path: root/.venv/lib/python3.12/site-packages/hatchet_sdk/worker/runner/run_loop_manager.py
blob: 27ed788cb56b9da00ecca3f650a3ea6ba1d0508d
import asyncio
import logging
from dataclasses import dataclass, field
from multiprocessing import Queue
from typing import Callable, TypeVar

from hatchet_sdk import Context
from hatchet_sdk.client import Client, new_client_raw
from hatchet_sdk.clients.dispatcher.action_listener import Action
from hatchet_sdk.loader import ClientConfig
from hatchet_sdk.logger import logger
from hatchet_sdk.utils.types import WorkflowValidator
from hatchet_sdk.worker.runner.runner import Runner
from hatchet_sdk.worker.runner.utils.capture_logs import capture_logs

STOP_LOOP = "STOP_LOOP"  # sentinel pushed onto the action queue to stop the run loop

T = TypeVar("T")


@dataclass
class WorkerActionRunLoopManager:
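    """Drive the worker's action run loop.

    Pulls Action messages off the multiprocessing action_queue on the worker's
    asyncio event loop and dispatches them to a Runner until the STOP_LOOP
    sentinel is received.
    """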
    name: str
    action_registry: dict[str, Callable[[Context], T]]
    validator_registry: dict[str, WorkflowValidator]
    max_runs: int | None
    config: ClientConfig
    action_queue: Queue
    event_queue: Queue
    loop: asyncio.AbstractEventLoop
    handle_kill: bool = True
    debug: bool = False
    labels: dict[str, str | int] = field(default_factory=dict)

    client: Client = field(init=False, default=None)

    killing: bool = field(init=False, default=False)
    runner: Runner = field(init=False, default=None)

    def __post_init__(self):
        if self.debug:
            logger.setLevel(logging.DEBUG)
        self.client = new_client_raw(self.config, self.debug)
        self.start()

    def start(self, retry_count=1):
        k = self.loop.create_task(self.async_start(retry_count))

    async def async_start(self, retry_count=1):
        await capture_logs(
            self.client.logInterceptor,
            self.client.event,
            self._async_start,
        )(retry_count=retry_count)

    async def _async_start(self, retry_count: int = 1) -> None:
        logger.info("starting runner...")
        self.loop = asyncio.get_running_loop()
        # needed for graceful termination
        k = self.loop.create_task(self._start_action_loop())
        await k

    def cleanup(self) -> None:
        self.killing = True

        self.action_queue.put(STOP_LOOP)

    async def wait_for_tasks(self) -> None:
        if self.runner:
            await self.runner.wait_for_tasks()

    async def _start_action_loop(self) -> None:
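        # Construct the Runner, then consume actions from the queue until the
        # STOP_LOOP sentinel arrives or self.killing is set.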
        self.runner = Runner(
            self.name,
            self.event_queue,
            self.max_runs,
            self.handle_kill,
            self.action_registry,
            self.validator_registry,
            self.config,
            self.labels,
        )

        logger.debug(f"'{self.name}' waiting for {list(self.action_registry.keys())}")
        while not self.killing:
            action: Action = await self._get_action()
            if action == STOP_LOOP:
                logger.debug("stopping action runner loop...")
                break

            self.runner.run(action)
        logger.debug("action runner loop stopped")

    async def _get_action(self):
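        # Queue.get is a blocking call, so run it in the default thread pool
        # executor to keep the event loop responsive.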
        return await self.loop.run_in_executor(None, self.action_queue.get)

    async def exit_gracefully(self) -> None:
        if self.killing:
            return

        logger.info("gracefully exiting runner...")

        self.cleanup()

        # Wait 1 second to allow the last calls to flush. These calls were added
        # to the event loop as task callbacks, so they do not show up in the
        # task list.
        await asyncio.sleep(1)

    def exit_forcefully(self) -> None:
        logger.info("forcefully exiting runner...")
        self.cleanup()
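

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module). It shows, under
# assumptions, how a worker process might wire the manager up: the registry
# key, the step function, and the default ClientConfig() below are
# placeholders, a real Hatchet token/server is required for the client to
# connect, and the actual wiring lives in hatchet_sdk's Worker implementation.
# ---------------------------------------------------------------------------
#
# import asyncio
# from multiprocessing import Queue
#
# from hatchet_sdk import Context
# from hatchet_sdk.loader import ClientConfig
# from hatchet_sdk.worker.runner.run_loop_manager import WorkerActionRunLoopManager
#
#
# def example_step(context: Context) -> dict:
#     # hypothetical action implementation
#     return {"status": "ok"}
#
#
# def run_worker_process() -> None:
#     loop = asyncio.new_event_loop()
#     asyncio.set_event_loop(loop)
#
#     manager = WorkerActionRunLoopManager(
#         name="example-worker",
#         action_registry={"example-workflow:example_step": example_step},
#         validator_registry={},
#         max_runs=5,
#         config=ClientConfig(),  # assumes credentials are provided via env/config
#         action_queue=Queue(),   # fed with Action objects by the dispatcher listener
#         event_queue=Queue(),
#         loop=loop,
#     )
#
#     try:
#         loop.run_forever()
#     finally:
#         loop.run_until_complete(manager.exit_gracefully())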