about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/hatchet_sdk/worker
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/hatchet_sdk/worker')
-rw-r--r--.venv/lib/python3.12/site-packages/hatchet_sdk/worker/__init__.py1
-rw-r--r--.venv/lib/python3.12/site-packages/hatchet_sdk/worker/action_listener_process.py278
-rw-r--r--.venv/lib/python3.12/site-packages/hatchet_sdk/worker/runner/run_loop_manager.py112
-rw-r--r--.venv/lib/python3.12/site-packages/hatchet_sdk/worker/runner/runner.py460
-rw-r--r--.venv/lib/python3.12/site-packages/hatchet_sdk/worker/runner/utils/capture_logs.py81
-rw-r--r--.venv/lib/python3.12/site-packages/hatchet_sdk/worker/runner/utils/error_with_traceback.py6
-rw-r--r--.venv/lib/python3.12/site-packages/hatchet_sdk/worker/worker.py392
7 files changed, 1330 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/hatchet_sdk/worker/__init__.py b/.venv/lib/python3.12/site-packages/hatchet_sdk/worker/__init__.py
new file mode 100644
index 00000000..450f0cac
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/hatchet_sdk/worker/__init__.py
@@ -0,0 +1 @@
+from .worker import Worker, WorkerStartOptions, WorkerStatus
diff --git a/.venv/lib/python3.12/site-packages/hatchet_sdk/worker/action_listener_process.py b/.venv/lib/python3.12/site-packages/hatchet_sdk/worker/action_listener_process.py
new file mode 100644
index 00000000..08508607
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/hatchet_sdk/worker/action_listener_process.py
@@ -0,0 +1,278 @@
+import asyncio
+import logging
+import signal
+import time
+from dataclasses import dataclass, field
+from multiprocessing import Queue
+from typing import Any, List, Mapping, Optional
+
+import grpc
+
+from hatchet_sdk.clients.dispatcher.action_listener import Action
+from hatchet_sdk.clients.dispatcher.dispatcher import (
+    ActionListener,
+    GetActionListenerRequest,
+    new_dispatcher,
+)
+from hatchet_sdk.contracts.dispatcher_pb2 import (
+    GROUP_KEY_EVENT_TYPE_STARTED,
+    STEP_EVENT_TYPE_STARTED,
+    ActionType,
+)
+from hatchet_sdk.loader import ClientConfig
+from hatchet_sdk.logger import logger
+from hatchet_sdk.utils.backoff import exp_backoff_sleep
+
+ACTION_EVENT_RETRY_COUNT = 5
+
+
+@dataclass
+class ActionEvent:
+    action: Action
+    type: Any  # TODO type
+    payload: Optional[str] = None
+
+
+STOP_LOOP = "STOP_LOOP"  # Sentinel object to stop the loop
+
+# TODO link to a block post
+BLOCKED_THREAD_WARNING = (
+    "THE TIME TO START THE STEP RUN IS TOO LONG, THE MAIN THREAD MAY BE BLOCKED"
+)
+
+
+def noop_handler():
+    pass
+
+
+@dataclass
+class WorkerActionListenerProcess:
+    name: str
+    actions: List[str]
+    max_runs: int
+    config: ClientConfig
+    action_queue: Queue
+    event_queue: Queue
+    handle_kill: bool = True
+    debug: bool = False
+    labels: dict = field(default_factory=dict)
+
+    listener: ActionListener = field(init=False, default=None)
+
+    killing: bool = field(init=False, default=False)
+
+    action_loop_task: asyncio.Task = field(init=False, default=None)
+    event_send_loop_task: asyncio.Task = field(init=False, default=None)
+
+    running_step_runs: Mapping[str, float] = field(init=False, default_factory=dict)
+
+    def __post_init__(self):
+        if self.debug:
+            logger.setLevel(logging.DEBUG)
+
+        loop = asyncio.get_event_loop()
+        loop.add_signal_handler(signal.SIGINT, noop_handler)
+        loop.add_signal_handler(signal.SIGTERM, noop_handler)
+        loop.add_signal_handler(
+            signal.SIGQUIT, lambda: asyncio.create_task(self.exit_gracefully())
+        )
+
+    async def start(self, retry_attempt=0):
+        if retry_attempt > 5:
+            logger.error("could not start action listener")
+            return
+
+        logger.debug(f"starting action listener: {self.name}")
+
+        try:
+            self.dispatcher_client = new_dispatcher(self.config)
+
+            self.listener = await self.dispatcher_client.get_action_listener(
+                GetActionListenerRequest(
+                    worker_name=self.name,
+                    services=["default"],
+                    actions=self.actions,
+                    max_runs=self.max_runs,
+                    _labels=self.labels,
+                )
+            )
+
+            logger.debug(f"acquired action listener: {self.listener.worker_id}")
+        except grpc.RpcError as rpc_error:
+            logger.error(f"could not start action listener: {rpc_error}")
+            return
+
+        # Start both loops as background tasks
+        self.action_loop_task = asyncio.create_task(self.start_action_loop())
+        self.event_send_loop_task = asyncio.create_task(self.start_event_send_loop())
+        self.blocked_main_loop = asyncio.create_task(self.start_blocked_main_loop())
+
+    # TODO move event methods to separate class
+    async def _get_event(self):
+        loop = asyncio.get_running_loop()
+        return await loop.run_in_executor(None, self.event_queue.get)
+
+    async def start_event_send_loop(self):
+        while True:
+            event: ActionEvent = await self._get_event()
+            if event == STOP_LOOP:
+                logger.debug("stopping event send loop...")
+                break
+
+            logger.debug(f"tx: event: {event.action.action_id}/{event.type}")
+            asyncio.create_task(self.send_event(event))
+
+    async def start_blocked_main_loop(self):
+        threshold = 1
+        while not self.killing:
+            count = 0
+            for step_run_id, start_time in self.running_step_runs.items():
+                diff = self.now() - start_time
+                if diff > threshold:
+                    count += 1
+
+            if count > 0:
+                logger.warning(f"{BLOCKED_THREAD_WARNING}: Waiting Steps {count}")
+            await asyncio.sleep(1)
+
+    async def send_event(self, event: ActionEvent, retry_attempt: int = 1):
+        try:
+            match event.action.action_type:
+                # FIXME: all events sent from an execution of a function are of type ActionType.START_STEP_RUN since
+                # the action is re-used. We should change this.
+                case ActionType.START_STEP_RUN:
+                    # TODO right now we're sending two start_step_run events
+                    # one on the action loop and one on the event loop
+                    # ideally we change the first to an ack to set the time
+                    if event.type == STEP_EVENT_TYPE_STARTED:
+                        if event.action.step_run_id in self.running_step_runs:
+                            diff = (
+                                self.now()
+                                - self.running_step_runs[event.action.step_run_id]
+                            )
+                            if diff > 0.1:
+                                logger.warning(
+                                    f"{BLOCKED_THREAD_WARNING}: time to start: {diff}s"
+                                )
+                            else:
+                                logger.debug(f"start time: {diff}")
+                            del self.running_step_runs[event.action.step_run_id]
+                        else:
+                            self.running_step_runs[event.action.step_run_id] = (
+                                self.now()
+                            )
+
+                    asyncio.create_task(
+                        self.dispatcher_client.send_step_action_event(
+                            event.action, event.type, event.payload
+                        )
+                    )
+                case ActionType.CANCEL_STEP_RUN:
+                    logger.debug("unimplemented event send")
+                case ActionType.START_GET_GROUP_KEY:
+                    asyncio.create_task(
+                        self.dispatcher_client.send_group_key_action_event(
+                            event.action, event.type, event.payload
+                        )
+                    )
+                case _:
+                    logger.error("unknown action type for event send")
+        except Exception as e:
+            logger.error(
+                f"could not send action event ({retry_attempt}/{ACTION_EVENT_RETRY_COUNT}): {e}"
+            )
+            if retry_attempt <= ACTION_EVENT_RETRY_COUNT:
+                await exp_backoff_sleep(retry_attempt, 1)
+                await self.send_event(event, retry_attempt + 1)
+
+    def now(self):
+        return time.time()
+
+    async def start_action_loop(self):
+        try:
+            async for action in self.listener:
+                if action is None:
+                    break
+
+                # Process the action here
+                match action.action_type:
+                    case ActionType.START_STEP_RUN:
+                        self.event_queue.put(
+                            ActionEvent(
+                                action=action,
+                                type=STEP_EVENT_TYPE_STARTED,  # TODO ack type
+                            )
+                        )
+                        logger.info(
+                            f"rx: start step run: {action.step_run_id}/{action.action_id}"
+                        )
+
+                        # TODO handle this case better...
+                        if action.step_run_id in self.running_step_runs:
+                            logger.warning(
+                                f"step run already running: {action.step_run_id}"
+                            )
+
+                    case ActionType.CANCEL_STEP_RUN:
+                        logger.info(f"rx: cancel step run: {action.step_run_id}")
+                    case ActionType.START_GET_GROUP_KEY:
+                        self.event_queue.put(
+                            ActionEvent(
+                                action=action,
+                                type=GROUP_KEY_EVENT_TYPE_STARTED,  # TODO ack type
+                            )
+                        )
+                        logger.info(
+                            f"rx: start group key: {action.get_group_key_run_id}"
+                        )
+                    case _:
+                        logger.error(
+                            f"rx: unknown action type ({action.action_type}): {action.action_type}"
+                        )
+                try:
+                    self.action_queue.put(action)
+                except Exception as e:
+                    logger.error(f"error putting action: {e}")
+
+        except Exception as e:
+            logger.error(f"error in action loop: {e}")
+        finally:
+            logger.info("action loop closed")
+            if not self.killing:
+                await self.exit_gracefully(skip_unregister=True)
+
+    async def cleanup(self):
+        self.killing = True
+
+        if self.listener is not None:
+            self.listener.cleanup()
+
+        self.event_queue.put(STOP_LOOP)
+
+    async def exit_gracefully(self, skip_unregister=False):
+        if self.killing:
+            return
+
+        logger.debug("closing action listener...")
+
+        await self.cleanup()
+
+        while not self.event_queue.empty():
+            pass
+
+        logger.info("action listener closed")
+
+    def exit_forcefully(self):
+        asyncio.run(self.cleanup())
+        logger.debug("forcefully closing listener...")
+
+
+def worker_action_listener_process(*args, **kwargs):
+    async def run():
+        process = WorkerActionListenerProcess(*args, **kwargs)
+        await process.start()
+        # Keep the process running
+        while not process.killing:
+            await asyncio.sleep(0.1)
+
+    asyncio.run(run())
diff --git a/.venv/lib/python3.12/site-packages/hatchet_sdk/worker/runner/run_loop_manager.py b/.venv/lib/python3.12/site-packages/hatchet_sdk/worker/runner/run_loop_manager.py
new file mode 100644
index 00000000..27ed788c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/hatchet_sdk/worker/runner/run_loop_manager.py
@@ -0,0 +1,112 @@
+import asyncio
+import logging
+from dataclasses import dataclass, field
+from multiprocessing import Queue
+from typing import Callable, TypeVar
+
+from hatchet_sdk import Context
+from hatchet_sdk.client import Client, new_client_raw
+from hatchet_sdk.clients.dispatcher.action_listener import Action
+from hatchet_sdk.loader import ClientConfig
+from hatchet_sdk.logger import logger
+from hatchet_sdk.utils.types import WorkflowValidator
+from hatchet_sdk.worker.runner.runner import Runner
+from hatchet_sdk.worker.runner.utils.capture_logs import capture_logs
+
+STOP_LOOP = "STOP_LOOP"
+
+T = TypeVar("T")
+
+
+@dataclass
+class WorkerActionRunLoopManager:
+    name: str
+    action_registry: dict[str, Callable[[Context], T]]
+    validator_registry: dict[str, WorkflowValidator]
+    max_runs: int | None
+    config: ClientConfig
+    action_queue: Queue
+    event_queue: Queue
+    loop: asyncio.AbstractEventLoop
+    handle_kill: bool = True
+    debug: bool = False
+    labels: dict[str, str | int] = field(default_factory=dict)
+
+    client: Client = field(init=False, default=None)
+
+    killing: bool = field(init=False, default=False)
+    runner: Runner = field(init=False, default=None)
+
+    def __post_init__(self):
+        if self.debug:
+            logger.setLevel(logging.DEBUG)
+        self.client = new_client_raw(self.config, self.debug)
+        self.start()
+
+    def start(self, retry_count=1):
+        k = self.loop.create_task(self.async_start(retry_count))
+
+    async def async_start(self, retry_count=1):
+        await capture_logs(
+            self.client.logInterceptor,
+            self.client.event,
+            self._async_start,
+        )(retry_count=retry_count)
+
+    async def _async_start(self, retry_count: int = 1) -> None:
+        logger.info("starting runner...")
+        self.loop = asyncio.get_running_loop()
+        # needed for graceful termination
+        k = self.loop.create_task(self._start_action_loop())
+        await k
+
+    def cleanup(self) -> None:
+        self.killing = True
+
+        self.action_queue.put(STOP_LOOP)
+
+    async def wait_for_tasks(self) -> None:
+        if self.runner:
+            await self.runner.wait_for_tasks()
+
+    async def _start_action_loop(self) -> None:
+        self.runner = Runner(
+            self.name,
+            self.event_queue,
+            self.max_runs,
+            self.handle_kill,
+            self.action_registry,
+            self.validator_registry,
+            self.config,
+            self.labels,
+        )
+
+        logger.debug(f"'{self.name}' waiting for {list(self.action_registry.keys())}")
+        while not self.killing:
+            action: Action = await self._get_action()
+            if action == STOP_LOOP:
+                logger.debug("stopping action runner loop...")
+                break
+
+            self.runner.run(action)
+        logger.debug("action runner loop stopped")
+
+    async def _get_action(self):
+        return await self.loop.run_in_executor(None, self.action_queue.get)
+
+    async def exit_gracefully(self) -> None:
+        if self.killing:
+            return
+
+        logger.info("gracefully exiting runner...")
+
+        self.cleanup()
+
+        # Wait for 1 second to allow last calls to flush. These are calls which have been
+        # added to the event loop as callbacks to tasks, so we're not aware of them in the
+        # task list.
+        await asyncio.sleep(1)
+
+    def exit_forcefully(self) -> None:
+        logger.info("forcefully exiting runner...")
+        self.cleanup()
diff --git a/.venv/lib/python3.12/site-packages/hatchet_sdk/worker/runner/runner.py b/.venv/lib/python3.12/site-packages/hatchet_sdk/worker/runner/runner.py
new file mode 100644
index 00000000..01e61bcf
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/hatchet_sdk/worker/runner/runner.py
@@ -0,0 +1,460 @@
+import asyncio
+import contextvars
+import ctypes
+import functools
+import json
+import time
+import traceback
+from concurrent.futures import ThreadPoolExecutor
+from enum import Enum
+from multiprocessing import Queue
+from threading import Thread, current_thread
+from typing import Any, Callable, Dict, cast
+
+from pydantic import BaseModel
+
+from hatchet_sdk.client import new_client_raw
+from hatchet_sdk.clients.admin import new_admin
+from hatchet_sdk.clients.dispatcher.action_listener import Action
+from hatchet_sdk.clients.dispatcher.dispatcher import new_dispatcher
+from hatchet_sdk.clients.run_event_listener import new_listener
+from hatchet_sdk.clients.workflow_listener import PooledWorkflowRunListener
+from hatchet_sdk.context import Context  # type: ignore[attr-defined]
+from hatchet_sdk.context.worker_context import WorkerContext
+from hatchet_sdk.contracts.dispatcher_pb2 import (
+    GROUP_KEY_EVENT_TYPE_COMPLETED,
+    GROUP_KEY_EVENT_TYPE_FAILED,
+    GROUP_KEY_EVENT_TYPE_STARTED,
+    STEP_EVENT_TYPE_COMPLETED,
+    STEP_EVENT_TYPE_FAILED,
+    STEP_EVENT_TYPE_STARTED,
+    ActionType,
+)
+from hatchet_sdk.loader import ClientConfig
+from hatchet_sdk.logger import logger
+from hatchet_sdk.utils.types import WorkflowValidator
+from hatchet_sdk.v2.callable import DurableContext
+from hatchet_sdk.worker.action_listener_process import ActionEvent
+from hatchet_sdk.worker.runner.utils.capture_logs import copy_context_vars, sr, wr
+
+
+class WorkerStatus(Enum):
+    INITIALIZED = 1
+    STARTING = 2
+    HEALTHY = 3
+    UNHEALTHY = 4
+
+
+class Runner:
+    def __init__(
+        self,
+        name: str,
+        event_queue: "Queue[Any]",
+        max_runs: int | None = None,
+        handle_kill: bool = True,
+        action_registry: dict[str, Callable[..., Any]] = {},
+        validator_registry: dict[str, WorkflowValidator] = {},
+        config: ClientConfig = ClientConfig(),
+        labels: dict[str, str | int] = {},
+    ):
+        # We store the config so we can dynamically create clients for the dispatcher client.
+        self.config = config
+        self.client = new_client_raw(config)
+        self.name = self.client.config.namespace + name
+        self.max_runs = max_runs
+        self.tasks: dict[str, asyncio.Task[Any]] = {}  # Store run ids and futures
+        self.contexts: dict[str, Context] = {}  # Store run ids and contexts
+        self.action_registry: dict[str, Callable[..., Any]] = action_registry
+        self.validator_registry = validator_registry
+
+        self.event_queue = event_queue
+
+        # The thread pool is used for synchronous functions which need to run concurrently
+        self.thread_pool = ThreadPoolExecutor(max_workers=max_runs)
+        self.threads: Dict[str, Thread] = {}  # Store run ids and threads
+
+        self.killing = False
+        self.handle_kill = handle_kill
+
+        # We need to initialize a new admin and dispatcher client *after* we've started the event loop,
+        # otherwise the grpc.aio methods will use a different event loop and we'll get a bunch of errors.
+        self.dispatcher_client = new_dispatcher(self.config)
+        self.admin_client = new_admin(self.config)
+        self.workflow_run_event_listener = new_listener(self.config)
+        self.client.workflow_listener = PooledWorkflowRunListener(self.config)
+
+        self.worker_context = WorkerContext(
+            labels=labels, client=new_client_raw(config).dispatcher
+        )
+
+    def create_workflow_run_url(self, action: Action) -> str:
+        return f"{self.config.server_url}/workflow-runs/{action.workflow_run_id}?tenant={action.tenant_id}"
+
+    def run(self, action: Action) -> None:
+        if self.worker_context.id() is None:
+            self.worker_context._worker_id = action.worker_id
+
+        match action.action_type:
+            case ActionType.START_STEP_RUN:
+                log = f"run: start step: {action.action_id}/{action.step_run_id}"
+                logger.info(log)
+                asyncio.create_task(self.handle_start_step_run(action))
+            case ActionType.CANCEL_STEP_RUN:
+                log = f"cancel: step run:  {action.action_id}/{action.step_run_id}"
+                logger.info(log)
+                asyncio.create_task(self.handle_cancel_action(action.step_run_id))
+            case ActionType.START_GET_GROUP_KEY:
+                log = f"run: get group key:  {action.action_id}/{action.get_group_key_run_id}"
+                logger.info(log)
+                asyncio.create_task(self.handle_start_group_key_run(action))
+            case _:
+                log = f"unknown action type: {action.action_type}"
+                logger.error(log)
+
+    def step_run_callback(self, action: Action) -> Callable[[asyncio.Task[Any]], None]:
+        def inner_callback(task: asyncio.Task[Any]) -> None:
+            self.cleanup_run_id(action.step_run_id)
+
+            errored = False
+            cancelled = task.cancelled()
+
+            # Get the output from the future
+            try:
+                if not cancelled:
+                    output = task.result()
+            except Exception as e:
+                errored = True
+
+                # This except is coming from the application itself, so we want to send that to the Hatchet instance
+                self.event_queue.put(
+                    ActionEvent(
+                        action=action,
+                        type=STEP_EVENT_TYPE_FAILED,
+                        payload=str(errorWithTraceback(f"{e}", e)),
+                    )
+                )
+
+                logger.error(
+                    f"failed step run: {action.action_id}/{action.step_run_id}"
+                )
+
+            if not errored and not cancelled:
+                self.event_queue.put(
+                    ActionEvent(
+                        action=action,
+                        type=STEP_EVENT_TYPE_COMPLETED,
+                        payload=self.serialize_output(output),
+                    )
+                )
+
+                logger.info(
+                    f"finished step run: {action.action_id}/{action.step_run_id}"
+                )
+
+        return inner_callback
+
+    def group_key_run_callback(
+        self, action: Action
+    ) -> Callable[[asyncio.Task[Any]], None]:
+        def inner_callback(task: asyncio.Task[Any]) -> None:
+            self.cleanup_run_id(action.get_group_key_run_id)
+
+            errored = False
+            cancelled = task.cancelled()
+
+            # Get the output from the future
+            try:
+                if not cancelled:
+                    output = task.result()
+            except Exception as e:
+                errored = True
+                self.event_queue.put(
+                    ActionEvent(
+                        action=action,
+                        type=GROUP_KEY_EVENT_TYPE_FAILED,
+                        payload=str(errorWithTraceback(f"{e}", e)),
+                    )
+                )
+
+                logger.error(
+                    f"failed step run: {action.action_id}/{action.step_run_id}"
+                )
+
+            if not errored and not cancelled:
+                self.event_queue.put(
+                    ActionEvent(
+                        action=action,
+                        type=GROUP_KEY_EVENT_TYPE_COMPLETED,
+                        payload=self.serialize_output(output),
+                    )
+                )
+
+                logger.info(
+                    f"finished step run: {action.action_id}/{action.step_run_id}"
+                )
+
+        return inner_callback
+
+    ## TODO: Stricter type hinting here
+    def thread_action_func(
+        self, context: Context, action_func: Callable[..., Any], action: Action
+    ) -> Any:
+        if action.step_run_id is not None and action.step_run_id != "":
+            self.threads[action.step_run_id] = current_thread()
+        elif (
+            action.get_group_key_run_id is not None
+            and action.get_group_key_run_id != ""
+        ):
+            self.threads[action.get_group_key_run_id] = current_thread()
+
+        return action_func(context)
+
+    ## TODO: Stricter type hinting here
+    # We wrap all actions in an async func
+    async def async_wrapped_action_func(
+        self,
+        context: Context,
+        action_func: Callable[..., Any],
+        action: Action,
+        run_id: str,
+    ) -> Any:
+        wr.set(context.workflow_run_id())
+        sr.set(context.step_run_id)
+
+        try:
+            if (
+                hasattr(action_func, "is_coroutine") and action_func.is_coroutine
+            ) or asyncio.iscoroutinefunction(action_func):
+                return await action_func(context)
+            else:
+                pfunc = functools.partial(
+                    # we must copy the context vars to the new thread, as only asyncio natively supports
+                    # contextvars
+                    copy_context_vars,
+                    contextvars.copy_context().items(),
+                    self.thread_action_func,
+                    context,
+                    action_func,
+                    action,
+                )
+
+                loop = asyncio.get_event_loop()
+                return await loop.run_in_executor(self.thread_pool, pfunc)
+        except Exception as e:
+            logger.error(
+                errorWithTraceback(
+                    f"exception raised in action ({action.action_id}, retry={action.retry_count}):\n{e}",
+                    e,
+                )
+            )
+            raise e
+        finally:
+            self.cleanup_run_id(run_id)
+
+    def cleanup_run_id(self, run_id: str | None) -> None:
+        if run_id in self.tasks:
+            del self.tasks[run_id]
+
+        if run_id in self.threads:
+            del self.threads[run_id]
+
+        if run_id in self.contexts:
+            del self.contexts[run_id]
+
+    def create_context(
+        self, action: Action, action_func: Callable[..., Any] | None
+    ) -> Context | DurableContext:
+        if hasattr(action_func, "durable") and getattr(action_func, "durable"):
+            return DurableContext(
+                action,
+                self.dispatcher_client,
+                self.admin_client,
+                self.client.event,
+                self.client.rest,
+                self.client.workflow_listener,
+                self.workflow_run_event_listener,
+                self.worker_context,
+                self.client.config.namespace,
+                validator_registry=self.validator_registry,
+            )
+
+        return Context(
+            action,
+            self.dispatcher_client,
+            self.admin_client,
+            self.client.event,
+            self.client.rest,
+            self.client.workflow_listener,
+            self.workflow_run_event_listener,
+            self.worker_context,
+            self.client.config.namespace,
+            validator_registry=self.validator_registry,
+        )
+
+    ## IMPORTANT: Keep this method's signature in sync with the wrapper in the OTel instrumentor
+    async def handle_start_step_run(self, action: Action) -> None | Exception:
+        action_name = action.action_id
+
+        # Find the corresponding action function from the registry
+        action_func = self.action_registry.get(action_name)
+
+        context = self.create_context(action, action_func)
+
+        self.contexts[action.step_run_id] = context
+
+        if action_func:
+            self.event_queue.put(
+                ActionEvent(
+                    action=action,
+                    type=STEP_EVENT_TYPE_STARTED,
+                )
+            )
+
+            loop = asyncio.get_event_loop()
+            task = loop.create_task(
+                self.async_wrapped_action_func(
+                    context, action_func, action, action.step_run_id
+                )
+            )
+
+            task.add_done_callback(self.step_run_callback(action))
+            self.tasks[action.step_run_id] = task
+
+            try:
+                await task
+            except Exception as e:
+                return e
+
+        return None
+
+    ## IMPORTANT: Keep this method's signature in sync with the wrapper in the OTel instrumentor
+    async def handle_start_group_key_run(self, action: Action) -> Exception | None:
+        action_name = action.action_id
+        context = Context(
+            action,
+            self.dispatcher_client,
+            self.admin_client,
+            self.client.event,
+            self.client.rest,
+            self.client.workflow_listener,
+            self.workflow_run_event_listener,
+            self.worker_context,
+            self.client.config.namespace,
+        )
+
+        self.contexts[action.get_group_key_run_id] = context
+
+        # Find the corresponding action function from the registry
+        action_func = self.action_registry.get(action_name)
+
+        if action_func:
+            # send an event that the group key run has started
+            self.event_queue.put(
+                ActionEvent(
+                    action=action,
+                    type=GROUP_KEY_EVENT_TYPE_STARTED,
+                )
+            )
+
+            loop = asyncio.get_event_loop()
+            task = loop.create_task(
+                self.async_wrapped_action_func(
+                    context, action_func, action, action.get_group_key_run_id
+                )
+            )
+
+            task.add_done_callback(self.group_key_run_callback(action))
+            self.tasks[action.get_group_key_run_id] = task
+
+            try:
+                await task
+            except Exception as e:
+                return e
+
+        return None
+
+    def force_kill_thread(self, thread: Thread) -> None:
+        """Terminate a python threading.Thread."""
+        try:
+            if not thread.is_alive():
+                return
+
+            ident = cast(int, thread.ident)
+
+            logger.info(f"Forcefully terminating thread {ident}")
+
+            exc = ctypes.py_object(SystemExit)
+            res = ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_long(ident), exc)
+            if res == 0:
+                raise ValueError("Invalid thread ID")
+            elif res != 1:
+                logger.error("PyThreadState_SetAsyncExc failed")
+
+                # Call with exception set to 0 is needed to cleanup properly.
+                ctypes.pythonapi.PyThreadState_SetAsyncExc(thread.ident, 0)
+                raise SystemError("PyThreadState_SetAsyncExc failed")
+
+            logger.info(f"Successfully terminated thread {ident}")
+
+            # Immediately add a new thread to the thread pool, because we've actually killed a worker
+            # in the ThreadPoolExecutor
+            self.thread_pool.submit(lambda: None)
+        except Exception as e:
+            logger.exception(f"Failed to terminate thread: {e}")
+
+    ## IMPORTANT: Keep this method's signature in sync with the wrapper in the OTel instrumentor
+    async def handle_cancel_action(self, run_id: str) -> None:
+        try:
+            # call cancel to signal the context to stop
+            if run_id in self.contexts:
+                context = self.contexts.get(run_id)
+
+                if context:
+                    context.cancel()
+
+            await asyncio.sleep(1)
+
+            if run_id in self.tasks:
+                future = self.tasks.get(run_id)
+
+                if future:
+                    future.cancel()
+
+            # check if thread is still running, if so, print a warning
+            if run_id in self.threads:
+                thread = self.threads.get(run_id)
+                if thread and self.client.config.enable_force_kill_sync_threads:
+                    self.force_kill_thread(thread)
+                    await asyncio.sleep(1)
+
+                logger.warning(
+                    f"Thread {self.threads[run_id].ident} with run id {run_id} is still running after cancellation. This could cause the thread pool to get blocked and prevent new tasks from running."
+                )
+        finally:
+            self.cleanup_run_id(run_id)
+
+    def serialize_output(self, output: Any) -> str:
+
+        if isinstance(output, BaseModel):
+            return output.model_dump_json()
+
+        if output is not None:
+            try:
+                return json.dumps(output)
+            except Exception as e:
+                logger.error(f"Could not serialize output: {e}")
+                return str(output)
+
+        return ""
+
+    async def wait_for_tasks(self) -> None:
+        running = len(self.tasks.keys())
+        while running > 0:
+            logger.info(f"waiting for {running} tasks to finish...")
+            await asyncio.sleep(1)
+            running = len(self.tasks.keys())
+
+
+def errorWithTraceback(message: str, e: Exception) -> str:
+    trace = "".join(traceback.format_exception(type(e), e, e.__traceback__))
+    return f"{message}\n{trace}"
diff --git a/.venv/lib/python3.12/site-packages/hatchet_sdk/worker/runner/utils/capture_logs.py b/.venv/lib/python3.12/site-packages/hatchet_sdk/worker/runner/utils/capture_logs.py
new file mode 100644
index 00000000..08c57de8
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/hatchet_sdk/worker/runner/utils/capture_logs.py
@@ -0,0 +1,81 @@
+import contextvars
+import functools
+import logging
+from concurrent.futures import ThreadPoolExecutor
+from io import StringIO
+from typing import Any, Coroutine
+
+from hatchet_sdk import logger
+from hatchet_sdk.clients.events import EventClient
+
+wr: contextvars.ContextVar[str | None] = contextvars.ContextVar(
+    "workflow_run_id", default=None
+)
+sr: contextvars.ContextVar[str | None] = contextvars.ContextVar(
+    "step_run_id", default=None
+)
+
+
+def copy_context_vars(ctx_vars, func, *args, **kwargs):
+    for var, value in ctx_vars:
+        var.set(value)
+    return func(*args, **kwargs)
+
+
+class InjectingFilter(logging.Filter):
+    # For some reason, only the InjectingFilter has access to the contextvars method sr.get(),
+    # otherwise we would use emit within the CustomLogHandler
+    def filter(self, record):
+        record.workflow_run_id = wr.get()
+        record.step_run_id = sr.get()
+        return True
+
+
+class CustomLogHandler(logging.StreamHandler):
+    def __init__(self, event_client: EventClient, stream=None):
+        super().__init__(stream)
+        self.logger_thread_pool = ThreadPoolExecutor(max_workers=1)
+        self.event_client = event_client
+
+    def _log(self, line: str, step_run_id: str | None):
+        try:
+            if not step_run_id:
+                return
+
+            self.event_client.log(message=line, step_run_id=step_run_id)
+        except Exception as e:
+            logger.error(f"Error logging: {e}")
+
+    def emit(self, record):
+        super().emit(record)
+
+        log_entry = self.format(record)
+        self.logger_thread_pool.submit(self._log, log_entry, record.step_run_id)
+
+
+def capture_logs(
+    logger: logging.Logger,
+    event_client: EventClient,
+    func: Coroutine[Any, Any, Any],
+):
+    @functools.wraps(func)
+    async def wrapper(*args, **kwargs):
+        if not logger:
+            raise Exception("No logger configured on client")
+
+        log_stream = StringIO()
+        custom_handler = CustomLogHandler(event_client, log_stream)
+        custom_handler.setLevel(logging.INFO)
+        custom_handler.addFilter(InjectingFilter())
+        logger.addHandler(custom_handler)
+
+        try:
+            result = await func(*args, **kwargs)
+        finally:
+            custom_handler.flush()
+            logger.removeHandler(custom_handler)
+            log_stream.close()
+
+        return result
+
+    return wrapper
diff --git a/.venv/lib/python3.12/site-packages/hatchet_sdk/worker/runner/utils/error_with_traceback.py b/.venv/lib/python3.12/site-packages/hatchet_sdk/worker/runner/utils/error_with_traceback.py
new file mode 100644
index 00000000..9c09602f
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/hatchet_sdk/worker/runner/utils/error_with_traceback.py
@@ -0,0 +1,6 @@
+import traceback
+
+
+def errorWithTraceback(message: str, e: Exception):
+    trace = "".join(traceback.format_exception(type(e), e, e.__traceback__))
+    return f"{message}\n{trace}"
diff --git a/.venv/lib/python3.12/site-packages/hatchet_sdk/worker/worker.py b/.venv/lib/python3.12/site-packages/hatchet_sdk/worker/worker.py
new file mode 100644
index 00000000..b6ec1531
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/hatchet_sdk/worker/worker.py
@@ -0,0 +1,392 @@
+import asyncio
+import multiprocessing
+import multiprocessing.context
+import os
+import signal
+import sys
+from concurrent.futures import Future
+from dataclasses import dataclass, field
+from enum import Enum
+from multiprocessing import Queue
+from multiprocessing.process import BaseProcess
+from types import FrameType
+from typing import Any, Callable, TypeVar, get_type_hints
+
+from aiohttp import web
+from aiohttp.web_request import Request
+from aiohttp.web_response import Response
+from prometheus_client import CONTENT_TYPE_LATEST, Gauge, generate_latest
+
+from hatchet_sdk import Context
+from hatchet_sdk.client import Client, new_client_raw
+from hatchet_sdk.contracts.workflows_pb2 import CreateWorkflowVersionOpts
+from hatchet_sdk.loader import ClientConfig
+from hatchet_sdk.logger import logger
+from hatchet_sdk.utils.types import WorkflowValidator
+from hatchet_sdk.utils.typing import is_basemodel_subclass
+from hatchet_sdk.v2.callable import HatchetCallable
+from hatchet_sdk.v2.concurrency import ConcurrencyFunction
+from hatchet_sdk.worker.action_listener_process import worker_action_listener_process
+from hatchet_sdk.worker.runner.run_loop_manager import WorkerActionRunLoopManager
+from hatchet_sdk.workflow import WorkflowInterface
+
+T = TypeVar("T")
+
+
+class WorkerStatus(Enum):
+    INITIALIZED = 1
+    STARTING = 2
+    HEALTHY = 3
+    UNHEALTHY = 4
+
+
+@dataclass
+class WorkerStartOptions:
+    loop: asyncio.AbstractEventLoop | None = field(default=None)
+
+
+TWorkflow = TypeVar("TWorkflow", bound=object)
+
+
+class Worker:
+    def __init__(
+        self,
+        name: str,
+        config: ClientConfig = ClientConfig(),
+        max_runs: int | None = None,
+        labels: dict[str, str | int] = {},
+        debug: bool = False,
+        owned_loop: bool = True,
+        handle_kill: bool = True,
+    ) -> None:
+        self.name = name
+        self.config = config
+        self.max_runs = max_runs
+        self.debug = debug
+        self.labels = labels
+        self.handle_kill = handle_kill
+        self.owned_loop = owned_loop
+
+        self.client: Client
+
+        self.action_registry: dict[str, Callable[[Context], Any]] = {}
+        self.validator_registry: dict[str, WorkflowValidator] = {}
+
+        self.killing: bool = False
+        self._status: WorkerStatus
+
+        self.action_listener_process: BaseProcess
+        self.action_listener_health_check: asyncio.Task[Any]
+        self.action_runner: WorkerActionRunLoopManager
+
+        self.ctx = multiprocessing.get_context("spawn")
+
+        self.action_queue: "Queue[Any]" = self.ctx.Queue()
+        self.event_queue: "Queue[Any]" = self.ctx.Queue()
+
+        self.loop: asyncio.AbstractEventLoop
+
+        self.client = new_client_raw(self.config, self.debug)
+        self.name = self.client.config.namespace + self.name
+
+        self._setup_signal_handlers()
+
+        self.worker_status_gauge = Gauge(
+            "hatchet_worker_status", "Current status of the Hatchet worker"
+        )
+
+    def register_function(self, action: str, func: Callable[[Context], Any]) -> None:
+        self.action_registry[action] = func
+
+    def register_workflow_from_opts(
+        self, name: str, opts: CreateWorkflowVersionOpts
+    ) -> None:
+        try:
+            self.client.admin.put_workflow(opts.name, opts)
+        except Exception as e:
+            logger.error(f"failed to register workflow: {opts.name}")
+            logger.error(e)
+            sys.exit(1)
+
+    def register_workflow(self, workflow: TWorkflow) -> None:
+        ## Hack for typing
+        assert isinstance(workflow, WorkflowInterface)
+
+        namespace = self.client.config.namespace
+
+        try:
+            self.client.admin.put_workflow(
+                workflow.get_name(namespace), workflow.get_create_opts(namespace)
+            )
+        except Exception as e:
+            logger.error(f"failed to register workflow: {workflow.get_name(namespace)}")
+            logger.error(e)
+            sys.exit(1)
+
+        def create_action_function(
+            action_func: Callable[..., T]
+        ) -> Callable[[Context], T]:
+            def action_function(context: Context) -> T:
+                return action_func(workflow, context)
+
+            if asyncio.iscoroutinefunction(action_func):
+                setattr(action_function, "is_coroutine", True)
+            else:
+                setattr(action_function, "is_coroutine", False)
+
+            return action_function
+
+        for action_name, action_func in workflow.get_actions(namespace):
+            self.action_registry[action_name] = create_action_function(action_func)
+            return_type = get_type_hints(action_func).get("return")
+
+            self.validator_registry[action_name] = WorkflowValidator(
+                workflow_input=workflow.input_validator,
+                step_output=return_type if is_basemodel_subclass(return_type) else None,
+            )
+
+    def status(self) -> WorkerStatus:
+        return self._status
+
+    def setup_loop(self, loop: asyncio.AbstractEventLoop | None = None) -> bool:
+        try:
+            loop = loop or asyncio.get_running_loop()
+            self.loop = loop
+            created_loop = False
+            logger.debug("using existing event loop")
+            return created_loop
+        except RuntimeError:
+            self.loop = asyncio.new_event_loop()
+            logger.debug("creating new event loop")
+            asyncio.set_event_loop(self.loop)
+            created_loop = True
+            return created_loop
+
+    async def health_check_handler(self, request: Request) -> Response:
+        status = self.status()
+
+        return web.json_response({"status": status.name})
+
+    async def metrics_handler(self, request: Request) -> Response:
+        self.worker_status_gauge.set(1 if self.status() == WorkerStatus.HEALTHY else 0)
+
+        return web.Response(body=generate_latest(), content_type="text/plain")
+
+    async def start_health_server(self) -> None:
+        port = self.config.worker_healthcheck_port or 8001
+
+        app = web.Application()
+        app.add_routes(
+            [
+                web.get("/health", self.health_check_handler),
+                web.get("/metrics", self.metrics_handler),
+            ]
+        )
+
+        runner = web.AppRunner(app)
+
+        try:
+            await runner.setup()
+            await web.TCPSite(runner, "0.0.0.0", port).start()
+        except Exception as e:
+            logger.error("failed to start healthcheck server")
+            logger.error(str(e))
+            return
+
+        logger.info(f"healthcheck server running on port {port}")
+
+    def start(
+        self, options: WorkerStartOptions = WorkerStartOptions()
+    ) -> Future[asyncio.Task[Any] | None]:
+        self.owned_loop = self.setup_loop(options.loop)
+
+        f = asyncio.run_coroutine_threadsafe(
+            self.async_start(options, _from_start=True), self.loop
+        )
+
+        # start the loop and wait until its closed
+        if self.owned_loop:
+            self.loop.run_forever()
+
+            if self.handle_kill:
+                sys.exit(0)
+
+        return f
+
+    ## Start methods
+    async def async_start(
+        self,
+        options: WorkerStartOptions = WorkerStartOptions(),
+        _from_start: bool = False,
+    ) -> Any | None:
+        main_pid = os.getpid()
+        logger.info("------------------------------------------")
+        logger.info("STARTING HATCHET...")
+        logger.debug(f"worker runtime starting on PID: {main_pid}")
+
+        self._status = WorkerStatus.STARTING
+
+        if len(self.action_registry.keys()) == 0:
+            logger.error(
+                "no actions registered, register workflows or actions before starting worker"
+            )
+            return None
+
+        # non blocking setup
+        if not _from_start:
+            self.setup_loop(options.loop)
+
+        if self.config.worker_healthcheck_enabled:
+            await self.start_health_server()
+
+        self.action_listener_process = self._start_listener()
+
+        self.action_runner = self._run_action_runner()
+
+        self.action_listener_health_check = self.loop.create_task(
+            self._check_listener_health()
+        )
+
+        return await self.action_listener_health_check
+
+    def _run_action_runner(self) -> WorkerActionRunLoopManager:
+        # Retrieve the shared queue
+        return WorkerActionRunLoopManager(
+            self.name,
+            self.action_registry,
+            self.validator_registry,
+            self.max_runs,
+            self.config,
+            self.action_queue,
+            self.event_queue,
+            self.loop,
+            self.handle_kill,
+            self.client.debug,
+            self.labels,
+        )
+
+    def _start_listener(self) -> multiprocessing.context.SpawnProcess:
+        action_list = [str(key) for key in self.action_registry.keys()]
+
+        try:
+            process = self.ctx.Process(
+                target=worker_action_listener_process,
+                args=(
+                    self.name,
+                    action_list,
+                    self.max_runs,
+                    self.config,
+                    self.action_queue,
+                    self.event_queue,
+                    self.handle_kill,
+                    self.client.debug,
+                    self.labels,
+                ),
+            )
+            process.start()
+            logger.debug(f"action listener starting on PID: {process.pid}")
+
+            return process
+        except Exception as e:
+            logger.error(f"failed to start action listener: {e}")
+            sys.exit(1)
+
+    async def _check_listener_health(self) -> None:
+        logger.debug("starting action listener health check...")
+        try:
+            while not self.killing:
+                if (
+                    self.action_listener_process is None
+                    or not self.action_listener_process.is_alive()
+                ):
+                    logger.debug("child action listener process killed...")
+                    self._status = WorkerStatus.UNHEALTHY
+                    if not self.killing:
+                        self.loop.create_task(self.exit_gracefully())
+                    break
+                else:
+                    self._status = WorkerStatus.HEALTHY
+                await asyncio.sleep(1)
+        except Exception as e:
+            logger.error(f"error checking listener health: {e}")
+
+    ## Cleanup methods
+    def _setup_signal_handlers(self) -> None:
+        signal.signal(signal.SIGTERM, self._handle_exit_signal)
+        signal.signal(signal.SIGINT, self._handle_exit_signal)
+        signal.signal(signal.SIGQUIT, self._handle_force_quit_signal)
+
+    def _handle_exit_signal(self, signum: int, frame: FrameType | None) -> None:
+        sig_name = "SIGTERM" if signum == signal.SIGTERM else "SIGINT"
+        logger.info(f"received signal {sig_name}...")
+        self.loop.create_task(self.exit_gracefully())
+
+    def _handle_force_quit_signal(self, signum: int, frame: FrameType | None) -> None:
+        logger.info("received SIGQUIT...")
+        self.exit_forcefully()
+
+    async def close(self) -> None:
+        logger.info(f"closing worker '{self.name}'...")
+        self.killing = True
+        # self.action_queue.close()
+        # self.event_queue.close()
+
+        if self.action_runner is not None:
+            self.action_runner.cleanup()
+
+        await self.action_listener_health_check
+
+    async def exit_gracefully(self) -> None:
+        logger.debug(f"gracefully stopping worker: {self.name}")
+
+        if self.killing:
+            return self.exit_forcefully()
+
+        self.killing = True
+
+        await self.action_runner.wait_for_tasks()
+
+        await self.action_runner.exit_gracefully()
+
+        if self.action_listener_process and self.action_listener_process.is_alive():
+            self.action_listener_process.kill()
+
+        await self.close()
+        if self.loop and self.owned_loop:
+            self.loop.stop()
+
+        logger.info("👋")
+
+    def exit_forcefully(self) -> None:
+        self.killing = True
+
+        logger.debug(f"forcefully stopping worker: {self.name}")
+
+        self.close()
+
+        if self.action_listener_process:
+            self.action_listener_process.kill()  # Forcefully kill the process
+
+        logger.info("👋")
+        sys.exit(
+            1
+        )  # Exit immediately TODO - should we exit with 1 here, there may be other workers to cleanup
+
+
+def register_on_worker(callable: HatchetCallable[T], worker: Worker) -> None:
+    worker.register_function(callable.get_action_name(), callable)
+
+    if callable.function_on_failure is not None:
+        worker.register_function(
+            callable.function_on_failure.get_action_name(), callable.function_on_failure
+        )
+
+    if callable.function_concurrency is not None:
+        worker.register_function(
+            callable.function_concurrency.get_action_name(),
+            callable.function_concurrency,
+        )
+
+    opts = callable.to_workflow_opts()
+
+    worker.register_workflow_from_opts(opts.name, opts)