Diffstat (limited to '.venv/lib/python3.12/site-packages/hatchet_sdk/worker')
-rw-r--r--  .venv/lib/python3.12/site-packages/hatchet_sdk/worker/__init__.py                          1
-rw-r--r--  .venv/lib/python3.12/site-packages/hatchet_sdk/worker/action_listener_process.py         278
-rw-r--r--  .venv/lib/python3.12/site-packages/hatchet_sdk/worker/runner/run_loop_manager.py         112
-rw-r--r--  .venv/lib/python3.12/site-packages/hatchet_sdk/worker/runner/runner.py                   460
-rw-r--r--  .venv/lib/python3.12/site-packages/hatchet_sdk/worker/runner/utils/capture_logs.py        81
-rw-r--r--  .venv/lib/python3.12/site-packages/hatchet_sdk/worker/runner/utils/error_with_traceback.py  6
-rw-r--r--  .venv/lib/python3.12/site-packages/hatchet_sdk/worker/worker.py                          392
7 files changed, 1330 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/hatchet_sdk/worker/__init__.py b/.venv/lib/python3.12/site-packages/hatchet_sdk/worker/__init__.py
new file mode 100644
index 00000000..450f0cac
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/hatchet_sdk/worker/__init__.py
@@ -0,0 +1 @@
+from .worker import Worker, WorkerStartOptions, WorkerStatus
diff --git a/.venv/lib/python3.12/site-packages/hatchet_sdk/worker/action_listener_process.py b/.venv/lib/python3.12/site-packages/hatchet_sdk/worker/action_listener_process.py
new file mode 100644
index 00000000..08508607
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/hatchet_sdk/worker/action_listener_process.py
@@ -0,0 +1,278 @@
+import asyncio
+import logging
+import signal
+import time
+from dataclasses import dataclass, field
+from multiprocessing import Queue
+from typing import Any, List, Optional
+
+import grpc
+
+from hatchet_sdk.clients.dispatcher.action_listener import Action
+from hatchet_sdk.clients.dispatcher.dispatcher import (
+ ActionListener,
+ GetActionListenerRequest,
+ new_dispatcher,
+)
+from hatchet_sdk.contracts.dispatcher_pb2 import (
+ GROUP_KEY_EVENT_TYPE_STARTED,
+ STEP_EVENT_TYPE_STARTED,
+ ActionType,
+)
+from hatchet_sdk.loader import ClientConfig
+from hatchet_sdk.logger import logger
+from hatchet_sdk.utils.backoff import exp_backoff_sleep
+
+ACTION_EVENT_RETRY_COUNT = 5
+
+
+@dataclass
+class ActionEvent:
+ action: Action
+ type: Any # TODO type
+ payload: Optional[str] = None
+
+
+STOP_LOOP = "STOP_LOOP"  # sentinel value used to stop the loops
+
+# TODO link to a blog post
+BLOCKED_THREAD_WARNING = (
+ "THE TIME TO START THE STEP RUN IS TOO LONG, THE MAIN THREAD MAY BE BLOCKED"
+)
+
+
+def noop_handler():
+ pass
+
+
+@dataclass
+class WorkerActionListenerProcess:
+ name: str
+ actions: List[str]
+ max_runs: int
+ config: ClientConfig
+ action_queue: Queue
+ event_queue: Queue
+ handle_kill: bool = True
+ debug: bool = False
+ labels: dict = field(default_factory=dict)
+
+    listener: Optional[ActionListener] = field(init=False, default=None)
+
+ killing: bool = field(init=False, default=False)
+
+    action_loop_task: Optional[asyncio.Task] = field(init=False, default=None)
+    event_send_loop_task: Optional[asyncio.Task] = field(init=False, default=None)
+
+    running_step_runs: dict[str, float] = field(init=False, default_factory=dict)
+
+ def __post_init__(self):
+ if self.debug:
+ logger.setLevel(logging.DEBUG)
+
+ loop = asyncio.get_event_loop()
+ loop.add_signal_handler(signal.SIGINT, noop_handler)
+ loop.add_signal_handler(signal.SIGTERM, noop_handler)
+ loop.add_signal_handler(
+ signal.SIGQUIT, lambda: asyncio.create_task(self.exit_gracefully())
+ )
+
+ async def start(self, retry_attempt=0):
+ if retry_attempt > 5:
+ logger.error("could not start action listener")
+ return
+
+ logger.debug(f"starting action listener: {self.name}")
+
+ try:
+ self.dispatcher_client = new_dispatcher(self.config)
+
+ self.listener = await self.dispatcher_client.get_action_listener(
+ GetActionListenerRequest(
+ worker_name=self.name,
+ services=["default"],
+ actions=self.actions,
+ max_runs=self.max_runs,
+ _labels=self.labels,
+ )
+ )
+
+ logger.debug(f"acquired action listener: {self.listener.worker_id}")
+ except grpc.RpcError as rpc_error:
+ logger.error(f"could not start action listener: {rpc_error}")
+ return
+
+ # Start both loops as background tasks
+ self.action_loop_task = asyncio.create_task(self.start_action_loop())
+ self.event_send_loop_task = asyncio.create_task(self.start_event_send_loop())
+ self.blocked_main_loop = asyncio.create_task(self.start_blocked_main_loop())
+
+ # TODO move event methods to separate class
+ async def _get_event(self):
+ loop = asyncio.get_running_loop()
+ return await loop.run_in_executor(None, self.event_queue.get)
+
+ async def start_event_send_loop(self):
+ while True:
+ event: ActionEvent = await self._get_event()
+ if event == STOP_LOOP:
+ logger.debug("stopping event send loop...")
+ break
+
+ logger.debug(f"tx: event: {event.action.action_id}/{event.type}")
+ asyncio.create_task(self.send_event(event))
+
+ async def start_blocked_main_loop(self):
+ threshold = 1
+ while not self.killing:
+ count = 0
+ for step_run_id, start_time in self.running_step_runs.items():
+ diff = self.now() - start_time
+ if diff > threshold:
+ count += 1
+
+ if count > 0:
+ logger.warning(f"{BLOCKED_THREAD_WARNING}: Waiting Steps {count}")
+ await asyncio.sleep(1)
+
+ async def send_event(self, event: ActionEvent, retry_attempt: int = 1):
+ try:
+ match event.action.action_type:
+ # FIXME: all events sent from an execution of a function are of type ActionType.START_STEP_RUN since
+ # the action is re-used. We should change this.
+ case ActionType.START_STEP_RUN:
+ # TODO right now we're sending two start_step_run events
+ # one on the action loop and one on the event loop
+ # ideally we change the first to an ack to set the time
+ if event.type == STEP_EVENT_TYPE_STARTED:
+ if event.action.step_run_id in self.running_step_runs:
+ diff = (
+ self.now()
+ - self.running_step_runs[event.action.step_run_id]
+ )
+ if diff > 0.1:
+ logger.warning(
+ f"{BLOCKED_THREAD_WARNING}: time to start: {diff}s"
+ )
+ else:
+ logger.debug(f"start time: {diff}")
+ del self.running_step_runs[event.action.step_run_id]
+ else:
+ self.running_step_runs[event.action.step_run_id] = (
+ self.now()
+ )
+
+ asyncio.create_task(
+ self.dispatcher_client.send_step_action_event(
+ event.action, event.type, event.payload
+ )
+ )
+ case ActionType.CANCEL_STEP_RUN:
+ logger.debug("unimplemented event send")
+ case ActionType.START_GET_GROUP_KEY:
+ asyncio.create_task(
+ self.dispatcher_client.send_group_key_action_event(
+ event.action, event.type, event.payload
+ )
+ )
+ case _:
+ logger.error("unknown action type for event send")
+ except Exception as e:
+ logger.error(
+ f"could not send action event ({retry_attempt}/{ACTION_EVENT_RETRY_COUNT}): {e}"
+ )
+ if retry_attempt <= ACTION_EVENT_RETRY_COUNT:
+ await exp_backoff_sleep(retry_attempt, 1)
+ await self.send_event(event, retry_attempt + 1)
+
+ def now(self):
+ return time.time()
+
+ async def start_action_loop(self):
+ try:
+ async for action in self.listener:
+ if action is None:
+ break
+
+ # Process the action here
+ match action.action_type:
+ case ActionType.START_STEP_RUN:
+ self.event_queue.put(
+ ActionEvent(
+ action=action,
+ type=STEP_EVENT_TYPE_STARTED, # TODO ack type
+ )
+ )
+ logger.info(
+ f"rx: start step run: {action.step_run_id}/{action.action_id}"
+ )
+
+ # TODO handle this case better...
+ if action.step_run_id in self.running_step_runs:
+ logger.warning(
+ f"step run already running: {action.step_run_id}"
+ )
+
+ case ActionType.CANCEL_STEP_RUN:
+ logger.info(f"rx: cancel step run: {action.step_run_id}")
+ case ActionType.START_GET_GROUP_KEY:
+ self.event_queue.put(
+ ActionEvent(
+ action=action,
+ type=GROUP_KEY_EVENT_TYPE_STARTED, # TODO ack type
+ )
+ )
+ logger.info(
+ f"rx: start group key: {action.get_group_key_run_id}"
+ )
+                    case _:
+                        logger.error(
+                            f"rx: unknown action type: {action.action_type}"
+                        )
+ try:
+ self.action_queue.put(action)
+ except Exception as e:
+ logger.error(f"error putting action: {e}")
+
+ except Exception as e:
+ logger.error(f"error in action loop: {e}")
+ finally:
+ logger.info("action loop closed")
+ if not self.killing:
+ await self.exit_gracefully(skip_unregister=True)
+
+ async def cleanup(self):
+ self.killing = True
+
+ if self.listener is not None:
+ self.listener.cleanup()
+
+ self.event_queue.put(STOP_LOOP)
+
+ async def exit_gracefully(self, skip_unregister=False):
+ if self.killing:
+ return
+
+ logger.debug("closing action listener...")
+
+ await self.cleanup()
+
+ while not self.event_queue.empty():
+ pass
+
+ logger.info("action listener closed")
+
+ def exit_forcefully(self):
+ asyncio.run(self.cleanup())
+ logger.debug("forcefully closing listener...")
+
+
+def worker_action_listener_process(*args, **kwargs):
+ async def run():
+ process = WorkerActionListenerProcess(*args, **kwargs)
+ await process.start()
+ # Keep the process running
+ while not process.killing:
+ await asyncio.sleep(0.1)
+
+ asyncio.run(run())
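
The retry path in send_event sleeps with exp_backoff_sleep, which this diff imports from hatchet_sdk.utils.backoff but does not show. A minimal sketch of such a helper, assuming the (attempt, max_sleep_time_seconds) signature implied by the call site and jittered exponential growth:

import asyncio
import random


async def exp_backoff_sleep(attempt: int, max_sleep_time_seconds: float = 5) -> None:
    base_time = 0.1  # base sleep time in seconds
    jitter = random.uniform(0, base_time)  # random jitter to avoid retry stampedes
    # double the sleep per attempt, capped at the configured maximum
    sleep_time = min(base_time * (2**attempt) + jitter, max_sleep_time_seconds)
    await asyncio.sleep(sleep_time)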
diff --git a/.venv/lib/python3.12/site-packages/hatchet_sdk/worker/runner/run_loop_manager.py b/.venv/lib/python3.12/site-packages/hatchet_sdk/worker/runner/run_loop_manager.py
new file mode 100644
index 00000000..27ed788c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/hatchet_sdk/worker/runner/run_loop_manager.py
@@ -0,0 +1,112 @@
+import asyncio
+import logging
+from dataclasses import dataclass, field
+from multiprocessing import Queue
+from typing import Callable, TypeVar
+
+from hatchet_sdk import Context
+from hatchet_sdk.client import Client, new_client_raw
+from hatchet_sdk.clients.dispatcher.action_listener import Action
+from hatchet_sdk.loader import ClientConfig
+from hatchet_sdk.logger import logger
+from hatchet_sdk.utils.types import WorkflowValidator
+from hatchet_sdk.worker.runner.runner import Runner
+from hatchet_sdk.worker.runner.utils.capture_logs import capture_logs
+
+STOP_LOOP = "STOP_LOOP"
+
+T = TypeVar("T")
+
+
+@dataclass
+class WorkerActionRunLoopManager:
+ name: str
+ action_registry: dict[str, Callable[[Context], T]]
+ validator_registry: dict[str, WorkflowValidator]
+ max_runs: int | None
+ config: ClientConfig
+ action_queue: Queue
+ event_queue: Queue
+ loop: asyncio.AbstractEventLoop
+ handle_kill: bool = True
+ debug: bool = False
+ labels: dict[str, str | int] = field(default_factory=dict)
+
+ client: Client = field(init=False, default=None)
+
+ killing: bool = field(init=False, default=False)
+ runner: Runner = field(init=False, default=None)
+
+ def __post_init__(self):
+ if self.debug:
+ logger.setLevel(logging.DEBUG)
+ self.client = new_client_raw(self.config, self.debug)
+ self.start()
+
+ def start(self, retry_count=1):
+        k = self.loop.create_task(self.async_start(retry_count))  # hold a reference so the task is not garbage collected
+
+ async def async_start(self, retry_count=1):
+ await capture_logs(
+ self.client.logInterceptor,
+ self.client.event,
+ self._async_start,
+ )(retry_count=retry_count)
+
+ async def _async_start(self, retry_count: int = 1) -> None:
+ logger.info("starting runner...")
+ self.loop = asyncio.get_running_loop()
+ # needed for graceful termination
+ k = self.loop.create_task(self._start_action_loop())
+ await k
+
+ def cleanup(self) -> None:
+ self.killing = True
+
+ self.action_queue.put(STOP_LOOP)
+
+ async def wait_for_tasks(self) -> None:
+ if self.runner:
+ await self.runner.wait_for_tasks()
+
+ async def _start_action_loop(self) -> None:
+ self.runner = Runner(
+ self.name,
+ self.event_queue,
+ self.max_runs,
+ self.handle_kill,
+ self.action_registry,
+ self.validator_registry,
+ self.config,
+ self.labels,
+ )
+
+ logger.debug(f"'{self.name}' waiting for {list(self.action_registry.keys())}")
+ while not self.killing:
+ action: Action = await self._get_action()
+ if action == STOP_LOOP:
+ logger.debug("stopping action runner loop...")
+ break
+
+ self.runner.run(action)
+ logger.debug("action runner loop stopped")
+
+ async def _get_action(self):
+ return await self.loop.run_in_executor(None, self.action_queue.get)
+
+ async def exit_gracefully(self) -> None:
+ if self.killing:
+ return
+
+ logger.info("gracefully exiting runner...")
+
+ self.cleanup()
+
+ # Wait for 1 second to allow last calls to flush. These are calls which have been
+ # added to the event loop as callbacks to tasks, so we're not aware of them in the
+ # task list.
+ await asyncio.sleep(1)
+
+ def exit_forcefully(self) -> None:
+ logger.info("forcefully exiting runner...")
+ self.cleanup()
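
Both queues held by WorkerActionRunLoopManager are multiprocessing.Queue objects whose get() blocks, so _get_action hands the call to the default thread-pool executor rather than blocking the event loop. The same pattern in isolation (consume is an illustrative name, not an SDK function):

import asyncio
from multiprocessing import Queue

STOP_LOOP = "STOP_LOOP"  # same string sentinel the manager uses


async def consume(queue: "Queue[object]") -> None:
    loop = asyncio.get_running_loop()
    while True:
        # queue.get() blocks a pool thread, not the event loop
        item = await loop.run_in_executor(None, queue.get)
        if item == STOP_LOOP:
            break
        print(f"got: {item}")


if __name__ == "__main__":
    q: "Queue[object]" = Queue()
    for i in range(3):
        q.put(i)
    q.put(STOP_LOOP)
    asyncio.run(consume(q))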
diff --git a/.venv/lib/python3.12/site-packages/hatchet_sdk/worker/runner/runner.py b/.venv/lib/python3.12/site-packages/hatchet_sdk/worker/runner/runner.py
new file mode 100644
index 00000000..01e61bcf
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/hatchet_sdk/worker/runner/runner.py
@@ -0,0 +1,460 @@
+import asyncio
+import contextvars
+import ctypes
+import functools
+import json
+import time
+import traceback
+from concurrent.futures import ThreadPoolExecutor
+from enum import Enum
+from multiprocessing import Queue
+from threading import Thread, current_thread
+from typing import Any, Callable, Dict, cast
+
+from pydantic import BaseModel
+
+from hatchet_sdk.client import new_client_raw
+from hatchet_sdk.clients.admin import new_admin
+from hatchet_sdk.clients.dispatcher.action_listener import Action
+from hatchet_sdk.clients.dispatcher.dispatcher import new_dispatcher
+from hatchet_sdk.clients.run_event_listener import new_listener
+from hatchet_sdk.clients.workflow_listener import PooledWorkflowRunListener
+from hatchet_sdk.context import Context # type: ignore[attr-defined]
+from hatchet_sdk.context.worker_context import WorkerContext
+from hatchet_sdk.contracts.dispatcher_pb2 import (
+ GROUP_KEY_EVENT_TYPE_COMPLETED,
+ GROUP_KEY_EVENT_TYPE_FAILED,
+ GROUP_KEY_EVENT_TYPE_STARTED,
+ STEP_EVENT_TYPE_COMPLETED,
+ STEP_EVENT_TYPE_FAILED,
+ STEP_EVENT_TYPE_STARTED,
+ ActionType,
+)
+from hatchet_sdk.loader import ClientConfig
+from hatchet_sdk.logger import logger
+from hatchet_sdk.utils.types import WorkflowValidator
+from hatchet_sdk.v2.callable import DurableContext
+from hatchet_sdk.worker.action_listener_process import ActionEvent
+from hatchet_sdk.worker.runner.utils.capture_logs import copy_context_vars, sr, wr
+
+
+class WorkerStatus(Enum):
+ INITIALIZED = 1
+ STARTING = 2
+ HEALTHY = 3
+ UNHEALTHY = 4
+
+
+class Runner:
+ def __init__(
+ self,
+ name: str,
+ event_queue: "Queue[Any]",
+ max_runs: int | None = None,
+ handle_kill: bool = True,
+ action_registry: dict[str, Callable[..., Any]] = {},
+ validator_registry: dict[str, WorkflowValidator] = {},
+ config: ClientConfig = ClientConfig(),
+ labels: dict[str, str | int] = {},
+ ):
+ # We store the config so we can dynamically create clients for the dispatcher client.
+ self.config = config
+ self.client = new_client_raw(config)
+ self.name = self.client.config.namespace + name
+ self.max_runs = max_runs
+ self.tasks: dict[str, asyncio.Task[Any]] = {} # Store run ids and futures
+ self.contexts: dict[str, Context] = {} # Store run ids and contexts
+ self.action_registry: dict[str, Callable[..., Any]] = action_registry
+ self.validator_registry = validator_registry
+
+ self.event_queue = event_queue
+
+ # The thread pool is used for synchronous functions which need to run concurrently
+ self.thread_pool = ThreadPoolExecutor(max_workers=max_runs)
+ self.threads: Dict[str, Thread] = {} # Store run ids and threads
+
+ self.killing = False
+ self.handle_kill = handle_kill
+
+ # We need to initialize a new admin and dispatcher client *after* we've started the event loop,
+ # otherwise the grpc.aio methods will use a different event loop and we'll get a bunch of errors.
+ self.dispatcher_client = new_dispatcher(self.config)
+ self.admin_client = new_admin(self.config)
+ self.workflow_run_event_listener = new_listener(self.config)
+ self.client.workflow_listener = PooledWorkflowRunListener(self.config)
+
+ self.worker_context = WorkerContext(
+ labels=labels, client=new_client_raw(config).dispatcher
+ )
+
+ def create_workflow_run_url(self, action: Action) -> str:
+ return f"{self.config.server_url}/workflow-runs/{action.workflow_run_id}?tenant={action.tenant_id}"
+
+ def run(self, action: Action) -> None:
+ if self.worker_context.id() is None:
+ self.worker_context._worker_id = action.worker_id
+
+ match action.action_type:
+ case ActionType.START_STEP_RUN:
+ log = f"run: start step: {action.action_id}/{action.step_run_id}"
+ logger.info(log)
+ asyncio.create_task(self.handle_start_step_run(action))
+ case ActionType.CANCEL_STEP_RUN:
+ log = f"cancel: step run: {action.action_id}/{action.step_run_id}"
+ logger.info(log)
+ asyncio.create_task(self.handle_cancel_action(action.step_run_id))
+ case ActionType.START_GET_GROUP_KEY:
+ log = f"run: get group key: {action.action_id}/{action.get_group_key_run_id}"
+ logger.info(log)
+ asyncio.create_task(self.handle_start_group_key_run(action))
+ case _:
+ log = f"unknown action type: {action.action_type}"
+ logger.error(log)
+
+ def step_run_callback(self, action: Action) -> Callable[[asyncio.Task[Any]], None]:
+ def inner_callback(task: asyncio.Task[Any]) -> None:
+ self.cleanup_run_id(action.step_run_id)
+
+ errored = False
+ cancelled = task.cancelled()
+
+ # Get the output from the future
+ try:
+ if not cancelled:
+ output = task.result()
+ except Exception as e:
+ errored = True
+
+ # This except is coming from the application itself, so we want to send that to the Hatchet instance
+ self.event_queue.put(
+ ActionEvent(
+ action=action,
+ type=STEP_EVENT_TYPE_FAILED,
+ payload=str(errorWithTraceback(f"{e}", e)),
+ )
+ )
+
+ logger.error(
+ f"failed step run: {action.action_id}/{action.step_run_id}"
+ )
+
+ if not errored and not cancelled:
+ self.event_queue.put(
+ ActionEvent(
+ action=action,
+ type=STEP_EVENT_TYPE_COMPLETED,
+ payload=self.serialize_output(output),
+ )
+ )
+
+ logger.info(
+ f"finished step run: {action.action_id}/{action.step_run_id}"
+ )
+
+ return inner_callback
+
+ def group_key_run_callback(
+ self, action: Action
+ ) -> Callable[[asyncio.Task[Any]], None]:
+ def inner_callback(task: asyncio.Task[Any]) -> None:
+ self.cleanup_run_id(action.get_group_key_run_id)
+
+ errored = False
+ cancelled = task.cancelled()
+
+ # Get the output from the future
+ try:
+ if not cancelled:
+ output = task.result()
+ except Exception as e:
+ errored = True
+ self.event_queue.put(
+ ActionEvent(
+ action=action,
+ type=GROUP_KEY_EVENT_TYPE_FAILED,
+ payload=str(errorWithTraceback(f"{e}", e)),
+ )
+ )
+
+                logger.error(
+                    f"failed group key run: {action.action_id}/{action.get_group_key_run_id}"
+                )
+
+ if not errored and not cancelled:
+ self.event_queue.put(
+ ActionEvent(
+ action=action,
+ type=GROUP_KEY_EVENT_TYPE_COMPLETED,
+ payload=self.serialize_output(output),
+ )
+ )
+
+                logger.info(
+                    f"finished group key run: {action.action_id}/{action.get_group_key_run_id}"
+                )
+
+ return inner_callback
+
+ ## TODO: Stricter type hinting here
+ def thread_action_func(
+ self, context: Context, action_func: Callable[..., Any], action: Action
+ ) -> Any:
+ if action.step_run_id is not None and action.step_run_id != "":
+ self.threads[action.step_run_id] = current_thread()
+ elif (
+ action.get_group_key_run_id is not None
+ and action.get_group_key_run_id != ""
+ ):
+ self.threads[action.get_group_key_run_id] = current_thread()
+
+ return action_func(context)
+
+ ## TODO: Stricter type hinting here
+ # We wrap all actions in an async func
+ async def async_wrapped_action_func(
+ self,
+ context: Context,
+ action_func: Callable[..., Any],
+ action: Action,
+ run_id: str,
+ ) -> Any:
+ wr.set(context.workflow_run_id())
+ sr.set(context.step_run_id)
+
+ try:
+ if (
+ hasattr(action_func, "is_coroutine") and action_func.is_coroutine
+ ) or asyncio.iscoroutinefunction(action_func):
+ return await action_func(context)
+ else:
+ pfunc = functools.partial(
+ # we must copy the context vars to the new thread, as only asyncio natively supports
+ # contextvars
+ copy_context_vars,
+ contextvars.copy_context().items(),
+ self.thread_action_func,
+ context,
+ action_func,
+ action,
+ )
+
+ loop = asyncio.get_event_loop()
+ return await loop.run_in_executor(self.thread_pool, pfunc)
+ except Exception as e:
+ logger.error(
+ errorWithTraceback(
+ f"exception raised in action ({action.action_id}, retry={action.retry_count}):\n{e}",
+ e,
+ )
+ )
+ raise e
+ finally:
+ self.cleanup_run_id(run_id)
+
+ def cleanup_run_id(self, run_id: str | None) -> None:
+ if run_id in self.tasks:
+ del self.tasks[run_id]
+
+ if run_id in self.threads:
+ del self.threads[run_id]
+
+ if run_id in self.contexts:
+ del self.contexts[run_id]
+
+ def create_context(
+ self, action: Action, action_func: Callable[..., Any] | None
+ ) -> Context | DurableContext:
+ if hasattr(action_func, "durable") and getattr(action_func, "durable"):
+ return DurableContext(
+ action,
+ self.dispatcher_client,
+ self.admin_client,
+ self.client.event,
+ self.client.rest,
+ self.client.workflow_listener,
+ self.workflow_run_event_listener,
+ self.worker_context,
+ self.client.config.namespace,
+ validator_registry=self.validator_registry,
+ )
+
+ return Context(
+ action,
+ self.dispatcher_client,
+ self.admin_client,
+ self.client.event,
+ self.client.rest,
+ self.client.workflow_listener,
+ self.workflow_run_event_listener,
+ self.worker_context,
+ self.client.config.namespace,
+ validator_registry=self.validator_registry,
+ )
+
+ ## IMPORTANT: Keep this method's signature in sync with the wrapper in the OTel instrumentor
+ async def handle_start_step_run(self, action: Action) -> None | Exception:
+ action_name = action.action_id
+
+ # Find the corresponding action function from the registry
+ action_func = self.action_registry.get(action_name)
+
+ context = self.create_context(action, action_func)
+
+ self.contexts[action.step_run_id] = context
+
+ if action_func:
+ self.event_queue.put(
+ ActionEvent(
+ action=action,
+ type=STEP_EVENT_TYPE_STARTED,
+ )
+ )
+
+ loop = asyncio.get_event_loop()
+ task = loop.create_task(
+ self.async_wrapped_action_func(
+ context, action_func, action, action.step_run_id
+ )
+ )
+
+ task.add_done_callback(self.step_run_callback(action))
+ self.tasks[action.step_run_id] = task
+
+ try:
+ await task
+ except Exception as e:
+ return e
+
+ return None
+
+ ## IMPORTANT: Keep this method's signature in sync with the wrapper in the OTel instrumentor
+ async def handle_start_group_key_run(self, action: Action) -> Exception | None:
+ action_name = action.action_id
+ context = Context(
+ action,
+ self.dispatcher_client,
+ self.admin_client,
+ self.client.event,
+ self.client.rest,
+ self.client.workflow_listener,
+ self.workflow_run_event_listener,
+ self.worker_context,
+ self.client.config.namespace,
+ )
+
+ self.contexts[action.get_group_key_run_id] = context
+
+ # Find the corresponding action function from the registry
+ action_func = self.action_registry.get(action_name)
+
+ if action_func:
+ # send an event that the group key run has started
+ self.event_queue.put(
+ ActionEvent(
+ action=action,
+ type=GROUP_KEY_EVENT_TYPE_STARTED,
+ )
+ )
+
+ loop = asyncio.get_event_loop()
+ task = loop.create_task(
+ self.async_wrapped_action_func(
+ context, action_func, action, action.get_group_key_run_id
+ )
+ )
+
+ task.add_done_callback(self.group_key_run_callback(action))
+ self.tasks[action.get_group_key_run_id] = task
+
+ try:
+ await task
+ except Exception as e:
+ return e
+
+ return None
+
+ def force_kill_thread(self, thread: Thread) -> None:
+ """Terminate a python threading.Thread."""
+ try:
+ if not thread.is_alive():
+ return
+
+ ident = cast(int, thread.ident)
+
+ logger.info(f"Forcefully terminating thread {ident}")
+
+ exc = ctypes.py_object(SystemExit)
+ res = ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_long(ident), exc)
+ if res == 0:
+ raise ValueError("Invalid thread ID")
+            elif res != 1:
+                logger.error("PyThreadState_SetAsyncExc failed")
+
+                # A follow-up call with the exception set to 0 is needed to clean up properly.
+                ctypes.pythonapi.PyThreadState_SetAsyncExc(thread.ident, 0)
+                raise SystemError("PyThreadState_SetAsyncExc failed")
+
+            logger.info(f"Successfully terminated thread {ident}")
+
+ # Immediately add a new thread to the thread pool, because we've actually killed a worker
+ # in the ThreadPoolExecutor
+ self.thread_pool.submit(lambda: None)
+ except Exception as e:
+ logger.exception(f"Failed to terminate thread: {e}")
+
+ ## IMPORTANT: Keep this method's signature in sync with the wrapper in the OTel instrumentor
+ async def handle_cancel_action(self, run_id: str) -> None:
+ try:
+ # call cancel to signal the context to stop
+ if run_id in self.contexts:
+ context = self.contexts.get(run_id)
+
+ if context:
+ context.cancel()
+
+ await asyncio.sleep(1)
+
+ if run_id in self.tasks:
+ future = self.tasks.get(run_id)
+
+ if future:
+ future.cancel()
+
+ # check if thread is still running, if so, print a warning
+ if run_id in self.threads:
+ thread = self.threads.get(run_id)
+ if thread and self.client.config.enable_force_kill_sync_threads:
+ self.force_kill_thread(thread)
+ await asyncio.sleep(1)
+
+ logger.warning(
+ f"Thread {self.threads[run_id].ident} with run id {run_id} is still running after cancellation. This could cause the thread pool to get blocked and prevent new tasks from running."
+ )
+ finally:
+ self.cleanup_run_id(run_id)
+
+ def serialize_output(self, output: Any) -> str:
+
+ if isinstance(output, BaseModel):
+ return output.model_dump_json()
+
+ if output is not None:
+ try:
+ return json.dumps(output)
+ except Exception as e:
+ logger.error(f"Could not serialize output: {e}")
+ return str(output)
+
+ return ""
+
+ async def wait_for_tasks(self) -> None:
+ running = len(self.tasks.keys())
+ while running > 0:
+ logger.info(f"waiting for {running} tasks to finish...")
+ await asyncio.sleep(1)
+ running = len(self.tasks.keys())
+
+
+def errorWithTraceback(message: str, e: Exception) -> str:
+ trace = "".join(traceback.format_exception(type(e), e, e.__traceback__))
+ return f"{message}\n{trace}"
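
The detour through copy_context_vars in async_wrapped_action_func exists because contextvars flow across awaits but not into ThreadPoolExecutor threads, which start with an empty context. A standalone sketch of the same pattern (sync_action and the variable name are illustrative):

import asyncio
import contextvars
import functools
from concurrent.futures import ThreadPoolExecutor

step_run_id: contextvars.ContextVar[str | None] = contextvars.ContextVar(
    "step_run_id", default=None
)


def copy_context_vars(ctx_vars, func, *args, **kwargs):
    # replay each captured (var, value) pair in the new thread, then call func
    for var, value in ctx_vars:
        var.set(value)
    return func(*args, **kwargs)


def sync_action() -> str | None:
    return step_run_id.get()  # only visible because the context was replayed


async def main() -> None:
    step_run_id.set("step-run-123")
    loop = asyncio.get_running_loop()
    pfunc = functools.partial(
        copy_context_vars, contextvars.copy_context().items(), sync_action
    )
    with ThreadPoolExecutor() as pool:
        print(await loop.run_in_executor(pool, pfunc))  # prints step-run-123


if __name__ == "__main__":
    asyncio.run(main())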
diff --git a/.venv/lib/python3.12/site-packages/hatchet_sdk/worker/runner/utils/capture_logs.py b/.venv/lib/python3.12/site-packages/hatchet_sdk/worker/runner/utils/capture_logs.py
new file mode 100644
index 00000000..08c57de8
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/hatchet_sdk/worker/runner/utils/capture_logs.py
@@ -0,0 +1,81 @@
+import contextvars
+import functools
+import logging
+from concurrent.futures import ThreadPoolExecutor
+from io import StringIO
+from typing import Any, Callable, Coroutine
+
+from hatchet_sdk import logger
+from hatchet_sdk.clients.events import EventClient
+
+wr: contextvars.ContextVar[str | None] = contextvars.ContextVar(
+ "workflow_run_id", default=None
+)
+sr: contextvars.ContextVar[str | None] = contextvars.ContextVar(
+ "step_run_id", default=None
+)
+
+
+def copy_context_vars(ctx_vars, func, *args, **kwargs):
+ for var, value in ctx_vars:
+ var.set(value)
+ return func(*args, **kwargs)
+
+
+class InjectingFilter(logging.Filter):
+    # For some reason, only the InjectingFilter has access to the contextvars via sr.get();
+    # otherwise we would read them directly in CustomLogHandler.emit
+ def filter(self, record):
+ record.workflow_run_id = wr.get()
+ record.step_run_id = sr.get()
+ return True
+
+
+class CustomLogHandler(logging.StreamHandler):
+ def __init__(self, event_client: EventClient, stream=None):
+ super().__init__(stream)
+ self.logger_thread_pool = ThreadPoolExecutor(max_workers=1)
+ self.event_client = event_client
+
+ def _log(self, line: str, step_run_id: str | None):
+ try:
+ if not step_run_id:
+ return
+
+ self.event_client.log(message=line, step_run_id=step_run_id)
+ except Exception as e:
+ logger.error(f"Error logging: {e}")
+
+ def emit(self, record):
+ super().emit(record)
+
+ log_entry = self.format(record)
+ self.logger_thread_pool.submit(self._log, log_entry, record.step_run_id)
+
+
+def capture_logs(
+    logger: logging.Logger,
+    event_client: EventClient,
+    func: Callable[..., Coroutine[Any, Any, Any]],
+):
+ @functools.wraps(func)
+ async def wrapper(*args, **kwargs):
+ if not logger:
+ raise Exception("No logger configured on client")
+
+ log_stream = StringIO()
+ custom_handler = CustomLogHandler(event_client, log_stream)
+ custom_handler.setLevel(logging.INFO)
+ custom_handler.addFilter(InjectingFilter())
+ logger.addHandler(custom_handler)
+
+ try:
+ result = await func(*args, **kwargs)
+ finally:
+ custom_handler.flush()
+ logger.removeHandler(custom_handler)
+ log_stream.close()
+
+ return result
+
+ return wrapper
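
In use, capture_logs wraps the manager's _async_start coroutine so that, while an action runs, log records gain workflow_run_id/step_run_id attributes and are forwarded to Hatchet. A sketch with a stand-in event client (FakeEventClient is not part of the SDK, and the sketch assumes the definitions above are in scope):

import asyncio
import logging


class FakeEventClient:
    def log(self, message: str, step_run_id: str) -> None:
        print(f"[{step_run_id}] {message}")


log = logging.getLogger("example")
log.setLevel(logging.INFO)


async def do_work() -> None:
    sr.set("step-run-42")  # normally set by the runner before the action executes
    log.info("hello from a step")


if __name__ == "__main__":
    wrapped = capture_logs(log, FakeEventClient(), do_work)
    asyncio.run(wrapped())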
diff --git a/.venv/lib/python3.12/site-packages/hatchet_sdk/worker/runner/utils/error_with_traceback.py b/.venv/lib/python3.12/site-packages/hatchet_sdk/worker/runner/utils/error_with_traceback.py
new file mode 100644
index 00000000..9c09602f
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/hatchet_sdk/worker/runner/utils/error_with_traceback.py
@@ -0,0 +1,6 @@
+import traceback
+
+
+def errorWithTraceback(message: str, e: Exception) -> str:
+ trace = "".join(traceback.format_exception(type(e), e, e.__traceback__))
+ return f"{message}\n{trace}"
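
Usage of the helper (assuming the definition above is in scope): it appends the formatted traceback to a caller-supplied message, which is what the runner sends as the payload of failure events.

try:
    1 / 0
except Exception as e:
    print(errorWithTraceback("step failed:", e))
    # step failed:
    # Traceback (most recent call last):
    #   ...
    # ZeroDivisionError: division by zero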
diff --git a/.venv/lib/python3.12/site-packages/hatchet_sdk/worker/worker.py b/.venv/lib/python3.12/site-packages/hatchet_sdk/worker/worker.py
new file mode 100644
index 00000000..b6ec1531
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/hatchet_sdk/worker/worker.py
@@ -0,0 +1,392 @@
+import asyncio
+import multiprocessing
+import multiprocessing.context
+import os
+import signal
+import sys
+from concurrent.futures import Future
+from dataclasses import dataclass, field
+from enum import Enum
+from multiprocessing import Queue
+from multiprocessing.process import BaseProcess
+from types import FrameType
+from typing import Any, Callable, TypeVar, get_type_hints
+
+from aiohttp import web
+from aiohttp.web_request import Request
+from aiohttp.web_response import Response
+from prometheus_client import CONTENT_TYPE_LATEST, Gauge, generate_latest
+
+from hatchet_sdk import Context
+from hatchet_sdk.client import Client, new_client_raw
+from hatchet_sdk.contracts.workflows_pb2 import CreateWorkflowVersionOpts
+from hatchet_sdk.loader import ClientConfig
+from hatchet_sdk.logger import logger
+from hatchet_sdk.utils.types import WorkflowValidator
+from hatchet_sdk.utils.typing import is_basemodel_subclass
+from hatchet_sdk.v2.callable import HatchetCallable
+from hatchet_sdk.v2.concurrency import ConcurrencyFunction
+from hatchet_sdk.worker.action_listener_process import worker_action_listener_process
+from hatchet_sdk.worker.runner.run_loop_manager import WorkerActionRunLoopManager
+from hatchet_sdk.workflow import WorkflowInterface
+
+T = TypeVar("T")
+
+
+class WorkerStatus(Enum):
+ INITIALIZED = 1
+ STARTING = 2
+ HEALTHY = 3
+ UNHEALTHY = 4
+
+
+@dataclass
+class WorkerStartOptions:
+ loop: asyncio.AbstractEventLoop | None = field(default=None)
+
+
+TWorkflow = TypeVar("TWorkflow", bound=object)
+
+
+class Worker:
+ def __init__(
+ self,
+ name: str,
+ config: ClientConfig = ClientConfig(),
+ max_runs: int | None = None,
+ labels: dict[str, str | int] = {},
+ debug: bool = False,
+ owned_loop: bool = True,
+ handle_kill: bool = True,
+ ) -> None:
+ self.name = name
+ self.config = config
+ self.max_runs = max_runs
+ self.debug = debug
+ self.labels = labels
+ self.handle_kill = handle_kill
+ self.owned_loop = owned_loop
+
+ self.client: Client
+
+ self.action_registry: dict[str, Callable[[Context], Any]] = {}
+ self.validator_registry: dict[str, WorkflowValidator] = {}
+
+ self.killing: bool = False
+ self._status: WorkerStatus
+
+ self.action_listener_process: BaseProcess
+ self.action_listener_health_check: asyncio.Task[Any]
+ self.action_runner: WorkerActionRunLoopManager
+
+ self.ctx = multiprocessing.get_context("spawn")
+
+ self.action_queue: "Queue[Any]" = self.ctx.Queue()
+ self.event_queue: "Queue[Any]" = self.ctx.Queue()
+
+ self.loop: asyncio.AbstractEventLoop
+
+ self.client = new_client_raw(self.config, self.debug)
+ self.name = self.client.config.namespace + self.name
+
+ self._setup_signal_handlers()
+
+ self.worker_status_gauge = Gauge(
+ "hatchet_worker_status", "Current status of the Hatchet worker"
+ )
+
+ def register_function(self, action: str, func: Callable[[Context], Any]) -> None:
+ self.action_registry[action] = func
+
+ def register_workflow_from_opts(
+ self, name: str, opts: CreateWorkflowVersionOpts
+ ) -> None:
+ try:
+ self.client.admin.put_workflow(opts.name, opts)
+ except Exception as e:
+ logger.error(f"failed to register workflow: {opts.name}")
+ logger.error(e)
+ sys.exit(1)
+
+ def register_workflow(self, workflow: TWorkflow) -> None:
+ ## Hack for typing
+ assert isinstance(workflow, WorkflowInterface)
+
+ namespace = self.client.config.namespace
+
+ try:
+ self.client.admin.put_workflow(
+ workflow.get_name(namespace), workflow.get_create_opts(namespace)
+ )
+ except Exception as e:
+ logger.error(f"failed to register workflow: {workflow.get_name(namespace)}")
+ logger.error(e)
+ sys.exit(1)
+
+ def create_action_function(
+ action_func: Callable[..., T]
+ ) -> Callable[[Context], T]:
+ def action_function(context: Context) -> T:
+ return action_func(workflow, context)
+
+ if asyncio.iscoroutinefunction(action_func):
+ setattr(action_function, "is_coroutine", True)
+ else:
+ setattr(action_function, "is_coroutine", False)
+
+ return action_function
+
+ for action_name, action_func in workflow.get_actions(namespace):
+ self.action_registry[action_name] = create_action_function(action_func)
+ return_type = get_type_hints(action_func).get("return")
+
+ self.validator_registry[action_name] = WorkflowValidator(
+ workflow_input=workflow.input_validator,
+ step_output=return_type if is_basemodel_subclass(return_type) else None,
+ )
+
+ def status(self) -> WorkerStatus:
+ return self._status
+
+ def setup_loop(self, loop: asyncio.AbstractEventLoop | None = None) -> bool:
+ try:
+ loop = loop or asyncio.get_running_loop()
+ self.loop = loop
+ created_loop = False
+ logger.debug("using existing event loop")
+ return created_loop
+ except RuntimeError:
+ self.loop = asyncio.new_event_loop()
+ logger.debug("creating new event loop")
+ asyncio.set_event_loop(self.loop)
+ created_loop = True
+ return created_loop
+
+ async def health_check_handler(self, request: Request) -> Response:
+ status = self.status()
+
+ return web.json_response({"status": status.name})
+
+ async def metrics_handler(self, request: Request) -> Response:
+ self.worker_status_gauge.set(1 if self.status() == WorkerStatus.HEALTHY else 0)
+
+ return web.Response(body=generate_latest(), content_type="text/plain")
+
+ async def start_health_server(self) -> None:
+ port = self.config.worker_healthcheck_port or 8001
+
+ app = web.Application()
+ app.add_routes(
+ [
+ web.get("/health", self.health_check_handler),
+ web.get("/metrics", self.metrics_handler),
+ ]
+ )
+
+ runner = web.AppRunner(app)
+
+ try:
+ await runner.setup()
+ await web.TCPSite(runner, "0.0.0.0", port).start()
+ except Exception as e:
+ logger.error("failed to start healthcheck server")
+ logger.error(str(e))
+ return
+
+ logger.info(f"healthcheck server running on port {port}")
+
+ def start(
+ self, options: WorkerStartOptions = WorkerStartOptions()
+ ) -> Future[asyncio.Task[Any] | None]:
+ self.owned_loop = self.setup_loop(options.loop)
+
+ f = asyncio.run_coroutine_threadsafe(
+ self.async_start(options, _from_start=True), self.loop
+ )
+
+ # start the loop and wait until its closed
+ if self.owned_loop:
+ self.loop.run_forever()
+
+ if self.handle_kill:
+ sys.exit(0)
+
+ return f
+
+ ## Start methods
+ async def async_start(
+ self,
+ options: WorkerStartOptions = WorkerStartOptions(),
+ _from_start: bool = False,
+ ) -> Any | None:
+ main_pid = os.getpid()
+ logger.info("------------------------------------------")
+ logger.info("STARTING HATCHET...")
+ logger.debug(f"worker runtime starting on PID: {main_pid}")
+
+ self._status = WorkerStatus.STARTING
+
+ if len(self.action_registry.keys()) == 0:
+ logger.error(
+ "no actions registered, register workflows or actions before starting worker"
+ )
+ return None
+
+ # non blocking setup
+ if not _from_start:
+ self.setup_loop(options.loop)
+
+ if self.config.worker_healthcheck_enabled:
+ await self.start_health_server()
+
+ self.action_listener_process = self._start_listener()
+
+ self.action_runner = self._run_action_runner()
+
+ self.action_listener_health_check = self.loop.create_task(
+ self._check_listener_health()
+ )
+
+ return await self.action_listener_health_check
+
+ def _run_action_runner(self) -> WorkerActionRunLoopManager:
+ # Retrieve the shared queue
+ return WorkerActionRunLoopManager(
+ self.name,
+ self.action_registry,
+ self.validator_registry,
+ self.max_runs,
+ self.config,
+ self.action_queue,
+ self.event_queue,
+ self.loop,
+ self.handle_kill,
+ self.client.debug,
+ self.labels,
+ )
+
+ def _start_listener(self) -> multiprocessing.context.SpawnProcess:
+ action_list = [str(key) for key in self.action_registry.keys()]
+
+ try:
+ process = self.ctx.Process(
+ target=worker_action_listener_process,
+ args=(
+ self.name,
+ action_list,
+ self.max_runs,
+ self.config,
+ self.action_queue,
+ self.event_queue,
+ self.handle_kill,
+ self.client.debug,
+ self.labels,
+ ),
+ )
+ process.start()
+ logger.debug(f"action listener starting on PID: {process.pid}")
+
+ return process
+ except Exception as e:
+ logger.error(f"failed to start action listener: {e}")
+ sys.exit(1)
+
+ async def _check_listener_health(self) -> None:
+ logger.debug("starting action listener health check...")
+ try:
+ while not self.killing:
+ if (
+ self.action_listener_process is None
+ or not self.action_listener_process.is_alive()
+ ):
+ logger.debug("child action listener process killed...")
+ self._status = WorkerStatus.UNHEALTHY
+ if not self.killing:
+ self.loop.create_task(self.exit_gracefully())
+ break
+ else:
+ self._status = WorkerStatus.HEALTHY
+ await asyncio.sleep(1)
+ except Exception as e:
+ logger.error(f"error checking listener health: {e}")
+
+ ## Cleanup methods
+ def _setup_signal_handlers(self) -> None:
+ signal.signal(signal.SIGTERM, self._handle_exit_signal)
+ signal.signal(signal.SIGINT, self._handle_exit_signal)
+ signal.signal(signal.SIGQUIT, self._handle_force_quit_signal)
+
+ def _handle_exit_signal(self, signum: int, frame: FrameType | None) -> None:
+ sig_name = "SIGTERM" if signum == signal.SIGTERM else "SIGINT"
+ logger.info(f"received signal {sig_name}...")
+ self.loop.create_task(self.exit_gracefully())
+
+ def _handle_force_quit_signal(self, signum: int, frame: FrameType | None) -> None:
+ logger.info("received SIGQUIT...")
+ self.exit_forcefully()
+
+ async def close(self) -> None:
+ logger.info(f"closing worker '{self.name}'...")
+ self.killing = True
+ # self.action_queue.close()
+ # self.event_queue.close()
+
+ if self.action_runner is not None:
+ self.action_runner.cleanup()
+
+ await self.action_listener_health_check
+
+ async def exit_gracefully(self) -> None:
+ logger.debug(f"gracefully stopping worker: {self.name}")
+
+ if self.killing:
+ return self.exit_forcefully()
+
+ self.killing = True
+
+ await self.action_runner.wait_for_tasks()
+
+ await self.action_runner.exit_gracefully()
+
+ if self.action_listener_process and self.action_listener_process.is_alive():
+ self.action_listener_process.kill()
+
+ await self.close()
+ if self.loop and self.owned_loop:
+ self.loop.stop()
+
+ logger.info("👋")
+
+ def exit_forcefully(self) -> None:
+ self.killing = True
+
+ logger.debug(f"forcefully stopping worker: {self.name}")
+
+        asyncio.run_coroutine_threadsafe(self.close(), self.loop)  # close() is a coroutine; schedule it so cleanup actually runs
+
+ if self.action_listener_process:
+ self.action_listener_process.kill() # Forcefully kill the process
+
+ logger.info("👋")
+        sys.exit(
+            1
+        )  # Exit immediately. TODO: should we exit with 1 here? There may be other workers to clean up.
+
+
+def register_on_worker(callable: HatchetCallable[T], worker: Worker) -> None:
+ worker.register_function(callable.get_action_name(), callable)
+
+ if callable.function_on_failure is not None:
+ worker.register_function(
+ callable.function_on_failure.get_action_name(), callable.function_on_failure
+ )
+
+ if callable.function_concurrency is not None:
+ worker.register_function(
+ callable.function_concurrency.get_action_name(),
+ callable.function_concurrency,
+ )
+
+ opts = callable.to_workflow_opts()
+
+ worker.register_workflow_from_opts(opts.name, opts)
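
A minimal end-to-end sketch of the lifecycle these files implement: the Worker spawns the action listener process, wires the shared queues into the WorkerActionRunLoopManager, and blocks on its owned event loop until a signal triggers exit_gracefully. The workflow class is illustrative and assumes the decorator-based API of this SDK version.

from hatchet_sdk import Context, Hatchet

hatchet = Hatchet(debug=True)


@hatchet.workflow(on_events=["example:run"])
class ExampleWorkflow:
    @hatchet.step()
    def step1(self, context: Context) -> dict:
        return {"ok": True}


def main() -> None:
    worker = hatchet.worker("example-worker", max_runs=4)
    worker.register_workflow(ExampleWorkflow())
    worker.start()  # runs the loop until SIGINT/SIGTERM triggers exit_gracefully()


if __name__ == "__main__":
    main()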