author    S. Solomon Darnell  2025-03-28 21:52:21 -0500
committer S. Solomon Darnell  2025-03-28 21:52:21 -0500
commit    4a52a71956a8d46fcb7294ac71734504bb09bcc2 (HEAD, master)
tree      ee3dc5af3b6313e921cd920906356f5d4febc4ed
parent    cc961e04ba734dd72309fb548a2f97d67d578813
download  gn-ai-master.tar.gz

    two versions of R2R are here
Diffstat (limited to '.venv/lib/python3.12/site-packages/hatchet_sdk/clients/admin.py')

-rw-r--r--  .venv/lib/python3.12/site-packages/hatchet_sdk/clients/admin.py | 542
 1 file changed, 542 insertions(+), 0 deletions(-)
diff --git a/.venv/lib/python3.12/site-packages/hatchet_sdk/clients/admin.py b/.venv/lib/python3.12/site-packages/hatchet_sdk/clients/admin.py
new file mode 100644
index 00000000..18664cef
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/hatchet_sdk/clients/admin.py
@@ -0,0 +1,542 @@
+import json
+from datetime import datetime
+from typing import Any, Callable, Dict, List, Optional, TypedDict, TypeVar, Union
+
+import grpc
+from google.protobuf import timestamp_pb2
+
+from hatchet_sdk.clients.rest.models.workflow_run import WorkflowRun
+from hatchet_sdk.clients.rest.tenacity_utils import tenacity_retry
+from hatchet_sdk.clients.run_event_listener import new_listener
+from hatchet_sdk.clients.workflow_listener import PooledWorkflowRunListener
+from hatchet_sdk.connection import new_conn
+from hatchet_sdk.contracts.workflows_pb2 import (
+ BulkTriggerWorkflowRequest,
+ BulkTriggerWorkflowResponse,
+ CreateWorkflowVersionOpts,
+ PutRateLimitRequest,
+ PutWorkflowRequest,
+ RateLimitDuration,
+ ScheduleWorkflowRequest,
+ TriggerWorkflowRequest,
+ TriggerWorkflowResponse,
+ WorkflowVersion,
+)
+from hatchet_sdk.contracts.workflows_pb2_grpc import WorkflowServiceStub
+from hatchet_sdk.utils.serialization import flatten
+from hatchet_sdk.workflow_run import RunRef, WorkflowRunRef
+
+from ..loader import ClientConfig
+from ..metadata import get_metadata
+from ..workflow import WorkflowMeta
+
+
+def new_admin(config: ClientConfig):
+ return AdminClient(config)
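+
+# A minimal construction sketch; the ClientConfig value is illustrative and
+# would normally come from hatchet_sdk's config loader:
+#
+#   admin = new_admin(ClientConfig(...))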
+
+
+class ScheduleTriggerWorkflowOptions(TypedDict, total=False):
+ parent_id: Optional[str]
+ parent_step_run_id: Optional[str]
+ child_index: Optional[int]
+ child_key: Optional[str]
+ namespace: Optional[str]
+
+
+class ChildTriggerWorkflowOptions(TypedDict, total=False):
+    additional_metadata: Dict[str, str] | None
+    sticky: bool | None
+
+
+class ChildWorkflowRunDict(TypedDict, total=False):
+ workflow_name: str
+ input: Any
+ options: ChildTriggerWorkflowOptions
+    key: str | None
+
+
+class TriggerWorkflowOptions(ScheduleTriggerWorkflowOptions, total=False):
+    additional_metadata: Dict[str, str] | None
+    desired_worker_id: str | None
+    namespace: str | None
+
+
+class WorkflowRunDict(TypedDict, total=False):
+ workflow_name: str
+ input: Any
+ options: TriggerWorkflowOptions | None
+
+
+class DedupeViolationErr(Exception):
+ """Raised by the Hatchet library to indicate that a workflow has already been run with this deduplication value."""
+
+ pass
+
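+# A hedged handling sketch; the admin client, workflow name, and input are
+# illustrative:
+#
+#   try:
+#       admin.run_workflow("my-workflow", {"id": 1})
+#   except DedupeViolationErr:
+#       pass  # a run with this dedupe value already exists
+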
+
+class AdminClientBase:
+ pooled_workflow_listener: PooledWorkflowRunListener | None = None
+
+ def _prepare_workflow_request(
+        self, workflow_name: str, input: Any, options: TriggerWorkflowOptions | None = None
+ ):
+        try:
+            payload_data = json.dumps(input)
+
+            meta = (
+                None
+                if options is None or "additional_metadata" not in options
+                else options["additional_metadata"]
+            )
+            if meta is not None:
+                options = {
+                    **options,
+                    "additional_metadata": json.dumps(meta).encode("utf-8"),
+                }
+
+            return TriggerWorkflowRequest(
+                name=workflow_name, input=payload_data, **(options or {})
+            )
+        except (TypeError, ValueError) as e:
+            # json.dumps raises TypeError for unserializable objects and
+            # ValueError for circular references; JSONDecodeError only occurs
+            # when decoding.
+            raise ValueError(f"Error encoding payload: {e}") from e
+
+ def _prepare_put_workflow_request(
+ self,
+ name: str,
+ workflow: CreateWorkflowVersionOpts | WorkflowMeta,
+ overrides: CreateWorkflowVersionOpts | None = None,
+ ):
+ try:
+ opts: CreateWorkflowVersionOpts
+
+ if isinstance(workflow, CreateWorkflowVersionOpts):
+ opts = workflow
+ else:
+ opts = workflow.get_create_opts(self.client.config.namespace)
+
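+            # protobuf MergeFrom semantics: scalar fields set on `overrides`
+            # replace those on `opts`; repeated fields are concatenated.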
+ if overrides is not None:
+ opts.MergeFrom(overrides)
+
+ opts.name = name
+
+ return PutWorkflowRequest(
+ opts=opts,
+ )
+ except grpc.RpcError as e:
+ raise ValueError(f"Could not put workflow: {e}")
+
+ def _prepare_schedule_workflow_request(
+ self,
+ name: str,
+ schedules: List[Union[datetime, timestamp_pb2.Timestamp]],
+        input: Any = None,
+        options: ScheduleTriggerWorkflowOptions | None = None,
+ ):
+ timestamp_schedules = []
+ for schedule in schedules:
+ if isinstance(schedule, datetime):
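+                # Split the POSIX timestamp into whole seconds plus the
+                # fractional remainder in nanoseconds, as the protobuf
+                # Timestamp message expects.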
+ t = schedule.timestamp()
+ seconds = int(t)
+                nanos = int((t % 1) * 1e9)
+ timestamp = timestamp_pb2.Timestamp(seconds=seconds, nanos=nanos)
+ timestamp_schedules.append(timestamp)
+ elif isinstance(schedule, timestamp_pb2.Timestamp):
+ timestamp_schedules.append(schedule)
+ else:
+ raise ValueError(
+ "Invalid schedule type. Must be datetime or timestamp_pb2.Timestamp."
+ )
+
+ return ScheduleWorkflowRequest(
+ name=name,
+ schedules=timestamp_schedules,
+            input=json.dumps(input if input is not None else {}),
+ **(options or {}),
+ )
+
+
+T = TypeVar("T")
+
+
+class AdminClientAioImpl(AdminClientBase):
+ def __init__(self, config: ClientConfig):
+ aio_conn = new_conn(config, True)
+ self.config = config
+ self.aio_client = WorkflowServiceStub(aio_conn)
+ self.token = config.token
+ self.listener_client = new_listener(config)
+ self.namespace = config.namespace
+
+ async def run(
+ self,
+ function: Union[str, Callable[[Any], T]],
+        input: Any,
+        options: TriggerWorkflowOptions | None = None,
+ ) -> "RunRef[T]":
+ workflow_name = function
+
+ if not isinstance(function, str):
+ workflow_name = function.function_name
+
+ wrr = await self.run_workflow(workflow_name, input, options)
+
+ return RunRef[T](
+ wrr.workflow_run_id, wrr.workflow_listener, wrr.workflow_run_event_listener
+ )
+
+ ## IMPORTANT: Keep this method's signature in sync with the wrapper in the OTel instrumentor
+ @tenacity_retry
+ async def run_workflow(
+        self, workflow_name: str, input: Any, options: TriggerWorkflowOptions | None = None
+ ) -> WorkflowRunRef:
+ try:
+ if not self.pooled_workflow_listener:
+ self.pooled_workflow_listener = PooledWorkflowRunListener(self.config)
+
+ namespace = self.namespace
+
+ if (
+ options is not None
+ and "namespace" in options
+ and options["namespace"] is not None
+ ):
+ namespace = options.pop("namespace")
+
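+            # Apply the tenant namespace as a name prefix unless the caller
+            # already namespaced the workflow name.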
+ if namespace != "" and not workflow_name.startswith(self.namespace):
+ workflow_name = f"{namespace}{workflow_name}"
+
+ request = self._prepare_workflow_request(workflow_name, input, options)
+
+ resp: TriggerWorkflowResponse = await self.aio_client.TriggerWorkflow(
+ request,
+ metadata=get_metadata(self.token),
+ )
+
+ return WorkflowRunRef(
+ workflow_run_id=resp.workflow_run_id,
+ workflow_listener=self.pooled_workflow_listener,
+ workflow_run_event_listener=self.listener_client,
+ )
+ except (grpc.RpcError, grpc.aio.AioRpcError) as e:
+ if e.code() == grpc.StatusCode.ALREADY_EXISTS:
+ raise DedupeViolationErr(e.details())
+
+ raise e
+
+ ## IMPORTANT: Keep this method's signature in sync with the wrapper in the OTel instrumentor
+ @tenacity_retry
+ async def run_workflows(
+ self,
+ workflows: list[WorkflowRunDict],
+ options: TriggerWorkflowOptions | None = None,
+ ) -> List[WorkflowRunRef]:
+ if len(workflows) == 0:
+ raise ValueError("No workflows to run")
+
+ if not self.pooled_workflow_listener:
+ self.pooled_workflow_listener = PooledWorkflowRunListener(self.config)
+
+ namespace = self.namespace
+
+ if (
+ options is not None
+ and "namespace" in options
+ and options["namespace"] is not None
+ ):
+ namespace = options["namespace"]
+ del options["namespace"]
+
+        workflow_run_requests: list[TriggerWorkflowRequest] = []
+
+ for workflow in workflows:
+ workflow_name = workflow["workflow_name"]
+ input_data = workflow["input"]
+ options = workflow["options"]
+
+            if namespace != "" and not workflow_name.startswith(namespace):
+ workflow_name = f"{namespace}{workflow_name}"
+
+ # Prepare and trigger workflow for each workflow name and input
+            request = self._prepare_workflow_request(
+                workflow_name, input_data, workflow_options
+            )
+ workflow_run_requests.append(request)
+
+ request = BulkTriggerWorkflowRequest(workflows=workflow_run_requests)
+
+ resp: BulkTriggerWorkflowResponse = await self.aio_client.BulkTriggerWorkflow(
+ request,
+ metadata=get_metadata(self.token),
+ )
+
+ return [
+ WorkflowRunRef(
+ workflow_run_id=workflow_run_id,
+ workflow_listener=self.pooled_workflow_listener,
+ workflow_run_event_listener=self.listener_client,
+ )
+ for workflow_run_id in resp.workflow_run_ids
+ ]
+
+ @tenacity_retry
+ async def put_workflow(
+ self,
+ name: str,
+ workflow: CreateWorkflowVersionOpts | WorkflowMeta,
+ overrides: CreateWorkflowVersionOpts | None = None,
+ ) -> WorkflowVersion:
+ opts = self._prepare_put_workflow_request(name, workflow, overrides)
+
+ return await self.aio_client.PutWorkflow(
+ opts,
+ metadata=get_metadata(self.token),
+ )
+
+ @tenacity_retry
+ async def put_rate_limit(
+ self,
+ key: str,
+ limit: int,
+ duration: RateLimitDuration = RateLimitDuration.SECOND,
+ ):
+ await self.aio_client.PutRateLimit(
+ PutRateLimitRequest(
+ key=key,
+ limit=limit,
+ duration=duration,
+ ),
+ metadata=get_metadata(self.token),
+ )
+
+ @tenacity_retry
+ async def schedule_workflow(
+ self,
+ name: str,
+ schedules: List[Union[datetime, timestamp_pb2.Timestamp]],
+        input: Any = None,
+        options: ScheduleTriggerWorkflowOptions | None = None,
+ ) -> WorkflowVersion:
+ try:
+ namespace = self.namespace
+
+ if (
+ options is not None
+ and "namespace" in options
+ and options["namespace"] is not None
+ ):
+ namespace = options["namespace"]
+ del options["namespace"]
+
+ if namespace != "" and not name.startswith(self.namespace):
+ name = f"{namespace}{name}"
+
+ request = self._prepare_schedule_workflow_request(
+ name, schedules, input, options
+ )
+
+ return await self.aio_client.ScheduleWorkflow(
+ request,
+ metadata=get_metadata(self.token),
+ )
+ except (grpc.aio.AioRpcError, grpc.RpcError) as e:
+ if e.code() == grpc.StatusCode.ALREADY_EXISTS:
+ raise DedupeViolationErr(e.details())
+
+ raise e
+
+
+class AdminClient(AdminClientBase):
+ def __init__(self, config: ClientConfig):
+ conn = new_conn(config)
+ self.config = config
+ self.client = WorkflowServiceStub(conn)
+ self.aio = AdminClientAioImpl(config)
+ self.token = config.token
+ self.listener_client = new_listener(config)
+ self.namespace = config.namespace
+
+ @tenacity_retry
+ def put_workflow(
+ self,
+ name: str,
+ workflow: CreateWorkflowVersionOpts | WorkflowMeta,
+ overrides: CreateWorkflowVersionOpts | None = None,
+ ) -> WorkflowVersion:
+ opts = self._prepare_put_workflow_request(name, workflow, overrides)
+
+ resp: WorkflowVersion = self.client.PutWorkflow(
+ opts,
+ metadata=get_metadata(self.token),
+ )
+
+ return resp
+
+ @tenacity_retry
+ def put_rate_limit(
+ self,
+ key: str,
+ limit: int,
+        duration: RateLimitDuration = RateLimitDuration.SECOND,
+ ):
+ self.client.PutRateLimit(
+ PutRateLimitRequest(
+ key=key,
+ limit=limit,
+ duration=duration,
+ ),
+ metadata=get_metadata(self.token),
+ )
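+
+    # A hedged usage sketch; the key name is illustrative, and MINUTE assumes
+    # RateLimitDuration defines that value:
+    #
+    #   admin.put_rate_limit("external-api", limit=100,
+    #                        duration=RateLimitDuration.MINUTE)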
+
+ @tenacity_retry
+ def schedule_workflow(
+ self,
+ name: str,
+ schedules: List[Union[datetime, timestamp_pb2.Timestamp]],
+        input: Any = None,
+        options: ScheduleTriggerWorkflowOptions | None = None,
+ ) -> WorkflowVersion:
+ try:
+ namespace = self.namespace
+
+ if (
+ options is not None
+ and "namespace" in options
+ and options["namespace"] is not None
+ ):
+ namespace = options["namespace"]
+ del options["namespace"]
+
+ if namespace != "" and not name.startswith(self.namespace):
+ name = f"{namespace}{name}"
+
+ request = self._prepare_schedule_workflow_request(
+ name, schedules, input, options
+ )
+
+ return self.client.ScheduleWorkflow(
+ request,
+ metadata=get_metadata(self.token),
+ )
+ except (grpc.RpcError, grpc.aio.AioRpcError) as e:
+ if e.code() == grpc.StatusCode.ALREADY_EXISTS:
+ raise DedupeViolationErr(e.details())
+
+ raise e
+
+ ## IMPORTANT: Keep this method's signature in sync with the wrapper in the OTel instrumentor
+ @tenacity_retry
+ def run_workflow(
+        self, workflow_name: str, input: Any, options: TriggerWorkflowOptions | None = None
+ ) -> WorkflowRunRef:
+ try:
+ if not self.pooled_workflow_listener:
+ self.pooled_workflow_listener = PooledWorkflowRunListener(self.config)
+
+ namespace = self.namespace
+
+ ## TODO: Factor this out - it's repeated a lot of places
+ if (
+ options is not None
+ and "namespace" in options
+ and options["namespace"] is not None
+ ):
+ namespace = options.pop("namespace")
+
+            if namespace != "" and not workflow_name.startswith(namespace):
+ workflow_name = f"{namespace}{workflow_name}"
+
+ request = self._prepare_workflow_request(workflow_name, input, options)
+
+ resp: TriggerWorkflowResponse = self.client.TriggerWorkflow(
+ request,
+ metadata=get_metadata(self.token),
+ )
+
+ return WorkflowRunRef(
+ workflow_run_id=resp.workflow_run_id,
+ workflow_listener=self.pooled_workflow_listener,
+ workflow_run_event_listener=self.listener_client,
+ )
+ except (grpc.RpcError, grpc.aio.AioRpcError) as e:
+ if e.code() == grpc.StatusCode.ALREADY_EXISTS:
+ raise DedupeViolationErr(e.details())
+
+ raise e
+
+ ## IMPORTANT: Keep this method's signature in sync with the wrapper in the OTel instrumentor
+ @tenacity_retry
+ def run_workflows(
+        self, workflows: List[WorkflowRunDict], options: TriggerWorkflowOptions | None = None
+ ) -> list[WorkflowRunRef]:
+        workflow_run_requests: list[TriggerWorkflowRequest] = []
+ if not self.pooled_workflow_listener:
+ self.pooled_workflow_listener = PooledWorkflowRunListener(self.config)
+
+ for workflow in workflows:
+ workflow_name = workflow["workflow_name"]
+ input_data = workflow["input"]
+ options = workflow["options"]
+
+ namespace = self.namespace
+
+            if (
+                workflow_options is not None
+                and "namespace" in workflow_options
+                and workflow_options["namespace"] is not None
+            ):
+                namespace = workflow_options.pop("namespace")
+
+            if namespace != "" and not workflow_name.startswith(namespace):
+ workflow_name = f"{namespace}{workflow_name}"
+
+ # Prepare and trigger workflow for each workflow name and input
+            request = self._prepare_workflow_request(
+                workflow_name, input_data, workflow_options
+            )
+
+ workflow_run_requests.append(request)
+
+ request = BulkTriggerWorkflowRequest(workflows=workflow_run_requests)
+
+ resp: BulkTriggerWorkflowResponse = self.client.BulkTriggerWorkflow(
+ request,
+ metadata=get_metadata(self.token),
+ )
+
+ return [
+ WorkflowRunRef(
+ workflow_run_id=workflow_run_id,
+ workflow_listener=self.pooled_workflow_listener,
+ workflow_run_event_listener=self.listener_client,
+ )
+ for workflow_run_id in resp.workflow_run_ids
+ ]
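+
+    # A hedged bulk-trigger sketch; workflow names and inputs are illustrative:
+    #
+    #   refs = admin.run_workflows(
+    #       [
+    #           {"workflow_name": "resize-image", "input": {"id": 1}, "options": {}},
+    #           {"workflow_name": "resize-image", "input": {"id": 2}, "options": {}},
+    #       ]
+    #   )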
+
+ def run(
+ self,
+ function: Union[str, Callable[[Any], T]],
+        input: Any,
+        options: TriggerWorkflowOptions | None = None,
+ ) -> "RunRef[T]":
+ workflow_name = function
+
+ if not isinstance(function, str):
+ workflow_name = function.function_name
+
+ wrr = self.run_workflow(workflow_name, input, options)
+
+ return RunRef[T](
+ wrr.workflow_run_id, wrr.workflow_listener, wrr.workflow_run_event_listener
+ )
+
+ def get_workflow_run(self, workflow_run_id: str) -> WorkflowRunRef:
+ try:
+ if not self.pooled_workflow_listener:
+ self.pooled_workflow_listener = PooledWorkflowRunListener(self.config)
+
+ return WorkflowRunRef(
+ workflow_run_id=workflow_run_id,
+ workflow_listener=self.pooled_workflow_listener,
+ workflow_run_event_listener=self.listener_client,
+ )
+ except grpc.RpcError as e:
+ raise ValueError(f"Could not get workflow run: {e}")