# (stray repository-browser navigation text "about summary refs log tree commit diff" removed; it is not part of this module)
import asyncio
from typing import (
    Any,
    Callable,
    Dict,
    Generic,
    List,
    Optional,
    TypedDict,
    TypeVar,
    Union,
)

from hatchet_sdk.clients.admin import ChildTriggerWorkflowOptions
from hatchet_sdk.context.context import Context
from hatchet_sdk.contracts.workflows_pb2 import (  # type: ignore[attr-defined]
    CreateStepRateLimit,
    CreateWorkflowJobOpts,
    CreateWorkflowStepOpts,
    CreateWorkflowVersionOpts,
    DesiredWorkerLabels,
    StickyStrategy,
    WorkflowConcurrencyOpts,
    WorkflowKind,
)
from hatchet_sdk.labels import DesiredWorkerLabel
from hatchet_sdk.logger import logger
from hatchet_sdk.rate_limit import RateLimit
from hatchet_sdk.v2.concurrency import ConcurrencyFunction
from hatchet_sdk.workflow_run import RunRef

# Return type of the user function wrapped by HatchetCallable (see `func` and `__call__`).
T = TypeVar("T")


class HatchetCallable(Generic[T]):
    """Wrap a function together with the options used to register it as a workflow.

    The instance is itself callable (it forwards to ``func``) and can render its
    configuration as ``CreateWorkflowVersionOpts`` / ``CreateWorkflowStepOpts``
    protobuf messages.
    """

    def __init__(
        self,
        func: Callable[[Context], T],
        durable: bool = False,
        name: str = "",
        auto_register: bool = True,
        on_events: list[str] | None = None,
        on_crons: list[str] | None = None,
        version: str = "",
        timeout: str = "60m",
        schedule_timeout: str = "5m",
        sticky: StickyStrategy = None,
        retries: int = 0,
        rate_limits: List[RateLimit] | None = None,
        concurrency: ConcurrencyFunction | None = None,
        on_failure: Union["HatchetCallable[T]", None] = None,
        desired_worker_labels: dict[str, DesiredWorkerLabel] | None = None,
        default_priority: int | None = None,
    ):
        """Store the function and its registration options.

        Args:
            func: Function invoked with a ``Context`` when the instance is called.
            durable: When true, the workflow kind is ``DURABLE`` instead of ``FUNCTION``.
            name: Function name; lower-cased. Falls back to ``func.__name__``
                (lower-cased) when empty.
            auto_register: Stored as ``function_auto_register``; not read in this class.
            on_events: Event triggers; ``None`` is treated as an empty list.
            on_crons: Cron triggers; ``None`` is treated as an empty list.
            version: Workflow version string.
            timeout: Step timeout passed to the step options.
            schedule_timeout: Schedule timeout passed to the workflow options.
            sticky: Sticky strategy passed through to the workflow options.
            retries: Retry count passed to the step options.
            rate_limits: Rate limits; each one's ``_req`` is forwarded to the step.
            concurrency: Optional concurrency function for the workflow.
            on_failure: Optional callable whose step becomes the on-failure job.
            desired_worker_labels: Mapping of label key to label spec. ``None``
                (the default) means no labels.
            default_priority: Optional priority; clamped to 1..3 when the
                workflow options are built.
        """
        self.func = func

        on_events = on_events or []
        on_crons = on_crons or []

        # An absent or empty rate-limit list is stored as None, not [].
        limits = (
            [rate_limit._req for rate_limit in rate_limits] if rate_limits else None
        )

        self.function_desired_worker_labels = {}

        # `None` default instead of a shared mutable `{}` default argument.
        for key, d in (desired_worker_labels or {}).items():
            value = d.get("value")
            is_int = isinstance(value, int)
            self.function_desired_worker_labels[key] = DesiredWorkerLabels(
                # A missing value must stay unset; str(None) would otherwise
                # send the literal string "None" as the label value.
                strValue=None if (value is None or is_int) else str(value),
                intValue=value if is_int else None,
                required=d.get("required"),
                weight=d.get("weight"),
                comparator=d.get("comparator"),
            )
        self.sticky = sticky
        self.default_priority = default_priority
        self.durable = durable
        self.function_name = name.lower() or str(func.__name__).lower()
        self.function_version = version
        self.function_on_events = on_events
        self.function_on_crons = on_crons
        self.function_timeout = timeout
        self.function_schedule_timeout = schedule_timeout
        self.function_retries = retries
        self.function_rate_limits = limits
        self.function_concurrency = concurrency
        self.function_on_failure = on_failure
        self.function_namespace = "default"
        self.function_auto_register = auto_register

        self.is_coroutine = asyncio.iscoroutinefunction(func)

    def __call__(self, context: Context) -> T:
        """Invoke the wrapped function with ``context`` and return its result."""
        return self.func(context)

    def with_namespace(self, namespace: str) -> None:
        """Set the namespace and prefix the function name with it.

        Does nothing when ``namespace`` is ``None`` or empty. The prefix is
        prepended on every effective call, so calling this twice prefixes twice.
        """
        if namespace is not None and namespace != "":
            self.function_namespace = namespace
            self.function_name = namespace + self.function_name

    def to_workflow_opts(self) -> CreateWorkflowVersionOpts:
        """Build the ``CreateWorkflowVersionOpts`` message for this callable.

        The workflow has a single job containing this callable's step, plus an
        optional on-failure job and optional concurrency options. Logs a warning
        when ``default_priority`` had to be clamped into the 1..3 range.
        """
        kind: WorkflowKind = WorkflowKind.FUNCTION

        if self.durable:
            kind = WorkflowKind.DURABLE

        on_failure_job: CreateWorkflowJobOpts | None = None

        if self.function_on_failure is not None:
            on_failure_job = CreateWorkflowJobOpts(
                name=self.function_name + "-on-failure",
                steps=[
                    self.function_on_failure.to_step(),
                ],
            )

        concurrency: WorkflowConcurrencyOpts | None = None

        if self.function_concurrency is not None:
            # The concurrency function gets this callable's namespace before
            # its action name is read.
            self.function_concurrency.set_namespace(self.function_namespace)
            concurrency = WorkflowConcurrencyOpts(
                action=self.function_concurrency.get_action_name(),
                max_runs=self.function_concurrency.max_runs,
                limit_strategy=self.function_concurrency.limit_strategy,
            )

        # Compare against None explicitly: a truthiness test would treat an
        # out-of-range 0 as "unset" while still warning that it was adjusted.
        validated_priority = (
            max(1, min(3, self.default_priority))
            if self.default_priority is not None
            else None
        )
        if validated_priority != self.default_priority:
            logger.warning(
                "Warning: Default Priority Must be between 1 and 3 -- inclusively. Adjusted to be within the range."
            )

        return CreateWorkflowVersionOpts(
            name=self.function_name,
            kind=kind,
            version=self.function_version,
            event_triggers=self.function_on_events,
            cron_triggers=self.function_on_crons,
            schedule_timeout=self.function_schedule_timeout,
            sticky=self.sticky,
            on_failure_job=on_failure_job,
            concurrency=concurrency,
            jobs=[
                CreateWorkflowJobOpts(
                    name=self.function_name,
                    steps=[
                        self.to_step(),
                    ],
                )
            ],
            default_priority=validated_priority,
        )

    def to_step(self) -> CreateWorkflowStepOpts:
        """Build the ``CreateWorkflowStepOpts`` message for this callable.

        The step has no parents and an empty JSON object as inputs.
        """
        return CreateWorkflowStepOpts(
            readable_id=self.function_name,
            action=self.get_action_name(),
            timeout=self.function_timeout,
            inputs="{}",
            parents=[],
            retries=self.function_retries,
            rate_limits=self.function_rate_limits,
            worker_labels=self.function_desired_worker_labels,
        )

    def get_action_name(self) -> str:
        """Return the action name in the form ``"<namespace>:<function_name>"``."""
        return self.function_namespace + ":" + self.function_name


