Diffstat (limited to '.venv/lib/python3.12/site-packages/hatchet_sdk/clients/dispatcher')
-rw-r--r-- | .venv/lib/python3.12/site-packages/hatchet_sdk/clients/dispatcher/action_listener.py | 423
-rw-r--r-- | .venv/lib/python3.12/site-packages/hatchet_sdk/clients/dispatcher/dispatcher.py | 204
2 files changed, 627 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/hatchet_sdk/clients/dispatcher/action_listener.py b/.venv/lib/python3.12/site-packages/hatchet_sdk/clients/dispatcher/action_listener.py
new file mode 100644
index 00000000..cf231a76
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/hatchet_sdk/clients/dispatcher/action_listener.py
@@ -0,0 +1,423 @@
+import asyncio
+import json
+import time
+from dataclasses import dataclass, field
+from typing import Any, AsyncGenerator, List, Optional
+
+import grpc
+from grpc._cython import cygrpc
+
+from hatchet_sdk.clients.event_ts import Event_ts, read_with_interrupt
+from hatchet_sdk.clients.run_event_listener import (
+    DEFAULT_ACTION_LISTENER_RETRY_INTERVAL,
+)
+from hatchet_sdk.connection import new_conn
+from hatchet_sdk.contracts.dispatcher_pb2 import (
+    ActionType,
+    AssignedAction,
+    HeartbeatRequest,
+    WorkerLabels,
+    WorkerListenRequest,
+    WorkerUnsubscribeRequest,
+)
+from hatchet_sdk.contracts.dispatcher_pb2_grpc import DispatcherStub
+from hatchet_sdk.logger import logger
+from hatchet_sdk.utils.backoff import exp_backoff_sleep
+from hatchet_sdk.utils.serialization import flatten
+
+from ...loader import ClientConfig
+from ...metadata import get_metadata
+from ..events import proto_timestamp_now
+
+DEFAULT_ACTION_TIMEOUT = 600  # seconds
+
+
+DEFAULT_ACTION_LISTENER_RETRY_INTERVAL = 5  # seconds
+DEFAULT_ACTION_LISTENER_RETRY_COUNT = 15
+
+
+@dataclass
+class GetActionListenerRequest:
+    worker_name: str
+    services: List[str]
+    actions: List[str]
+    max_runs: Optional[int] = None
+    _labels: dict[str, str | int] = field(default_factory=dict)
+
+    labels: dict[str, WorkerLabels] = field(init=False)
+
+    def __post_init__(self):
+        self.labels = {}
+
+        for key, value in self._labels.items():
+            if isinstance(value, int):
+                self.labels[key] = WorkerLabels(intValue=value)
+            else:
+                self.labels[key] = WorkerLabels(strValue=str(value))
+
+
+@dataclass
+class Action:
+    worker_id: str
+    tenant_id: str
+    workflow_run_id: str
+    get_group_key_run_id: str
+    job_id: str
+    job_name: str
+    job_run_id: str
+    step_id: str
+    step_run_id: str
+    action_id: str
+    action_payload: str
+    action_type: ActionType
+    retry_count: int
+    additional_metadata: dict[str, str] | None = None
+
+    child_workflow_index: int | None = None
+    child_workflow_key: str | None = None
+    parent_workflow_run_id: str | None = None
+
+    def __post_init__(self):
+        if isinstance(self.additional_metadata, str) and self.additional_metadata != "":
+            try:
+                self.additional_metadata = json.loads(self.additional_metadata)
+            except json.JSONDecodeError:
+                # If JSON decoding fails, keep the original string
+                pass
+
+        # Ensure additional_metadata is always a dictionary
+        if not isinstance(self.additional_metadata, dict):
+            self.additional_metadata = {}
+
+    @property
+    def otel_attributes(self) -> dict[str, str | int]:
+        try:
+            payload_str = json.dumps(self.action_payload, default=str)
+        except Exception:
+            payload_str = str(self.action_payload)
+
+        attrs: dict[str, str | int | None] = {
+            "hatchet.tenant_id": self.tenant_id,
+            "hatchet.worker_id": self.worker_id,
+            "hatchet.workflow_run_id": self.workflow_run_id,
+            "hatchet.step_id": self.step_id,
+            "hatchet.step_run_id": self.step_run_id,
+            "hatchet.retry_count": self.retry_count,
+            "hatchet.parent_workflow_run_id": self.parent_workflow_run_id,
+            "hatchet.child_workflow_index": self.child_workflow_index,
+            "hatchet.child_workflow_key": self.child_workflow_key,
+            "hatchet.action_payload": payload_str,
+            "hatchet.workflow_name": self.job_name,
+            "hatchet.action_name": self.action_id,
+            "hatchet.get_group_key_run_id": self.get_group_key_run_id,
+        }
+
+        return {k: v for k, v in attrs.items() if v}
+
+
+START_STEP_RUN = 0
+CANCEL_STEP_RUN = 1
+START_GET_GROUP_KEY = 2
+
+
+@dataclass
+class ActionListener:
+    config: ClientConfig
+    worker_id: str
+
+    client: DispatcherStub = field(init=False)
+    aio_client: DispatcherStub = field(init=False)
+    token: str = field(init=False)
+    retries: int = field(default=0, init=False)
+    last_connection_attempt: float = field(default=0, init=False)
+    last_heartbeat_succeeded: bool = field(default=True, init=False)
+    time_last_hb_succeeded: float = field(default=9999999999999, init=False)
+    heartbeat_task: Optional[asyncio.Task] = field(default=None, init=False)
+    run_heartbeat: bool = field(default=True, init=False)
+    listen_strategy: str = field(default="v2", init=False)
+    stop_signal: bool = field(default=False, init=False)
+
+    missed_heartbeats: int = field(default=0, init=False)
+
+    def __post_init__(self):
+        self.client = DispatcherStub(new_conn(self.config))
+        self.aio_client = DispatcherStub(new_conn(self.config, True))
+        self.token = self.config.token
+
+    def is_healthy(self):
+        return self.last_heartbeat_succeeded
+
+    async def heartbeat(self):
+        # send a heartbeat every 4 seconds
+        heartbeat_delay = 4
+
+        while True:
+            if not self.run_heartbeat:
+                break
+
+            try:
+                logger.debug("sending heartbeat")
+                await self.aio_client.Heartbeat(
+                    HeartbeatRequest(
+                        workerId=self.worker_id,
+                        heartbeatAt=proto_timestamp_now(),
+                    ),
+                    timeout=5,
+                    metadata=get_metadata(self.token),
+                )
+
+                if self.last_heartbeat_succeeded is False:
+                    logger.info("listener established")
+
+                now = time.time()
+                diff = now - self.time_last_hb_succeeded
+                if diff > heartbeat_delay + 1:
+                    logger.warn(
+                        f"time since last successful heartbeat: {diff:.2f}s, expects {heartbeat_delay}s"
+                    )
+
+                self.last_heartbeat_succeeded = True
+                self.time_last_hb_succeeded = now
+                self.missed_heartbeats = 0
+            except grpc.RpcError as e:
+                self.missed_heartbeats = self.missed_heartbeats + 1
+                self.last_heartbeat_succeeded = False
+
+                if (
+                    e.code() == grpc.StatusCode.UNAVAILABLE
+                    or e.code() == grpc.StatusCode.FAILED_PRECONDITION
+                ):
+                    # todo case on "recvmsg:Connection reset by peer" for updates?
+                    if self.missed_heartbeats >= 3:
+                        # we don't reraise the error here, as we don't want to stop the heartbeat thread
+                        logger.error(
+                            f"⛔️ failed heartbeat ({self.missed_heartbeats}): {e.details()}"
+                        )
+                    elif self.missed_heartbeats > 1:
+                        logger.warning(
+                            f"failed to send heartbeat ({self.missed_heartbeats}): {e.details()}"
+                        )
+                else:
+                    logger.error(f"failed to send heartbeat: {e}")
+
+                if self.interrupt is not None:
+                    self.interrupt.set()
+
+                if e.code() == grpc.StatusCode.UNIMPLEMENTED:
+                    break
+            await asyncio.sleep(heartbeat_delay)
+
+    async def start_heartbeater(self):
+        if self.heartbeat_task is not None:
+            return
+
+        try:
+            loop = asyncio.get_event_loop()
+        except RuntimeError as e:
+            if str(e).startswith("There is no current event loop in thread"):
+                loop = asyncio.new_event_loop()
+                asyncio.set_event_loop(loop)
+            else:
+                raise e
+        self.heartbeat_task = loop.create_task(self.heartbeat())
+
+    def __aiter__(self):
+        return self._generator()
+
+    async def _generator(self) -> AsyncGenerator[Action, None]:
+        listener = None
+
+        while not self.stop_signal:
+            if listener is not None:
+                listener.cancel()
+
+            try:
+                listener = await self.get_listen_client()
+            except Exception:
+                logger.info("closing action listener loop")
+                yield None
+
+            try:
+                while not self.stop_signal:
+                    self.interrupt = Event_ts()
+                    t = asyncio.create_task(
+                        read_with_interrupt(listener, self.interrupt)
+                    )
+                    await self.interrupt.wait()
+
+                    if not t.done():
+                        # print a warning
+                        logger.warning(
+                            "Interrupted read_with_interrupt task of action listener"
+                        )
+
+                        t.cancel()
+                        listener.cancel()
+                        break
+
+                    assigned_action = t.result()
+
+                    if assigned_action is cygrpc.EOF:
+                        self.retries = self.retries + 1
+                        break
+
+                    self.retries = 0
+                    assigned_action: AssignedAction
+
+                    # Process the received action
+                    action_type = self.map_action_type(assigned_action.actionType)
+
+                    if (
+                        assigned_action.actionPayload is None
+                        or assigned_action.actionPayload == ""
+                    ):
+                        action_payload = None
+                    else:
+                        action_payload = self.parse_action_payload(
+                            assigned_action.actionPayload
+                        )
+
+                    action = Action(
+                        tenant_id=assigned_action.tenantId,
+                        worker_id=self.worker_id,
+                        workflow_run_id=assigned_action.workflowRunId,
+                        get_group_key_run_id=assigned_action.getGroupKeyRunId,
+                        job_id=assigned_action.jobId,
+                        job_name=assigned_action.jobName,
+                        job_run_id=assigned_action.jobRunId,
+                        step_id=assigned_action.stepId,
+                        step_run_id=assigned_action.stepRunId,
+                        action_id=assigned_action.actionId,
+                        action_payload=action_payload,
+                        action_type=action_type,
+                        retry_count=assigned_action.retryCount,
+                        additional_metadata=assigned_action.additional_metadata,
+                        child_workflow_index=assigned_action.child_workflow_index,
+                        child_workflow_key=assigned_action.child_workflow_key,
+                        parent_workflow_run_id=assigned_action.parent_workflow_run_id,
+                    )
+
+                    yield action
+            except grpc.RpcError as e:
+                self.last_heartbeat_succeeded = False
+
+                # Handle different types of errors
+                if e.code() == grpc.StatusCode.CANCELLED:
+                    # Context cancelled, unsubscribe and close
+                    logger.debug("Context cancelled, closing listener")
+                elif e.code() == grpc.StatusCode.DEADLINE_EXCEEDED:
+                    logger.info("Deadline exceeded, retrying subscription")
+                elif (
+                    self.listen_strategy == "v2"
+                    and e.code() == grpc.StatusCode.UNIMPLEMENTED
+                ):
+                    # ListenV2 is not available, fallback to Listen
+                    self.listen_strategy = "v1"
+                    self.run_heartbeat = False
+                    logger.info("ListenV2 not available, falling back to Listen")
+                else:
+                    # TODO retry
+                    if e.code() == grpc.StatusCode.UNAVAILABLE:
+                        logger.error(f"action listener error: {e.details()}")
+                    else:
+                        # Unknown error, report and break
+                        logger.error(f"action listener error: {e}")
+
+                self.retries = self.retries + 1
+
+    def parse_action_payload(self, payload: str):
+        try:
+            payload_data = json.loads(payload)
+        except json.JSONDecodeError as e:
+            raise ValueError(f"Error decoding payload: {e}")
+        return payload_data
+
+    def map_action_type(self, action_type):
+        if action_type == ActionType.START_STEP_RUN:
+            return START_STEP_RUN
+        elif action_type == ActionType.CANCEL_STEP_RUN:
+            return CANCEL_STEP_RUN
+        elif action_type == ActionType.START_GET_GROUP_KEY:
+            return START_GET_GROUP_KEY
+        else:
+            # logger.error(f"Unknown action type: {action_type}")
+            return None
+
+    async def get_listen_client(self):
+        current_time = int(time.time())
+
+        if (
+            current_time - self.last_connection_attempt
+            > DEFAULT_ACTION_LISTENER_RETRY_INTERVAL
+        ):
+            # reset retries if last connection was long lived
+            self.retries = 0
+
+        if self.retries > DEFAULT_ACTION_LISTENER_RETRY_COUNT:
+            # TODO this is the problem case...
+            logger.error(
+                f"could not establish action listener connection after {DEFAULT_ACTION_LISTENER_RETRY_COUNT} retries"
+            )
+            self.run_heartbeat = False
+            raise Exception("retry_exhausted")
+        elif self.retries >= 1:
+            # logger.info
+            # if we are retrying, we wait for a bit. this should eventually be replaced with exp backoff + jitter
+            await exp_backoff_sleep(
+                self.retries, DEFAULT_ACTION_LISTENER_RETRY_INTERVAL
+            )
+
+            logger.info(
+                f"action listener connection interrupted, retrying... ({self.retries}/{DEFAULT_ACTION_LISTENER_RETRY_COUNT})"
+            )
+
+        self.aio_client = DispatcherStub(new_conn(self.config, True))
+
+        if self.listen_strategy == "v2":
+            # we should await for the listener to be established before
+            # starting the heartbeater
+            listener = self.aio_client.ListenV2(
+                WorkerListenRequest(workerId=self.worker_id),
+                timeout=self.config.listener_v2_timeout,
+                metadata=get_metadata(self.token),
+            )
+            await self.start_heartbeater()
+        else:
+            # if ListenV2 is not available, fallback to Listen
+            listener = self.aio_client.Listen(
+                WorkerListenRequest(workerId=self.worker_id),
+                timeout=DEFAULT_ACTION_TIMEOUT,
+                metadata=get_metadata(self.token),
+            )
+
+        self.last_connection_attempt = current_time
+
+        return listener
+
+    def cleanup(self):
+        self.run_heartbeat = False
+        self.heartbeat_task.cancel()
+
+        try:
+            self.unregister()
+        except Exception as e:
+            logger.error(f"failed to unregister: {e}")
+
+        if self.interrupt:
+            self.interrupt.set()
+
+    def unregister(self):
+        self.run_heartbeat = False
+        self.heartbeat_task.cancel()
+
+        try:
+            req = self.aio_client.Unsubscribe(
+                WorkerUnsubscribeRequest(workerId=self.worker_id),
+                timeout=5,
+                metadata=get_metadata(self.token),
+            )
+            if self.interrupt is not None:
+                self.interrupt.set()
+            return req
+        except grpc.RpcError as e:
+            raise Exception(f"Failed to unsubscribe: {e}")
diff --git a/.venv/lib/python3.12/site-packages/hatchet_sdk/clients/dispatcher/dispatcher.py b/.venv/lib/python3.12/site-packages/hatchet_sdk/clients/dispatcher/dispatcher.py
new file mode 100644
index 00000000..407a80cc
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/hatchet_sdk/clients/dispatcher/dispatcher.py
@@ -0,0 +1,204 @@
+from typing import Any, cast
+
+from google.protobuf.timestamp_pb2 import Timestamp
+
+from hatchet_sdk.clients.dispatcher.action_listener import (
+    Action,
+    ActionListener,
+    GetActionListenerRequest,
+)
+from hatchet_sdk.clients.rest.tenacity_utils import tenacity_retry
+from hatchet_sdk.connection import new_conn
+from hatchet_sdk.contracts.dispatcher_pb2 import (
+    STEP_EVENT_TYPE_COMPLETED,
+    STEP_EVENT_TYPE_FAILED,
+    ActionEventResponse,
+    GroupKeyActionEvent,
+    GroupKeyActionEventType,
+    OverridesData,
+    RefreshTimeoutRequest,
+    ReleaseSlotRequest,
+    StepActionEvent,
+    StepActionEventType,
+    UpsertWorkerLabelsRequest,
+    WorkerLabels,
+    WorkerRegisterRequest,
+    WorkerRegisterResponse,
+)
+from hatchet_sdk.contracts.dispatcher_pb2_grpc import DispatcherStub
+
+from ...loader import ClientConfig
+from ...metadata import get_metadata
+
+DEFAULT_REGISTER_TIMEOUT = 30
+
+
+def new_dispatcher(config: ClientConfig) -> "DispatcherClient":
+    return DispatcherClient(config=config)
+
+
+class DispatcherClient:
+    config: ClientConfig
+
+    def __init__(self, config: ClientConfig):
+        conn = new_conn(config)
+        self.client = DispatcherStub(conn)  # type: ignore[no-untyped-call]
+
+        aio_conn = new_conn(config, True)
+        self.aio_client = DispatcherStub(aio_conn)  # type: ignore[no-untyped-call]
+        self.token = config.token
+        self.config = config
+
+    async def get_action_listener(
+        self, req: GetActionListenerRequest
+    ) -> ActionListener:
+
+        # Override labels with the preset labels
+        preset_labels = self.config.worker_preset_labels
+
+        for key, value in preset_labels.items():
+            req.labels[key] = WorkerLabels(strValue=str(value))
+
+        # Register the worker
+        response: WorkerRegisterResponse = await self.aio_client.Register(
+            WorkerRegisterRequest(
+                workerName=req.worker_name,
+                actions=req.actions,
+                services=req.services,
+                maxRuns=req.max_runs,
+                labels=req.labels,
+            ),
+            timeout=DEFAULT_REGISTER_TIMEOUT,
+            metadata=get_metadata(self.token),
+        )
+
+        return ActionListener(self.config, response.workerId)
+
+    async def send_step_action_event(
+        self, action: Action, event_type: StepActionEventType, payload: str
+    ) -> Any:
+        try:
+            return await self._try_send_step_action_event(action, event_type, payload)
+        except Exception as e:
+            # for step action events, send a failure event when we cannot send the completed event
+            if (
+                event_type == STEP_EVENT_TYPE_COMPLETED
+                or event_type == STEP_EVENT_TYPE_FAILED
+            ):
+                await self._try_send_step_action_event(
+                    action,
+                    STEP_EVENT_TYPE_FAILED,
+                    "Failed to send finished event: " + str(e),
+                )
+
+            return
+
+    @tenacity_retry
+    async def _try_send_step_action_event(
+        self, action: Action, event_type: StepActionEventType, payload: str
+    ) -> Any:
+        eventTimestamp = Timestamp()
+        eventTimestamp.GetCurrentTime()
+
+        event = StepActionEvent(
+            workerId=action.worker_id,
+            jobId=action.job_id,
+            jobRunId=action.job_run_id,
+            stepId=action.step_id,
+            stepRunId=action.step_run_id,
+            actionId=action.action_id,
+            eventTimestamp=eventTimestamp,
+            eventType=event_type,
+            eventPayload=payload,
+            retryCount=action.retry_count,
+        )
+
+        ## TODO: What does this return?
+        return await self.aio_client.SendStepActionEvent(
+            event,
+            metadata=get_metadata(self.token),
+        )
+
+    async def send_group_key_action_event(
+        self, action: Action, event_type: GroupKeyActionEventType, payload: str
+    ) -> Any:
+        eventTimestamp = Timestamp()
+        eventTimestamp.GetCurrentTime()
+
+        event = GroupKeyActionEvent(
+            workerId=action.worker_id,
+            workflowRunId=action.workflow_run_id,
+            getGroupKeyRunId=action.get_group_key_run_id,
+            actionId=action.action_id,
+            eventTimestamp=eventTimestamp,
+            eventType=event_type,
+            eventPayload=payload,
+        )
+
+        ## TODO: What does this return?
+        return await self.aio_client.SendGroupKeyActionEvent(
+            event,
+            metadata=get_metadata(self.token),
+        )
+
+    def put_overrides_data(self, data: OverridesData) -> ActionEventResponse:
+        return cast(
+            ActionEventResponse,
+            self.client.PutOverridesData(
+                data,
+                metadata=get_metadata(self.token),
+            ),
+        )
+
+    def release_slot(self, step_run_id: str) -> None:
+        self.client.ReleaseSlot(
+            ReleaseSlotRequest(stepRunId=step_run_id),
+            timeout=DEFAULT_REGISTER_TIMEOUT,
+            metadata=get_metadata(self.token),
+        )
+
+    def refresh_timeout(self, step_run_id: str, increment_by: str) -> None:
+        self.client.RefreshTimeout(
+            RefreshTimeoutRequest(
+                stepRunId=step_run_id,
+                incrementTimeoutBy=increment_by,
+            ),
+            timeout=DEFAULT_REGISTER_TIMEOUT,
+            metadata=get_metadata(self.token),
+        )
+
+    def upsert_worker_labels(
+        self, worker_id: str | None, labels: dict[str, str | int]
+    ) -> None:
+        worker_labels = {}
+
+        for key, value in labels.items():
+            if isinstance(value, int):
+                worker_labels[key] = WorkerLabels(intValue=value)
+            else:
+                worker_labels[key] = WorkerLabels(strValue=str(value))
+
+        self.client.UpsertWorkerLabels(
+            UpsertWorkerLabelsRequest(workerId=worker_id, labels=worker_labels),
+            timeout=DEFAULT_REGISTER_TIMEOUT,
+            metadata=get_metadata(self.token),
+        )
+
+    async def async_upsert_worker_labels(
+        self,
+        worker_id: str | None,
+        labels: dict[str, str | int],
+    ) -> None:
+        worker_labels = {}
+
+        for key, value in labels.items():
+            if isinstance(value, int):
+                worker_labels[key] = WorkerLabels(intValue=value)
+            else:
+                worker_labels[key] = WorkerLabels(strValue=str(value))
+
+        await self.aio_client.UpsertWorkerLabels(
+            UpsertWorkerLabelsRequest(workerId=worker_id, labels=worker_labels),
+            timeout=DEFAULT_REGISTER_TIMEOUT,
+            metadata=get_metadata(self.token),
+        )