about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/hatchet_sdk/workflow.py
diff options
context:
space:
mode:
author S. Solomon Darnell 2025-03-28 21:52:21 -0500
committer S. Solomon Darnell 2025-03-28 21:52:21 -0500
commit 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/hatchet_sdk/workflow.py
parent cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two versions of R2R are here (HEAD, master)
Diffstat (limited to '.venv/lib/python3.12/site-packages/hatchet_sdk/workflow.py')
-rw-r--r--.venv/lib/python3.12/site-packages/hatchet_sdk/workflow.py261
1 files changed, 261 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/hatchet_sdk/workflow.py b/.venv/lib/python3.12/site-packages/hatchet_sdk/workflow.py
new file mode 100644
index 00000000..9c5cef90
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/hatchet_sdk/workflow.py
@@ -0,0 +1,261 @@
+import functools
+from typing import (
+ Any,
+ Callable,
+ Protocol,
+ Type,
+ TypeVar,
+ Union,
+ cast,
+ get_type_hints,
+ runtime_checkable,
+)
+
+from pydantic import BaseModel
+
+from hatchet_sdk import ConcurrencyLimitStrategy
+from hatchet_sdk.contracts.workflows_pb2 import (
+ CreateWorkflowJobOpts,
+ CreateWorkflowStepOpts,
+ CreateWorkflowVersionOpts,
+ StickyStrategy,
+ WorkflowConcurrencyOpts,
+ WorkflowKind,
+)
+from hatchet_sdk.logger import logger
+from hatchet_sdk.utils.typing import is_basemodel_subclass
+
+
class WorkflowStepProtocol(Protocol):
    """Structural type for a callable registered as part of a workflow.

    ``WorkflowMeta`` discovers these callables by probing for the marker
    attributes declared below (``hasattr(func, "_step_name")`` etc.), so any
    function carrying the relevant markers satisfies this protocol.
    Presumably the markers are attached by the step/concurrency/on-failure
    decorators elsewhere in the SDK — confirm against the decorator module.
    """

    def __call__(self, *args: Any, **kwargs: Any) -> Any: ...

    # Plain function name; present on any Python function object.
    __name__: str

    # Markers for an ordinary workflow step.
    _step_name: str
    _step_timeout: str | None
    _step_parents: list[str]
    _step_retries: int | None
    _step_rate_limits: list[str] | None
    _step_desired_worker_labels: dict[str, str]
    _step_backoff_factor: float | None
    _step_backoff_max_seconds: int | None

    # Markers for a concurrency-key function.
    _concurrency_fn_name: str
    _concurrency_max_runs: int | None
    _concurrency_limit_strategy: str | None

    # Markers for an on-failure step.
    _on_failure_step_name: str
    _on_failure_step_timeout: str | None
    _on_failure_step_retries: int
    _on_failure_step_rate_limits: list[str] | None
    _on_failure_step_backoff_factor: float | None
    _on_failure_step_backoff_max_seconds: int | None
+
+
# Pairs of (marker-derived step/action name, the decorated function itself),
# as harvested off a workflow class body by WorkflowMeta.__new__.
StepsType = list[tuple[str, WorkflowStepProtocol]]

T = TypeVar("T")
# Bound to the workflow protocol so the metaclass-generated methods type-check
# against the interface defined below.
TW = TypeVar("TW", bound="WorkflowInterface")
+
+
class ConcurrencyExpression:
    """
    Concurrency configuration for a workflow, keyed by a CEL expression.

    Args:
        expression (str): CEL expression used to group runs for limiting
            purposes (e.g. "input.user_id").
        max_runs (int): Upper bound on simultaneously running workflow
            instances within one group.
        limit_strategy (ConcurrencyLimitStrategy): How violations of the
            bound are handled.

    Example:
        ConcurrencyExpression("input.user_id", 5, ConcurrencyLimitStrategy.CANCEL_IN_PROGRESS)
    """

    def __init__(
        self, expression: str, max_runs: int, limit_strategy: ConcurrencyLimitStrategy
    ):
        # Plain value object: store the three settings verbatim.
        self.limit_strategy = limit_strategy
        self.max_runs = max_runs
        self.expression = expression
