about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/hatchet_sdk/workflow.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/hatchet_sdk/workflow.py')
-rw-r--r--.venv/lib/python3.12/site-packages/hatchet_sdk/workflow.py261
1 files changed, 261 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/hatchet_sdk/workflow.py b/.venv/lib/python3.12/site-packages/hatchet_sdk/workflow.py
new file mode 100644
index 00000000..9c5cef90
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/hatchet_sdk/workflow.py
@@ -0,0 +1,261 @@
+import functools
+from typing import (
+    Any,
+    Callable,
+    Protocol,
+    Type,
+    TypeVar,
+    Union,
+    cast,
+    get_type_hints,
+    runtime_checkable,
+)
+
+from pydantic import BaseModel
+
+from hatchet_sdk import ConcurrencyLimitStrategy
+from hatchet_sdk.contracts.workflows_pb2 import (
+    CreateWorkflowJobOpts,
+    CreateWorkflowStepOpts,
+    CreateWorkflowVersionOpts,
+    StickyStrategy,
+    WorkflowConcurrencyOpts,
+    WorkflowKind,
+)
+from hatchet_sdk.logger import logger
+from hatchet_sdk.utils.typing import is_basemodel_subclass
+
+
+class WorkflowStepProtocol(Protocol):
+    def __call__(self, *args: Any, **kwargs: Any) -> Any: ...
+
+    __name__: str
+
+    _step_name: str
+    _step_timeout: str | None
+    _step_parents: list[str]
+    _step_retries: int | None
+    _step_rate_limits: list[str] | None
+    _step_desired_worker_labels: dict[str, str]
+    _step_backoff_factor: float | None
+    _step_backoff_max_seconds: int | None
+
+    _concurrency_fn_name: str
+    _concurrency_max_runs: int | None
+    _concurrency_limit_strategy: str | None
+
+    _on_failure_step_name: str
+    _on_failure_step_timeout: str | None
+    _on_failure_step_retries: int
+    _on_failure_step_rate_limits: list[str] | None
+    _on_failure_step_backoff_factor: float | None
+    _on_failure_step_backoff_max_seconds: int | None
+
+
+StepsType = list[tuple[str, WorkflowStepProtocol]]
+
+T = TypeVar("T")
+TW = TypeVar("TW", bound="WorkflowInterface")
+
+
+class ConcurrencyExpression:
+    """
+    Defines concurrency limits for a workflow using a CEL expression.
+
+    Args:
+        expression (str): CEL expression to determine concurrency grouping. (i.e. "input.user_id")
+        max_runs (int): Maximum number of concurrent workflow runs.
+        limit_strategy (ConcurrencyLimitStrategy): Strategy for handling limit violations.
+
+    Example:
+        ConcurrencyExpression("input.user_id", 5, ConcurrencyLimitStrategy.CANCEL_IN_PROGRESS)
+    """
+
+    def __init__(
+        self, expression: str, max_runs: int, limit_strategy: ConcurrencyLimitStrategy
+    ):
+        self.expression = expression
+        self.max_runs = max_runs
+        self.limit_strategy = limit_strategy
+
+
+@runtime_checkable
+class WorkflowInterface(Protocol):
+    def get_name(self, namespace: str) -> str: ...
+
+    def get_actions(self, namespace: str) -> list[tuple[str, Callable[..., Any]]]: ...
+
+    def get_create_opts(self, namespace: str) -> Any: ...
+
+    on_events: list[str] | None
+    on_crons: list[str] | None
+    name: str
+    version: str
+    timeout: str
+    schedule_timeout: str
+    sticky: Union[StickyStrategy.Value, None]  # type: ignore[name-defined]
+    default_priority: int | None
+    concurrency_expression: ConcurrencyExpression | None
+    input_validator: Type[BaseModel] | None
+
+
+class WorkflowMeta(type):
+    def __new__(
+        cls: Type["WorkflowMeta"],
+        name: str,
+        bases: tuple[type, ...],
+        attrs: dict[str, Any],
+    ) -> "WorkflowMeta":
+        def _create_steps_actions_list(name: str) -> StepsType:
+            return [
+                (getattr(func, name), attrs.pop(func_name))
+                for func_name, func in list(attrs.items())
+                if hasattr(func, name)
+            ]
+
+        concurrencyActions = _create_steps_actions_list("_concurrency_fn_name")
+        steps = _create_steps_actions_list("_step_name")
+
+        onFailureSteps = _create_steps_actions_list("_on_failure_step_name")
+
+        # Define __init__ and get_step_order methods
+        original_init = attrs.get("__init__")  # Get the original __init__ if it exists
+
+        def __init__(self: TW, *args: Any, **kwargs: Any) -> None:
+            if original_init:
+                original_init(self, *args, **kwargs)  # Call original __init__
+
+        def get_service_name(namespace: str) -> str:
+            return f"{namespace}{name.lower()}"
+
+        @functools.cache
+        def get_actions(self: TW, namespace: str) -> StepsType:
+            serviceName = get_service_name(namespace)
+
+            func_actions = [
+                (serviceName + ":" + func_name, func) for func_name, func in steps
+            ]
+            concurrency_actions = [
+                (serviceName + ":" + func_name, func)
+                for func_name, func in concurrencyActions
+            ]
+            onFailure_actions = [
+                (serviceName + ":" + func_name, func)
+                for func_name, func in onFailureSteps
+            ]
+
+            return func_actions + concurrency_actions + onFailure_actions
+
+        # Add these methods and steps to class attributes
+        attrs["__init__"] = __init__
+        attrs["get_actions"] = get_actions
+
+        for step_name, step_func in steps:
+            attrs[step_name] = step_func
+
+        def get_name(self: TW, namespace: str) -> str:
+            return namespace + cast(str, attrs["name"])
+
+        attrs["get_name"] = get_name
+
+        cron_triggers = attrs["on_crons"]
+        version = attrs["version"]
+        schedule_timeout = attrs["schedule_timeout"]
+        sticky = attrs["sticky"]
+        default_priority = attrs["default_priority"]
+
+        @functools.cache
+        def get_create_opts(self: TW, namespace: str) -> CreateWorkflowVersionOpts:
+            serviceName = get_service_name(namespace)
+            name = self.get_name(namespace)
+            event_triggers = [namespace + event for event in attrs["on_events"]]
+            createStepOpts: list[CreateWorkflowStepOpts] = [
+                CreateWorkflowStepOpts(
+                    readable_id=step_name,
+                    action=serviceName + ":" + step_name,
+                    timeout=func._step_timeout or "60s",
+                    inputs="{}",
+                    parents=[x for x in func._step_parents],
+                    retries=func._step_retries,
+                    rate_limits=func._step_rate_limits,  # type: ignore[arg-type]
+                    worker_labels=func._step_desired_worker_labels,  # type: ignore[arg-type]
+                    backoff_factor=func._step_backoff_factor,
+                    backoff_max_seconds=func._step_backoff_max_seconds,
+                )
+                for step_name, func in steps
+            ]
+
+            concurrency: WorkflowConcurrencyOpts | None = None
+
+            if len(concurrencyActions) > 0:
+                action = concurrencyActions[0]
+
+                concurrency = WorkflowConcurrencyOpts(
+                    action=serviceName + ":" + action[0],
+                    max_runs=action[1]._concurrency_max_runs,
+                    limit_strategy=action[1]._concurrency_limit_strategy,
+                )
+
+            if self.concurrency_expression:
+                concurrency = WorkflowConcurrencyOpts(
+                    expression=self.concurrency_expression.expression,
+                    max_runs=self.concurrency_expression.max_runs,
+                    limit_strategy=self.concurrency_expression.limit_strategy,
+                )
+
+            if len(concurrencyActions) > 0 and self.concurrency_expression:
+                raise ValueError(
+                    "Error: Both concurrencyActions and concurrency_expression are defined. Please use only one concurrency configuration method."
+                )
+
+            on_failure_job: CreateWorkflowJobOpts | None = None
+
+            if len(onFailureSteps) > 0:
+                func_name, func = onFailureSteps[0]
+                on_failure_job = CreateWorkflowJobOpts(
+                    name=name + "-on-failure",
+                    steps=[
+                        CreateWorkflowStepOpts(
+                            readable_id=func_name,
+                            action=serviceName + ":" + func_name,
+                            timeout=func._on_failure_step_timeout or "60s",
+                            inputs="{}",
+                            parents=[],
+                            retries=func._on_failure_step_retries,
+                            rate_limits=func._on_failure_step_rate_limits,  # type: ignore[arg-type]
+                            backoff_factor=func._on_failure_step_backoff_factor,
+                            backoff_max_seconds=func._on_failure_step_backoff_max_seconds,
+                        )
+                    ],
+                )
+
+            validated_priority = (
+                max(1, min(3, default_priority)) if default_priority else None
+            )
+            if validated_priority != default_priority:
+                logger.warning(
+                    "Warning: Default Priority Must be between 1 and 3 -- inclusively. Adjusted to be within the range."
+                )
+
+            return CreateWorkflowVersionOpts(
+                name=name,
+                kind=WorkflowKind.DAG,
+                version=version,
+                event_triggers=event_triggers,
+                cron_triggers=cron_triggers,
+                schedule_timeout=schedule_timeout,
+                sticky=sticky,
+                jobs=[
+                    CreateWorkflowJobOpts(
+                        name=name,
+                        steps=createStepOpts,
+                    )
+                ],
+                on_failure_job=on_failure_job,
+                concurrency=concurrency,
+                default_priority=validated_priority,
+            )
+
+        attrs["get_create_opts"] = get_create_opts
+
+        return super(WorkflowMeta, cls).__new__(cls, name, bases, attrs)