import json
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, TypedDict, TypeVar, Union
import grpc
from google.protobuf import timestamp_pb2
from hatchet_sdk.clients.rest.models.workflow_run import WorkflowRun
from hatchet_sdk.clients.rest.tenacity_utils import tenacity_retry
from hatchet_sdk.clients.run_event_listener import new_listener
from hatchet_sdk.clients.workflow_listener import PooledWorkflowRunListener
from hatchet_sdk.connection import new_conn
from hatchet_sdk.contracts.workflows_pb2 import (
BulkTriggerWorkflowRequest,
BulkTriggerWorkflowResponse,
CreateWorkflowVersionOpts,
PutRateLimitRequest,
PutWorkflowRequest,
RateLimitDuration,
ScheduleWorkflowRequest,
TriggerWorkflowRequest,
TriggerWorkflowResponse,
WorkflowVersion,
)
from hatchet_sdk.contracts.workflows_pb2_grpc import WorkflowServiceStub
from hatchet_sdk.utils.serialization import flatten
from hatchet_sdk.workflow_run import RunRef, WorkflowRunRef
from ..loader import ClientConfig
from ..metadata import get_metadata
from ..workflow import WorkflowMeta
def new_admin(config: ClientConfig):
    """Create and return a synchronous ``AdminClient`` bound to ``config``."""
    admin_client = AdminClient(config)
    return admin_client
class ScheduleTriggerWorkflowOptions(TypedDict, total=False):
    """Optional fields accepted when scheduling (or triggering) a workflow.

    Declared with ``total=False``: every key may be omitted.
    """

    # Parent linkage fields for runs started from within another run.
    parent_id: Optional[str]
    parent_step_run_id: Optional[str]
    child_index: Optional[int]
    child_key: Optional[str]
    # Namespace override; consumed by the client rather than sent as-is.
    namespace: Optional[str]
class ChildTriggerWorkflowOptions(TypedDict, total=False):
additional_metadata: Dict[str, str] | None = None
sticky: bool | None = None
class ChildWorkflowRunDict(TypedDict, total=False):
    """Describes a child workflow run: its name, input, options and key.

    All keys are optional (``total=False``). TypedDict fields cannot declare
    default values, so an absent ``key`` simply means "not provided".
    """

    workflow_name: str
    input: Any
    options: ChildTriggerWorkflowOptions
    # NOTE(review): presumably an identifier distinguishing this child among
    # its siblings — confirm against the consumer of this dict.
    key: str | None
class TriggerWorkflowOptions(ScheduleTriggerWorkflowOptions, total=False):
    """Options for triggering a workflow run.

    Inherits every key of ``ScheduleTriggerWorkflowOptions``, including
    ``namespace``. That key is not redeclared here, because type checkers
    reject overwriting an inherited TypedDict field. Keys carry no defaults:
    they are either present or absent.
    """

    # Optional string-to-string metadata; JSON-encoded when the request is built.
    additional_metadata: Dict[str, str] | None
    desired_worker_id: str | None
class WorkflowRunDict(TypedDict, total=False):
    """A single workflow run request: name, input and optional trigger options.

    Declared with ``total=False``, so each key may be omitted at the type level.
    """

    # Name of the workflow to run (may be namespaced by the client).
    workflow_name: str
    # Workflow input; presumably must be JSON-serializable — verify with callers.
    input: Any
    # Per-run trigger options, or None for defaults.
    options: TriggerWorkflowOptions | None
class DedupeViolationErr(Exception):
    """Signals that a workflow run was rejected as a duplicate.

    A run with the same deduplication value has already been created.
    """
