"""Workflow metaclass machinery for the Hatchet SDK.

``WorkflowMeta`` collects decorated step / concurrency / on-failure functions
from a class body and synthesizes the ``get_name`` / ``get_actions`` /
``get_create_opts`` methods required by ``WorkflowInterface``.
"""

import functools
from typing import (
    Any,
    Callable,
    Protocol,
    Type,
    TypeVar,
    Union,
    cast,
    get_type_hints,
    runtime_checkable,
)

from pydantic import BaseModel

from hatchet_sdk import ConcurrencyLimitStrategy
from hatchet_sdk.contracts.workflows_pb2 import (
    CreateWorkflowJobOpts,
    CreateWorkflowStepOpts,
    CreateWorkflowVersionOpts,
    StickyStrategy,
    WorkflowConcurrencyOpts,
    WorkflowKind,
)
from hatchet_sdk.logger import logger
from hatchet_sdk.utils.typing import is_basemodel_subclass


class WorkflowStepProtocol(Protocol):
    """Structural type of a function decorated as a workflow step.

    The ``_step_*`` / ``_concurrency_*`` / ``_on_failure_*`` attributes are
    attached by the step decorators; ``WorkflowMeta`` reads them when building
    the workflow registration options.
    """

    def __call__(self, *args: Any, **kwargs: Any) -> Any: ...

    __name__: str

    # Regular step metadata.
    _step_name: str
    _step_timeout: str | None
    _step_parents: list[str]
    _step_retries: int | None
    _step_rate_limits: list[str] | None
    _step_desired_worker_labels: dict[str, str]
    _step_backoff_factor: float | None
    _step_backoff_max_seconds: int | None

    # Concurrency-function metadata.
    _concurrency_fn_name: str
    _concurrency_max_runs: int | None
    _concurrency_limit_strategy: str | None

    # On-failure step metadata.
    _on_failure_step_name: str
    _on_failure_step_timeout: str | None
    _on_failure_step_retries: int
    _on_failure_step_rate_limits: list[str] | None
    _on_failure_step_backoff_factor: float | None
    _on_failure_step_backoff_max_seconds: int | None


# (readable step name, step function) pairs collected from a class body.
StepsType = list[tuple[str, WorkflowStepProtocol]]

T = TypeVar("T")
TW = TypeVar("TW", bound="WorkflowInterface")


class ConcurrencyExpression:
    """
    Defines concurrency limits for a workflow using a CEL expression.

    Args:
        expression (str): CEL expression to determine concurrency grouping. (i.e. "input.user_id")
        max_runs (int): Maximum number of concurrent workflow runs.
        limit_strategy (ConcurrencyLimitStrategy): Strategy for handling limit violations.

    Example:
        ConcurrencyExpression("input.user_id", 5, ConcurrencyLimitStrategy.CANCEL_IN_PROGRESS)
    """

    def __init__(
        self, expression: str, max_runs: int, limit_strategy: ConcurrencyLimitStrategy
    ):
        self.expression = expression
        self.max_runs = max_runs
        self.limit_strategy = limit_strategy


@runtime_checkable
class WorkflowInterface(Protocol):
    """Interface that classes produced by ``WorkflowMeta`` satisfy."""

    def get_name(self, namespace: str) -> str: ...

    def get_actions(self, namespace: str) -> list[tuple[str, Callable[..., Any]]]: ...

    def get_create_opts(self, namespace: str) -> Any: ...

    on_events: list[str] | None
    on_crons: list[str] | None
    name: str
    version: str
    timeout: str
    schedule_timeout: str
    sticky: Union[StickyStrategy.Value, None]  # type: ignore[name-defined]
    default_priority: int | None
    concurrency_expression: ConcurrencyExpression | None
    input_validator: Type[BaseModel] | None


