about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/hatchet_sdk/worker/runner/run_loop_manager.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/hatchet_sdk/worker/runner/run_loop_manager.py')
-rw-r--r--.venv/lib/python3.12/site-packages/hatchet_sdk/worker/runner/run_loop_manager.py112
1 files changed, 112 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/hatchet_sdk/worker/runner/run_loop_manager.py b/.venv/lib/python3.12/site-packages/hatchet_sdk/worker/runner/run_loop_manager.py
new file mode 100644
index 00000000..27ed788c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/hatchet_sdk/worker/runner/run_loop_manager.py
@@ -0,0 +1,112 @@
+import asyncio
+import logging
+from dataclasses import dataclass, field
+from multiprocessing import Queue
+from typing import Callable, TypeVar
+
+from hatchet_sdk import Context
+from hatchet_sdk.client import Client, new_client_raw
+from hatchet_sdk.clients.dispatcher.action_listener import Action
+from hatchet_sdk.loader import ClientConfig
+from hatchet_sdk.logger import logger
+from hatchet_sdk.utils.types import WorkflowValidator
+from hatchet_sdk.worker.runner.runner import Runner
+from hatchet_sdk.worker.runner.utils.capture_logs import capture_logs
+
+STOP_LOOP = "STOP_LOOP"
+
+T = TypeVar("T")
+
+
+@dataclass
+class WorkerActionRunLoopManager:
+    name: str
+    action_registry: dict[str, Callable[[Context], T]]
+    validator_registry: dict[str, WorkflowValidator]
+    max_runs: int | None
+    config: ClientConfig
+    action_queue: Queue
+    event_queue: Queue
+    loop: asyncio.AbstractEventLoop
+    handle_kill: bool = True
+    debug: bool = False
+    labels: dict[str, str | int] = field(default_factory=dict)
+
+    client: Client = field(init=False, default=None)
+
+    killing: bool = field(init=False, default=False)
+    runner: Runner = field(init=False, default=None)
+
+    def __post_init__(self):
+        if self.debug:
+            logger.setLevel(logging.DEBUG)
+        self.client = new_client_raw(self.config, self.debug)
+        self.start()
+
+    def start(self, retry_count=1):
+        k = self.loop.create_task(self.async_start(retry_count))
+
+    async def async_start(self, retry_count=1):
+        await capture_logs(
+            self.client.logInterceptor,
+            self.client.event,
+            self._async_start,
+        )(retry_count=retry_count)
+
+    async def _async_start(self, retry_count: int = 1) -> None:
+        logger.info("starting runner...")
+        self.loop = asyncio.get_running_loop()
+        # needed for graceful termination
+        k = self.loop.create_task(self._start_action_loop())
+        await k
+
+    def cleanup(self) -> None:
+        self.killing = True
+
+        self.action_queue.put(STOP_LOOP)
+
+    async def wait_for_tasks(self) -> None:
+        if self.runner:
+            await self.runner.wait_for_tasks()
+
+    async def _start_action_loop(self) -> None:
+        self.runner = Runner(
+            self.name,
+            self.event_queue,
+            self.max_runs,
+            self.handle_kill,
+            self.action_registry,
+            self.validator_registry,
+            self.config,
+            self.labels,
+        )
+
+        logger.debug(f"'{self.name}' waiting for {list(self.action_registry.keys())}")
+        while not self.killing:
+            action: Action = await self._get_action()
+            if action == STOP_LOOP:
+                logger.debug("stopping action runner loop...")
+                break
+
+            self.runner.run(action)
+        logger.debug("action runner loop stopped")
+
+    async def _get_action(self):
+        return await self.loop.run_in_executor(None, self.action_queue.get)
+
+    async def exit_gracefully(self) -> None:
+        if self.killing:
+            return
+
+        logger.info("gracefully exiting runner...")
+
+        self.cleanup()
+
+        # Wait for 1 second to allow last calls to flush. These are calls which have been
+        # added to the event loop as callbacks to tasks, so we're not aware of them in the
+        # task list.
+        await asyncio.sleep(1)
+
+    def exit_forcefully(self) -> None:
+        logger.info("forcefully exiting runner...")
+        self.cleanup()