diff options
Diffstat (limited to '.venv/lib/python3.12/site-packages/hatchet_sdk/hatchet.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/hatchet_sdk/hatchet.py | 310 |
1 files changed, 310 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/hatchet_sdk/hatchet.py b/.venv/lib/python3.12/site-packages/hatchet_sdk/hatchet.py new file mode 100644 index 00000000..bf0e9089 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/hatchet_sdk/hatchet.py @@ -0,0 +1,310 @@ +import asyncio +import logging +from typing import Any, Callable, Optional, ParamSpec, Type, TypeVar, Union + +from pydantic import BaseModel +from typing_extensions import deprecated + +from hatchet_sdk.clients.rest_client import RestApi +from hatchet_sdk.context.context import Context +from hatchet_sdk.contracts.workflows_pb2 import ( + ConcurrencyLimitStrategy, + CreateStepRateLimit, + DesiredWorkerLabels, + StickyStrategy, +) +from hatchet_sdk.features.cron import CronClient +from hatchet_sdk.features.scheduled import ScheduledClient +from hatchet_sdk.labels import DesiredWorkerLabel +from hatchet_sdk.loader import ClientConfig, ConfigLoader +from hatchet_sdk.rate_limit import RateLimit +from hatchet_sdk.v2.callable import HatchetCallable + +from .client import Client, new_client, new_client_raw +from .clients.admin import AdminClient +from .clients.dispatcher.dispatcher import DispatcherClient +from .clients.events import EventClient +from .clients.run_event_listener import RunEventListenerClient +from .logger import logger +from .worker.worker import Worker +from .workflow import ( + ConcurrencyExpression, + WorkflowInterface, + WorkflowMeta, + WorkflowStepProtocol, +) + +T = TypeVar("T", bound=BaseModel) +R = TypeVar("R") +P = ParamSpec("P") + +TWorkflow = TypeVar("TWorkflow", bound=object) + + +def workflow( + name: str = "", + on_events: list[str] | None = None, + on_crons: list[str] | None = None, + version: str = "", + timeout: str = "60m", + schedule_timeout: str = "5m", + sticky: Union[StickyStrategy.Value, None] = None, # type: ignore[name-defined] + default_priority: int | None = None, + concurrency: ConcurrencyExpression | None = None, + input_validator: Type[T] | None = None, +) -> Callable[[Type[TWorkflow]], WorkflowMeta]: + on_events = on_events or [] + on_crons = on_crons or [] + + def inner(cls: Type[TWorkflow]) -> WorkflowMeta: + nonlocal name + name = name or str(cls.__name__) + + setattr(cls, "on_events", on_events) + setattr(cls, "on_crons", on_crons) + setattr(cls, "name", name) + setattr(cls, "version", version) + setattr(cls, "timeout", timeout) + setattr(cls, "schedule_timeout", schedule_timeout) + setattr(cls, "sticky", sticky) + setattr(cls, "default_priority", default_priority) + setattr(cls, "concurrency_expression", concurrency) + + # Define a new class with the same name and bases as the original, but + # with WorkflowMeta as its metaclass + + ## TODO: Figure out how to type this metaclass correctly + setattr(cls, "input_validator", input_validator) + + return WorkflowMeta(name, cls.__bases__, dict(cls.__dict__)) + + return inner + + +def step( + name: str = "", + timeout: str = "", + parents: list[str] | None = None, + retries: int = 0, + rate_limits: list[RateLimit] | None = None, + desired_worker_labels: dict[str, DesiredWorkerLabel] = {}, + backoff_factor: float | None = None, + backoff_max_seconds: int | None = None, +) -> Callable[[Callable[P, R]], Callable[P, R]]: + parents = parents or [] + + def inner(func: Callable[P, R]) -> Callable[P, R]: + limits = None + if rate_limits: + limits = [rate_limit._req for rate_limit in rate_limits or []] + + setattr(func, "_step_name", name.lower() or str(func.__name__).lower()) + setattr(func, "_step_parents", parents) + setattr(func, "_step_timeout", timeout) + setattr(func, "_step_retries", retries) + setattr(func, "_step_rate_limits", limits) + setattr(func, "_step_backoff_factor", backoff_factor) + setattr(func, "_step_backoff_max_seconds", backoff_max_seconds) + + def create_label(d: DesiredWorkerLabel) -> DesiredWorkerLabels: + value = d["value"] if "value" in d else None + return DesiredWorkerLabels( + strValue=str(value) if not isinstance(value, int) else None, + intValue=value if isinstance(value, int) else None, + required=d["required"] if "required" in d else None, # type: ignore[arg-type] + weight=d["weight"] if "weight" in d else None, + comparator=d["comparator"] if "comparator" in d else None, # type: ignore[arg-type] + ) + + setattr( + func, + "_step_desired_worker_labels", + {key: create_label(d) for key, d in desired_worker_labels.items()}, + ) + + return func + + return inner + + +def on_failure_step( + name: str = "", + timeout: str = "", + retries: int = 0, + rate_limits: list[RateLimit] | None = None, + backoff_factor: float | None = None, + backoff_max_seconds: int | None = None, +) -> Callable[..., Any]: + def inner(func: Callable[[Context], Any]) -> Callable[[Context], Any]: + limits = None + if rate_limits: + limits = [ + CreateStepRateLimit(key=rate_limit.static_key, units=rate_limit.units) # type: ignore[arg-type] + for rate_limit in rate_limits or [] + ] + + setattr( + func, "_on_failure_step_name", name.lower() or str(func.__name__).lower() + ) + setattr(func, "_on_failure_step_timeout", timeout) + setattr(func, "_on_failure_step_retries", retries) + setattr(func, "_on_failure_step_rate_limits", limits) + setattr(func, "_on_failure_step_backoff_factor", backoff_factor) + setattr(func, "_on_failure_step_backoff_max_seconds", backoff_max_seconds) + + return func + + return inner + + +def concurrency( + name: str = "", + max_runs: int = 1, + limit_strategy: ConcurrencyLimitStrategy = ConcurrencyLimitStrategy.CANCEL_IN_PROGRESS, +) -> Callable[..., Any]: + def inner(func: Callable[[Context], Any]) -> Callable[[Context], Any]: + setattr( + func, + "_concurrency_fn_name", + name.lower() or str(func.__name__).lower(), + ) + setattr(func, "_concurrency_max_runs", max_runs) + setattr(func, "_concurrency_limit_strategy", limit_strategy) + + return func + + return inner + + +class HatchetRest: + """ + Main client for interacting with the Hatchet API. + + This class provides access to various client interfaces and utility methods + for working with Hatchet via the REST API, + + Attributes: + rest (RestApi): Interface for REST API operations. + """ + + rest: RestApi + + def __init__(self, config: ClientConfig = ClientConfig()): + _config: ClientConfig = ConfigLoader(".").load_client_config(config) + self.rest = RestApi(_config.server_url, _config.token, _config.tenant_id) + + +class Hatchet: + """ + Main client for interacting with the Hatchet SDK. + + This class provides access to various client interfaces and utility methods + for working with Hatchet workers, workflows, and steps. + + Attributes: + cron (CronClient): Interface for cron trigger operations. + + admin (AdminClient): Interface for administrative operations. + dispatcher (DispatcherClient): Interface for dispatching operations. + event (EventClient): Interface for event-related operations. + rest (RestApi): Interface for REST API operations. + """ + + _client: Client + cron: CronClient + scheduled: ScheduledClient + + @classmethod + def from_environment( + cls, defaults: ClientConfig = ClientConfig(), **kwargs: Any + ) -> "Hatchet": + return cls(client=new_client(defaults), **kwargs) + + @classmethod + def from_config(cls, config: ClientConfig, **kwargs: Any) -> "Hatchet": + return cls(client=new_client_raw(config), **kwargs) + + def __init__( + self, + debug: bool = False, + client: Optional[Client] = None, + config: ClientConfig = ClientConfig(), + ): + """ + Initialize a new Hatchet instance. + + Args: + debug (bool, optional): Enable debug logging. Defaults to False. + client (Optional[Client], optional): A pre-configured Client instance. Defaults to None. + config (ClientConfig, optional): Configuration for creating a new Client. Defaults to ClientConfig(). + """ + if client is not None: + self._client = client + else: + self._client = new_client(config, debug) + + if debug: + logger.setLevel(logging.DEBUG) + + self.cron = CronClient(self._client) + self.scheduled = ScheduledClient(self._client) + + @property + @deprecated( + "Direct access to client is deprecated and will be removed in a future version. Use specific client properties (Hatchet.admin, Hatchet.dispatcher, Hatchet.event, Hatchet.rest) instead. [0.32.0]", + ) + def client(self) -> Client: + return self._client + + @property + def admin(self) -> AdminClient: + return self._client.admin + + @property + def dispatcher(self) -> DispatcherClient: + return self._client.dispatcher + + @property + def event(self) -> EventClient: + return self._client.event + + @property + def rest(self) -> RestApi: + return self._client.rest + + @property + def listener(self) -> RunEventListenerClient: + return self._client.listener + + @property + def config(self) -> ClientConfig: + return self._client.config + + @property + def tenant_id(self) -> str: + return self._client.config.tenant_id + + concurrency = staticmethod(concurrency) + + workflow = staticmethod(workflow) + + step = staticmethod(step) + + on_failure_step = staticmethod(on_failure_step) + + def worker( + self, name: str, max_runs: int | None = None, labels: dict[str, str | int] = {} + ) -> Worker: + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + + return Worker( + name=name, + max_runs=max_runs, + labels=labels, + config=self._client.config, + debug=self._client.debug, + owned_loop=loop is None, + ) |