aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/hatchet_sdk/worker/worker.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/hatchet_sdk/worker/worker.py')
-rw-r--r--.venv/lib/python3.12/site-packages/hatchet_sdk/worker/worker.py392
1 files changed, 392 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/hatchet_sdk/worker/worker.py b/.venv/lib/python3.12/site-packages/hatchet_sdk/worker/worker.py
new file mode 100644
index 00000000..b6ec1531
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/hatchet_sdk/worker/worker.py
@@ -0,0 +1,392 @@
+import asyncio
+import multiprocessing
+import multiprocessing.context
+import os
+import signal
+import sys
+from concurrent.futures import Future
+from dataclasses import dataclass, field
+from enum import Enum
+from multiprocessing import Queue
+from multiprocessing.process import BaseProcess
+from types import FrameType
+from typing import Any, Callable, TypeVar, get_type_hints
+
+from aiohttp import web
+from aiohttp.web_request import Request
+from aiohttp.web_response import Response
+from prometheus_client import CONTENT_TYPE_LATEST, Gauge, generate_latest
+
+from hatchet_sdk import Context
+from hatchet_sdk.client import Client, new_client_raw
+from hatchet_sdk.contracts.workflows_pb2 import CreateWorkflowVersionOpts
+from hatchet_sdk.loader import ClientConfig
+from hatchet_sdk.logger import logger
+from hatchet_sdk.utils.types import WorkflowValidator
+from hatchet_sdk.utils.typing import is_basemodel_subclass
+from hatchet_sdk.v2.callable import HatchetCallable
+from hatchet_sdk.v2.concurrency import ConcurrencyFunction
+from hatchet_sdk.worker.action_listener_process import worker_action_listener_process
+from hatchet_sdk.worker.runner.run_loop_manager import WorkerActionRunLoopManager
+from hatchet_sdk.workflow import WorkflowInterface
+
# Generic result type: the return type of a wrapped action function in
# ``Worker.register_workflow`` and the type parameter of ``HatchetCallable``
# in ``register_on_worker``.
T = TypeVar("T")
+
+
class WorkerStatus(Enum):
    """Lifecycle / health states reported by a ``Worker``.

    The member *name* is what the ``/health`` endpoint returns, and the
    ``/metrics`` gauge is 1 only while the status is ``HEALTHY``.
    """

    # Worker object exists but has not been started (presumed from the name).
    INITIALIZED = 1
    # Set at the beginning of ``Worker.async_start``.
    STARTING = 2
    # Set by the listener health check while the listener process is alive.
    HEALTHY = 3
    # Set by the listener health check once the listener process is gone.
    UNHEALTHY = 4
