author     S. Solomon Darnell        2025-03-28 21:52:21 -0500
committer  S. Solomon Darnell        2025-03-28 21:52:21 -0500
commit     4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree       ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/hatchet_sdk/clients/dispatcher
parent     cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
download   gn-ai-master.tar.gz
two version of R2R are here (HEAD, master)
Diffstat (limited to '.venv/lib/python3.12/site-packages/hatchet_sdk/clients/dispatcher')
-rw-r--r--  .venv/lib/python3.12/site-packages/hatchet_sdk/clients/dispatcher/action_listener.py | 423
-rw-r--r--  .venv/lib/python3.12/site-packages/hatchet_sdk/clients/dispatcher/dispatcher.py      | 204
2 files changed, 627 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/hatchet_sdk/clients/dispatcher/action_listener.py b/.venv/lib/python3.12/site-packages/hatchet_sdk/clients/dispatcher/action_listener.py
new file mode 100644
index 00000000..cf231a76
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/hatchet_sdk/clients/dispatcher/action_listener.py
@@ -0,0 +1,423 @@
+import asyncio
+import json
+import time
+from dataclasses import dataclass, field
+from typing import Any, AsyncGenerator, List, Optional
+
+import grpc
+from grpc._cython import cygrpc
+
+from hatchet_sdk.clients.event_ts import Event_ts, read_with_interrupt
+from hatchet_sdk.clients.run_event_listener import (
+ DEFAULT_ACTION_LISTENER_RETRY_INTERVAL,
+)
+from hatchet_sdk.connection import new_conn
+from hatchet_sdk.contracts.dispatcher_pb2 import (
+ ActionType,
+ AssignedAction,
+ HeartbeatRequest,
+ WorkerLabels,
+ WorkerListenRequest,
+ WorkerUnsubscribeRequest,
+)
+from hatchet_sdk.contracts.dispatcher_pb2_grpc import DispatcherStub
+from hatchet_sdk.logger import logger
+from hatchet_sdk.utils.backoff import exp_backoff_sleep
+from hatchet_sdk.utils.serialization import flatten
+
+from ...loader import ClientConfig
+from ...metadata import get_metadata
+from ..events import proto_timestamp_now
+
+DEFAULT_ACTION_TIMEOUT = 600 # seconds
+
+
+DEFAULT_ACTION_LISTENER_RETRY_INTERVAL = 5 # seconds
+DEFAULT_ACTION_LISTENER_RETRY_COUNT = 15
+
+
+@dataclass
+class GetActionListenerRequest:
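+    """Parameters used to register a worker and request an action listener."""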
+ worker_name: str
+ services: List[str]
+ actions: List[str]
+ max_runs: Optional[int] = None
+ _labels: dict[str, str | int] = field(default_factory=dict)
+
+ labels: dict[str, WorkerLabels] = field(init=False)
+
+ def __post_init__(self):
+ self.labels = {}
+
+ for key, value in self._labels.items():
+ if isinstance(value, int):
+ self.labels[key] = WorkerLabels(intValue=value)
+ else:
+ self.labels[key] = WorkerLabels(strValue=str(value))
+
+
+@dataclass
+class Action:
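+    """A single action assigned to this worker: a step run, a step-run cancellation, or a group-key run."""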
+ worker_id: str
+ tenant_id: str
+ workflow_run_id: str
+ get_group_key_run_id: str
+ job_id: str
+ job_name: str
+ job_run_id: str
+ step_id: str
+ step_run_id: str
+ action_id: str
+ action_payload: str
+ action_type: ActionType
+ retry_count: int
+ additional_metadata: dict[str, str] | None = None
+
+ child_workflow_index: int | None = None
+ child_workflow_key: str | None = None
+ parent_workflow_run_id: str | None = None
+
+ def __post_init__(self):
+ if isinstance(self.additional_metadata, str) and self.additional_metadata != "":
+ try:
+ self.additional_metadata = json.loads(self.additional_metadata)
+ except json.JSONDecodeError:
+ # If JSON decoding fails, keep the original string
+ pass
+
+ # Ensure additional_metadata is always a dictionary
+ if not isinstance(self.additional_metadata, dict):
+ self.additional_metadata = {}
+
+ @property
+ def otel_attributes(self) -> dict[str, str | int]:
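+        """Flatten identifying fields into OpenTelemetry-style span attributes, dropping empty values."""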
+ try:
+ payload_str = json.dumps(self.action_payload, default=str)
+ except Exception:
+ payload_str = str(self.action_payload)
+
+ attrs: dict[str, str | int | None] = {
+ "hatchet.tenant_id": self.tenant_id,
+ "hatchet.worker_id": self.worker_id,
+ "hatchet.workflow_run_id": self.workflow_run_id,
+ "hatchet.step_id": self.step_id,
+ "hatchet.step_run_id": self.step_run_id,
+ "hatchet.retry_count": self.retry_count,
+ "hatchet.parent_workflow_run_id": self.parent_workflow_run_id,
+ "hatchet.child_workflow_index": self.child_workflow_index,
+ "hatchet.child_workflow_key": self.child_workflow_key,
+ "hatchet.action_payload": payload_str,
+ "hatchet.workflow_name": self.job_name,
+ "hatchet.action_name": self.action_id,
+ "hatchet.get_group_key_run_id": self.get_group_key_run_id,
+ }
+
+ return {k: v for k, v in attrs.items() if v}
+
+
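+# Action-type codes returned by map_action_type for the protobuf ActionType values.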
+START_STEP_RUN = 0
+CANCEL_STEP_RUN = 1
+START_GET_GROUP_KEY = 2
+
+
+@dataclass
+class ActionListener:
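+    """Holds the gRPC action stream for a worker and runs a background heartbeat to the dispatcher."""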
+ config: ClientConfig
+ worker_id: str
+
+ client: DispatcherStub = field(init=False)
+ aio_client: DispatcherStub = field(init=False)
+ token: str = field(init=False)
+ retries: int = field(default=0, init=False)
+ last_connection_attempt: float = field(default=0, init=False)
+ last_heartbeat_succeeded: bool = field(default=True, init=False)
+ time_last_hb_succeeded: float = field(default=9999999999999, init=False)
+ heartbeat_task: Optional[asyncio.Task] = field(default=None, init=False)
+ run_heartbeat: bool = field(default=True, init=False)
+ listen_strategy: str = field(default="v2", init=False)
+ stop_signal: bool = field(default=False, init=False)
+
+    missed_heartbeats: int = field(default=0, init=False)
+    interrupt: Optional[Event_ts] = field(default=None, init=False)
+
+ def __post_init__(self):
+ self.client = DispatcherStub(new_conn(self.config))
+ self.aio_client = DispatcherStub(new_conn(self.config, True))
+ self.token = self.config.token
+
+ def is_healthy(self):
+ return self.last_heartbeat_succeeded
+
+ async def heartbeat(self):
+ # send a heartbeat every 4 seconds
+ heartbeat_delay = 4
+
+ while True:
+ if not self.run_heartbeat:
+ break
+
+ try:
+ logger.debug("sending heartbeat")
+ await self.aio_client.Heartbeat(
+ HeartbeatRequest(
+ workerId=self.worker_id,
+ heartbeatAt=proto_timestamp_now(),
+ ),
+ timeout=5,
+ metadata=get_metadata(self.token),
+ )
+
+ if self.last_heartbeat_succeeded is False:
+ logger.info("listener established")
+
+ now = time.time()
+ diff = now - self.time_last_hb_succeeded
+                if diff > heartbeat_delay + 1:
+                    logger.warning(
+                        f"time since last successful heartbeat: {diff:.2f}s, expected {heartbeat_delay}s"
+                    )
+
+ self.last_heartbeat_succeeded = True
+ self.time_last_hb_succeeded = now
+ self.missed_heartbeats = 0
+ except grpc.RpcError as e:
+ self.missed_heartbeats = self.missed_heartbeats + 1
+ self.last_heartbeat_succeeded = False
+
+ if (
+ e.code() == grpc.StatusCode.UNAVAILABLE
+ or e.code() == grpc.StatusCode.FAILED_PRECONDITION
+ ):
+ # todo case on "recvmsg:Connection reset by peer" for updates?
+ if self.missed_heartbeats >= 3:
+ # we don't reraise the error here, as we don't want to stop the heartbeat thread
+ logger.error(
+ f"⛔️ failed heartbeat ({self.missed_heartbeats}): {e.details()}"
+ )
+ elif self.missed_heartbeats > 1:
+ logger.warning(
+ f"failed to send heartbeat ({self.missed_heartbeats}): {e.details()}"
+ )
+ else:
+ logger.error(f"failed to send heartbeat: {e}")
+
+ if self.interrupt is not None:
+ self.interrupt.set()
+
+ if e.code() == grpc.StatusCode.UNIMPLEMENTED:
+ break
+ await asyncio.sleep(heartbeat_delay)
+
+ async def start_heartbeater(self):
+ if self.heartbeat_task is not None:
+ return
+
+ try:
+ loop = asyncio.get_event_loop()
+ except RuntimeError as e:
+ if str(e).startswith("There is no current event loop in thread"):
+ loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(loop)
+ else:
+ raise e
+ self.heartbeat_task = loop.create_task(self.heartbeat())
+
+ def __aiter__(self):
+ return self._generator()
+
+ async def _generator(self) -> AsyncGenerator[Action, None]:
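+        """Yield assigned actions from the listen stream, reconnecting when the stream drops or errors."""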
+ listener = None
+
+ while not self.stop_signal:
+ if listener is not None:
+ listener.cancel()
+
+ try:
+ listener = await self.get_listen_client()
+ except Exception:
+ logger.info("closing action listener loop")
+ yield None
+
+ try:
+ while not self.stop_signal:
+ self.interrupt = Event_ts()
+ t = asyncio.create_task(
+ read_with_interrupt(listener, self.interrupt)
+ )
+ await self.interrupt.wait()
+
+ if not t.done():
+ # print a warning
+ logger.warning(
+ "Interrupted read_with_interrupt task of action listener"
+ )
+
+ t.cancel()
+ listener.cancel()
+ break
+
+ assigned_action = t.result()
+
+ if assigned_action is cygrpc.EOF:
+ self.retries = self.retries + 1
+ break
+
+ self.retries = 0
+ assigned_action: AssignedAction
+
+ # Process the received action
+ action_type = self.map_action_type(assigned_action.actionType)
+
+ if (
+ assigned_action.actionPayload is None
+ or assigned_action.actionPayload == ""
+ ):
+ action_payload = None
+ else:
+ action_payload = self.parse_action_payload(
+ assigned_action.actionPayload
+ )
+
+ action = Action(
+ tenant_id=assigned_action.tenantId,
+ worker_id=self.worker_id,
+ workflow_run_id=assigned_action.workflowRunId,
+ get_group_key_run_id=assigned_action.getGroupKeyRunId,
+ job_id=assigned_action.jobId,
+ job_name=assigned_action.jobName,
+ job_run_id=assigned_action.jobRunId,
+ step_id=assigned_action.stepId,
+ step_run_id=assigned_action.stepRunId,
+ action_id=assigned_action.actionId,
+ action_payload=action_payload,
+ action_type=action_type,
+ retry_count=assigned_action.retryCount,
+ additional_metadata=assigned_action.additional_metadata,
+ child_workflow_index=assigned_action.child_workflow_index,
+ child_workflow_key=assigned_action.child_workflow_key,
+ parent_workflow_run_id=assigned_action.parent_workflow_run_id,
+ )
+
+ yield action
+ except grpc.RpcError as e:
+ self.last_heartbeat_succeeded = False
+
+ # Handle different types of errors
+ if e.code() == grpc.StatusCode.CANCELLED:
+ # Context cancelled, unsubscribe and close
+ logger.debug("Context cancelled, closing listener")
+ elif e.code() == grpc.StatusCode.DEADLINE_EXCEEDED:
+ logger.info("Deadline exceeded, retrying subscription")
+ elif (
+ self.listen_strategy == "v2"
+ and e.code() == grpc.StatusCode.UNIMPLEMENTED
+ ):
+ # ListenV2 is not available, fallback to Listen
+ self.listen_strategy = "v1"
+ self.run_heartbeat = False
+ logger.info("ListenV2 not available, falling back to Listen")
+ else:
+ # TODO retry
+ if e.code() == grpc.StatusCode.UNAVAILABLE:
+ logger.error(f"action listener error: {e.details()}")
+ else:
+ # Unknown error, report and break
+ logger.error(f"action listener error: {e}")
+
+ self.retries = self.retries + 1
+
+ def parse_action_payload(self, payload: str):
+ try:
+ payload_data = json.loads(payload)
+ except json.JSONDecodeError as e:
+ raise ValueError(f"Error decoding payload: {e}")
+ return payload_data
+
+ def map_action_type(self, action_type):
+ if action_type == ActionType.START_STEP_RUN:
+ return START_STEP_RUN
+ elif action_type == ActionType.CANCEL_STEP_RUN:
+ return CANCEL_STEP_RUN
+ elif action_type == ActionType.START_GET_GROUP_KEY:
+ return START_GET_GROUP_KEY
+ else:
+ # logger.error(f"Unknown action type: {action_type}")
+ return None
+
+ async def get_listen_client(self):
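+        """(Re)open the listen stream, backing off between retries and failing once the retry limit is exhausted."""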
+ current_time = int(time.time())
+
+ if (
+ current_time - self.last_connection_attempt
+ > DEFAULT_ACTION_LISTENER_RETRY_INTERVAL
+ ):
+ # reset retries if last connection was long lived
+ self.retries = 0
+
+ if self.retries > DEFAULT_ACTION_LISTENER_RETRY_COUNT:
+ # TODO this is the problem case...
+ logger.error(
+ f"could not establish action listener connection after {DEFAULT_ACTION_LISTENER_RETRY_COUNT} retries"
+ )
+ self.run_heartbeat = False
+ raise Exception("retry_exhausted")
+ elif self.retries >= 1:
+            # when retrying, wait before reconnecting (exponential backoff via exp_backoff_sleep)
+ await exp_backoff_sleep(
+ self.retries, DEFAULT_ACTION_LISTENER_RETRY_INTERVAL
+ )
+
+ logger.info(
+ f"action listener connection interrupted, retrying... ({self.retries}/{DEFAULT_ACTION_LISTENER_RETRY_COUNT})"
+ )
+
+ self.aio_client = DispatcherStub(new_conn(self.config, True))
+
+ if self.listen_strategy == "v2":
+ # we should await for the listener to be established before
+ # starting the heartbeater
+ listener = self.aio_client.ListenV2(
+ WorkerListenRequest(workerId=self.worker_id),
+ timeout=self.config.listener_v2_timeout,
+ metadata=get_metadata(self.token),
+ )
+ await self.start_heartbeater()
+ else:
+ # if ListenV2 is not available, fallback to Listen
+ listener = self.aio_client.Listen(
+ WorkerListenRequest(workerId=self.worker_id),
+ timeout=DEFAULT_ACTION_TIMEOUT,
+ metadata=get_metadata(self.token),
+ )
+
+ self.last_connection_attempt = current_time
+
+ return listener
+
+ def cleanup(self):
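+        """Stop the heartbeat, unsubscribe the worker, and wake any pending stream read."""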
+ self.run_heartbeat = False
+        if self.heartbeat_task is not None:
+            self.heartbeat_task.cancel()
+
+ try:
+ self.unregister()
+ except Exception as e:
+ logger.error(f"failed to unregister: {e}")
+
+ if self.interrupt:
+ self.interrupt.set()
+
+ def unregister(self):
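+        """Send an Unsubscribe request for this worker and interrupt any pending stream read."""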
+ self.run_heartbeat = False
+        if self.heartbeat_task is not None:
+            self.heartbeat_task.cancel()
+
+ try:
+ req = self.aio_client.Unsubscribe(
+ WorkerUnsubscribeRequest(workerId=self.worker_id),
+ timeout=5,
+ metadata=get_metadata(self.token),
+ )
+ if self.interrupt is not None:
+ self.interrupt.set()
+ return req
+ except grpc.RpcError as e:
+ raise Exception(f"Failed to unsubscribe: {e}")
diff --git a/.venv/lib/python3.12/site-packages/hatchet_sdk/clients/dispatcher/dispatcher.py b/.venv/lib/python3.12/site-packages/hatchet_sdk/clients/dispatcher/dispatcher.py
new file mode 100644
index 00000000..407a80cc
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/hatchet_sdk/clients/dispatcher/dispatcher.py
@@ -0,0 +1,204 @@
+from typing import Any, cast
+
+from google.protobuf.timestamp_pb2 import Timestamp
+
+from hatchet_sdk.clients.dispatcher.action_listener import (
+ Action,
+ ActionListener,
+ GetActionListenerRequest,
+)
+from hatchet_sdk.clients.rest.tenacity_utils import tenacity_retry
+from hatchet_sdk.connection import new_conn
+from hatchet_sdk.contracts.dispatcher_pb2 import (
+ STEP_EVENT_TYPE_COMPLETED,
+ STEP_EVENT_TYPE_FAILED,
+ ActionEventResponse,
+ GroupKeyActionEvent,
+ GroupKeyActionEventType,
+ OverridesData,
+ RefreshTimeoutRequest,
+ ReleaseSlotRequest,
+ StepActionEvent,
+ StepActionEventType,
+ UpsertWorkerLabelsRequest,
+ WorkerLabels,
+ WorkerRegisterRequest,
+ WorkerRegisterResponse,
+)
+from hatchet_sdk.contracts.dispatcher_pb2_grpc import DispatcherStub
+
+from ...loader import ClientConfig
+from ...metadata import get_metadata
+
+DEFAULT_REGISTER_TIMEOUT = 30
+
+
+def new_dispatcher(config: ClientConfig) -> "DispatcherClient":
+ return DispatcherClient(config=config)
+
+
+class DispatcherClient:
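+    """Client for the dispatcher gRPC service: worker registration, action events, slots, and labels."""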
+ config: ClientConfig
+
+ def __init__(self, config: ClientConfig):
+ conn = new_conn(config)
+ self.client = DispatcherStub(conn) # type: ignore[no-untyped-call]
+
+ aio_conn = new_conn(config, True)
+ self.aio_client = DispatcherStub(aio_conn) # type: ignore[no-untyped-call]
+ self.token = config.token
+ self.config = config
+
+ async def get_action_listener(
+ self, req: GetActionListenerRequest
+ ) -> ActionListener:
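+        """Register the worker with the dispatcher and return an ActionListener bound to the assigned worker id."""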
+
+ # Override labels with the preset labels
+ preset_labels = self.config.worker_preset_labels
+
+ for key, value in preset_labels.items():
+ req.labels[key] = WorkerLabels(strValue=str(value))
+
+ # Register the worker
+ response: WorkerRegisterResponse = await self.aio_client.Register(
+ WorkerRegisterRequest(
+ workerName=req.worker_name,
+ actions=req.actions,
+ services=req.services,
+ maxRuns=req.max_runs,
+ labels=req.labels,
+ ),
+ timeout=DEFAULT_REGISTER_TIMEOUT,
+ metadata=get_metadata(self.token),
+ )
+
+ return ActionListener(self.config, response.workerId)
+
+ async def send_step_action_event(
+ self, action: Action, event_type: StepActionEventType, payload: str
+ ) -> Any:
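+        """Send a step action event; if a terminal (completed/failed) event cannot be sent, report a failure event instead."""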
+ try:
+ return await self._try_send_step_action_event(action, event_type, payload)
+ except Exception as e:
+ # for step action events, send a failure event when we cannot send the completed event
+ if (
+ event_type == STEP_EVENT_TYPE_COMPLETED
+ or event_type == STEP_EVENT_TYPE_FAILED
+ ):
+ await self._try_send_step_action_event(
+ action,
+ STEP_EVENT_TYPE_FAILED,
+ "Failed to send finished event: " + str(e),
+ )
+
+ return
+
+ @tenacity_retry
+ async def _try_send_step_action_event(
+ self, action: Action, event_type: StepActionEventType, payload: str
+ ) -> Any:
+ eventTimestamp = Timestamp()
+ eventTimestamp.GetCurrentTime()
+
+ event = StepActionEvent(
+ workerId=action.worker_id,
+ jobId=action.job_id,
+ jobRunId=action.job_run_id,
+ stepId=action.step_id,
+ stepRunId=action.step_run_id,
+ actionId=action.action_id,
+ eventTimestamp=eventTimestamp,
+ eventType=event_type,
+ eventPayload=payload,
+ retryCount=action.retry_count,
+ )
+
+ ## TODO: What does this return?
+ return await self.aio_client.SendStepActionEvent(
+ event,
+ metadata=get_metadata(self.token),
+ )
+
+ async def send_group_key_action_event(
+ self, action: Action, event_type: GroupKeyActionEventType, payload: str
+ ) -> Any:
+ eventTimestamp = Timestamp()
+ eventTimestamp.GetCurrentTime()
+
+ event = GroupKeyActionEvent(
+ workerId=action.worker_id,
+ workflowRunId=action.workflow_run_id,
+ getGroupKeyRunId=action.get_group_key_run_id,
+ actionId=action.action_id,
+ eventTimestamp=eventTimestamp,
+ eventType=event_type,
+ eventPayload=payload,
+ )
+
+ ## TODO: What does this return?
+ return await self.aio_client.SendGroupKeyActionEvent(
+ event,
+ metadata=get_metadata(self.token),
+ )
+
+ def put_overrides_data(self, data: OverridesData) -> ActionEventResponse:
+ return cast(
+ ActionEventResponse,
+ self.client.PutOverridesData(
+ data,
+ metadata=get_metadata(self.token),
+ ),
+ )
+
+ def release_slot(self, step_run_id: str) -> None:
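+        """Release the worker slot held by the given step run."""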
+ self.client.ReleaseSlot(
+ ReleaseSlotRequest(stepRunId=step_run_id),
+ timeout=DEFAULT_REGISTER_TIMEOUT,
+ metadata=get_metadata(self.token),
+ )
+
+ def refresh_timeout(self, step_run_id: str, increment_by: str) -> None:
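+        """Ask the dispatcher to extend the timeout of a running step by increment_by."""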
+ self.client.RefreshTimeout(
+ RefreshTimeoutRequest(
+ stepRunId=step_run_id,
+ incrementTimeoutBy=increment_by,
+ ),
+ timeout=DEFAULT_REGISTER_TIMEOUT,
+ metadata=get_metadata(self.token),
+ )
+
+ def upsert_worker_labels(
+ self, worker_id: str | None, labels: dict[str, str | int]
+ ) -> None:
+ worker_labels = {}
+
+ for key, value in labels.items():
+ if isinstance(value, int):
+ worker_labels[key] = WorkerLabels(intValue=value)
+ else:
+ worker_labels[key] = WorkerLabels(strValue=str(value))
+
+ self.client.UpsertWorkerLabels(
+ UpsertWorkerLabelsRequest(workerId=worker_id, labels=worker_labels),
+ timeout=DEFAULT_REGISTER_TIMEOUT,
+ metadata=get_metadata(self.token),
+ )
+
+ async def async_upsert_worker_labels(
+ self,
+ worker_id: str | None,
+ labels: dict[str, str | int],
+ ) -> None:
+ worker_labels = {}
+
+ for key, value in labels.items():
+ if isinstance(value, int):
+ worker_labels[key] = WorkerLabels(intValue=value)
+ else:
+ worker_labels[key] = WorkerLabels(strValue=str(value))
+
+ await self.aio_client.UpsertWorkerLabels(
+ UpsertWorkerLabelsRequest(workerId=worker_id, labels=worker_labels),
+ timeout=DEFAULT_REGISTER_TIMEOUT,
+ metadata=get_metadata(self.token),
+ )
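
For orientation, a minimal consumption sketch of the two modules added above. This is not part of the commit; the bare ClientConfig() construction and the worker/action names are illustrative assumptions, not values taken from the diff.

    import asyncio

    from hatchet_sdk.clients.dispatcher.action_listener import GetActionListenerRequest
    from hatchet_sdk.clients.dispatcher.dispatcher import DispatcherClient
    from hatchet_sdk.loader import ClientConfig

    async def main() -> None:
        config = ClientConfig()  # assumption: token/host are normally supplied via config or env
        dispatcher = DispatcherClient(config)

        # Register the worker and obtain a listener bound to the returned worker id.
        listener = await dispatcher.get_action_listener(
            GetActionListenerRequest(
                worker_name="example-worker",   # illustrative
                services=["default"],
                actions=["example:step"],       # illustrative action id
                max_runs=5,
            )
        )

        # The listener is an async iterable of Action objects; it yields None once
        # the stream is permanently closed (see _generator above).
        async for action in listener:
            if action is None:
                break
            print(action.action_id, action.retry_count)

        listener.cleanup()

    asyncio.run(main())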