class AdminClientBase:
    """Request-building helpers shared by the sync and asyncio admin clients.

    Subclasses must set ``self.config`` (a ``ClientConfig``) before calling
    ``_prepare_put_workflow_request``.
    """

    # Created lazily by subclasses the first time a workflow run ref is built.
    pooled_workflow_listener: PooledWorkflowRunListener | None = None

    def _prepare_workflow_request(
        self,
        workflow_name: str,
        input: Any,
        options: Optional[TriggerWorkflowOptions] = None,
    ) -> TriggerWorkflowRequest:
        """Build a ``TriggerWorkflowRequest``.

        Args:
            workflow_name: Workflow name, already namespaced by the caller.
            input: Workflow input, serialized with ``json.dumps``.
            options: Extra request fields passed through as keyword arguments.
                A non-None ``additional_metadata`` entry is JSON-encoded first.
                The caller's dict is never mutated.

        Returns:
            The populated ``TriggerWorkflowRequest``.

        Raises:
            ValueError: if ``input`` or ``additional_metadata`` cannot be
                JSON-encoded.
        """
        try:
            payload_data = json.dumps(input)
            meta = None if options is None else options.get("additional_metadata")
            if meta is not None:
                # Build a new dict instead of mutating the caller's options.
                options = {
                    **options,
                    "additional_metadata": json.dumps(meta).encode("utf-8"),
                }
        except (TypeError, ValueError) as e:
            # json.dumps raises TypeError for unserializable objects and
            # ValueError for e.g. circular references; it never raises
            # JSONDecodeError, which the previous handler caught.
            raise ValueError(f"Error encoding payload: {e}") from e

        return TriggerWorkflowRequest(
            name=workflow_name, input=payload_data, **(options or {})
        )

    def _prepare_put_workflow_request(
        self,
        name: str,
        workflow: CreateWorkflowVersionOpts | WorkflowMeta,
        overrides: CreateWorkflowVersionOpts | None = None,
    ) -> PutWorkflowRequest:
        """Build a ``PutWorkflowRequest`` registering ``workflow`` under ``name``.

        Args:
            name: Value written to ``opts.name`` of the request.
            workflow: Ready-made ``CreateWorkflowVersionOpts`` (copied, never
                mutated), or a ``WorkflowMeta`` whose ``get_create_opts`` is
                called with the configured namespace.
            overrides: Optional message merged on top via ``MergeFrom``.

        Returns:
            The populated ``PutWorkflowRequest``.

        Raises:
            ValueError: wraps any ``grpc.RpcError`` raised while building opts.
        """
        try:
            opts: CreateWorkflowVersionOpts

            if isinstance(workflow, CreateWorkflowVersionOpts):
                # Work on a copy. MergeFrom appends repeated fields, so mutating
                # the caller's message would duplicate overrides on every
                # repeated call (e.g. a retry) with the same object.
                opts = CreateWorkflowVersionOpts()
                opts.CopyFrom(workflow)
            else:
                # The ClientConfig lives on self.config in both subclasses. The
                # gRPC stub stored in self.client has no ``config`` attribute.
                opts = workflow.get_create_opts(self.config.namespace)

            if overrides is not None:
                opts.MergeFrom(overrides)

            opts.name = name

            return PutWorkflowRequest(opts=opts)
        except grpc.RpcError as e:
            raise ValueError(f"Could not put workflow: {e}") from e

    def _prepare_schedule_workflow_request(
        self,
        name: str,
        schedules: List[Union[datetime, timestamp_pb2.Timestamp]],
        input: Any = {},  # never mutated; only serialized with json.dumps
        options: Optional[ScheduleTriggerWorkflowOptions] = None,
    ) -> ScheduleWorkflowRequest:
        """Build a ``ScheduleWorkflowRequest`` for ``name`` at each of ``schedules``.

        Args:
            name: Workflow name, already namespaced by the caller.
            schedules: ``datetime`` values (converted via ``datetime.timestamp``,
                so naive values are interpreted as local time) or ready-made
                ``Timestamp`` messages.
            input: Workflow input, serialized with ``json.dumps``.
            options: Extra request fields passed through as keyword arguments.

        Returns:
            The populated ``ScheduleWorkflowRequest``.

        Raises:
            ValueError: if a schedule is neither a ``datetime`` nor a ``Timestamp``.
        """
        timestamp_schedules = []
        for schedule in schedules:
            if isinstance(schedule, datetime):
                # Take whole seconds and microseconds separately. Splitting the
                # float from timestamp() loses sub-microsecond precision, and
                # int() truncation gave wrong results for pre-epoch datetimes.
                seconds = int(schedule.replace(microsecond=0).timestamp())
                nanos = schedule.microsecond * 1000
                timestamp_schedules.append(
                    timestamp_pb2.Timestamp(seconds=seconds, nanos=nanos)
                )
            elif isinstance(schedule, timestamp_pb2.Timestamp):
                timestamp_schedules.append(schedule)
            else:
                raise ValueError(
                    "Invalid schedule type. Must be datetime or timestamp_pb2.Timestamp."
                )

        return ScheduleWorkflowRequest(
            name=name,
            schedules=timestamp_schedules,
            input=json.dumps(input),
            **(options or {}),
        )
