aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/hatchet_sdk/clients/dispatcher/dispatcher.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/hatchet_sdk/clients/dispatcher/dispatcher.py')
-rw-r--r--.venv/lib/python3.12/site-packages/hatchet_sdk/clients/dispatcher/dispatcher.py204
1 files changed, 204 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/hatchet_sdk/clients/dispatcher/dispatcher.py b/.venv/lib/python3.12/site-packages/hatchet_sdk/clients/dispatcher/dispatcher.py
new file mode 100644
index 00000000..407a80cc
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/hatchet_sdk/clients/dispatcher/dispatcher.py
@@ -0,0 +1,204 @@
+from typing import Any, cast
+
+from google.protobuf.timestamp_pb2 import Timestamp
+
+from hatchet_sdk.clients.dispatcher.action_listener import (
+ Action,
+ ActionListener,
+ GetActionListenerRequest,
+)
+from hatchet_sdk.clients.rest.tenacity_utils import tenacity_retry
+from hatchet_sdk.connection import new_conn
+from hatchet_sdk.contracts.dispatcher_pb2 import (
+ STEP_EVENT_TYPE_COMPLETED,
+ STEP_EVENT_TYPE_FAILED,
+ ActionEventResponse,
+ GroupKeyActionEvent,
+ GroupKeyActionEventType,
+ OverridesData,
+ RefreshTimeoutRequest,
+ ReleaseSlotRequest,
+ StepActionEvent,
+ StepActionEventType,
+ UpsertWorkerLabelsRequest,
+ WorkerLabels,
+ WorkerRegisterRequest,
+ WorkerRegisterResponse,
+)
+from hatchet_sdk.contracts.dispatcher_pb2_grpc import DispatcherStub
+
+from ...loader import ClientConfig
+from ...metadata import get_metadata
+
+DEFAULT_REGISTER_TIMEOUT = 30
+
+
+def new_dispatcher(config: ClientConfig) -> "DispatcherClient":
+ return DispatcherClient(config=config)
+
+
+class DispatcherClient:
+ config: ClientConfig
+
+ def __init__(self, config: ClientConfig):
+ conn = new_conn(config)
+ self.client = DispatcherStub(conn) # type: ignore[no-untyped-call]
+
+ aio_conn = new_conn(config, True)
+ self.aio_client = DispatcherStub(aio_conn) # type: ignore[no-untyped-call]
+ self.token = config.token
+ self.config = config
+
+ async def get_action_listener(
+ self, req: GetActionListenerRequest
+ ) -> ActionListener:
+
+ # Override labels with the preset labels
+ preset_labels = self.config.worker_preset_labels
+
+ for key, value in preset_labels.items():
+ req.labels[key] = WorkerLabels(strValue=str(value))
+
+ # Register the worker
+ response: WorkerRegisterResponse = await self.aio_client.Register(
+ WorkerRegisterRequest(
+ workerName=req.worker_name,
+ actions=req.actions,
+ services=req.services,
+ maxRuns=req.max_runs,
+ labels=req.labels,
+ ),
+ timeout=DEFAULT_REGISTER_TIMEOUT,
+ metadata=get_metadata(self.token),
+ )
+
+ return ActionListener(self.config, response.workerId)
+
+ async def send_step_action_event(
+ self, action: Action, event_type: StepActionEventType, payload: str
+ ) -> Any:
+ try:
+ return await self._try_send_step_action_event(action, event_type, payload)
+ except Exception as e:
+ # for step action events, send a failure event when we cannot send the completed event
+ if (
+ event_type == STEP_EVENT_TYPE_COMPLETED
+ or event_type == STEP_EVENT_TYPE_FAILED
+ ):
+ await self._try_send_step_action_event(
+ action,
+ STEP_EVENT_TYPE_FAILED,
+ "Failed to send finished event: " + str(e),
+ )
+
+ return
+
+ @tenacity_retry
+ async def _try_send_step_action_event(
+ self, action: Action, event_type: StepActionEventType, payload: str
+ ) -> Any:
+ eventTimestamp = Timestamp()
+ eventTimestamp.GetCurrentTime()
+
+ event = StepActionEvent(
+ workerId=action.worker_id,
+ jobId=action.job_id,
+ jobRunId=action.job_run_id,
+ stepId=action.step_id,
+ stepRunId=action.step_run_id,
+ actionId=action.action_id,
+ eventTimestamp=eventTimestamp,
+ eventType=event_type,
+ eventPayload=payload,
+ retryCount=action.retry_count,
+ )
+
+ ## TODO: What does this return?
+ return await self.aio_client.SendStepActionEvent(
+ event,
+ metadata=get_metadata(self.token),
+ )
+
+ async def send_group_key_action_event(
+ self, action: Action, event_type: GroupKeyActionEventType, payload: str
+ ) -> Any:
+ eventTimestamp = Timestamp()
+ eventTimestamp.GetCurrentTime()
+
+ event = GroupKeyActionEvent(
+ workerId=action.worker_id,
+ workflowRunId=action.workflow_run_id,
+ getGroupKeyRunId=action.get_group_key_run_id,
+ actionId=action.action_id,
+ eventTimestamp=eventTimestamp,
+ eventType=event_type,
+ eventPayload=payload,
+ )
+
+ ## TODO: What does this return?
+ return await self.aio_client.SendGroupKeyActionEvent(
+ event,
+ metadata=get_metadata(self.token),
+ )
+
+ def put_overrides_data(self, data: OverridesData) -> ActionEventResponse:
+ return cast(
+ ActionEventResponse,
+ self.client.PutOverridesData(
+ data,
+ metadata=get_metadata(self.token),
+ ),
+ )
+
+ def release_slot(self, step_run_id: str) -> None:
+ self.client.ReleaseSlot(
+ ReleaseSlotRequest(stepRunId=step_run_id),
+ timeout=DEFAULT_REGISTER_TIMEOUT,
+ metadata=get_metadata(self.token),
+ )
+
+ def refresh_timeout(self, step_run_id: str, increment_by: str) -> None:
+ self.client.RefreshTimeout(
+ RefreshTimeoutRequest(
+ stepRunId=step_run_id,
+ incrementTimeoutBy=increment_by,
+ ),
+ timeout=DEFAULT_REGISTER_TIMEOUT,
+ metadata=get_metadata(self.token),
+ )
+
+ def upsert_worker_labels(
+ self, worker_id: str | None, labels: dict[str, str | int]
+ ) -> None:
+ worker_labels = {}
+
+ for key, value in labels.items():
+ if isinstance(value, int):
+ worker_labels[key] = WorkerLabels(intValue=value)
+ else:
+ worker_labels[key] = WorkerLabels(strValue=str(value))
+
+ self.client.UpsertWorkerLabels(
+ UpsertWorkerLabelsRequest(workerId=worker_id, labels=worker_labels),
+ timeout=DEFAULT_REGISTER_TIMEOUT,
+ metadata=get_metadata(self.token),
+ )
+
+ async def async_upsert_worker_labels(
+ self,
+ worker_id: str | None,
+ labels: dict[str, str | int],
+ ) -> None:
+ worker_labels = {}
+
+ for key, value in labels.items():
+ if isinstance(value, int):
+ worker_labels[key] = WorkerLabels(intValue=value)
+ else:
+ worker_labels[key] = WorkerLabels(strValue=str(value))
+
+ await self.aio_client.UpsertWorkerLabels(
+ UpsertWorkerLabelsRequest(workerId=worker_id, labels=worker_labels),
+ timeout=DEFAULT_REGISTER_TIMEOUT,
+ metadata=get_metadata(self.token),
+ )