Diffstat (limited to '.venv/lib/python3.12/site-packages/hatchet_sdk/worker')
7 files changed, 1330 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/hatchet_sdk/worker/__init__.py b/.venv/lib/python3.12/site-packages/hatchet_sdk/worker/__init__.py new file mode 100644 index 00000000..450f0cac --- /dev/null +++ b/.venv/lib/python3.12/site-packages/hatchet_sdk/worker/__init__.py @@ -0,0 +1 @@ +from .worker import Worker, WorkerStartOptions, WorkerStatus diff --git a/.venv/lib/python3.12/site-packages/hatchet_sdk/worker/action_listener_process.py b/.venv/lib/python3.12/site-packages/hatchet_sdk/worker/action_listener_process.py new file mode 100644 index 00000000..08508607 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/hatchet_sdk/worker/action_listener_process.py @@ -0,0 +1,278 @@ +import asyncio +import logging +import signal +import time +from dataclasses import dataclass, field +from multiprocessing import Queue +from typing import Any, List, Mapping, Optional + +import grpc + +from hatchet_sdk.clients.dispatcher.action_listener import Action +from hatchet_sdk.clients.dispatcher.dispatcher import ( + ActionListener, + GetActionListenerRequest, + new_dispatcher, +) +from hatchet_sdk.contracts.dispatcher_pb2 import ( + GROUP_KEY_EVENT_TYPE_STARTED, + STEP_EVENT_TYPE_STARTED, + ActionType, +) +from hatchet_sdk.loader import ClientConfig +from hatchet_sdk.logger import logger +from hatchet_sdk.utils.backoff import exp_backoff_sleep + +ACTION_EVENT_RETRY_COUNT = 5 + + +@dataclass +class ActionEvent: + action: Action + type: Any # TODO type + payload: Optional[str] = None + + +STOP_LOOP = "STOP_LOOP" # Sentinel object to stop the loop + +# TODO link to a block post +BLOCKED_THREAD_WARNING = ( + "THE TIME TO START THE STEP RUN IS TOO LONG, THE MAIN THREAD MAY BE BLOCKED" +) + + +def noop_handler(): + pass + + +@dataclass +class WorkerActionListenerProcess: + name: str + actions: List[str] + max_runs: int + config: ClientConfig + action_queue: Queue + event_queue: Queue + handle_kill: bool = True + debug: bool = False + labels: dict = field(default_factory=dict) + + listener: ActionListener = field(init=False, default=None) + + killing: bool = field(init=False, default=False) + + action_loop_task: asyncio.Task = field(init=False, default=None) + event_send_loop_task: asyncio.Task = field(init=False, default=None) + + running_step_runs: Mapping[str, float] = field(init=False, default_factory=dict) + + def __post_init__(self): + if self.debug: + logger.setLevel(logging.DEBUG) + + loop = asyncio.get_event_loop() + loop.add_signal_handler(signal.SIGINT, noop_handler) + loop.add_signal_handler(signal.SIGTERM, noop_handler) + loop.add_signal_handler( + signal.SIGQUIT, lambda: asyncio.create_task(self.exit_gracefully()) + ) + + async def start(self, retry_attempt=0): + if retry_attempt > 5: + logger.error("could not start action listener") + return + + logger.debug(f"starting action listener: {self.name}") + + try: + self.dispatcher_client = new_dispatcher(self.config) + + self.listener = await self.dispatcher_client.get_action_listener( + GetActionListenerRequest( + worker_name=self.name, + services=["default"], + actions=self.actions, + max_runs=self.max_runs, + _labels=self.labels, + ) + ) + + logger.debug(f"acquired action listener: {self.listener.worker_id}") + except grpc.RpcError as rpc_error: + logger.error(f"could not start action listener: {rpc_error}") + return + + # Start both loops as background tasks + self.action_loop_task = asyncio.create_task(self.start_action_loop()) + self.event_send_loop_task = asyncio.create_task(self.start_event_send_loop()) + self.blocked_main_loop = 
asyncio.create_task(self.start_blocked_main_loop()) + + # TODO move event methods to separate class + async def _get_event(self): + loop = asyncio.get_running_loop() + return await loop.run_in_executor(None, self.event_queue.get) + + async def start_event_send_loop(self): + while True: + event: ActionEvent = await self._get_event() + if event == STOP_LOOP: + logger.debug("stopping event send loop...") + break + + logger.debug(f"tx: event: {event.action.action_id}/{event.type}") + asyncio.create_task(self.send_event(event)) + + async def start_blocked_main_loop(self): + threshold = 1 + while not self.killing: + count = 0 + for step_run_id, start_time in self.running_step_runs.items(): + diff = self.now() - start_time + if diff > threshold: + count += 1 + + if count > 0: + logger.warning(f"{BLOCKED_THREAD_WARNING}: Waiting Steps {count}") + await asyncio.sleep(1) + + async def send_event(self, event: ActionEvent, retry_attempt: int = 1): + try: + match event.action.action_type: + # FIXME: all events sent from an execution of a function are of type ActionType.START_STEP_RUN since + # the action is re-used. We should change this. + case ActionType.START_STEP_RUN: + # TODO right now we're sending two start_step_run events + # one on the action loop and one on the event loop + # ideally we change the first to an ack to set the time + if event.type == STEP_EVENT_TYPE_STARTED: + if event.action.step_run_id in self.running_step_runs: + diff = ( + self.now() + - self.running_step_runs[event.action.step_run_id] + ) + if diff > 0.1: + logger.warning( + f"{BLOCKED_THREAD_WARNING}: time to start: {diff}s" + ) + else: + logger.debug(f"start time: {diff}") + del self.running_step_runs[event.action.step_run_id] + else: + self.running_step_runs[event.action.step_run_id] = ( + self.now() + ) + + asyncio.create_task( + self.dispatcher_client.send_step_action_event( + event.action, event.type, event.payload + ) + ) + case ActionType.CANCEL_STEP_RUN: + logger.debug("unimplemented event send") + case ActionType.START_GET_GROUP_KEY: + asyncio.create_task( + self.dispatcher_client.send_group_key_action_event( + event.action, event.type, event.payload + ) + ) + case _: + logger.error("unknown action type for event send") + except Exception as e: + logger.error( + f"could not send action event ({retry_attempt}/{ACTION_EVENT_RETRY_COUNT}): {e}" + ) + if retry_attempt <= ACTION_EVENT_RETRY_COUNT: + await exp_backoff_sleep(retry_attempt, 1) + await self.send_event(event, retry_attempt + 1) + + def now(self): + return time.time() + + async def start_action_loop(self): + try: + async for action in self.listener: + if action is None: + break + + # Process the action here + match action.action_type: + case ActionType.START_STEP_RUN: + self.event_queue.put( + ActionEvent( + action=action, + type=STEP_EVENT_TYPE_STARTED, # TODO ack type + ) + ) + logger.info( + f"rx: start step run: {action.step_run_id}/{action.action_id}" + ) + + # TODO handle this case better... 
+ if action.step_run_id in self.running_step_runs: + logger.warning( + f"step run already running: {action.step_run_id}" + ) + + case ActionType.CANCEL_STEP_RUN: + logger.info(f"rx: cancel step run: {action.step_run_id}") + case ActionType.START_GET_GROUP_KEY: + self.event_queue.put( + ActionEvent( + action=action, + type=GROUP_KEY_EVENT_TYPE_STARTED, # TODO ack type + ) + ) + logger.info( + f"rx: start group key: {action.get_group_key_run_id}" + ) + case _: + logger.error( + f"rx: unknown action type ({action.action_type}): {action.action_type}" + ) + try: + self.action_queue.put(action) + except Exception as e: + logger.error(f"error putting action: {e}") + + except Exception as e: + logger.error(f"error in action loop: {e}") + finally: + logger.info("action loop closed") + if not self.killing: + await self.exit_gracefully(skip_unregister=True) + + async def cleanup(self): + self.killing = True + + if self.listener is not None: + self.listener.cleanup() + + self.event_queue.put(STOP_LOOP) + + async def exit_gracefully(self, skip_unregister=False): + if self.killing: + return + + logger.debug("closing action listener...") + + await self.cleanup() + + while not self.event_queue.empty(): + pass + + logger.info("action listener closed") + + def exit_forcefully(self): + asyncio.run(self.cleanup()) + logger.debug("forcefully closing listener...") + + +def worker_action_listener_process(*args, **kwargs): + async def run(): + process = WorkerActionListenerProcess(*args, **kwargs) + await process.start() + # Keep the process running + while not process.killing: + await asyncio.sleep(0.1) + + asyncio.run(run()) diff --git a/.venv/lib/python3.12/site-packages/hatchet_sdk/worker/runner/run_loop_manager.py b/.venv/lib/python3.12/site-packages/hatchet_sdk/worker/runner/run_loop_manager.py new file mode 100644 index 00000000..27ed788c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/hatchet_sdk/worker/runner/run_loop_manager.py @@ -0,0 +1,112 @@ +import asyncio +import logging +from dataclasses import dataclass, field +from multiprocessing import Queue +from typing import Callable, TypeVar + +from hatchet_sdk import Context +from hatchet_sdk.client import Client, new_client_raw +from hatchet_sdk.clients.dispatcher.action_listener import Action +from hatchet_sdk.loader import ClientConfig +from hatchet_sdk.logger import logger +from hatchet_sdk.utils.types import WorkflowValidator +from hatchet_sdk.worker.runner.runner import Runner +from hatchet_sdk.worker.runner.utils.capture_logs import capture_logs + +STOP_LOOP = "STOP_LOOP" + +T = TypeVar("T") + + +@dataclass +class WorkerActionRunLoopManager: + name: str + action_registry: dict[str, Callable[[Context], T]] + validator_registry: dict[str, WorkflowValidator] + max_runs: int | None + config: ClientConfig + action_queue: Queue + event_queue: Queue + loop: asyncio.AbstractEventLoop + handle_kill: bool = True + debug: bool = False + labels: dict[str, str | int] = field(default_factory=dict) + + client: Client = field(init=False, default=None) + + killing: bool = field(init=False, default=False) + runner: Runner = field(init=False, default=None) + + def __post_init__(self): + if self.debug: + logger.setLevel(logging.DEBUG) + self.client = new_client_raw(self.config, self.debug) + self.start() + + def start(self, retry_count=1): + k = self.loop.create_task(self.async_start(retry_count)) + + async def async_start(self, retry_count=1): + await capture_logs( + self.client.logInterceptor, + self.client.event, + self._async_start, + 
)(retry_count=retry_count) + + async def _async_start(self, retry_count: int = 1) -> None: + logger.info("starting runner...") + self.loop = asyncio.get_running_loop() + # needed for graceful termination + k = self.loop.create_task(self._start_action_loop()) + await k + + def cleanup(self) -> None: + self.killing = True + + self.action_queue.put(STOP_LOOP) + + async def wait_for_tasks(self) -> None: + if self.runner: + await self.runner.wait_for_tasks() + + async def _start_action_loop(self) -> None: + self.runner = Runner( + self.name, + self.event_queue, + self.max_runs, + self.handle_kill, + self.action_registry, + self.validator_registry, + self.config, + self.labels, + ) + + logger.debug(f"'{self.name}' waiting for {list(self.action_registry.keys())}") + while not self.killing: + action: Action = await self._get_action() + if action == STOP_LOOP: + logger.debug("stopping action runner loop...") + break + + self.runner.run(action) + logger.debug("action runner loop stopped") + + async def _get_action(self): + return await self.loop.run_in_executor(None, self.action_queue.get) + + async def exit_gracefully(self) -> None: + if self.killing: + return + + logger.info("gracefully exiting runner...") + + self.cleanup() + + # Wait for 1 second to allow last calls to flush. These are calls which have been + # added to the event loop as callbacks to tasks, so we're not aware of them in the + # task list. + await asyncio.sleep(1) + + def exit_forcefully(self) -> None: + logger.info("forcefully exiting runner...") + self.cleanup() diff --git a/.venv/lib/python3.12/site-packages/hatchet_sdk/worker/runner/runner.py b/.venv/lib/python3.12/site-packages/hatchet_sdk/worker/runner/runner.py new file mode 100644 index 00000000..01e61bcf --- /dev/null +++ b/.venv/lib/python3.12/site-packages/hatchet_sdk/worker/runner/runner.py @@ -0,0 +1,460 @@ +import asyncio +import contextvars +import ctypes +import functools +import json +import time +import traceback +from concurrent.futures import ThreadPoolExecutor +from enum import Enum +from multiprocessing import Queue +from threading import Thread, current_thread +from typing import Any, Callable, Dict, cast + +from pydantic import BaseModel + +from hatchet_sdk.client import new_client_raw +from hatchet_sdk.clients.admin import new_admin +from hatchet_sdk.clients.dispatcher.action_listener import Action +from hatchet_sdk.clients.dispatcher.dispatcher import new_dispatcher +from hatchet_sdk.clients.run_event_listener import new_listener +from hatchet_sdk.clients.workflow_listener import PooledWorkflowRunListener +from hatchet_sdk.context import Context # type: ignore[attr-defined] +from hatchet_sdk.context.worker_context import WorkerContext +from hatchet_sdk.contracts.dispatcher_pb2 import ( + GROUP_KEY_EVENT_TYPE_COMPLETED, + GROUP_KEY_EVENT_TYPE_FAILED, + GROUP_KEY_EVENT_TYPE_STARTED, + STEP_EVENT_TYPE_COMPLETED, + STEP_EVENT_TYPE_FAILED, + STEP_EVENT_TYPE_STARTED, + ActionType, +) +from hatchet_sdk.loader import ClientConfig +from hatchet_sdk.logger import logger +from hatchet_sdk.utils.types import WorkflowValidator +from hatchet_sdk.v2.callable import DurableContext +from hatchet_sdk.worker.action_listener_process import ActionEvent +from hatchet_sdk.worker.runner.utils.capture_logs import copy_context_vars, sr, wr + + +class WorkerStatus(Enum): + INITIALIZED = 1 + STARTING = 2 + HEALTHY = 3 + UNHEALTHY = 4 + + +class Runner: + def __init__( + self, + name: str, + event_queue: "Queue[Any]", + max_runs: int | None = None, + handle_kill: bool = True, + 
action_registry: dict[str, Callable[..., Any]] = {}, + validator_registry: dict[str, WorkflowValidator] = {}, + config: ClientConfig = ClientConfig(), + labels: dict[str, str | int] = {}, + ): + # We store the config so we can dynamically create clients for the dispatcher client. + self.config = config + self.client = new_client_raw(config) + self.name = self.client.config.namespace + name + self.max_runs = max_runs + self.tasks: dict[str, asyncio.Task[Any]] = {} # Store run ids and futures + self.contexts: dict[str, Context] = {} # Store run ids and contexts + self.action_registry: dict[str, Callable[..., Any]] = action_registry + self.validator_registry = validator_registry + + self.event_queue = event_queue + + # The thread pool is used for synchronous functions which need to run concurrently + self.thread_pool = ThreadPoolExecutor(max_workers=max_runs) + self.threads: Dict[str, Thread] = {} # Store run ids and threads + + self.killing = False + self.handle_kill = handle_kill + + # We need to initialize a new admin and dispatcher client *after* we've started the event loop, + # otherwise the grpc.aio methods will use a different event loop and we'll get a bunch of errors. + self.dispatcher_client = new_dispatcher(self.config) + self.admin_client = new_admin(self.config) + self.workflow_run_event_listener = new_listener(self.config) + self.client.workflow_listener = PooledWorkflowRunListener(self.config) + + self.worker_context = WorkerContext( + labels=labels, client=new_client_raw(config).dispatcher + ) + + def create_workflow_run_url(self, action: Action) -> str: + return f"{self.config.server_url}/workflow-runs/{action.workflow_run_id}?tenant={action.tenant_id}" + + def run(self, action: Action) -> None: + if self.worker_context.id() is None: + self.worker_context._worker_id = action.worker_id + + match action.action_type: + case ActionType.START_STEP_RUN: + log = f"run: start step: {action.action_id}/{action.step_run_id}" + logger.info(log) + asyncio.create_task(self.handle_start_step_run(action)) + case ActionType.CANCEL_STEP_RUN: + log = f"cancel: step run: {action.action_id}/{action.step_run_id}" + logger.info(log) + asyncio.create_task(self.handle_cancel_action(action.step_run_id)) + case ActionType.START_GET_GROUP_KEY: + log = f"run: get group key: {action.action_id}/{action.get_group_key_run_id}" + logger.info(log) + asyncio.create_task(self.handle_start_group_key_run(action)) + case _: + log = f"unknown action type: {action.action_type}" + logger.error(log) + + def step_run_callback(self, action: Action) -> Callable[[asyncio.Task[Any]], None]: + def inner_callback(task: asyncio.Task[Any]) -> None: + self.cleanup_run_id(action.step_run_id) + + errored = False + cancelled = task.cancelled() + + # Get the output from the future + try: + if not cancelled: + output = task.result() + except Exception as e: + errored = True + + # This except is coming from the application itself, so we want to send that to the Hatchet instance + self.event_queue.put( + ActionEvent( + action=action, + type=STEP_EVENT_TYPE_FAILED, + payload=str(errorWithTraceback(f"{e}", e)), + ) + ) + + logger.error( + f"failed step run: {action.action_id}/{action.step_run_id}" + ) + + if not errored and not cancelled: + self.event_queue.put( + ActionEvent( + action=action, + type=STEP_EVENT_TYPE_COMPLETED, + payload=self.serialize_output(output), + ) + ) + + logger.info( + f"finished step run: {action.action_id}/{action.step_run_id}" + ) + + return inner_callback + + def group_key_run_callback( + self, action: 
Action + ) -> Callable[[asyncio.Task[Any]], None]: + def inner_callback(task: asyncio.Task[Any]) -> None: + self.cleanup_run_id(action.get_group_key_run_id) + + errored = False + cancelled = task.cancelled() + + # Get the output from the future + try: + if not cancelled: + output = task.result() + except Exception as e: + errored = True + self.event_queue.put( + ActionEvent( + action=action, + type=GROUP_KEY_EVENT_TYPE_FAILED, + payload=str(errorWithTraceback(f"{e}", e)), + ) + ) + + logger.error( + f"failed step run: {action.action_id}/{action.step_run_id}" + ) + + if not errored and not cancelled: + self.event_queue.put( + ActionEvent( + action=action, + type=GROUP_KEY_EVENT_TYPE_COMPLETED, + payload=self.serialize_output(output), + ) + ) + + logger.info( + f"finished step run: {action.action_id}/{action.step_run_id}" + ) + + return inner_callback + + ## TODO: Stricter type hinting here + def thread_action_func( + self, context: Context, action_func: Callable[..., Any], action: Action + ) -> Any: + if action.step_run_id is not None and action.step_run_id != "": + self.threads[action.step_run_id] = current_thread() + elif ( + action.get_group_key_run_id is not None + and action.get_group_key_run_id != "" + ): + self.threads[action.get_group_key_run_id] = current_thread() + + return action_func(context) + + ## TODO: Stricter type hinting here + # We wrap all actions in an async func + async def async_wrapped_action_func( + self, + context: Context, + action_func: Callable[..., Any], + action: Action, + run_id: str, + ) -> Any: + wr.set(context.workflow_run_id()) + sr.set(context.step_run_id) + + try: + if ( + hasattr(action_func, "is_coroutine") and action_func.is_coroutine + ) or asyncio.iscoroutinefunction(action_func): + return await action_func(context) + else: + pfunc = functools.partial( + # we must copy the context vars to the new thread, as only asyncio natively supports + # contextvars + copy_context_vars, + contextvars.copy_context().items(), + self.thread_action_func, + context, + action_func, + action, + ) + + loop = asyncio.get_event_loop() + return await loop.run_in_executor(self.thread_pool, pfunc) + except Exception as e: + logger.error( + errorWithTraceback( + f"exception raised in action ({action.action_id}, retry={action.retry_count}):\n{e}", + e, + ) + ) + raise e + finally: + self.cleanup_run_id(run_id) + + def cleanup_run_id(self, run_id: str | None) -> None: + if run_id in self.tasks: + del self.tasks[run_id] + + if run_id in self.threads: + del self.threads[run_id] + + if run_id in self.contexts: + del self.contexts[run_id] + + def create_context( + self, action: Action, action_func: Callable[..., Any] | None + ) -> Context | DurableContext: + if hasattr(action_func, "durable") and getattr(action_func, "durable"): + return DurableContext( + action, + self.dispatcher_client, + self.admin_client, + self.client.event, + self.client.rest, + self.client.workflow_listener, + self.workflow_run_event_listener, + self.worker_context, + self.client.config.namespace, + validator_registry=self.validator_registry, + ) + + return Context( + action, + self.dispatcher_client, + self.admin_client, + self.client.event, + self.client.rest, + self.client.workflow_listener, + self.workflow_run_event_listener, + self.worker_context, + self.client.config.namespace, + validator_registry=self.validator_registry, + ) + + ## IMPORTANT: Keep this method's signature in sync with the wrapper in the OTel instrumentor + async def handle_start_step_run(self, action: Action) -> None | Exception: + 
action_name = action.action_id + + # Find the corresponding action function from the registry + action_func = self.action_registry.get(action_name) + + context = self.create_context(action, action_func) + + self.contexts[action.step_run_id] = context + + if action_func: + self.event_queue.put( + ActionEvent( + action=action, + type=STEP_EVENT_TYPE_STARTED, + ) + ) + + loop = asyncio.get_event_loop() + task = loop.create_task( + self.async_wrapped_action_func( + context, action_func, action, action.step_run_id + ) + ) + + task.add_done_callback(self.step_run_callback(action)) + self.tasks[action.step_run_id] = task + + try: + await task + except Exception as e: + return e + + return None + + ## IMPORTANT: Keep this method's signature in sync with the wrapper in the OTel instrumentor + async def handle_start_group_key_run(self, action: Action) -> Exception | None: + action_name = action.action_id + context = Context( + action, + self.dispatcher_client, + self.admin_client, + self.client.event, + self.client.rest, + self.client.workflow_listener, + self.workflow_run_event_listener, + self.worker_context, + self.client.config.namespace, + ) + + self.contexts[action.get_group_key_run_id] = context + + # Find the corresponding action function from the registry + action_func = self.action_registry.get(action_name) + + if action_func: + # send an event that the group key run has started + self.event_queue.put( + ActionEvent( + action=action, + type=GROUP_KEY_EVENT_TYPE_STARTED, + ) + ) + + loop = asyncio.get_event_loop() + task = loop.create_task( + self.async_wrapped_action_func( + context, action_func, action, action.get_group_key_run_id + ) + ) + + task.add_done_callback(self.group_key_run_callback(action)) + self.tasks[action.get_group_key_run_id] = task + + try: + await task + except Exception as e: + return e + + return None + + def force_kill_thread(self, thread: Thread) -> None: + """Terminate a python threading.Thread.""" + try: + if not thread.is_alive(): + return + + ident = cast(int, thread.ident) + + logger.info(f"Forcefully terminating thread {ident}") + + exc = ctypes.py_object(SystemExit) + res = ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_long(ident), exc) + if res == 0: + raise ValueError("Invalid thread ID") + elif res != 1: + logger.error("PyThreadState_SetAsyncExc failed") + + # Call with exception set to 0 is needed to cleanup properly. 
+ ctypes.pythonapi.PyThreadState_SetAsyncExc(thread.ident, 0) + raise SystemError("PyThreadState_SetAsyncExc failed") + + logger.info(f"Successfully terminated thread {ident}") + + # Immediately add a new thread to the thread pool, because we've actually killed a worker + # in the ThreadPoolExecutor + self.thread_pool.submit(lambda: None) + except Exception as e: + logger.exception(f"Failed to terminate thread: {e}") + + ## IMPORTANT: Keep this method's signature in sync with the wrapper in the OTel instrumentor + async def handle_cancel_action(self, run_id: str) -> None: + try: + # call cancel to signal the context to stop + if run_id in self.contexts: + context = self.contexts.get(run_id) + + if context: + context.cancel() + + await asyncio.sleep(1) + + if run_id in self.tasks: + future = self.tasks.get(run_id) + + if future: + future.cancel() + + # check if thread is still running, if so, print a warning + if run_id in self.threads: + thread = self.threads.get(run_id) + if thread and self.client.config.enable_force_kill_sync_threads: + self.force_kill_thread(thread) + await asyncio.sleep(1) + + logger.warning( + f"Thread {self.threads[run_id].ident} with run id {run_id} is still running after cancellation. This could cause the thread pool to get blocked and prevent new tasks from running." + ) + finally: + self.cleanup_run_id(run_id) + + def serialize_output(self, output: Any) -> str: + + if isinstance(output, BaseModel): + return output.model_dump_json() + + if output is not None: + try: + return json.dumps(output) + except Exception as e: + logger.error(f"Could not serialize output: {e}") + return str(output) + + return "" + + async def wait_for_tasks(self) -> None: + running = len(self.tasks.keys()) + while running > 0: + logger.info(f"waiting for {running} tasks to finish...") + await asyncio.sleep(1) + running = len(self.tasks.keys()) + + +def errorWithTraceback(message: str, e: Exception) -> str: + trace = "".join(traceback.format_exception(type(e), e, e.__traceback__)) + return f"{message}\n{trace}" diff --git a/.venv/lib/python3.12/site-packages/hatchet_sdk/worker/runner/utils/capture_logs.py b/.venv/lib/python3.12/site-packages/hatchet_sdk/worker/runner/utils/capture_logs.py new file mode 100644 index 00000000..08c57de8 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/hatchet_sdk/worker/runner/utils/capture_logs.py @@ -0,0 +1,81 @@ +import contextvars +import functools +import logging +from concurrent.futures import ThreadPoolExecutor +from io import StringIO +from typing import Any, Coroutine + +from hatchet_sdk import logger +from hatchet_sdk.clients.events import EventClient + +wr: contextvars.ContextVar[str | None] = contextvars.ContextVar( + "workflow_run_id", default=None +) +sr: contextvars.ContextVar[str | None] = contextvars.ContextVar( + "step_run_id", default=None +) + + +def copy_context_vars(ctx_vars, func, *args, **kwargs): + for var, value in ctx_vars: + var.set(value) + return func(*args, **kwargs) + + +class InjectingFilter(logging.Filter): + # For some reason, only the InjectingFilter has access to the contextvars method sr.get(), + # otherwise we would use emit within the CustomLogHandler + def filter(self, record): + record.workflow_run_id = wr.get() + record.step_run_id = sr.get() + return True + + +class CustomLogHandler(logging.StreamHandler): + def __init__(self, event_client: EventClient, stream=None): + super().__init__(stream) + self.logger_thread_pool = ThreadPoolExecutor(max_workers=1) + self.event_client = event_client + + def _log(self, 
line: str, step_run_id: str | None): + try: + if not step_run_id: + return + + self.event_client.log(message=line, step_run_id=step_run_id) + except Exception as e: + logger.error(f"Error logging: {e}") + + def emit(self, record): + super().emit(record) + + log_entry = self.format(record) + self.logger_thread_pool.submit(self._log, log_entry, record.step_run_id) + + +def capture_logs( + logger: logging.Logger, + event_client: EventClient, + func: Coroutine[Any, Any, Any], +): + @functools.wraps(func) + async def wrapper(*args, **kwargs): + if not logger: + raise Exception("No logger configured on client") + + log_stream = StringIO() + custom_handler = CustomLogHandler(event_client, log_stream) + custom_handler.setLevel(logging.INFO) + custom_handler.addFilter(InjectingFilter()) + logger.addHandler(custom_handler) + + try: + result = await func(*args, **kwargs) + finally: + custom_handler.flush() + logger.removeHandler(custom_handler) + log_stream.close() + + return result + + return wrapper diff --git a/.venv/lib/python3.12/site-packages/hatchet_sdk/worker/runner/utils/error_with_traceback.py b/.venv/lib/python3.12/site-packages/hatchet_sdk/worker/runner/utils/error_with_traceback.py new file mode 100644 index 00000000..9c09602f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/hatchet_sdk/worker/runner/utils/error_with_traceback.py @@ -0,0 +1,6 @@ +import traceback + + +def errorWithTraceback(message: str, e: Exception): + trace = "".join(traceback.format_exception(type(e), e, e.__traceback__)) + return f"{message}\n{trace}" diff --git a/.venv/lib/python3.12/site-packages/hatchet_sdk/worker/worker.py b/.venv/lib/python3.12/site-packages/hatchet_sdk/worker/worker.py new file mode 100644 index 00000000..b6ec1531 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/hatchet_sdk/worker/worker.py @@ -0,0 +1,392 @@ +import asyncio +import multiprocessing +import multiprocessing.context +import os +import signal +import sys +from concurrent.futures import Future +from dataclasses import dataclass, field +from enum import Enum +from multiprocessing import Queue +from multiprocessing.process import BaseProcess +from types import FrameType +from typing import Any, Callable, TypeVar, get_type_hints + +from aiohttp import web +from aiohttp.web_request import Request +from aiohttp.web_response import Response +from prometheus_client import CONTENT_TYPE_LATEST, Gauge, generate_latest + +from hatchet_sdk import Context +from hatchet_sdk.client import Client, new_client_raw +from hatchet_sdk.contracts.workflows_pb2 import CreateWorkflowVersionOpts +from hatchet_sdk.loader import ClientConfig +from hatchet_sdk.logger import logger +from hatchet_sdk.utils.types import WorkflowValidator +from hatchet_sdk.utils.typing import is_basemodel_subclass +from hatchet_sdk.v2.callable import HatchetCallable +from hatchet_sdk.v2.concurrency import ConcurrencyFunction +from hatchet_sdk.worker.action_listener_process import worker_action_listener_process +from hatchet_sdk.worker.runner.run_loop_manager import WorkerActionRunLoopManager +from hatchet_sdk.workflow import WorkflowInterface + +T = TypeVar("T") + + +class WorkerStatus(Enum): + INITIALIZED = 1 + STARTING = 2 + HEALTHY = 3 + UNHEALTHY = 4 + + +@dataclass +class WorkerStartOptions: + loop: asyncio.AbstractEventLoop | None = field(default=None) + + +TWorkflow = TypeVar("TWorkflow", bound=object) + + +class Worker: + def __init__( + self, + name: str, + config: ClientConfig = ClientConfig(), + max_runs: int | None = None, + labels: dict[str, str | int] = 
{}, + debug: bool = False, + owned_loop: bool = True, + handle_kill: bool = True, + ) -> None: + self.name = name + self.config = config + self.max_runs = max_runs + self.debug = debug + self.labels = labels + self.handle_kill = handle_kill + self.owned_loop = owned_loop + + self.client: Client + + self.action_registry: dict[str, Callable[[Context], Any]] = {} + self.validator_registry: dict[str, WorkflowValidator] = {} + + self.killing: bool = False + self._status: WorkerStatus + + self.action_listener_process: BaseProcess + self.action_listener_health_check: asyncio.Task[Any] + self.action_runner: WorkerActionRunLoopManager + + self.ctx = multiprocessing.get_context("spawn") + + self.action_queue: "Queue[Any]" = self.ctx.Queue() + self.event_queue: "Queue[Any]" = self.ctx.Queue() + + self.loop: asyncio.AbstractEventLoop + + self.client = new_client_raw(self.config, self.debug) + self.name = self.client.config.namespace + self.name + + self._setup_signal_handlers() + + self.worker_status_gauge = Gauge( + "hatchet_worker_status", "Current status of the Hatchet worker" + ) + + def register_function(self, action: str, func: Callable[[Context], Any]) -> None: + self.action_registry[action] = func + + def register_workflow_from_opts( + self, name: str, opts: CreateWorkflowVersionOpts + ) -> None: + try: + self.client.admin.put_workflow(opts.name, opts) + except Exception as e: + logger.error(f"failed to register workflow: {opts.name}") + logger.error(e) + sys.exit(1) + + def register_workflow(self, workflow: TWorkflow) -> None: + ## Hack for typing + assert isinstance(workflow, WorkflowInterface) + + namespace = self.client.config.namespace + + try: + self.client.admin.put_workflow( + workflow.get_name(namespace), workflow.get_create_opts(namespace) + ) + except Exception as e: + logger.error(f"failed to register workflow: {workflow.get_name(namespace)}") + logger.error(e) + sys.exit(1) + + def create_action_function( + action_func: Callable[..., T] + ) -> Callable[[Context], T]: + def action_function(context: Context) -> T: + return action_func(workflow, context) + + if asyncio.iscoroutinefunction(action_func): + setattr(action_function, "is_coroutine", True) + else: + setattr(action_function, "is_coroutine", False) + + return action_function + + for action_name, action_func in workflow.get_actions(namespace): + self.action_registry[action_name] = create_action_function(action_func) + return_type = get_type_hints(action_func).get("return") + + self.validator_registry[action_name] = WorkflowValidator( + workflow_input=workflow.input_validator, + step_output=return_type if is_basemodel_subclass(return_type) else None, + ) + + def status(self) -> WorkerStatus: + return self._status + + def setup_loop(self, loop: asyncio.AbstractEventLoop | None = None) -> bool: + try: + loop = loop or asyncio.get_running_loop() + self.loop = loop + created_loop = False + logger.debug("using existing event loop") + return created_loop + except RuntimeError: + self.loop = asyncio.new_event_loop() + logger.debug("creating new event loop") + asyncio.set_event_loop(self.loop) + created_loop = True + return created_loop + + async def health_check_handler(self, request: Request) -> Response: + status = self.status() + + return web.json_response({"status": status.name}) + + async def metrics_handler(self, request: Request) -> Response: + self.worker_status_gauge.set(1 if self.status() == WorkerStatus.HEALTHY else 0) + + return web.Response(body=generate_latest(), content_type="text/plain") + + async def 
start_health_server(self) -> None: + port = self.config.worker_healthcheck_port or 8001 + + app = web.Application() + app.add_routes( + [ + web.get("/health", self.health_check_handler), + web.get("/metrics", self.metrics_handler), + ] + ) + + runner = web.AppRunner(app) + + try: + await runner.setup() + await web.TCPSite(runner, "0.0.0.0", port).start() + except Exception as e: + logger.error("failed to start healthcheck server") + logger.error(str(e)) + return + + logger.info(f"healthcheck server running on port {port}") + + def start( + self, options: WorkerStartOptions = WorkerStartOptions() + ) -> Future[asyncio.Task[Any] | None]: + self.owned_loop = self.setup_loop(options.loop) + + f = asyncio.run_coroutine_threadsafe( + self.async_start(options, _from_start=True), self.loop + ) + + # start the loop and wait until its closed + if self.owned_loop: + self.loop.run_forever() + + if self.handle_kill: + sys.exit(0) + + return f + + ## Start methods + async def async_start( + self, + options: WorkerStartOptions = WorkerStartOptions(), + _from_start: bool = False, + ) -> Any | None: + main_pid = os.getpid() + logger.info("------------------------------------------") + logger.info("STARTING HATCHET...") + logger.debug(f"worker runtime starting on PID: {main_pid}") + + self._status = WorkerStatus.STARTING + + if len(self.action_registry.keys()) == 0: + logger.error( + "no actions registered, register workflows or actions before starting worker" + ) + return None + + # non blocking setup + if not _from_start: + self.setup_loop(options.loop) + + if self.config.worker_healthcheck_enabled: + await self.start_health_server() + + self.action_listener_process = self._start_listener() + + self.action_runner = self._run_action_runner() + + self.action_listener_health_check = self.loop.create_task( + self._check_listener_health() + ) + + return await self.action_listener_health_check + + def _run_action_runner(self) -> WorkerActionRunLoopManager: + # Retrieve the shared queue + return WorkerActionRunLoopManager( + self.name, + self.action_registry, + self.validator_registry, + self.max_runs, + self.config, + self.action_queue, + self.event_queue, + self.loop, + self.handle_kill, + self.client.debug, + self.labels, + ) + + def _start_listener(self) -> multiprocessing.context.SpawnProcess: + action_list = [str(key) for key in self.action_registry.keys()] + + try: + process = self.ctx.Process( + target=worker_action_listener_process, + args=( + self.name, + action_list, + self.max_runs, + self.config, + self.action_queue, + self.event_queue, + self.handle_kill, + self.client.debug, + self.labels, + ), + ) + process.start() + logger.debug(f"action listener starting on PID: {process.pid}") + + return process + except Exception as e: + logger.error(f"failed to start action listener: {e}") + sys.exit(1) + + async def _check_listener_health(self) -> None: + logger.debug("starting action listener health check...") + try: + while not self.killing: + if ( + self.action_listener_process is None + or not self.action_listener_process.is_alive() + ): + logger.debug("child action listener process killed...") + self._status = WorkerStatus.UNHEALTHY + if not self.killing: + self.loop.create_task(self.exit_gracefully()) + break + else: + self._status = WorkerStatus.HEALTHY + await asyncio.sleep(1) + except Exception as e: + logger.error(f"error checking listener health: {e}") + + ## Cleanup methods + def _setup_signal_handlers(self) -> None: + signal.signal(signal.SIGTERM, self._handle_exit_signal) + 
signal.signal(signal.SIGINT, self._handle_exit_signal) + signal.signal(signal.SIGQUIT, self._handle_force_quit_signal) + + def _handle_exit_signal(self, signum: int, frame: FrameType | None) -> None: + sig_name = "SIGTERM" if signum == signal.SIGTERM else "SIGINT" + logger.info(f"received signal {sig_name}...") + self.loop.create_task(self.exit_gracefully()) + + def _handle_force_quit_signal(self, signum: int, frame: FrameType | None) -> None: + logger.info("received SIGQUIT...") + self.exit_forcefully() + + async def close(self) -> None: + logger.info(f"closing worker '{self.name}'...") + self.killing = True + # self.action_queue.close() + # self.event_queue.close() + + if self.action_runner is not None: + self.action_runner.cleanup() + + await self.action_listener_health_check + + async def exit_gracefully(self) -> None: + logger.debug(f"gracefully stopping worker: {self.name}") + + if self.killing: + return self.exit_forcefully() + + self.killing = True + + await self.action_runner.wait_for_tasks() + + await self.action_runner.exit_gracefully() + + if self.action_listener_process and self.action_listener_process.is_alive(): + self.action_listener_process.kill() + + await self.close() + if self.loop and self.owned_loop: + self.loop.stop() + + logger.info("👋") + + def exit_forcefully(self) -> None: + self.killing = True + + logger.debug(f"forcefully stopping worker: {self.name}") + + self.close() + + if self.action_listener_process: + self.action_listener_process.kill() # Forcefully kill the process + + logger.info("👋") + sys.exit( + 1 + ) # Exit immediately TODO - should we exit with 1 here, there may be other workers to cleanup + + +def register_on_worker(callable: HatchetCallable[T], worker: Worker) -> None: + worker.register_function(callable.get_action_name(), callable) + + if callable.function_on_failure is not None: + worker.register_function( + callable.function_on_failure.get_action_name(), callable.function_on_failure + ) + + if callable.function_concurrency is not None: + worker.register_function( + callable.function_concurrency.get_action_name(), + callable.function_concurrency, + ) + + opts = callable.to_workflow_opts() + + worker.register_workflow_from_opts(opts.name, opts) |
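Worker.start() ties the modules in this diff together: it spawns the action-listener child process (worker_action_listener_process), builds a WorkerActionRunLoopManager over the shared action/event queues, and watches listener health. A minimal usage sketch, based only on the signatures visible above — the action id and handler are hypothetical, connection details (token, server address) must come from ClientConfig or the environment, and most applications obtain a worker via the higher-level Hatchet client rather than constructing Worker directly:

# Hypothetical action id and handler; a configured Hatchet environment is assumed.
from hatchet_sdk import Context
from hatchet_sdk.loader import ClientConfig
from hatchet_sdk.worker import Worker, WorkerStartOptions


def say_hello(context: Context) -> dict:
    # Step handlers take a Context and return a JSON-serializable value
    # (Runner.serialize_output also accepts pydantic BaseModel instances).
    return {"greeting": "hello"}


worker = Worker(
    name="example-worker",   # prefixed with config.namespace in the constructor
    config=ClientConfig(),   # token / server address must be supplied here or via the environment
    max_runs=5,              # bounds the Runner's ThreadPoolExecutor and concurrent step runs
)

# "my-service:say-hello" is a placeholder; real action ids come from a registered workflow.
worker.register_function("my-service:say-hello", say_hello)

# With the default owned_loop=True, start() runs the event loop and blocks until
# SIGINT/SIGTERM triggers exit_gracefully() (SIGQUIT forces an immediate exit).
worker.start(WorkerStartOptions())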