class DurableContext(Context):
    """Context subclass that can start other functions through the admin client."""

    def run(
        self,
        function: str | Callable[[Context], Any],
        input: dict[Any, Any] | None = None,
        key: str | None = None,
        options: ChildTriggerWorkflowOptions | None = None,
    ) -> "RunRef[T]":
        """Run ``function`` via the admin client and return the resulting run reference.

        Args:
            function: Workflow name, or a callable exposing ``function_name``.
            input: Input passed to the run. ``None`` (the default) means a
                fresh empty dict.
            key: Forwarded to ``_prepare_workflow_options``.
            options: Forwarded to ``_prepare_workflow_options``.

        Returns:
            Whatever ``admin_client.run`` returns for the triggered run.
        """
        # Fresh dict per call: a shared `{}` default would leak any downstream
        # mutation into later calls.
        if input is None:
            input = {}

        worker_id = self.worker.id()

        # The resolved name is only consumed by the disabled sticky check below;
        # the lookup is kept so a callable without `function_name` fails here.
        workflow_name = function

        if not isinstance(function, str):
            workflow_name = function.function_name

        # if (
        #     options is not None
        #     and "sticky" in options
        #     and options["sticky"] == True
        #     and not self.worker.has_workflow(workflow_name)
        # ):
        #     raise Exception(
        #         f"cannot run with sticky: workflow {workflow_name} is not registered on the worker"
        #     )

        trigger_options = self._prepare_workflow_options(key, options, worker_id)

        return self.admin_client.run(function, input, trigger_options)