# Generic result type for the ``run`` helpers: the return type of the workflow
# callable (``Callable[[Any], T]``), carried through into ``RunRef[T]``.
T = TypeVar("T")
class AdminClientAioImpl(AdminClientBase):
    """Asyncio admin client: registers, schedules and triggers workflows over an aio channel."""

    def __init__(self, config: ClientConfig):
        aio_conn = new_conn(config, True)
        self.config = config
        self.aio_client = WorkflowServiceStub(aio_conn)
        self.token = config.token
        self.listener_client = new_listener(config)
        self.namespace = config.namespace

    def _resolve_namespace(
        self,
        options: Optional[Dict[str, Any]],
        fallback: Optional[str] = None,
    ) -> tuple[Optional[str], Optional[Dict[str, Any]]]:
        """Split a namespace override out of ``options``.

        Args:
            options: Trigger/schedule options, possibly containing ``namespace``.
            fallback: Namespace to use when ``options`` has no non-None
                override; defaults to the client's namespace.

        Returns:
            ``(namespace, remaining_options)``. The ``namespace`` key is always
            removed from the returned options, since the request messages are
            built from the remaining keys. ``options`` itself is never mutated,
            so a retried call sees the same arguments as the first attempt.
        """
        namespace = self.namespace if fallback is None else fallback
        if options is None:
            return namespace, None

        remaining = dict(options)
        override = remaining.pop("namespace", None)
        if override is not None:
            namespace = override
        return namespace, remaining

    @staticmethod
    def _apply_namespace(name: str, namespace: Optional[str]) -> str:
        """Prefix ``name`` with ``namespace`` unless it is empty or already present."""
        if namespace and not name.startswith(namespace):
            return f"{namespace}{name}"
        return name

    async def run(
        self,
        function: Union[str, Callable[[Any], T]],
        input: Any,
        options: Optional[TriggerWorkflowOptions] = None,
    ) -> "RunRef[T]":
        """Trigger ``function`` (a workflow name or an object with ``function_name``)."""
        workflow_name = function
        if not isinstance(function, str):
            workflow_name = function.function_name

        wrr = await self.run_workflow(workflow_name, input, options)

        return RunRef[T](
            wrr.workflow_run_id, wrr.workflow_listener, wrr.workflow_run_event_listener
        )

    ## IMPORTANT: Keep this method's signature in sync with the wrapper in the OTel instrumentor
    @tenacity_retry
    async def run_workflow(
        self,
        workflow_name: str,
        input: Any,
        options: Optional[TriggerWorkflowOptions] = None,
    ) -> WorkflowRunRef:
        """Trigger one run of ``workflow_name`` and return a reference to it.

        ``options["namespace"]`` (when not None) overrides the client namespace
        used to prefix ``workflow_name``. The caller's ``options`` is not mutated.

        Raises:
            DedupeViolationErr: if the server responds with ``ALREADY_EXISTS``.
        """
        try:
            if not self.pooled_workflow_listener:
                self.pooled_workflow_listener = PooledWorkflowRunListener(self.config)

            namespace, options = self._resolve_namespace(options)
            workflow_name = self._apply_namespace(workflow_name, namespace)

            request = self._prepare_workflow_request(workflow_name, input, options)

            resp: TriggerWorkflowResponse = await self.aio_client.TriggerWorkflow(
                request,
                metadata=get_metadata(self.token),
            )

            return WorkflowRunRef(
                workflow_run_id=resp.workflow_run_id,
                workflow_listener=self.pooled_workflow_listener,
                workflow_run_event_listener=self.listener_client,
            )
        except (grpc.RpcError, grpc.aio.AioRpcError) as e:
            if e.code() == grpc.StatusCode.ALREADY_EXISTS:
                raise DedupeViolationErr(e.details()) from e
            raise

    ## IMPORTANT: Keep this method's signature in sync with the wrapper in the OTel instrumentor
    @tenacity_retry
    async def run_workflows(
        self,
        workflows: list[WorkflowRunDict],
        options: TriggerWorkflowOptions | None = None,
    ) -> List[WorkflowRunRef]:
        """Trigger several workflow runs in one bulk request.

        A batch-level ``options["namespace"]`` becomes the default namespace for
        every entry. Each entry's own ``options["namespace"]`` still takes
        precedence. Other batch-level option keys are not applied to the
        entries. No caller dict is mutated.

        Raises:
            ValueError: if ``workflows`` is empty.
        """
        if not workflows:
            raise ValueError("No workflows to run")

        if not self.pooled_workflow_listener:
            self.pooled_workflow_listener = PooledWorkflowRunListener(self.config)

        batch_namespace, _ = self._resolve_namespace(options)

        workflow_run_requests: List[TriggerWorkflowRequest] = []
        for workflow in workflows:
            # A missing "options" key is treated as no options (total=False dict).
            namespace, workflow_options = self._resolve_namespace(
                workflow.get("options"), fallback=batch_namespace
            )
            workflow_name = self._apply_namespace(workflow["workflow_name"], namespace)

            workflow_run_requests.append(
                self._prepare_workflow_request(
                    workflow_name, workflow["input"], workflow_options
                )
            )

        request = BulkTriggerWorkflowRequest(workflows=workflow_run_requests)

        resp: BulkTriggerWorkflowResponse = await self.aio_client.BulkTriggerWorkflow(
            request,
            metadata=get_metadata(self.token),
        )

        return [
            WorkflowRunRef(
                workflow_run_id=workflow_run_id,
                workflow_listener=self.pooled_workflow_listener,
                workflow_run_event_listener=self.listener_client,
            )
            for workflow_run_id in resp.workflow_run_ids
        ]

    @tenacity_retry
    async def put_workflow(
        self,
        name: str,
        workflow: CreateWorkflowVersionOpts | WorkflowMeta,
        overrides: CreateWorkflowVersionOpts | None = None,
    ) -> WorkflowVersion:
        """Register (create or update) ``workflow`` under ``name``."""
        opts = self._prepare_put_workflow_request(name, workflow, overrides)

        return await self.aio_client.PutWorkflow(
            opts,
            metadata=get_metadata(self.token),
        )

    @tenacity_retry
    async def put_rate_limit(
        self,
        key: str,
        limit: int,
        duration: RateLimitDuration = RateLimitDuration.SECOND,
    ):
        """Create or update the rate limit ``key`` with ``limit`` per ``duration``."""
        await self.aio_client.PutRateLimit(
            PutRateLimitRequest(
                key=key,
                limit=limit,
                duration=duration,
            ),
            metadata=get_metadata(self.token),
        )

    @tenacity_retry
    async def schedule_workflow(
        self,
        name: str,
        schedules: List[Union[datetime, timestamp_pb2.Timestamp]],
        input={},  # never mutated; only serialized
        options: Optional[ScheduleTriggerWorkflowOptions] = None,
    ) -> WorkflowVersion:
        """Schedule runs of ``name`` at each of ``schedules``.

        ``options["namespace"]`` (when not None) overrides the client namespace.
        The caller's ``options`` is not mutated.

        Raises:
            DedupeViolationErr: if the server responds with ``ALREADY_EXISTS``.
        """
        try:
            namespace, options = self._resolve_namespace(options)
            name = self._apply_namespace(name, namespace)

            request = self._prepare_schedule_workflow_request(
                name, schedules, input, options
            )

            return await self.aio_client.ScheduleWorkflow(
                request,
                metadata=get_metadata(self.token),
            )
        except (grpc.aio.AioRpcError, grpc.RpcError) as e:
            if e.code() == grpc.StatusCode.ALREADY_EXISTS:
                raise DedupeViolationErr(e.details()) from e
            raise