+
+
@runtime_checkable
class WorkflowInterface(Protocol):
    """Structural interface produced by classes built with ``WorkflowMeta``.

    ``runtime_checkable`` allows ``isinstance`` checks against this protocol
    (method presence only, per the typing module's semantics).
    """

    def get_name(self, namespace: str) -> str: ...

    def get_actions(self, namespace: str) -> list[tuple[str, Callable[..., Any]]]: ...

    def get_create_opts(self, namespace: str) -> Any: ...

    # Declarative configuration read by WorkflowMeta.__new__ out of the
    # class body (see attrs["on_crons"], attrs["version"], etc. below).
    on_events: list[str] | None
    on_crons: list[str] | None
    name: str
    version: str
    timeout: str
    schedule_timeout: str
    sticky: Union[StickyStrategy.Value, None]  # type: ignore[name-defined]
    default_priority: int | None
    concurrency_expression: ConcurrencyExpression | None
    input_validator: Type[BaseModel] | None
+
+
class WorkflowMeta(type):
    """Metaclass that compiles a declaratively-written workflow class.

    At class-creation time it harvests every function carrying a step,
    concurrency, or on-failure marker out of the class body, then installs
    generated ``__init__`` / ``get_actions`` / ``get_name`` /
    ``get_create_opts`` methods in their place.
    """

    def __new__(
        cls: Type["WorkflowMeta"],
        name: str,
        bases: tuple[type, ...],
        attrs: dict[str, Any],
    ) -> "WorkflowMeta":
        def _create_steps_actions_list(name: str) -> StepsType:
            # Collect (marker value, function) for every class attribute
            # carrying the given marker attribute, POPPING each match out of
            # `attrs` as it is found. Iterating over list(attrs.items())
            # makes the concurrent pops safe.
            return [
                (getattr(func, name), attrs.pop(func_name))
                for func_name, func in list(attrs.items())
                if hasattr(func, name)
            ]

        # Order matters: each call removes its matches from `attrs`, so a
        # function is claimed by the first category whose marker it carries.
        concurrencyActions = _create_steps_actions_list("_concurrency_fn_name")
        steps = _create_steps_actions_list("_step_name")

        onFailureSteps = _create_steps_actions_list("_on_failure_step_name")

        # Define __init__ and get_step_order methods
        original_init = attrs.get("__init__")  # Get the original __init__ if it exists

        def __init__(self: TW, *args: Any, **kwargs: Any) -> None:
            # Replacement __init__ that simply delegates to the user-defined
            # one when present; otherwise it is a no-op.
            if original_init:
                original_init(self, *args, **kwargs)  # Call original __init__

        def get_service_name(namespace: str) -> str:
            # Service identifier: namespace prefix + lowercased class name.
            # `name` is the class name captured from the enclosing __new__.
            return f"{namespace}{name.lower()}"

        # NOTE(review): functools.cache on a method keys on `self`, so the
        # cache keeps every instance alive for its lifetime (ruff B019) —
        # confirm workflow instances are few/long-lived before relying on it.
        @functools.cache
        def get_actions(self: TW, namespace: str) -> StepsType:
            # Every registered callable, keyed as "<service>:<name>", with
            # ordinary steps first, then concurrency fns, then on-failure steps.
            serviceName = get_service_name(namespace)

            func_actions = [
                (serviceName + ":" + func_name, func) for func_name, func in steps
            ]
            concurrency_actions = [
                (serviceName + ":" + func_name, func)
                for func_name, func in concurrencyActions
            ]
            onFailure_actions = [
                (serviceName + ":" + func_name, func)
                for func_name, func in onFailureSteps
            ]

            return func_actions + concurrency_actions + onFailure_actions

        # Add these methods and steps to class attributes
        attrs["__init__"] = __init__
        attrs["get_actions"] = get_actions

        # Re-expose each step function under its marker-derived step name
        # (the original attribute name was popped during harvesting above).
        for step_name, step_func in steps:
            attrs[step_name] = step_func

        def get_name(self: TW, namespace: str) -> str:
            # Workflow name is the class-declared `name` with the namespace
            # prepended verbatim.
            return namespace + cast(str, attrs["name"])

        attrs["get_name"] = get_name

        # Capture the declarative config now; these raise KeyError at class
        # creation if the class body omits any of them.
        cron_triggers = attrs["on_crons"]
        version = attrs["version"]
        schedule_timeout = attrs["schedule_timeout"]
        sticky = attrs["sticky"]
        default_priority = attrs["default_priority"]

        # NOTE(review): same B019 instance-retention caveat as get_actions.
        @functools.cache
        def get_create_opts(self: TW, namespace: str) -> CreateWorkflowVersionOpts:
            """Build the CreateWorkflowVersionOpts message registered with the server."""
            serviceName = get_service_name(namespace)
            name = self.get_name(namespace)
            # Event triggers are namespaced the same way the workflow name is.
            event_triggers = [namespace + event for event in attrs["on_events"]]
            createStepOpts: list[CreateWorkflowStepOpts] = [
                CreateWorkflowStepOpts(
                    readable_id=step_name,
                    action=serviceName + ":" + step_name,
                    timeout=func._step_timeout or "60s",  # default step timeout
                    inputs="{}",
                    parents=[x for x in func._step_parents],
                    retries=func._step_retries,
                    rate_limits=func._step_rate_limits,  # type: ignore[arg-type]
                    worker_labels=func._step_desired_worker_labels,  # type: ignore[arg-type]
                    backoff_factor=func._step_backoff_factor,
                    backoff_max_seconds=func._step_backoff_max_seconds,
                )
                for step_name, func in steps
            ]

            concurrency: WorkflowConcurrencyOpts | None = None

            # Decorator-based concurrency: only the FIRST harvested
            # concurrency function is used.
            if len(concurrencyActions) > 0:
                action = concurrencyActions[0]

                concurrency = WorkflowConcurrencyOpts(
                    action=serviceName + ":" + action[0],
                    max_runs=action[1]._concurrency_max_runs,
                    limit_strategy=action[1]._concurrency_limit_strategy,
                )

            # A class-level ConcurrencyExpression overwrites the decorator
            # form — though having both is rejected just below.
            if self.concurrency_expression:
                concurrency = WorkflowConcurrencyOpts(
                    expression=self.concurrency_expression.expression,
                    max_runs=self.concurrency_expression.max_runs,
                    limit_strategy=self.concurrency_expression.limit_strategy,
                )

            if len(concurrencyActions) > 0 and self.concurrency_expression:
                raise ValueError(
                    "Error: Both concurrencyActions and concurrency_expression are defined. Please use only one concurrency configuration method."
                )

            on_failure_job: CreateWorkflowJobOpts | None = None

            # At most one on-failure step is honored; extras are ignored.
            if len(onFailureSteps) > 0:
                func_name, func = onFailureSteps[0]
                on_failure_job = CreateWorkflowJobOpts(
                    name=name + "-on-failure",
                    steps=[
                        CreateWorkflowStepOpts(
                            readable_id=func_name,
                            action=serviceName + ":" + func_name,
                            timeout=func._on_failure_step_timeout or "60s",
                            inputs="{}",
                            parents=[],
                            retries=func._on_failure_step_retries,
                            rate_limits=func._on_failure_step_rate_limits,  # type: ignore[arg-type]
                            backoff_factor=func._on_failure_step_backoff_factor,
                            backoff_max_seconds=func._on_failure_step_backoff_max_seconds,
                        )
                    ],
                )

            # Clamp the declared priority into [1, 3], warning if adjusted.
            # NOTE(review): the truthiness test means a priority of 0 (or
            # None) yields None rather than being clamped to 1 — confirm
            # whether 0 should clamp instead.
            validated_priority = (
                max(1, min(3, default_priority)) if default_priority else None
            )
            if validated_priority != default_priority:
                logger.warning(
                    "Warning: Default Priority Must be between 1 and 3 -- inclusively. Adjusted to be within the range."
                )

            # All ordinary steps form a single DAG job named after the workflow.
            return CreateWorkflowVersionOpts(
                name=name,
                kind=WorkflowKind.DAG,
                version=version,
                event_triggers=event_triggers,
                cron_triggers=cron_triggers,
                schedule_timeout=schedule_timeout,
                sticky=sticky,
                jobs=[
                    CreateWorkflowJobOpts(
                        name=name,
                        steps=createStepOpts,
                    )
                ],
                on_failure_job=on_failure_job,
                concurrency=concurrency,
                default_priority=validated_priority,
            )

        attrs["get_create_opts"] = get_create_opts

        return super(WorkflowMeta, cls).__new__(cls, name, bases, attrs)