import functools
from typing import (
    Any,
    Callable,
    Protocol,
    Type,
    TypeVar,
    Union,
    cast,
    get_type_hints,
    runtime_checkable,
)

from pydantic import BaseModel

from hatchet_sdk import ConcurrencyLimitStrategy
from hatchet_sdk.contracts.workflows_pb2 import (
    CreateWorkflowJobOpts,
    CreateWorkflowStepOpts,
    CreateWorkflowVersionOpts,
    StickyStrategy,
    WorkflowConcurrencyOpts,
    WorkflowKind,
)
from hatchet_sdk.logger import logger
from hatchet_sdk.utils.typing import is_basemodel_subclass


class WorkflowStepProtocol(Protocol):
    def __call__(self, *args: Any, **kwargs: Any) -> Any: ...

    __name__: str

    _step_name: str
    _step_timeout: str | None
    _step_parents: list[str]
    _step_retries: int | None
    _step_rate_limits: list[str] | None
    _step_desired_worker_labels: dict[str, str]
    _step_backoff_factor: float | None
    _step_backoff_max_seconds: int | None

    _concurrency_fn_name: str
    _concurrency_max_runs: int | None
    _concurrency_limit_strategy: str | None

    _on_failure_step_name: str
    _on_failure_step_timeout: str | None
    _on_failure_step_retries: int
    _on_failure_step_rate_limits: list[str] | None
    _on_failure_step_backoff_factor: float | None
    _on_failure_step_backoff_max_seconds: int | None

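# Illustrative sketch only: a plain function carrying the marker attributes that
# WorkflowStepProtocol describes. In practice these attributes are set by the SDK's
# step decorator rather than by hand; the function and values below are hypothetical.
#
#   def process(self, context): ...
#   process._step_name = "process"
#   process._step_timeout = "30s"
#   process._step_parents = []
#   process._step_retries = 3
#   # ... plus the remaining _step_* attributes listed above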

StepsType = list[tuple[str, WorkflowStepProtocol]]

T = TypeVar("T")
TW = TypeVar("TW", bound="WorkflowInterface")


class ConcurrencyExpression:
    """
    Defines concurrency limits for a workflow using a CEL expression.

    Args:
        expression (str): CEL expression used to group runs for concurrency limiting (e.g. "input.user_id").
        max_runs (int): Maximum number of concurrent workflow runs.
        limit_strategy (ConcurrencyLimitStrategy): Strategy for handling limit violations.

    Example:
        ConcurrencyExpression("input.user_id", 5, ConcurrencyLimitStrategy.CANCEL_IN_PROGRESS)
    """

    def __init__(
        self, expression: str, max_runs: int, limit_strategy: ConcurrencyLimitStrategy
    ):
        self.expression = expression
        self.max_runs = max_runs
        self.limit_strategy = limit_strategy


@runtime_checkable
class WorkflowInterface(Protocol):
    def get_name(self, namespace: str) -> str: ...

    def get_actions(self, namespace: str) -> list[tuple[str, Callable[..., Any]]]: ...

    def get_create_opts(self, namespace: str) -> Any: ...

    on_events: list[str] | None
    on_crons: list[str] | None
    name: str
    version: str
    timeout: str
    schedule_timeout: str
    sticky: Union[StickyStrategy.Value, None]  # type: ignore[name-defined]
    default_priority: int | None
    concurrency_expression: ConcurrencyExpression | None
    input_validator: Type[BaseModel] | None


class WorkflowMeta(type):
    def __new__(
        cls: Type["WorkflowMeta"],
        name: str,
        bases: tuple[type, ...],
        attrs: dict[str, Any],
    ) -> "WorkflowMeta":
        def _create_steps_actions_list(attr_name: str) -> StepsType:
            # Collect (name, function) pairs for every class attribute that a step,
            # concurrency, or on-failure decorator has marked with `attr_name`,
            # removing those functions from the class namespace as we go.
            return [
                (getattr(func, attr_name), attrs.pop(func_name))
                for func_name, func in list(attrs.items())
                if hasattr(func, attr_name)
            ]

        concurrencyActions = _create_steps_actions_list("_concurrency_fn_name")
        steps = _create_steps_actions_list("_step_name")

        onFailureSteps = _create_steps_actions_list("_on_failure_step_name")

        # Wrap the user-defined __init__ (if any) and define the methods attached to the class below
        original_init = attrs.get("__init__")  # Get the original __init__ if it exists

        def __init__(self: TW, *args: Any, **kwargs: Any) -> None:
            if original_init:
                original_init(self, *args, **kwargs)  # Call original __init__

        def get_service_name(namespace: str) -> str:
            # Registered action IDs take the form
            # "<namespace><workflow class name, lowercased>:<function name>".
            return f"{namespace}{name.lower()}"

        @functools.cache
        def get_actions(self: TW, namespace: str) -> StepsType:
            serviceName = get_service_name(namespace)

            func_actions = [
                (serviceName + ":" + func_name, func) for func_name, func in steps
            ]
            concurrency_actions = [
                (serviceName + ":" + func_name, func)
                for func_name, func in concurrencyActions
            ]
            onFailure_actions = [
                (serviceName + ":" + func_name, func)
                for func_name, func in onFailureSteps
            ]

            return func_actions + concurrency_actions + onFailure_actions

        # Add these methods and steps to class attributes
        attrs["__init__"] = __init__
        attrs["get_actions"] = get_actions

        for step_name, step_func in steps:
            attrs[step_name] = step_func

        def get_name(self: TW, namespace: str) -> str:
            return namespace + cast(str, attrs["name"])

        attrs["get_name"] = get_name

        cron_triggers = attrs["on_crons"]
        version = attrs["version"]
        schedule_timeout = attrs["schedule_timeout"]
        sticky = attrs["sticky"]
        default_priority = attrs["default_priority"]

        @functools.cache
        def get_create_opts(self: TW, namespace: str) -> CreateWorkflowVersionOpts:
            serviceName = get_service_name(namespace)
            name = self.get_name(namespace)
            event_triggers = [namespace + event for event in attrs["on_events"]]
            createStepOpts: list[CreateWorkflowStepOpts] = [
                CreateWorkflowStepOpts(
                    readable_id=step_name,
                    action=serviceName + ":" + step_name,
                    timeout=func._step_timeout or "60s",
                    inputs="{}",
                    parents=[x for x in func._step_parents],
                    retries=func._step_retries,
                    rate_limits=func._step_rate_limits,  # type: ignore[arg-type]
                    worker_labels=func._step_desired_worker_labels,  # type: ignore[arg-type]
                    backoff_factor=func._step_backoff_factor,
                    backoff_max_seconds=func._step_backoff_max_seconds,
                )
                for step_name, func in steps
            ]

            if len(concurrencyActions) > 0 and self.concurrency_expression:
                raise ValueError(
                    "Both a concurrency function and a concurrency_expression are defined. Please use only one concurrency configuration method."
                )

            concurrency: WorkflowConcurrencyOpts | None = None

            if len(concurrencyActions) > 0:
                action = concurrencyActions[0]

                concurrency = WorkflowConcurrencyOpts(
                    action=serviceName + ":" + action[0],
                    max_runs=action[1]._concurrency_max_runs,
                    limit_strategy=action[1]._concurrency_limit_strategy,
                )

            if self.concurrency_expression:
                concurrency = WorkflowConcurrencyOpts(
                    expression=self.concurrency_expression.expression,
                    max_runs=self.concurrency_expression.max_runs,
                    limit_strategy=self.concurrency_expression.limit_strategy,
                )

            on_failure_job: CreateWorkflowJobOpts | None = None

            if len(onFailureSteps) > 0:
                func_name, func = onFailureSteps[0]
                on_failure_job = CreateWorkflowJobOpts(
                    name=name + "-on-failure",
                    steps=[
                        CreateWorkflowStepOpts(
                            readable_id=func_name,
                            action=serviceName + ":" + func_name,
                            timeout=func._on_failure_step_timeout or "60s",
                            inputs="{}",
                            parents=[],
                            retries=func._on_failure_step_retries,
                            rate_limits=func._on_failure_step_rate_limits,  # type: ignore[arg-type]
                            backoff_factor=func._on_failure_step_backoff_factor,
                            backoff_max_seconds=func._on_failure_step_backoff_max_seconds,
                        )
                    ],
                )

            validated_priority = (
                max(1, min(3, default_priority)) if default_priority else None
            )
            if validated_priority != default_priority:
                logger.warning(
                    "Warning: Default Priority Must be between 1 and 3 -- inclusively. Adjusted to be within the range."
                )

            return CreateWorkflowVersionOpts(
                name=name,
                kind=WorkflowKind.DAG,
                version=version,
                event_triggers=event_triggers,
                cron_triggers=cron_triggers,
                schedule_timeout=schedule_timeout,
                sticky=sticky,
                jobs=[
                    CreateWorkflowJobOpts(
                        name=name,
                        steps=createStepOpts,
                    )
                ],
                on_failure_job=on_failure_job,
                concurrency=concurrency,
                default_priority=validated_priority,
            )

        attrs["get_create_opts"] = get_create_opts

        return super(WorkflowMeta, cls).__new__(cls, name, bases, attrs)
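

# Illustrative sketch (hypothetical names, for reference only): a class built with
# WorkflowMeta is expected to provide the class-level attributes read above. In
# practice the SDK's workflow and step decorators fill these in; spelled out by hand
# it would look roughly like this:
#
#   class MyWorkflow(metaclass=WorkflowMeta):
#       name = "my-workflow"
#       version = "v1"
#       on_events = ["user:created"]
#       on_crons = []
#       schedule_timeout = "5m"
#       sticky = None
#       default_priority = 1
#       concurrency_expression = None
#       input_validator = None
#
#       def first_step(self, context): ...
#       first_step._step_name = "first_step"
#       first_step._step_timeout = "30s"
#       first_step._step_parents = []
#       # ... plus the remaining _step_* attributes from WorkflowStepProtocol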