class WorkflowMeta(type):
    """Metaclass that turns a class body of decorated step functions into a
    registrable Hatchet workflow definition."""

    def __new__(
        cls: Type["WorkflowMeta"],
        name: str,
        bases: tuple[type, ...],
        attrs: dict[str, Any],
    ) -> "WorkflowMeta":
        def _create_steps_actions_list(attr_name: str) -> StepsType:
            # Pop decorated functions out of the class body; step functions
            # are re-attached below under their readable step names.
            # (Parameter renamed from ``name`` to avoid shadowing the class
            # name captured by the enclosing scope.)
            return [
                (getattr(func, attr_name), attrs.pop(func_name))
                for func_name, func in list(attrs.items())
                if hasattr(func, attr_name)
            ]

        concurrencyActions = _create_steps_actions_list("_concurrency_fn_name")
        steps = _create_steps_actions_list("_step_name")
        onFailureSteps = _create_steps_actions_list("_on_failure_step_name")

        # Define __init__ and get_step_order methods
        original_init = attrs.get("__init__")  # Get the original __init__ if it exists

        def __init__(self: TW, *args: Any, **kwargs: Any) -> None:
            if original_init:
                original_init(self, *args, **kwargs)  # Call original __init__

        def get_service_name(namespace: str) -> str:
            return f"{namespace}{name.lower()}"

        # NOTE: functools.cache here keys on (self, namespace) and keeps
        # workflow instances alive for the cache's lifetime; acceptable
        # because workflow objects are long-lived singletons.
        @functools.cache
        def get_actions(self: TW, namespace: str) -> StepsType:
            """Return every runnable action as ``(service:func_name, func)``."""
            serviceName = get_service_name(namespace)

            func_actions = [
                (serviceName + ":" + func_name, func) for func_name, func in steps
            ]
            concurrency_actions = [
                (serviceName + ":" + func_name, func)
                for func_name, func in concurrencyActions
            ]
            onFailure_actions = [
                (serviceName + ":" + func_name, func)
                for func_name, func in onFailureSteps
            ]

            return func_actions + concurrency_actions + onFailure_actions

        # Add these methods and steps to class attributes
        attrs["__init__"] = __init__
        attrs["get_actions"] = get_actions

        for step_name, step_func in steps:
            attrs[step_name] = step_func

        def get_name(self: TW, namespace: str) -> str:
            return namespace + cast(str, attrs["name"])

        attrs["get_name"] = get_name

        # FIX: ``on_crons`` may be None per WorkflowInterface; the protobuf
        # repeated field rejects None, so normalize to an empty list.
        cron_triggers = attrs["on_crons"] or []
        version = attrs["version"]
        schedule_timeout = attrs["schedule_timeout"]
        sticky = attrs["sticky"]
        default_priority = attrs["default_priority"]

        @functools.cache
        def get_create_opts(self: TW, namespace: str) -> CreateWorkflowVersionOpts:
            """Build the ``CreateWorkflowVersionOpts`` used to register this
            workflow with the Hatchet engine.

            Raises:
                ValueError: if both a concurrency function and a
                    ``concurrency_expression`` are configured.
            """
            serviceName = get_service_name(namespace)
            name = self.get_name(namespace)
            # FIX: ``on_events`` may be None per WorkflowInterface; guard
            # before iterating.
            event_triggers = [
                namespace + event for event in (attrs["on_events"] or [])
            ]

            createStepOpts: list[CreateWorkflowStepOpts] = [
                CreateWorkflowStepOpts(
                    readable_id=step_name,
                    action=serviceName + ":" + step_name,
                    timeout=func._step_timeout or "60s",
                    inputs="{}",
                    parents=[x for x in func._step_parents],
                    retries=func._step_retries,
                    rate_limits=func._step_rate_limits,  # type: ignore[arg-type]
                    worker_labels=func._step_desired_worker_labels,  # type: ignore[arg-type]
                    backoff_factor=func._step_backoff_factor,
                    backoff_max_seconds=func._step_backoff_max_seconds,
                )
                for step_name, func in steps
            ]

            # FIX: validate the mutually-exclusive concurrency configuration
            # up front. Previously this check ran only after both
            # WorkflowConcurrencyOpts branches had executed, so the
            # expression-based opts silently overwrote the action-based opts
            # before the error was raised.
            if len(concurrencyActions) > 0 and self.concurrency_expression:
                raise ValueError(
                    "Error: Both concurrencyActions and concurrency_expression are defined. Please use only one concurrency configuration method."
                )

            concurrency: WorkflowConcurrencyOpts | None = None

            if len(concurrencyActions) > 0:
                action = concurrencyActions[0]

                concurrency = WorkflowConcurrencyOpts(
                    action=serviceName + ":" + action[0],
                    max_runs=action[1]._concurrency_max_runs,
                    limit_strategy=action[1]._concurrency_limit_strategy,
                )
            elif self.concurrency_expression:
                concurrency = WorkflowConcurrencyOpts(
                    expression=self.concurrency_expression.expression,
                    max_runs=self.concurrency_expression.max_runs,
                    limit_strategy=self.concurrency_expression.limit_strategy,
                )

            on_failure_job: CreateWorkflowJobOpts | None = None

            if len(onFailureSteps) > 0:
                func_name, func = onFailureSteps[0]

                on_failure_job = CreateWorkflowJobOpts(
                    name=name + "-on-failure",
                    steps=[
                        CreateWorkflowStepOpts(
                            readable_id=func_name,
                            action=serviceName + ":" + func_name,
                            timeout=func._on_failure_step_timeout or "60s",
                            inputs="{}",
                            parents=[],
                            retries=func._on_failure_step_retries,
                            rate_limits=func._on_failure_step_rate_limits,  # type: ignore[arg-type]
                            backoff_factor=func._on_failure_step_backoff_factor,
                            backoff_max_seconds=func._on_failure_step_backoff_max_seconds,
                        )
                    ],
                )

            # FIX: use an explicit None check so an (invalid) priority of 0 is
            # clamped into range, matching the warning text, instead of being
            # silently dropped to None by the old truthiness test
            # (``if default_priority``).
            validated_priority = (
                max(1, min(3, default_priority))
                if default_priority is not None
                else None
            )

            if validated_priority != default_priority:
                logger.warning(
                    "Warning: Default Priority Must be between 1 and 3 -- inclusively. Adjusted to be within the range."
                )

            return CreateWorkflowVersionOpts(
                name=name,
                kind=WorkflowKind.DAG,
                version=version,
                event_triggers=event_triggers,
                cron_triggers=cron_triggers,
                schedule_timeout=schedule_timeout,
                sticky=sticky,
                jobs=[
                    CreateWorkflowJobOpts(
                        name=name,
                        steps=createStepOpts,
                    )
                ],
                on_failure_job=on_failure_job,
                concurrency=concurrency,
                default_priority=validated_priority,
            )

        attrs["get_create_opts"] = get_create_opts

        return super(WorkflowMeta, cls).__new__(cls, name, bases, attrs)