From 4a52a71956a8d46fcb7294ac71734504bb09bcc2 Mon Sep 17 00:00:00 2001 From: S. Solomon Darnell Date: Fri, 28 Mar 2025 21:52:21 -0500 Subject: two version of R2R are here --- .../site-packages/hatchet_sdk/workflow.py | 261 +++++++++++++++++++++ 1 file changed, 261 insertions(+) create mode 100644 .venv/lib/python3.12/site-packages/hatchet_sdk/workflow.py (limited to '.venv/lib/python3.12/site-packages/hatchet_sdk/workflow.py') diff --git a/.venv/lib/python3.12/site-packages/hatchet_sdk/workflow.py b/.venv/lib/python3.12/site-packages/hatchet_sdk/workflow.py new file mode 100644 index 00000000..9c5cef90 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/hatchet_sdk/workflow.py @@ -0,0 +1,261 @@ +import functools +from typing import ( + Any, + Callable, + Protocol, + Type, + TypeVar, + Union, + cast, + get_type_hints, + runtime_checkable, +) + +from pydantic import BaseModel + +from hatchet_sdk import ConcurrencyLimitStrategy +from hatchet_sdk.contracts.workflows_pb2 import ( + CreateWorkflowJobOpts, + CreateWorkflowStepOpts, + CreateWorkflowVersionOpts, + StickyStrategy, + WorkflowConcurrencyOpts, + WorkflowKind, +) +from hatchet_sdk.logger import logger +from hatchet_sdk.utils.typing import is_basemodel_subclass + + +class WorkflowStepProtocol(Protocol): + def __call__(self, *args: Any, **kwargs: Any) -> Any: ... + + __name__: str + + _step_name: str + _step_timeout: str | None + _step_parents: list[str] + _step_retries: int | None + _step_rate_limits: list[str] | None + _step_desired_worker_labels: dict[str, str] + _step_backoff_factor: float | None + _step_backoff_max_seconds: int | None + + _concurrency_fn_name: str + _concurrency_max_runs: int | None + _concurrency_limit_strategy: str | None + + _on_failure_step_name: str + _on_failure_step_timeout: str | None + _on_failure_step_retries: int + _on_failure_step_rate_limits: list[str] | None + _on_failure_step_backoff_factor: float | None + _on_failure_step_backoff_max_seconds: int | None + + +StepsType = list[tuple[str, WorkflowStepProtocol]] + +T = TypeVar("T") +TW = TypeVar("TW", bound="WorkflowInterface") + + +class ConcurrencyExpression: + """ + Defines concurrency limits for a workflow using a CEL expression. + + Args: + expression (str): CEL expression to determine concurrency grouping. (i.e. "input.user_id") + max_runs (int): Maximum number of concurrent workflow runs. + limit_strategy (ConcurrencyLimitStrategy): Strategy for handling limit violations. + + Example: + ConcurrencyExpression("input.user_id", 5, ConcurrencyLimitStrategy.CANCEL_IN_PROGRESS) + """ + + def __init__( + self, expression: str, max_runs: int, limit_strategy: ConcurrencyLimitStrategy + ): + self.expression = expression + self.max_runs = max_runs + self.limit_strategy = limit_strategy + + +@runtime_checkable +class WorkflowInterface(Protocol): + def get_name(self, namespace: str) -> str: ... + + def get_actions(self, namespace: str) -> list[tuple[str, Callable[..., Any]]]: ... + + def get_create_opts(self, namespace: str) -> Any: ... + + on_events: list[str] | None + on_crons: list[str] | None + name: str + version: str + timeout: str + schedule_timeout: str + sticky: Union[StickyStrategy.Value, None] # type: ignore[name-defined] + default_priority: int | None + concurrency_expression: ConcurrencyExpression | None + input_validator: Type[BaseModel] | None + + +class WorkflowMeta(type): + def __new__( + cls: Type["WorkflowMeta"], + name: str, + bases: tuple[type, ...], + attrs: dict[str, Any], + ) -> "WorkflowMeta": + def _create_steps_actions_list(name: str) -> StepsType: + return [ + (getattr(func, name), attrs.pop(func_name)) + for func_name, func in list(attrs.items()) + if hasattr(func, name) + ] + + concurrencyActions = _create_steps_actions_list("_concurrency_fn_name") + steps = _create_steps_actions_list("_step_name") + + onFailureSteps = _create_steps_actions_list("_on_failure_step_name") + + # Define __init__ and get_step_order methods + original_init = attrs.get("__init__") # Get the original __init__ if it exists + + def __init__(self: TW, *args: Any, **kwargs: Any) -> None: + if original_init: + original_init(self, *args, **kwargs) # Call original __init__ + + def get_service_name(namespace: str) -> str: + return f"{namespace}{name.lower()}" + + @functools.cache + def get_actions(self: TW, namespace: str) -> StepsType: + serviceName = get_service_name(namespace) + + func_actions = [ + (serviceName + ":" + func_name, func) for func_name, func in steps + ] + concurrency_actions = [ + (serviceName + ":" + func_name, func) + for func_name, func in concurrencyActions + ] + onFailure_actions = [ + (serviceName + ":" + func_name, func) + for func_name, func in onFailureSteps + ] + + return func_actions + concurrency_actions + onFailure_actions + + # Add these methods and steps to class attributes + attrs["__init__"] = __init__ + attrs["get_actions"] = get_actions + + for step_name, step_func in steps: + attrs[step_name] = step_func + + def get_name(self: TW, namespace: str) -> str: + return namespace + cast(str, attrs["name"]) + + attrs["get_name"] = get_name + + cron_triggers = attrs["on_crons"] + version = attrs["version"] + schedule_timeout = attrs["schedule_timeout"] + sticky = attrs["sticky"] + default_priority = attrs["default_priority"] + + @functools.cache + def get_create_opts(self: TW, namespace: str) -> CreateWorkflowVersionOpts: + serviceName = get_service_name(namespace) + name = self.get_name(namespace) + event_triggers = [namespace + event for event in attrs["on_events"]] + createStepOpts: list[CreateWorkflowStepOpts] = [ + CreateWorkflowStepOpts( + readable_id=step_name, + action=serviceName + ":" + step_name, + timeout=func._step_timeout or "60s", + inputs="{}", + parents=[x for x in func._step_parents], + retries=func._step_retries, + rate_limits=func._step_rate_limits, # type: ignore[arg-type] + worker_labels=func._step_desired_worker_labels, # type: ignore[arg-type] + backoff_factor=func._step_backoff_factor, + backoff_max_seconds=func._step_backoff_max_seconds, + ) + for step_name, func in steps + ] + + concurrency: WorkflowConcurrencyOpts | None = None + + if len(concurrencyActions) > 0: + action = concurrencyActions[0] + + concurrency = WorkflowConcurrencyOpts( + action=serviceName + ":" + action[0], + max_runs=action[1]._concurrency_max_runs, + limit_strategy=action[1]._concurrency_limit_strategy, + ) + + if self.concurrency_expression: + concurrency = WorkflowConcurrencyOpts( + expression=self.concurrency_expression.expression, + max_runs=self.concurrency_expression.max_runs, + limit_strategy=self.concurrency_expression.limit_strategy, + ) + + if len(concurrencyActions) > 0 and self.concurrency_expression: + raise ValueError( + "Error: Both concurrencyActions and concurrency_expression are defined. Please use only one concurrency configuration method." + ) + + on_failure_job: CreateWorkflowJobOpts | None = None + + if len(onFailureSteps) > 0: + func_name, func = onFailureSteps[0] + on_failure_job = CreateWorkflowJobOpts( + name=name + "-on-failure", + steps=[ + CreateWorkflowStepOpts( + readable_id=func_name, + action=serviceName + ":" + func_name, + timeout=func._on_failure_step_timeout or "60s", + inputs="{}", + parents=[], + retries=func._on_failure_step_retries, + rate_limits=func._on_failure_step_rate_limits, # type: ignore[arg-type] + backoff_factor=func._on_failure_step_backoff_factor, + backoff_max_seconds=func._on_failure_step_backoff_max_seconds, + ) + ], + ) + + validated_priority = ( + max(1, min(3, default_priority)) if default_priority else None + ) + if validated_priority != default_priority: + logger.warning( + "Warning: Default Priority Must be between 1 and 3 -- inclusively. Adjusted to be within the range." + ) + + return CreateWorkflowVersionOpts( + name=name, + kind=WorkflowKind.DAG, + version=version, + event_triggers=event_triggers, + cron_triggers=cron_triggers, + schedule_timeout=schedule_timeout, + sticky=sticky, + jobs=[ + CreateWorkflowJobOpts( + name=name, + steps=createStepOpts, + ) + ], + on_failure_job=on_failure_job, + concurrency=concurrency, + default_priority=validated_priority, + ) + + attrs["get_create_opts"] = get_create_opts + + return super(WorkflowMeta, cls).__new__(cls, name, bases, attrs) -- cgit v1.2.3