class AdminClient(AdminClientBase):
    """Synchronous admin client: registers, schedules and triggers workflows.

    The asyncio variant is exposed as ``self.aio``.
    """

    def __init__(self, config: ClientConfig):
        conn = new_conn(config)
        self.config = config
        self.client = WorkflowServiceStub(conn)
        self.aio = AdminClientAioImpl(config)
        self.token = config.token
        self.listener_client = new_listener(config)
        self.namespace = config.namespace

    def _resolve_namespace(
        self,
        options: Optional[Dict[str, Any]],
        fallback: Optional[str] = None,
    ) -> tuple[Optional[str], Optional[Dict[str, Any]]]:
        """Split a namespace override out of ``options``.

        Args:
            options: Trigger/schedule options, possibly containing ``namespace``.
            fallback: Namespace to use when ``options`` has no non-None
                override; defaults to the client's namespace.

        Returns:
            ``(namespace, remaining_options)``. The ``namespace`` key is always
            removed from the returned options, since the request messages are
            built from the remaining keys. ``options`` itself is never mutated,
            so a retried call sees the same arguments as the first attempt.
        """
        namespace = self.namespace if fallback is None else fallback
        if options is None:
            return namespace, None

        remaining = dict(options)
        override = remaining.pop("namespace", None)
        if override is not None:
            namespace = override
        return namespace, remaining

    @staticmethod
    def _apply_namespace(name: str, namespace: Optional[str]) -> str:
        """Prefix ``name`` with ``namespace`` unless it is empty or already present."""
        if namespace and not name.startswith(namespace):
            return f"{namespace}{name}"
        return name

    @tenacity_retry
    def put_workflow(
        self,
        name: str,
        workflow: CreateWorkflowVersionOpts | WorkflowMeta,
        overrides: CreateWorkflowVersionOpts | None = None,
    ) -> WorkflowVersion:
        """Register (create or update) ``workflow`` under ``name``."""
        opts = self._prepare_put_workflow_request(name, workflow, overrides)

        resp: WorkflowVersion = self.client.PutWorkflow(
            opts,
            metadata=get_metadata(self.token),
        )

        return resp

    @tenacity_retry
    def put_rate_limit(
        self,
        key: str,
        limit: int,
        # RateLimitDuration enum value (an int, e.g. RateLimitDuration.SECOND)
        # or, as the previous annotation allowed, a string.
        duration: Union[int, str] = RateLimitDuration.SECOND,
    ):
        """Create or update the rate limit ``key`` with ``limit`` per ``duration``."""
        self.client.PutRateLimit(
            PutRateLimitRequest(
                key=key,
                limit=limit,
                duration=duration,
            ),
            metadata=get_metadata(self.token),
        )

    @tenacity_retry
    def schedule_workflow(
        self,
        name: str,
        schedules: List[Union[datetime, timestamp_pb2.Timestamp]],
        input={},  # never mutated; only serialized
        options: Optional[ScheduleTriggerWorkflowOptions] = None,
    ) -> WorkflowVersion:
        """Schedule runs of ``name`` at each of ``schedules``.

        ``options["namespace"]`` (when not None) overrides the client namespace.
        The caller's ``options`` is not mutated.

        Raises:
            DedupeViolationErr: if the server responds with ``ALREADY_EXISTS``.
        """
        try:
            namespace, options = self._resolve_namespace(options)
            name = self._apply_namespace(name, namespace)

            request = self._prepare_schedule_workflow_request(
                name, schedules, input, options
            )

            return self.client.ScheduleWorkflow(
                request,
                metadata=get_metadata(self.token),
            )
        except (grpc.RpcError, grpc.aio.AioRpcError) as e:
            if e.code() == grpc.StatusCode.ALREADY_EXISTS:
                raise DedupeViolationErr(e.details()) from e
            raise

    ## IMPORTANT: Keep this method's signature in sync with the wrapper in the OTel instrumentor
    @tenacity_retry
    def run_workflow(
        self,
        workflow_name: str,
        input: Any,
        options: Optional[TriggerWorkflowOptions] = None,
    ) -> WorkflowRunRef:
        """Trigger one run of ``workflow_name`` and return a reference to it.

        ``options["namespace"]`` (when not None) overrides the client namespace
        used to prefix ``workflow_name``. The caller's ``options`` is not mutated.

        Raises:
            DedupeViolationErr: if the server responds with ``ALREADY_EXISTS``.
        """
        try:
            if not self.pooled_workflow_listener:
                self.pooled_workflow_listener = PooledWorkflowRunListener(self.config)

            namespace, options = self._resolve_namespace(options)
            workflow_name = self._apply_namespace(workflow_name, namespace)

            request = self._prepare_workflow_request(workflow_name, input, options)

            resp: TriggerWorkflowResponse = self.client.TriggerWorkflow(
                request,
                metadata=get_metadata(self.token),
            )

            return WorkflowRunRef(
                workflow_run_id=resp.workflow_run_id,
                workflow_listener=self.pooled_workflow_listener,
                workflow_run_event_listener=self.listener_client,
            )
        except (grpc.RpcError, grpc.aio.AioRpcError) as e:
            if e.code() == grpc.StatusCode.ALREADY_EXISTS:
                raise DedupeViolationErr(e.details()) from e
            raise

    ## IMPORTANT: Keep this method's signature in sync with the wrapper in the OTel instrumentor
    @tenacity_retry
    def run_workflows(
        self,
        workflows: List[WorkflowRunDict],
        options: Optional[TriggerWorkflowOptions] = None,
    ) -> list[WorkflowRunRef]:
        """Trigger several workflow runs in one bulk request.

        A batch-level ``options["namespace"]`` becomes the default namespace for
        every entry. Each entry's own ``options["namespace"]`` still takes
        precedence. Other batch-level option keys are not applied to the
        entries. No caller dict is mutated.
        """
        if not self.pooled_workflow_listener:
            self.pooled_workflow_listener = PooledWorkflowRunListener(self.config)

        batch_namespace, _ = self._resolve_namespace(options)

        workflow_run_requests: List[TriggerWorkflowRequest] = []
        for workflow in workflows:
            # A missing "options" key is treated as no options (total=False dict).
            namespace, workflow_options = self._resolve_namespace(
                workflow.get("options"), fallback=batch_namespace
            )
            workflow_name = self._apply_namespace(workflow["workflow_name"], namespace)

            workflow_run_requests.append(
                self._prepare_workflow_request(
                    workflow_name, workflow["input"], workflow_options
                )
            )

        request = BulkTriggerWorkflowRequest(workflows=workflow_run_requests)

        resp: BulkTriggerWorkflowResponse = self.client.BulkTriggerWorkflow(
            request,
            metadata=get_metadata(self.token),
        )

        return [
            WorkflowRunRef(
                workflow_run_id=workflow_run_id,
                workflow_listener=self.pooled_workflow_listener,
                workflow_run_event_listener=self.listener_client,
            )
            for workflow_run_id in resp.workflow_run_ids
        ]

    def run(
        self,
        function: Union[str, Callable[[Any], T]],
        input: Any,
        options: Optional[TriggerWorkflowOptions] = None,
    ) -> "RunRef[T]":
        """Trigger ``function`` (a workflow name or an object with ``function_name``)."""
        workflow_name = function
        if not isinstance(function, str):
            workflow_name = function.function_name

        wrr = self.run_workflow(workflow_name, input, options)

        return RunRef[T](
            wrr.workflow_run_id, wrr.workflow_listener, wrr.workflow_run_event_listener
        )

    def get_workflow_run(self, workflow_run_id: str) -> WorkflowRunRef:
        """Return a ``WorkflowRunRef`` for an existing run id (no RPC is issued here)."""
        try:
            if not self.pooled_workflow_listener:
                self.pooled_workflow_listener = PooledWorkflowRunListener(self.config)

            return WorkflowRunRef(
                workflow_run_id=workflow_run_id,
                workflow_listener=self.pooled_workflow_listener,
                workflow_run_event_listener=self.listener_client,
            )
        except grpc.RpcError as e:
            raise ValueError(f"Could not get workflow run: {e}") from e