+
+
+@dataclass
+class WorkerStartOptions:
+ loop: asyncio.AbstractEventLoop | None = field(default=None)
+
+
# Type of the workflow object accepted by ``Worker.register_workflow``; it is
# bound to ``object`` and narrowed at runtime to ``WorkflowInterface`` there.
TWorkflow = TypeVar("TWorkflow", bound=object)
+
+
class Worker:
    """Registers workflows/actions with Hatchet and runs them.

    The worker owns two cooperating pieces that talk over two
    ``multiprocessing`` queues created from a ``spawn`` context:

    * an *action listener* child process (``worker_action_listener_process``)
    * an in-process ``WorkerActionRunLoopManager`` that executes the actions

    A periodic health check watches the listener process and starts a graceful
    shutdown when it is no longer alive.
    """

    def __init__(
        self,
        name: str,
        config: ClientConfig | None = None,
        max_runs: int | None = None,
        labels: dict[str, str | int] | None = None,
        debug: bool = False,
        owned_loop: bool = True,
        handle_kill: bool = True,
    ) -> None:
        """Create the client, the IPC queues and install signal handlers.

        Args:
            name: Worker name; the client's configured namespace is prepended.
            config: Client configuration. A fresh ``ClientConfig()`` is built
                per worker when omitted (instead of one instance created at
                import time and shared by every worker).
            max_runs: Passed through to the listener process and the runner.
            labels: Worker labels passed through to the listener process and
                the runner. Defaults to a new empty dict per worker.
            debug: Passed to ``new_client_raw``.
            owned_loop: Whether this worker owns (and therefore stops) the
                event loop. Overwritten by ``start()`` with the result of
                ``setup_loop``.
            handle_kill: When true, ``start()`` calls ``sys.exit(0)`` once the
                owned loop stops; also forwarded to the listener and runner.
        """
        self.name = name
        self.config = ClientConfig() if config is None else config
        self.max_runs = max_runs
        self.debug = debug
        self.labels: dict[str, str | int] = {} if labels is None else labels
        self.handle_kill = handle_kill
        self.owned_loop = owned_loop

        self.client: Client

        self.action_registry: dict[str, Callable[[Context], Any]] = {}
        self.validator_registry: dict[str, WorkflowValidator] = {}

        self.killing: bool = False
        # Give the status a real initial value so status(), /health and
        # /metrics work before async_start() has run.
        self._status: WorkerStatus = WorkerStatus.INITIALIZED

        # These are created in async_start(). They start as None because the
        # shutdown paths test them against None / truthiness and may run
        # before the worker was ever started (e.g. an early signal).
        self.action_listener_process: BaseProcess | None = None
        self.action_listener_health_check: asyncio.Task[Any] | None = None
        self.action_runner: WorkerActionRunLoopManager | None = None

        self.ctx = multiprocessing.get_context("spawn")

        self.action_queue: "Queue[Any]" = self.ctx.Queue()
        self.event_queue: "Queue[Any]" = self.ctx.Queue()

        # Assigned by setup_loop().
        self.loop: asyncio.AbstractEventLoop

        self.client = new_client_raw(self.config, self.debug)
        self.name = self.client.config.namespace + self.name

        self._setup_signal_handlers()

        self.worker_status_gauge = Gauge(
            "hatchet_worker_status", "Current status of the Hatchet worker"
        )

    def register_function(self, action: str, func: Callable[[Context], Any]) -> None:
        """Map the action name ``action`` to the callable ``func``."""
        self.action_registry[action] = func

    def register_workflow_from_opts(
        self, name: str, opts: CreateWorkflowVersionOpts
    ) -> None:
        """Push a workflow definition to Hatchet; exit the process on failure.

        Args:
            name: Unused by this method; ``opts.name`` is what is registered.
            opts: Workflow creation options sent via ``client.admin``.
        """
        try:
            self.client.admin.put_workflow(opts.name, opts)
        except Exception as e:
            logger.error(f"failed to register workflow: {opts.name}")
            logger.error(e)
            sys.exit(1)

    def register_workflow(self, workflow: TWorkflow) -> None:
        """Register a workflow object and all of its actions.

        The workflow is pushed to Hatchet (the process exits with status 1 if
        that fails). Then every action returned by ``workflow.get_actions`` is
        added to the action registry, together with a ``WorkflowValidator``
        built from the workflow's input validator and the action's annotated
        return type (only when that type is a ``BaseModel`` subclass).
        """
        ## Hack for typing
        assert isinstance(workflow, WorkflowInterface)

        namespace = self.client.config.namespace

        try:
            self.client.admin.put_workflow(
                workflow.get_name(namespace), workflow.get_create_opts(namespace)
            )
        except Exception as e:
            logger.error(f"failed to register workflow: {workflow.get_name(namespace)}")
            logger.error(e)
            sys.exit(1)

        def create_action_function(
            action_func: Callable[..., T]
        ) -> Callable[[Context], T]:
            # Bind the workflow instance so the runner can call the action
            # with just a Context.
            def action_function(context: Context) -> T:
                return action_func(workflow, context)

            # Record whether the wrapped function is a coroutine function.
            setattr(
                action_function,
                "is_coroutine",
                asyncio.iscoroutinefunction(action_func),
            )

            return action_function

        for action_name, action_func in workflow.get_actions(namespace):
            self.action_registry[action_name] = create_action_function(action_func)
            return_type = get_type_hints(action_func).get("return")

            self.validator_registry[action_name] = WorkflowValidator(
                workflow_input=workflow.input_validator,
                step_output=return_type if is_basemodel_subclass(return_type) else None,
            )

    def status(self) -> WorkerStatus:
        """Return the current worker status."""
        return self._status

    def setup_loop(self, loop: asyncio.AbstractEventLoop | None = None) -> bool:
        """Select the event loop the worker runs on.

        Uses ``loop`` if given, otherwise the currently running loop; if there
        is none, a new loop is created and installed as the current loop.

        Returns:
            ``True`` if a new loop was created, ``False`` if an existing one
            is used.
        """
        try:
            self.loop = loop or asyncio.get_running_loop()
        except RuntimeError:
            # get_running_loop() raises RuntimeError when no loop is running.
            self.loop = asyncio.new_event_loop()
            logger.debug("creating new event loop")
            asyncio.set_event_loop(self.loop)
            return True

        logger.debug("using existing event loop")
        return False

    async def health_check_handler(self, request: Request) -> Response:
        """HTTP handler for ``/health``: JSON body with the status name."""
        status = self.status()

        return web.json_response({"status": status.name})

    async def metrics_handler(self, request: Request) -> Response:
        """HTTP handler for ``/metrics``: Prometheus text exposition.

        The status gauge is refreshed first: 1 when healthy, 0 otherwise.
        """
        self.worker_status_gauge.set(1 if self.status() == WorkerStatus.HEALTHY else 0)

        return web.Response(body=generate_latest(), content_type="text/plain")

    async def start_health_server(self) -> None:
        """Serve ``/health`` and ``/metrics`` on all interfaces.

        The port is ``config.worker_healthcheck_port`` or 8001. Startup
        failures are logged and otherwise ignored (best effort).
        """
        port = self.config.worker_healthcheck_port or 8001

        app = web.Application()
        app.add_routes(
            [
                web.get("/health", self.health_check_handler),
                web.get("/metrics", self.metrics_handler),
            ]
        )

        runner = web.AppRunner(app)

        try:
            await runner.setup()
            await web.TCPSite(runner, "0.0.0.0", port).start()
        except Exception as e:
            logger.error("failed to start healthcheck server")
            logger.error(str(e))
            return

        logger.info(f"healthcheck server running on port {port}")

    def start(
        self, options: WorkerStartOptions | None = None
    ) -> Future[asyncio.Task[Any] | None]:
        """Start the worker; block on the loop if this call created it.

        Args:
            options: Start options; defaults to a fresh ``WorkerStartOptions``.

        Returns:
            The future of the scheduled ``async_start`` coroutine. When the
            worker created its own loop this is only reached after the loop
            has been stopped and ``handle_kill`` is false.
        """
        options = WorkerStartOptions() if options is None else options

        # setup_loop() reports whether it created the loop; that decides
        # ownership from here on, replacing the constructor argument.
        self.owned_loop = self.setup_loop(options.loop)

        f = asyncio.run_coroutine_threadsafe(
            self.async_start(options, _from_start=True), self.loop
        )

        # start the loop and wait until its closed
        # NOTE(review): if async_start() returns early because no actions are
        # registered, nothing stops an owned loop and this blocks forever.
        if self.owned_loop:
            self.loop.run_forever()

            if self.handle_kill:
                sys.exit(0)

        return f

    ## Start methods
    async def async_start(
        self,
        options: WorkerStartOptions | None = None,
        _from_start: bool = False,
    ) -> Any | None:
        """Start the listener process, the action runner and the health check.

        Args:
            options: Start options; defaults to a fresh ``WorkerStartOptions``.
            _from_start: Internal flag set by ``start()``, which has already
                selected the loop.

        Returns:
            ``None`` immediately if no actions are registered; otherwise the
            result of awaiting the listener health-check task.
        """
        options = WorkerStartOptions() if options is None else options

        main_pid = os.getpid()
        logger.info("------------------------------------------")
        logger.info("STARTING HATCHET...")
        logger.debug(f"worker runtime starting on PID: {main_pid}")

        self._status = WorkerStatus.STARTING

        if not self.action_registry:
            logger.error(
                "no actions registered, register workflows or actions before starting worker"
            )
            return None

        # non blocking setup
        if not _from_start:
            self.setup_loop(options.loop)

        if self.config.worker_healthcheck_enabled:
            await self.start_health_server()

        self.action_listener_process = self._start_listener()

        self.action_runner = self._run_action_runner()

        self.action_listener_health_check = self.loop.create_task(
            self._check_listener_health()
        )

        return await self.action_listener_health_check

    def _run_action_runner(self) -> WorkerActionRunLoopManager:
        """Build the in-process run-loop manager on the shared queues."""
        return WorkerActionRunLoopManager(
            self.name,
            self.action_registry,
            self.validator_registry,
            self.max_runs,
            self.config,
            self.action_queue,
            self.event_queue,
            self.loop,
            self.handle_kill,
            self.client.debug,
            self.labels,
        )

    def _start_listener(self) -> multiprocessing.context.SpawnProcess:
        """Spawn the action listener child process.

        Exits the current process with status 1 if the child cannot be
        created or started.
        """
        action_list = [str(key) for key in self.action_registry]

        try:
            process = self.ctx.Process(
                target=worker_action_listener_process,
                args=(
                    self.name,
                    action_list,
                    self.max_runs,
                    self.config,
                    self.action_queue,
                    self.event_queue,
                    self.handle_kill,
                    self.client.debug,
                    self.labels,
                ),
            )
            process.start()
            logger.debug(f"action listener starting on PID: {process.pid}")

            return process
        except Exception as e:
            logger.error(f"failed to start action listener: {e}")
            sys.exit(1)

    async def _check_listener_health(self) -> None:
        """Poll the listener process once per second until shutdown.

        Marks the worker HEALTHY while the process is alive. When it is not,
        marks the worker UNHEALTHY, schedules a graceful exit and stops
        polling. Errors are logged rather than propagated.
        """
        logger.debug("starting action listener health check...")
        try:
            while not self.killing:
                if (
                    self.action_listener_process is None
                    or not self.action_listener_process.is_alive()
                ):
                    logger.debug("child action listener process killed...")
                    self._status = WorkerStatus.UNHEALTHY
                    if not self.killing:
                        self.loop.create_task(self.exit_gracefully())
                    break
                else:
                    self._status = WorkerStatus.HEALTHY
                await asyncio.sleep(1)
        except Exception as e:
            logger.error(f"error checking listener health: {e}")

    ## Cleanup methods
    def _setup_signal_handlers(self) -> None:
        """Install handlers: SIGTERM/SIGINT exit gracefully, SIGQUIT forces."""
        signal.signal(signal.SIGTERM, self._handle_exit_signal)
        signal.signal(signal.SIGINT, self._handle_exit_signal)
        # SIGQUIT is only defined on Unix; skip it where it does not exist
        # instead of failing worker construction with AttributeError.
        if hasattr(signal, "SIGQUIT"):
            signal.signal(signal.SIGQUIT, self._handle_force_quit_signal)

    def _handle_exit_signal(self, signum: int, frame: FrameType | None) -> None:
        """Signal handler for SIGTERM / SIGINT: begin a graceful shutdown."""
        sig_name = "SIGTERM" if signum == signal.SIGTERM else "SIGINT"
        logger.info(f"received signal {sig_name}...")

        # The handler is installed in __init__, so the signal can arrive
        # before setup_loop() has assigned self.loop.
        loop = getattr(self, "loop", None)
        if loop is None:
            # Nothing is running yet, so there is nothing to drain.
            self.exit_forcefully()
            return

        # A handler installed with signal.signal() must not drive the loop
        # directly; hand the work over in a way that also wakes the loop.
        loop.call_soon_threadsafe(
            lambda: loop.create_task(self.exit_gracefully())
        )

    def _handle_force_quit_signal(self, signum: int, frame: FrameType | None) -> None:
        """Signal handler for SIGQUIT: stop immediately."""
        logger.info("received SIGQUIT...")
        self.exit_forcefully()

    async def close(self) -> None:
        """Mark the worker as stopping, clean up the runner and wait for the
        health-check task to finish (if the worker was started)."""
        logger.info(f"closing worker '{self.name}'...")
        self.killing = True

        if self.action_runner is not None:
            self.action_runner.cleanup()

        if self.action_listener_health_check is not None:
            await self.action_listener_health_check

    async def exit_gracefully(self) -> None:
        """Drain running tasks, stop the runner and listener, then close.

        A second request while a shutdown is already in progress escalates to
        ``exit_forcefully``. The event loop is stopped only if this worker
        owns it.
        """
        logger.debug(f"gracefully stopping worker: {self.name}")

        if self.killing:
            return self.exit_forcefully()

        self.killing = True

        # The runner does not exist if shutdown is requested before start.
        if self.action_runner is not None:
            await self.action_runner.wait_for_tasks()

            await self.action_runner.exit_gracefully()

        if self.action_listener_process and self.action_listener_process.is_alive():
            self.action_listener_process.kill()

        await self.close()
        if self.loop and self.owned_loop:
            self.loop.stop()

        logger.info("👋")

    def exit_forcefully(self) -> None:
        """Stop immediately: clean up the runner, kill the listener, exit 1.

        Raises:
            SystemExit: always, with status 1.
        """
        self.killing = True

        logger.debug(f"forcefully stopping worker: {self.name}")

        # close() is a coroutine and cannot be awaited from this synchronous
        # path (calling it bare would only create a never-awaited coroutine),
        # so run its synchronous cleanup step directly.
        if self.action_runner is not None:
            self.action_runner.cleanup()

        if self.action_listener_process:
            self.action_listener_process.kill()  # Forcefully kill the process

        logger.info("👋")
        # Exit immediately TODO - should we exit with 1 here, there may be other workers to cleanup
        sys.exit(1)
+
def register_on_worker(callable: HatchetCallable[T], worker: Worker) -> None:
    """Register a ``HatchetCallable`` and its companions on ``worker``.

    The callable itself is registered under its action name, followed by its
    on-failure and concurrency callables when they are set. Finally the
    workflow options derived from the callable are pushed through
    ``worker.register_workflow_from_opts``.
    """
    worker.register_function(callable.get_action_name(), callable)

    # Optional companions, registered in this order when present.
    for companion_attr in ("function_on_failure", "function_concurrency"):
        companion = getattr(callable, companion_attr)
        if companion is None:
            continue
        worker.register_function(companion.get_action_name(), companion)

    workflow_opts = callable.to_workflow_opts()
    worker.register_workflow_from_opts(workflow_opts.name, workflow_opts)