diff options
author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
---|---|---|
committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) | |
tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/hatchet_sdk/clients/admin.py | |
parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff) | |
download | gn-ai-master.tar.gz |
Diffstat (limited to '.venv/lib/python3.12/site-packages/hatchet_sdk/clients/admin.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/hatchet_sdk/clients/admin.py | 542 |
1 files changed, 542 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/hatchet_sdk/clients/admin.py b/.venv/lib/python3.12/site-packages/hatchet_sdk/clients/admin.py new file mode 100644 index 00000000..18664cef --- /dev/null +++ b/.venv/lib/python3.12/site-packages/hatchet_sdk/clients/admin.py @@ -0,0 +1,542 @@ +import json +from datetime import datetime +from typing import Any, Callable, Dict, List, Optional, TypedDict, TypeVar, Union + +import grpc +from google.protobuf import timestamp_pb2 + +from hatchet_sdk.clients.rest.models.workflow_run import WorkflowRun +from hatchet_sdk.clients.rest.tenacity_utils import tenacity_retry +from hatchet_sdk.clients.run_event_listener import new_listener +from hatchet_sdk.clients.workflow_listener import PooledWorkflowRunListener +from hatchet_sdk.connection import new_conn +from hatchet_sdk.contracts.workflows_pb2 import ( + BulkTriggerWorkflowRequest, + BulkTriggerWorkflowResponse, + CreateWorkflowVersionOpts, + PutRateLimitRequest, + PutWorkflowRequest, + RateLimitDuration, + ScheduleWorkflowRequest, + TriggerWorkflowRequest, + TriggerWorkflowResponse, + WorkflowVersion, +) +from hatchet_sdk.contracts.workflows_pb2_grpc import WorkflowServiceStub +from hatchet_sdk.utils.serialization import flatten +from hatchet_sdk.workflow_run import RunRef, WorkflowRunRef + +from ..loader import ClientConfig +from ..metadata import get_metadata +from ..workflow import WorkflowMeta + + +def new_admin(config: ClientConfig): + return AdminClient(config) + + +class ScheduleTriggerWorkflowOptions(TypedDict, total=False): + parent_id: Optional[str] + parent_step_run_id: Optional[str] + child_index: Optional[int] + child_key: Optional[str] + namespace: Optional[str] + + +class ChildTriggerWorkflowOptions(TypedDict, total=False): + additional_metadata: Dict[str, str] | None = None + sticky: bool | None = None + + +class ChildWorkflowRunDict(TypedDict, total=False): + workflow_name: str + input: Any + options: ChildTriggerWorkflowOptions + key: str | None = None + + +class TriggerWorkflowOptions(ScheduleTriggerWorkflowOptions, total=False): + additional_metadata: Dict[str, str] | None = None + desired_worker_id: str | None = None + namespace: str | None = None + + +class WorkflowRunDict(TypedDict, total=False): + workflow_name: str + input: Any + options: TriggerWorkflowOptions | None + + +class DedupeViolationErr(Exception): + """Raised by the Hatchet library to indicate that a workflow has already been run with this deduplication value.""" + + pass + + +class AdminClientBase: + pooled_workflow_listener: PooledWorkflowRunListener | None = None + + def _prepare_workflow_request( + self, workflow_name: str, input: any, options: TriggerWorkflowOptions = None + ): + try: + payload_data = json.dumps(input) + + try: + meta = ( + None + if options is None or "additional_metadata" not in options + else options["additional_metadata"] + ) + if meta is not None: + options = { + **options, + "additional_metadata": json.dumps(meta).encode("utf-8"), + } + except json.JSONDecodeError as e: + raise ValueError(f"Error encoding payload: {e}") + + return TriggerWorkflowRequest( + name=workflow_name, input=payload_data, **(options or {}) + ) + except json.JSONDecodeError as e: + raise ValueError(f"Error encoding payload: {e}") + + def _prepare_put_workflow_request( + self, + name: str, + workflow: CreateWorkflowVersionOpts | WorkflowMeta, + overrides: CreateWorkflowVersionOpts | None = None, + ): + try: + opts: CreateWorkflowVersionOpts + + if isinstance(workflow, CreateWorkflowVersionOpts): + opts = workflow + else: + opts = workflow.get_create_opts(self.client.config.namespace) + + if overrides is not None: + opts.MergeFrom(overrides) + + opts.name = name + + return PutWorkflowRequest( + opts=opts, + ) + except grpc.RpcError as e: + raise ValueError(f"Could not put workflow: {e}") + + def _prepare_schedule_workflow_request( + self, + name: str, + schedules: List[Union[datetime, timestamp_pb2.Timestamp]], + input={}, + options: ScheduleTriggerWorkflowOptions = None, + ): + timestamp_schedules = [] + for schedule in schedules: + if isinstance(schedule, datetime): + t = schedule.timestamp() + seconds = int(t) + nanos = int(t % 1 * 1e9) + timestamp = timestamp_pb2.Timestamp(seconds=seconds, nanos=nanos) + timestamp_schedules.append(timestamp) + elif isinstance(schedule, timestamp_pb2.Timestamp): + timestamp_schedules.append(schedule) + else: + raise ValueError( + "Invalid schedule type. Must be datetime or timestamp_pb2.Timestamp." + ) + + return ScheduleWorkflowRequest( + name=name, + schedules=timestamp_schedules, + input=json.dumps(input), + **(options or {}), + ) + + +T = TypeVar("T") + + +class AdminClientAioImpl(AdminClientBase): + def __init__(self, config: ClientConfig): + aio_conn = new_conn(config, True) + self.config = config + self.aio_client = WorkflowServiceStub(aio_conn) + self.token = config.token + self.listener_client = new_listener(config) + self.namespace = config.namespace + + async def run( + self, + function: Union[str, Callable[[Any], T]], + input: any, + options: TriggerWorkflowOptions = None, + ) -> "RunRef[T]": + workflow_name = function + + if not isinstance(function, str): + workflow_name = function.function_name + + wrr = await self.run_workflow(workflow_name, input, options) + + return RunRef[T]( + wrr.workflow_run_id, wrr.workflow_listener, wrr.workflow_run_event_listener + ) + + ## IMPORTANT: Keep this method's signature in sync with the wrapper in the OTel instrumentor + @tenacity_retry + async def run_workflow( + self, workflow_name: str, input: any, options: TriggerWorkflowOptions = None + ) -> WorkflowRunRef: + try: + if not self.pooled_workflow_listener: + self.pooled_workflow_listener = PooledWorkflowRunListener(self.config) + + namespace = self.namespace + + if ( + options is not None + and "namespace" in options + and options["namespace"] is not None + ): + namespace = options.pop("namespace") + + if namespace != "" and not workflow_name.startswith(self.namespace): + workflow_name = f"{namespace}{workflow_name}" + + request = self._prepare_workflow_request(workflow_name, input, options) + + resp: TriggerWorkflowResponse = await self.aio_client.TriggerWorkflow( + request, + metadata=get_metadata(self.token), + ) + + return WorkflowRunRef( + workflow_run_id=resp.workflow_run_id, + workflow_listener=self.pooled_workflow_listener, + workflow_run_event_listener=self.listener_client, + ) + except (grpc.RpcError, grpc.aio.AioRpcError) as e: + if e.code() == grpc.StatusCode.ALREADY_EXISTS: + raise DedupeViolationErr(e.details()) + + raise e + + ## IMPORTANT: Keep this method's signature in sync with the wrapper in the OTel instrumentor + @tenacity_retry + async def run_workflows( + self, + workflows: list[WorkflowRunDict], + options: TriggerWorkflowOptions | None = None, + ) -> List[WorkflowRunRef]: + if len(workflows) == 0: + raise ValueError("No workflows to run") + + if not self.pooled_workflow_listener: + self.pooled_workflow_listener = PooledWorkflowRunListener(self.config) + + namespace = self.namespace + + if ( + options is not None + and "namespace" in options + and options["namespace"] is not None + ): + namespace = options["namespace"] + del options["namespace"] + + workflow_run_requests: TriggerWorkflowRequest = [] + + for workflow in workflows: + workflow_name = workflow["workflow_name"] + input_data = workflow["input"] + options = workflow["options"] + + if namespace != "" and not workflow_name.startswith(self.namespace): + workflow_name = f"{namespace}{workflow_name}" + + # Prepare and trigger workflow for each workflow name and input + request = self._prepare_workflow_request(workflow_name, input_data, options) + workflow_run_requests.append(request) + + request = BulkTriggerWorkflowRequest(workflows=workflow_run_requests) + + resp: BulkTriggerWorkflowResponse = await self.aio_client.BulkTriggerWorkflow( + request, + metadata=get_metadata(self.token), + ) + + return [ + WorkflowRunRef( + workflow_run_id=workflow_run_id, + workflow_listener=self.pooled_workflow_listener, + workflow_run_event_listener=self.listener_client, + ) + for workflow_run_id in resp.workflow_run_ids + ] + + @tenacity_retry + async def put_workflow( + self, + name: str, + workflow: CreateWorkflowVersionOpts | WorkflowMeta, + overrides: CreateWorkflowVersionOpts | None = None, + ) -> WorkflowVersion: + opts = self._prepare_put_workflow_request(name, workflow, overrides) + + return await self.aio_client.PutWorkflow( + opts, + metadata=get_metadata(self.token), + ) + + @tenacity_retry + async def put_rate_limit( + self, + key: str, + limit: int, + duration: RateLimitDuration = RateLimitDuration.SECOND, + ): + await self.aio_client.PutRateLimit( + PutRateLimitRequest( + key=key, + limit=limit, + duration=duration, + ), + metadata=get_metadata(self.token), + ) + + @tenacity_retry + async def schedule_workflow( + self, + name: str, + schedules: List[Union[datetime, timestamp_pb2.Timestamp]], + input={}, + options: ScheduleTriggerWorkflowOptions = None, + ) -> WorkflowVersion: + try: + namespace = self.namespace + + if ( + options is not None + and "namespace" in options + and options["namespace"] is not None + ): + namespace = options["namespace"] + del options["namespace"] + + if namespace != "" and not name.startswith(self.namespace): + name = f"{namespace}{name}" + + request = self._prepare_schedule_workflow_request( + name, schedules, input, options + ) + + return await self.aio_client.ScheduleWorkflow( + request, + metadata=get_metadata(self.token), + ) + except (grpc.aio.AioRpcError, grpc.RpcError) as e: + if e.code() == grpc.StatusCode.ALREADY_EXISTS: + raise DedupeViolationErr(e.details()) + + raise e + + +class AdminClient(AdminClientBase): + def __init__(self, config: ClientConfig): + conn = new_conn(config) + self.config = config + self.client = WorkflowServiceStub(conn) + self.aio = AdminClientAioImpl(config) + self.token = config.token + self.listener_client = new_listener(config) + self.namespace = config.namespace + + @tenacity_retry + def put_workflow( + self, + name: str, + workflow: CreateWorkflowVersionOpts | WorkflowMeta, + overrides: CreateWorkflowVersionOpts | None = None, + ) -> WorkflowVersion: + opts = self._prepare_put_workflow_request(name, workflow, overrides) + + resp: WorkflowVersion = self.client.PutWorkflow( + opts, + metadata=get_metadata(self.token), + ) + + return resp + + @tenacity_retry + def put_rate_limit( + self, + key: str, + limit: int, + duration: Union[RateLimitDuration.Value, str] = RateLimitDuration.SECOND, + ): + self.client.PutRateLimit( + PutRateLimitRequest( + key=key, + limit=limit, + duration=duration, + ), + metadata=get_metadata(self.token), + ) + + @tenacity_retry + def schedule_workflow( + self, + name: str, + schedules: List[Union[datetime, timestamp_pb2.Timestamp]], + input={}, + options: ScheduleTriggerWorkflowOptions = None, + ) -> WorkflowVersion: + try: + namespace = self.namespace + + if ( + options is not None + and "namespace" in options + and options["namespace"] is not None + ): + namespace = options["namespace"] + del options["namespace"] + + if namespace != "" and not name.startswith(self.namespace): + name = f"{namespace}{name}" + + request = self._prepare_schedule_workflow_request( + name, schedules, input, options + ) + + return self.client.ScheduleWorkflow( + request, + metadata=get_metadata(self.token), + ) + except (grpc.RpcError, grpc.aio.AioRpcError) as e: + if e.code() == grpc.StatusCode.ALREADY_EXISTS: + raise DedupeViolationErr(e.details()) + + raise e + + ## IMPORTANT: Keep this method's signature in sync with the wrapper in the OTel instrumentor + @tenacity_retry + def run_workflow( + self, workflow_name: str, input: any, options: TriggerWorkflowOptions = None + ) -> WorkflowRunRef: + try: + if not self.pooled_workflow_listener: + self.pooled_workflow_listener = PooledWorkflowRunListener(self.config) + + namespace = self.namespace + + ## TODO: Factor this out - it's repeated a lot of places + if ( + options is not None + and "namespace" in options + and options["namespace"] is not None + ): + namespace = options.pop("namespace") + + if namespace != "" and not workflow_name.startswith(self.namespace): + workflow_name = f"{namespace}{workflow_name}" + + request = self._prepare_workflow_request(workflow_name, input, options) + + resp: TriggerWorkflowResponse = self.client.TriggerWorkflow( + request, + metadata=get_metadata(self.token), + ) + + return WorkflowRunRef( + workflow_run_id=resp.workflow_run_id, + workflow_listener=self.pooled_workflow_listener, + workflow_run_event_listener=self.listener_client, + ) + except (grpc.RpcError, grpc.aio.AioRpcError) as e: + if e.code() == grpc.StatusCode.ALREADY_EXISTS: + raise DedupeViolationErr(e.details()) + + raise e + + ## IMPORTANT: Keep this method's signature in sync with the wrapper in the OTel instrumentor + @tenacity_retry + def run_workflows( + self, workflows: List[WorkflowRunDict], options: TriggerWorkflowOptions = None + ) -> list[WorkflowRunRef]: + workflow_run_requests: TriggerWorkflowRequest = [] + if not self.pooled_workflow_listener: + self.pooled_workflow_listener = PooledWorkflowRunListener(self.config) + + for workflow in workflows: + workflow_name = workflow["workflow_name"] + input_data = workflow["input"] + options = workflow["options"] + + namespace = self.namespace + + if ( + options is not None + and "namespace" in options + and options["namespace"] is not None + ): + namespace = options["namespace"] + del options["namespace"] + + if namespace != "" and not workflow_name.startswith(self.namespace): + workflow_name = f"{namespace}{workflow_name}" + + # Prepare and trigger workflow for each workflow name and input + request = self._prepare_workflow_request(workflow_name, input_data, options) + + workflow_run_requests.append(request) + + request = BulkTriggerWorkflowRequest(workflows=workflow_run_requests) + + resp: BulkTriggerWorkflowResponse = self.client.BulkTriggerWorkflow( + request, + metadata=get_metadata(self.token), + ) + + return [ + WorkflowRunRef( + workflow_run_id=workflow_run_id, + workflow_listener=self.pooled_workflow_listener, + workflow_run_event_listener=self.listener_client, + ) + for workflow_run_id in resp.workflow_run_ids + ] + + def run( + self, + function: Union[str, Callable[[Any], T]], + input: any, + options: TriggerWorkflowOptions = None, + ) -> "RunRef[T]": + workflow_name = function + + if not isinstance(function, str): + workflow_name = function.function_name + + wrr = self.run_workflow(workflow_name, input, options) + + return RunRef[T]( + wrr.workflow_run_id, wrr.workflow_listener, wrr.workflow_run_event_listener + ) + + def get_workflow_run(self, workflow_run_id: str) -> WorkflowRunRef: + try: + if not self.pooled_workflow_listener: + self.pooled_workflow_listener = PooledWorkflowRunListener(self.config) + + return WorkflowRunRef( + workflow_run_id=workflow_run_id, + workflow_listener=self.pooled_workflow_listener, + workflow_run_event_listener=self.listener_client, + ) + except grpc.RpcError as e: + raise ValueError(f"Could not get workflow run: {e}") |