aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/hatchet_sdk/v2/callable.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/hatchet_sdk/v2/callable.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two version of R2R are hereHEADmaster
Diffstat (limited to '.venv/lib/python3.12/site-packages/hatchet_sdk/v2/callable.py')
-rw-r--r--.venv/lib/python3.12/site-packages/hatchet_sdk/v2/callable.py202
1 files changed, 202 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/hatchet_sdk/v2/callable.py b/.venv/lib/python3.12/site-packages/hatchet_sdk/v2/callable.py
new file mode 100644
index 00000000..097a7d87
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/hatchet_sdk/v2/callable.py
@@ -0,0 +1,202 @@
+import asyncio
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ Generic,
+ List,
+ Optional,
+ TypedDict,
+ TypeVar,
+ Union,
+)
+
+from hatchet_sdk.clients.admin import ChildTriggerWorkflowOptions
+from hatchet_sdk.context.context import Context
+from hatchet_sdk.contracts.workflows_pb2 import ( # type: ignore[attr-defined]
+ CreateStepRateLimit,
+ CreateWorkflowJobOpts,
+ CreateWorkflowStepOpts,
+ CreateWorkflowVersionOpts,
+ DesiredWorkerLabels,
+ StickyStrategy,
+ WorkflowConcurrencyOpts,
+ WorkflowKind,
+)
+from hatchet_sdk.labels import DesiredWorkerLabel
+from hatchet_sdk.logger import logger
+from hatchet_sdk.rate_limit import RateLimit
+from hatchet_sdk.v2.concurrency import ConcurrencyFunction
+from hatchet_sdk.workflow_run import RunRef
+
+T = TypeVar("T")
+
+
+class HatchetCallable(Generic[T]):
+ def __init__(
+ self,
+ func: Callable[[Context], T],
+ durable: bool = False,
+ name: str = "",
+ auto_register: bool = True,
+ on_events: list[str] | None = None,
+ on_crons: list[str] | None = None,
+ version: str = "",
+ timeout: str = "60m",
+ schedule_timeout: str = "5m",
+ sticky: StickyStrategy = None,
+ retries: int = 0,
+ rate_limits: List[RateLimit] | None = None,
+ concurrency: ConcurrencyFunction | None = None,
+ on_failure: Union["HatchetCallable[T]", None] = None,
+ desired_worker_labels: dict[str, DesiredWorkerLabel] = {},
+ default_priority: int | None = None,
+ ):
+ self.func = func
+
+ on_events = on_events or []
+ on_crons = on_crons or []
+
+ limits = None
+ if rate_limits:
+ limits = [rate_limit._req for rate_limit in rate_limits or []]
+
+ self.function_desired_worker_labels = {}
+
+ for key, d in desired_worker_labels.items():
+ value = d["value"] if "value" in d else None
+ self.function_desired_worker_labels[key] = DesiredWorkerLabels(
+ strValue=str(value) if not isinstance(value, int) else None,
+ intValue=value if isinstance(value, int) else None,
+ required=d["required"] if "required" in d else None,
+ weight=d["weight"] if "weight" in d else None,
+ comparator=d["comparator"] if "comparator" in d else None,
+ )
+ self.sticky = sticky
+ self.default_priority = default_priority
+ self.durable = durable
+ self.function_name = name.lower() or str(func.__name__).lower()
+ self.function_version = version
+ self.function_on_events = on_events
+ self.function_on_crons = on_crons
+ self.function_timeout = timeout
+ self.function_schedule_timeout = schedule_timeout
+ self.function_retries = retries
+ self.function_rate_limits = limits
+ self.function_concurrency = concurrency
+ self.function_on_failure = on_failure
+ self.function_namespace = "default"
+ self.function_auto_register = auto_register
+
+ self.is_coroutine = False
+
+ if asyncio.iscoroutinefunction(func):
+ self.is_coroutine = True
+
+ def __call__(self, context: Context) -> T:
+ return self.func(context)
+
+ def with_namespace(self, namespace: str) -> None:
+ if namespace is not None and namespace != "":
+ self.function_namespace = namespace
+ self.function_name = namespace + self.function_name
+
+ def to_workflow_opts(self) -> CreateWorkflowVersionOpts:
+ kind: WorkflowKind = WorkflowKind.FUNCTION
+
+ if self.durable:
+ kind = WorkflowKind.DURABLE
+
+ on_failure_job: CreateWorkflowJobOpts | None = None
+
+ if self.function_on_failure is not None:
+ on_failure_job = CreateWorkflowJobOpts(
+ name=self.function_name + "-on-failure",
+ steps=[
+ self.function_on_failure.to_step(),
+ ],
+ )
+
+ concurrency: WorkflowConcurrencyOpts | None = None
+
+ if self.function_concurrency is not None:
+ self.function_concurrency.set_namespace(self.function_namespace)
+ concurrency = WorkflowConcurrencyOpts(
+ action=self.function_concurrency.get_action_name(),
+ max_runs=self.function_concurrency.max_runs,
+ limit_strategy=self.function_concurrency.limit_strategy,
+ )
+
+ validated_priority = (
+ max(1, min(3, self.default_priority)) if self.default_priority else None
+ )
+ if validated_priority != self.default_priority:
+ logger.warning(
+ "Warning: Default Priority Must be between 1 and 3 -- inclusively. Adjusted to be within the range."
+ )
+
+ return CreateWorkflowVersionOpts(
+ name=self.function_name,
+ kind=kind,
+ version=self.function_version,
+ event_triggers=self.function_on_events,
+ cron_triggers=self.function_on_crons,
+ schedule_timeout=self.function_schedule_timeout,
+ sticky=self.sticky,
+ on_failure_job=on_failure_job,
+ concurrency=concurrency,
+ jobs=[
+ CreateWorkflowJobOpts(
+ name=self.function_name,
+ steps=[
+ self.to_step(),
+ ],
+ )
+ ],
+ default_priority=validated_priority,
+ )
+
+ def to_step(self) -> CreateWorkflowStepOpts:
+ return CreateWorkflowStepOpts(
+ readable_id=self.function_name,
+ action=self.get_action_name(),
+ timeout=self.function_timeout,
+ inputs="{}",
+ parents=[],
+ retries=self.function_retries,
+ rate_limits=self.function_rate_limits,
+ worker_labels=self.function_desired_worker_labels,
+ )
+
+ def get_action_name(self) -> str:
+ return self.function_namespace + ":" + self.function_name
+
+
+class DurableContext(Context):
+ def run(
+ self,
+ function: str | Callable[[Context], Any],
+ input: dict[Any, Any] = {},
+ key: str | None = None,
+ options: ChildTriggerWorkflowOptions | None = None,
+ ) -> "RunRef[T]":
+ worker_id = self.worker.id()
+
+ workflow_name = function
+
+ if not isinstance(function, str):
+ workflow_name = function.function_name
+
+ # if (
+ # options is not None
+ # and "sticky" in options
+ # and options["sticky"] == True
+ # and not self.worker.has_workflow(workflow_name)
+ # ):
+ # raise Exception(
+ # f"cannot run with sticky: workflow {workflow_name} is not registered on the worker"
+ # )
+
+ trigger_options = self._prepare_workflow_options(key, options, worker_id)
+
+ return self.admin_client.run(function, input, trigger_options)