about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/hatchet_sdk/clients/admin.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/hatchet_sdk/clients/admin.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two version of R2R are here HEAD master
Diffstat (limited to '.venv/lib/python3.12/site-packages/hatchet_sdk/clients/admin.py')
-rw-r--r--.venv/lib/python3.12/site-packages/hatchet_sdk/clients/admin.py542
1 files changed, 542 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/hatchet_sdk/clients/admin.py b/.venv/lib/python3.12/site-packages/hatchet_sdk/clients/admin.py
new file mode 100644
index 00000000..18664cef
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/hatchet_sdk/clients/admin.py
@@ -0,0 +1,542 @@
+import json
+from datetime import datetime
+from typing import Any, Callable, Dict, List, Optional, TypedDict, TypeVar, Union
+
+import grpc
+from google.protobuf import timestamp_pb2
+
+from hatchet_sdk.clients.rest.models.workflow_run import WorkflowRun
+from hatchet_sdk.clients.rest.tenacity_utils import tenacity_retry
+from hatchet_sdk.clients.run_event_listener import new_listener
+from hatchet_sdk.clients.workflow_listener import PooledWorkflowRunListener
+from hatchet_sdk.connection import new_conn
+from hatchet_sdk.contracts.workflows_pb2 import (
+    BulkTriggerWorkflowRequest,
+    BulkTriggerWorkflowResponse,
+    CreateWorkflowVersionOpts,
+    PutRateLimitRequest,
+    PutWorkflowRequest,
+    RateLimitDuration,
+    ScheduleWorkflowRequest,
+    TriggerWorkflowRequest,
+    TriggerWorkflowResponse,
+    WorkflowVersion,
+)
+from hatchet_sdk.contracts.workflows_pb2_grpc import WorkflowServiceStub
+from hatchet_sdk.utils.serialization import flatten
+from hatchet_sdk.workflow_run import RunRef, WorkflowRunRef
+
+from ..loader import ClientConfig
+from ..metadata import get_metadata
+from ..workflow import WorkflowMeta
+
+
+def new_admin(config: ClientConfig):
+    return AdminClient(config)
+
+
+class ScheduleTriggerWorkflowOptions(TypedDict, total=False):
+    parent_id: Optional[str]
+    parent_step_run_id: Optional[str]
+    child_index: Optional[int]
+    child_key: Optional[str]
+    namespace: Optional[str]
+
+
+class ChildTriggerWorkflowOptions(TypedDict, total=False):
+    additional_metadata: Dict[str, str] | None = None
+    sticky: bool | None = None
+
+
+class ChildWorkflowRunDict(TypedDict, total=False):
+    workflow_name: str
+    input: Any
+    options: ChildTriggerWorkflowOptions
+    key: str | None = None
+
+
+class TriggerWorkflowOptions(ScheduleTriggerWorkflowOptions, total=False):
+    additional_metadata: Dict[str, str] | None = None
+    desired_worker_id: str | None = None
+    namespace: str | None = None
+
+
+class WorkflowRunDict(TypedDict, total=False):
+    workflow_name: str
+    input: Any
+    options: TriggerWorkflowOptions | None
+
+
+class DedupeViolationErr(Exception):
+    """Raised by the Hatchet library to indicate that a workflow has already been run with this deduplication value."""
+
+    pass
+
+
+class AdminClientBase:
+    pooled_workflow_listener: PooledWorkflowRunListener | None = None
+
+    def _prepare_workflow_request(
+        self, workflow_name: str, input: any, options: TriggerWorkflowOptions = None
+    ):
+        try:
+            payload_data = json.dumps(input)
+
+            try:
+                meta = (
+                    None
+                    if options is None or "additional_metadata" not in options
+                    else options["additional_metadata"]
+                )
+                if meta is not None:
+                    options = {
+                        **options,
+                        "additional_metadata": json.dumps(meta).encode("utf-8"),
+                    }
+            except json.JSONDecodeError as e:
+                raise ValueError(f"Error encoding payload: {e}")
+
+            return TriggerWorkflowRequest(
+                name=workflow_name, input=payload_data, **(options or {})
+            )
+        except json.JSONDecodeError as e:
+            raise ValueError(f"Error encoding payload: {e}")
+
+    def _prepare_put_workflow_request(
+        self,
+        name: str,
+        workflow: CreateWorkflowVersionOpts | WorkflowMeta,
+        overrides: CreateWorkflowVersionOpts | None = None,
+    ):
+        try:
+            opts: CreateWorkflowVersionOpts
+
+            if isinstance(workflow, CreateWorkflowVersionOpts):
+                opts = workflow
+            else:
+                opts = workflow.get_create_opts(self.client.config.namespace)
+
+            if overrides is not None:
+                opts.MergeFrom(overrides)
+
+            opts.name = name
+
+            return PutWorkflowRequest(
+                opts=opts,
+            )
+        except grpc.RpcError as e:
+            raise ValueError(f"Could not put workflow: {e}")
+
+    def _prepare_schedule_workflow_request(
+        self,
+        name: str,
+        schedules: List[Union[datetime, timestamp_pb2.Timestamp]],
+        input={},
+        options: ScheduleTriggerWorkflowOptions = None,
+    ):
+        timestamp_schedules = []
+        for schedule in schedules:
+            if isinstance(schedule, datetime):
+                t = schedule.timestamp()
+                seconds = int(t)
+                nanos = int(t % 1 * 1e9)
+                timestamp = timestamp_pb2.Timestamp(seconds=seconds, nanos=nanos)
+                timestamp_schedules.append(timestamp)
+            elif isinstance(schedule, timestamp_pb2.Timestamp):
+                timestamp_schedules.append(schedule)
+            else:
+                raise ValueError(
+                    "Invalid schedule type. Must be datetime or timestamp_pb2.Timestamp."
+                )
+
+        return ScheduleWorkflowRequest(
+            name=name,
+            schedules=timestamp_schedules,
+            input=json.dumps(input),
+            **(options or {}),
+        )
+
+
+T = TypeVar("T")
+
+
+class AdminClientAioImpl(AdminClientBase):
+    def __init__(self, config: ClientConfig):
+        aio_conn = new_conn(config, True)
+        self.config = config
+        self.aio_client = WorkflowServiceStub(aio_conn)
+        self.token = config.token
+        self.listener_client = new_listener(config)
+        self.namespace = config.namespace
+
+    async def run(
+        self,
+        function: Union[str, Callable[[Any], T]],
+        input: any,
+        options: TriggerWorkflowOptions = None,
+    ) -> "RunRef[T]":
+        workflow_name = function
+
+        if not isinstance(function, str):
+            workflow_name = function.function_name
+
+        wrr = await self.run_workflow(workflow_name, input, options)
+
+        return RunRef[T](
+            wrr.workflow_run_id, wrr.workflow_listener, wrr.workflow_run_event_listener
+        )
+
+    ## IMPORTANT: Keep this method's signature in sync with the wrapper in the OTel instrumentor
+    @tenacity_retry
+    async def run_workflow(
+        self, workflow_name: str, input: any, options: TriggerWorkflowOptions = None
+    ) -> WorkflowRunRef:
+        try:
+            if not self.pooled_workflow_listener:
+                self.pooled_workflow_listener = PooledWorkflowRunListener(self.config)
+
+            namespace = self.namespace
+
+            if (
+                options is not None
+                and "namespace" in options
+                and options["namespace"] is not None
+            ):
+                namespace = options.pop("namespace")
+
+            if namespace != "" and not workflow_name.startswith(self.namespace):
+                workflow_name = f"{namespace}{workflow_name}"
+
+            request = self._prepare_workflow_request(workflow_name, input, options)
+
+            resp: TriggerWorkflowResponse = await self.aio_client.TriggerWorkflow(
+                request,
+                metadata=get_metadata(self.token),
+            )
+
+            return WorkflowRunRef(
+                workflow_run_id=resp.workflow_run_id,
+                workflow_listener=self.pooled_workflow_listener,
+                workflow_run_event_listener=self.listener_client,
+            )
+        except (grpc.RpcError, grpc.aio.AioRpcError) as e:
+            if e.code() == grpc.StatusCode.ALREADY_EXISTS:
+                raise DedupeViolationErr(e.details())
+
+            raise e
+
+    ## IMPORTANT: Keep this method's signature in sync with the wrapper in the OTel instrumentor
+    @tenacity_retry
+    async def run_workflows(
+        self,
+        workflows: list[WorkflowRunDict],
+        options: TriggerWorkflowOptions | None = None,
+    ) -> List[WorkflowRunRef]:
+        if len(workflows) == 0:
+            raise ValueError("No workflows to run")
+
+        if not self.pooled_workflow_listener:
+            self.pooled_workflow_listener = PooledWorkflowRunListener(self.config)
+
+        namespace = self.namespace
+
+        if (
+            options is not None
+            and "namespace" in options
+            and options["namespace"] is not None
+        ):
+            namespace = options["namespace"]
+            del options["namespace"]
+
+        workflow_run_requests: TriggerWorkflowRequest = []
+
+        for workflow in workflows:
+            workflow_name = workflow["workflow_name"]
+            input_data = workflow["input"]
+            options = workflow["options"]
+
+            if namespace != "" and not workflow_name.startswith(self.namespace):
+                workflow_name = f"{namespace}{workflow_name}"
+
+            # Prepare and trigger workflow for each workflow name and input
+            request = self._prepare_workflow_request(workflow_name, input_data, options)
+            workflow_run_requests.append(request)
+
+        request = BulkTriggerWorkflowRequest(workflows=workflow_run_requests)
+
+        resp: BulkTriggerWorkflowResponse = await self.aio_client.BulkTriggerWorkflow(
+            request,
+            metadata=get_metadata(self.token),
+        )
+
+        return [
+            WorkflowRunRef(
+                workflow_run_id=workflow_run_id,
+                workflow_listener=self.pooled_workflow_listener,
+                workflow_run_event_listener=self.listener_client,
+            )
+            for workflow_run_id in resp.workflow_run_ids
+        ]
+
+    @tenacity_retry
+    async def put_workflow(
+        self,
+        name: str,
+        workflow: CreateWorkflowVersionOpts | WorkflowMeta,
+        overrides: CreateWorkflowVersionOpts | None = None,
+    ) -> WorkflowVersion:
+        opts = self._prepare_put_workflow_request(name, workflow, overrides)
+
+        return await self.aio_client.PutWorkflow(
+            opts,
+            metadata=get_metadata(self.token),
+        )
+
+    @tenacity_retry
+    async def put_rate_limit(
+        self,
+        key: str,
+        limit: int,
+        duration: RateLimitDuration = RateLimitDuration.SECOND,
+    ):
+        await self.aio_client.PutRateLimit(
+            PutRateLimitRequest(
+                key=key,
+                limit=limit,
+                duration=duration,
+            ),
+            metadata=get_metadata(self.token),
+        )
+
+    @tenacity_retry
+    async def schedule_workflow(
+        self,
+        name: str,
+        schedules: List[Union[datetime, timestamp_pb2.Timestamp]],
+        input={},
+        options: ScheduleTriggerWorkflowOptions = None,
+    ) -> WorkflowVersion:
+        try:
+            namespace = self.namespace
+
+            if (
+                options is not None
+                and "namespace" in options
+                and options["namespace"] is not None
+            ):
+                namespace = options["namespace"]
+                del options["namespace"]
+
+            if namespace != "" and not name.startswith(self.namespace):
+                name = f"{namespace}{name}"
+
+            request = self._prepare_schedule_workflow_request(
+                name, schedules, input, options
+            )
+
+            return await self.aio_client.ScheduleWorkflow(
+                request,
+                metadata=get_metadata(self.token),
+            )
+        except (grpc.aio.AioRpcError, grpc.RpcError) as e:
+            if e.code() == grpc.StatusCode.ALREADY_EXISTS:
+                raise DedupeViolationErr(e.details())
+
+            raise e
+
+
+class AdminClient(AdminClientBase):
+    def __init__(self, config: ClientConfig):
+        conn = new_conn(config)
+        self.config = config
+        self.client = WorkflowServiceStub(conn)
+        self.aio = AdminClientAioImpl(config)
+        self.token = config.token
+        self.listener_client = new_listener(config)
+        self.namespace = config.namespace
+
+    @tenacity_retry
+    def put_workflow(
+        self,
+        name: str,
+        workflow: CreateWorkflowVersionOpts | WorkflowMeta,
+        overrides: CreateWorkflowVersionOpts | None = None,
+    ) -> WorkflowVersion:
+        opts = self._prepare_put_workflow_request(name, workflow, overrides)
+
+        resp: WorkflowVersion = self.client.PutWorkflow(
+            opts,
+            metadata=get_metadata(self.token),
+        )
+
+        return resp
+
+    @tenacity_retry
+    def put_rate_limit(
+        self,
+        key: str,
+        limit: int,
+        duration: Union[RateLimitDuration.Value, str] = RateLimitDuration.SECOND,
+    ):
+        self.client.PutRateLimit(
+            PutRateLimitRequest(
+                key=key,
+                limit=limit,
+                duration=duration,
+            ),
+            metadata=get_metadata(self.token),
+        )
+
+    @tenacity_retry
+    def schedule_workflow(
+        self,
+        name: str,
+        schedules: List[Union[datetime, timestamp_pb2.Timestamp]],
+        input={},
+        options: ScheduleTriggerWorkflowOptions = None,
+    ) -> WorkflowVersion:
+        try:
+            namespace = self.namespace
+
+            if (
+                options is not None
+                and "namespace" in options
+                and options["namespace"] is not None
+            ):
+                namespace = options["namespace"]
+                del options["namespace"]
+
+            if namespace != "" and not name.startswith(self.namespace):
+                name = f"{namespace}{name}"
+
+            request = self._prepare_schedule_workflow_request(
+                name, schedules, input, options
+            )
+
+            return self.client.ScheduleWorkflow(
+                request,
+                metadata=get_metadata(self.token),
+            )
+        except (grpc.RpcError, grpc.aio.AioRpcError) as e:
+            if e.code() == grpc.StatusCode.ALREADY_EXISTS:
+                raise DedupeViolationErr(e.details())
+
+            raise e
+
+    ## IMPORTANT: Keep this method's signature in sync with the wrapper in the OTel instrumentor
+    @tenacity_retry
+    def run_workflow(
+        self, workflow_name: str, input: any, options: TriggerWorkflowOptions = None
+    ) -> WorkflowRunRef:
+        try:
+            if not self.pooled_workflow_listener:
+                self.pooled_workflow_listener = PooledWorkflowRunListener(self.config)
+
+            namespace = self.namespace
+
+            ## TODO: Factor this out - it's repeated a lot of places
+            if (
+                options is not None
+                and "namespace" in options
+                and options["namespace"] is not None
+            ):
+                namespace = options.pop("namespace")
+
+            if namespace != "" and not workflow_name.startswith(self.namespace):
+                workflow_name = f"{namespace}{workflow_name}"
+
+            request = self._prepare_workflow_request(workflow_name, input, options)
+
+            resp: TriggerWorkflowResponse = self.client.TriggerWorkflow(
+                request,
+                metadata=get_metadata(self.token),
+            )
+
+            return WorkflowRunRef(
+                workflow_run_id=resp.workflow_run_id,
+                workflow_listener=self.pooled_workflow_listener,
+                workflow_run_event_listener=self.listener_client,
+            )
+        except (grpc.RpcError, grpc.aio.AioRpcError) as e:
+            if e.code() == grpc.StatusCode.ALREADY_EXISTS:
+                raise DedupeViolationErr(e.details())
+
+            raise e
+
+    ## IMPORTANT: Keep this method's signature in sync with the wrapper in the OTel instrumentor
+    @tenacity_retry
+    def run_workflows(
+        self, workflows: List[WorkflowRunDict], options: TriggerWorkflowOptions = None
+    ) -> list[WorkflowRunRef]:
+        workflow_run_requests: TriggerWorkflowRequest = []
+        if not self.pooled_workflow_listener:
+            self.pooled_workflow_listener = PooledWorkflowRunListener(self.config)
+
+        for workflow in workflows:
+            workflow_name = workflow["workflow_name"]
+            input_data = workflow["input"]
+            options = workflow["options"]
+
+            namespace = self.namespace
+
+            if (
+                options is not None
+                and "namespace" in options
+                and options["namespace"] is not None
+            ):
+                namespace = options["namespace"]
+                del options["namespace"]
+
+            if namespace != "" and not workflow_name.startswith(self.namespace):
+                workflow_name = f"{namespace}{workflow_name}"
+
+            # Prepare and trigger workflow for each workflow name and input
+            request = self._prepare_workflow_request(workflow_name, input_data, options)
+
+            workflow_run_requests.append(request)
+
+            request = BulkTriggerWorkflowRequest(workflows=workflow_run_requests)
+
+        resp: BulkTriggerWorkflowResponse = self.client.BulkTriggerWorkflow(
+            request,
+            metadata=get_metadata(self.token),
+        )
+
+        return [
+            WorkflowRunRef(
+                workflow_run_id=workflow_run_id,
+                workflow_listener=self.pooled_workflow_listener,
+                workflow_run_event_listener=self.listener_client,
+            )
+            for workflow_run_id in resp.workflow_run_ids
+        ]
+
+    def run(
+        self,
+        function: Union[str, Callable[[Any], T]],
+        input: any,
+        options: TriggerWorkflowOptions = None,
+    ) -> "RunRef[T]":
+        workflow_name = function
+
+        if not isinstance(function, str):
+            workflow_name = function.function_name
+
+        wrr = self.run_workflow(workflow_name, input, options)
+
+        return RunRef[T](
+            wrr.workflow_run_id, wrr.workflow_listener, wrr.workflow_run_event_listener
+        )
+
+    def get_workflow_run(self, workflow_run_id: str) -> WorkflowRunRef:
+        try:
+            if not self.pooled_workflow_listener:
+                self.pooled_workflow_listener = PooledWorkflowRunListener(self.config)
+
+            return WorkflowRunRef(
+                workflow_run_id=workflow_run_id,
+                workflow_listener=self.pooled_workflow_listener,
+                workflow_run_event_listener=self.listener_client,
+            )
+        except grpc.RpcError as e:
+            raise ValueError(f"Could not get workflow run: {e}")