diff options
author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
---|---|---|
committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) | |
tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/hatchet_sdk/v2/callable.py | |
parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff) | |
download | gn-ai-4a52a71956a8d46fcb7294ac71734504bb09bcc2.tar.gz |
Diffstat (limited to '.venv/lib/python3.12/site-packages/hatchet_sdk/v2/callable.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/hatchet_sdk/v2/callable.py | 202 |
1 files changed, 202 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/hatchet_sdk/v2/callable.py b/.venv/lib/python3.12/site-packages/hatchet_sdk/v2/callable.py new file mode 100644 index 00000000..097a7d87 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/hatchet_sdk/v2/callable.py @@ -0,0 +1,202 @@ +import asyncio +from typing import ( + Any, + Callable, + Dict, + Generic, + List, + Optional, + TypedDict, + TypeVar, + Union, +) + +from hatchet_sdk.clients.admin import ChildTriggerWorkflowOptions +from hatchet_sdk.context.context import Context +from hatchet_sdk.contracts.workflows_pb2 import ( # type: ignore[attr-defined] + CreateStepRateLimit, + CreateWorkflowJobOpts, + CreateWorkflowStepOpts, + CreateWorkflowVersionOpts, + DesiredWorkerLabels, + StickyStrategy, + WorkflowConcurrencyOpts, + WorkflowKind, +) +from hatchet_sdk.labels import DesiredWorkerLabel +from hatchet_sdk.logger import logger +from hatchet_sdk.rate_limit import RateLimit +from hatchet_sdk.v2.concurrency import ConcurrencyFunction +from hatchet_sdk.workflow_run import RunRef + +T = TypeVar("T") + + +class HatchetCallable(Generic[T]): + def __init__( + self, + func: Callable[[Context], T], + durable: bool = False, + name: str = "", + auto_register: bool = True, + on_events: list[str] | None = None, + on_crons: list[str] | None = None, + version: str = "", + timeout: str = "60m", + schedule_timeout: str = "5m", + sticky: StickyStrategy = None, + retries: int = 0, + rate_limits: List[RateLimit] | None = None, + concurrency: ConcurrencyFunction | None = None, + on_failure: Union["HatchetCallable[T]", None] = None, + desired_worker_labels: dict[str, DesiredWorkerLabel] = {}, + default_priority: int | None = None, + ): + self.func = func + + on_events = on_events or [] + on_crons = on_crons or [] + + limits = None + if rate_limits: + limits = [rate_limit._req for rate_limit in rate_limits or []] + + self.function_desired_worker_labels = {} + + for key, d in desired_worker_labels.items(): + value = d["value"] if "value" in d else None + self.function_desired_worker_labels[key] = DesiredWorkerLabels( + strValue=str(value) if not isinstance(value, int) else None, + intValue=value if isinstance(value, int) else None, + required=d["required"] if "required" in d else None, + weight=d["weight"] if "weight" in d else None, + comparator=d["comparator"] if "comparator" in d else None, + ) + self.sticky = sticky + self.default_priority = default_priority + self.durable = durable + self.function_name = name.lower() or str(func.__name__).lower() + self.function_version = version + self.function_on_events = on_events + self.function_on_crons = on_crons + self.function_timeout = timeout + self.function_schedule_timeout = schedule_timeout + self.function_retries = retries + self.function_rate_limits = limits + self.function_concurrency = concurrency + self.function_on_failure = on_failure + self.function_namespace = "default" + self.function_auto_register = auto_register + + self.is_coroutine = False + + if asyncio.iscoroutinefunction(func): + self.is_coroutine = True + + def __call__(self, context: Context) -> T: + return self.func(context) + + def with_namespace(self, namespace: str) -> None: + if namespace is not None and namespace != "": + self.function_namespace = namespace + self.function_name = namespace + self.function_name + + def to_workflow_opts(self) -> CreateWorkflowVersionOpts: + kind: WorkflowKind = WorkflowKind.FUNCTION + + if self.durable: + kind = WorkflowKind.DURABLE + + on_failure_job: CreateWorkflowJobOpts | None = None + + if self.function_on_failure is not None: + on_failure_job = CreateWorkflowJobOpts( + name=self.function_name + "-on-failure", + steps=[ + self.function_on_failure.to_step(), + ], + ) + + concurrency: WorkflowConcurrencyOpts | None = None + + if self.function_concurrency is not None: + self.function_concurrency.set_namespace(self.function_namespace) + concurrency = WorkflowConcurrencyOpts( + action=self.function_concurrency.get_action_name(), + max_runs=self.function_concurrency.max_runs, + limit_strategy=self.function_concurrency.limit_strategy, + ) + + validated_priority = ( + max(1, min(3, self.default_priority)) if self.default_priority else None + ) + if validated_priority != self.default_priority: + logger.warning( + "Warning: Default Priority Must be between 1 and 3 -- inclusively. Adjusted to be within the range." + ) + + return CreateWorkflowVersionOpts( + name=self.function_name, + kind=kind, + version=self.function_version, + event_triggers=self.function_on_events, + cron_triggers=self.function_on_crons, + schedule_timeout=self.function_schedule_timeout, + sticky=self.sticky, + on_failure_job=on_failure_job, + concurrency=concurrency, + jobs=[ + CreateWorkflowJobOpts( + name=self.function_name, + steps=[ + self.to_step(), + ], + ) + ], + default_priority=validated_priority, + ) + + def to_step(self) -> CreateWorkflowStepOpts: + return CreateWorkflowStepOpts( + readable_id=self.function_name, + action=self.get_action_name(), + timeout=self.function_timeout, + inputs="{}", + parents=[], + retries=self.function_retries, + rate_limits=self.function_rate_limits, + worker_labels=self.function_desired_worker_labels, + ) + + def get_action_name(self) -> str: + return self.function_namespace + ":" + self.function_name + + +class DurableContext(Context): + def run( + self, + function: str | Callable[[Context], Any], + input: dict[Any, Any] = {}, + key: str | None = None, + options: ChildTriggerWorkflowOptions | None = None, + ) -> "RunRef[T]": + worker_id = self.worker.id() + + workflow_name = function + + if not isinstance(function, str): + workflow_name = function.function_name + + # if ( + # options is not None + # and "sticky" in options + # and options["sticky"] == True + # and not self.worker.has_workflow(workflow_name) + # ): + # raise Exception( + # f"cannot run with sticky: workflow {workflow_name} is not registered on the worker" + # ) + + trigger_options = self._prepare_workflow_options(key, options, worker_id) + + return self.admin_client.run(function, input, trigger_options) |