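# NOTE: this line originally held web-page navigation residue ("about summary refs log tree commit diff"); it is not part of the module.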
import json
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, TypedDict, TypeVar, Union

import grpc
from google.protobuf import timestamp_pb2

from hatchet_sdk.clients.rest.models.workflow_run import WorkflowRun
from hatchet_sdk.clients.rest.tenacity_utils import tenacity_retry
from hatchet_sdk.clients.run_event_listener import new_listener
from hatchet_sdk.clients.workflow_listener import PooledWorkflowRunListener
from hatchet_sdk.connection import new_conn
from hatchet_sdk.contracts.workflows_pb2 import (
    BulkTriggerWorkflowRequest,
    BulkTriggerWorkflowResponse,
    CreateWorkflowVersionOpts,
    PutRateLimitRequest,
    PutWorkflowRequest,
    RateLimitDuration,
    ScheduleWorkflowRequest,
    TriggerWorkflowRequest,
    TriggerWorkflowResponse,
    WorkflowVersion,
)
from hatchet_sdk.contracts.workflows_pb2_grpc import WorkflowServiceStub
from hatchet_sdk.utils.serialization import flatten
from hatchet_sdk.workflow_run import RunRef, WorkflowRunRef

from ..loader import ClientConfig
from ..metadata import get_metadata
from ..workflow import WorkflowMeta


def new_admin(config: ClientConfig):
    """Create an ``AdminClient`` for the given client configuration."""
    admin_client = AdminClient(config)
    return admin_client


class ScheduleTriggerWorkflowOptions(TypedDict, total=False):
    parent_id: Optional[str]
    parent_step_run_id: Optional[str]
    child_index: Optional[int]
    child_key: Optional[str]
    namespace: Optional[str]


class ChildTriggerWorkflowOptions(TypedDict, total=False):
    additional_metadata: Dict[str, str] | None = None
    sticky: bool | None = None


class ChildWorkflowRunDict(TypedDict, total=False):
    """Description of one child workflow run: name, input, options and key.

    All keys are optional (``total=False``). TypedDict items cannot carry
    default values, so ``key`` is declared without one.
    """

    workflow_name: str
    input: Any
    options: ChildTriggerWorkflowOptions
    key: str | None

class TriggerWorkflowOptions(ScheduleTriggerWorkflowOptions, total=False):
    """Options accepted when triggering a workflow run.

    Extends ``ScheduleTriggerWorkflowOptions`` with additional optional keys.
    TypedDict items cannot carry default values, so none are declared.
    """

    additional_metadata: Dict[str, str] | None
    desired_worker_id: str | None
    # Same key and type as in the parent TypedDict; restated here as in the original.
    namespace: str | None

class WorkflowRunDict(TypedDict, total=False):
    """One entry of a bulk trigger call: workflow name, input and options.

    All keys are optional as far as the type is concerned (``total=False``).
    """

    workflow_name: str
    input: Any
    options: Optional[TriggerWorkflowOptions]

class DedupeViolationErr(Exception):
    """Signals that a workflow was already run with the same deduplication value.

    Raised by the Hatchet library.
    """


class AdminClientBase:
    """Request-building helpers shared by the admin clients in this module.

    ``_prepare_put_workflow_request`` reads ``self.namespace``; both concrete
    clients defined in this file assign it in ``__init__``.
    """

    # Created lazily by the subclasses the first time a run reference is needed.
    pooled_workflow_listener: PooledWorkflowRunListener | None = None

    def _prepare_workflow_request(
        self,
        workflow_name: str,
        input: Any,
        options: TriggerWorkflowOptions | None = None,
    ) -> TriggerWorkflowRequest:
        """Build a ``TriggerWorkflowRequest`` for ``workflow_name``.

        Args:
            workflow_name: Name placed in the request.
            input: JSON-serializable payload for the run.
            options: Extra request fields, forwarded as keyword arguments.
                ``additional_metadata`` is JSON-encoded to UTF-8 bytes first.
                The mapping passed in is never modified.

        Raises:
            ValueError: If ``input`` or ``additional_metadata`` cannot be
                serialized to JSON.
        """
        # json.dumps reports unserializable objects with TypeError and circular
        # references / disallowed floats with ValueError; JSONDecodeError is
        # only ever raised when *decoding*, so catching it here never fired.
        try:
            payload_data = json.dumps(input)
        except (TypeError, ValueError) as e:
            raise ValueError(f"Error encoding payload: {e}") from e

        # Work on a copy so the caller's options survive repeated calls.
        request_options = dict(options) if options else {}

        meta = request_options.get("additional_metadata")
        if meta is not None:
            try:
                request_options["additional_metadata"] = json.dumps(meta).encode(
                    "utf-8"
                )
            except (TypeError, ValueError) as e:
                raise ValueError(f"Error encoding additional metadata: {e}") from e

        return TriggerWorkflowRequest(
            name=workflow_name, input=payload_data, **request_options
        )

    def _prepare_put_workflow_request(
        self,
        name: str,
        workflow: CreateWorkflowVersionOpts | WorkflowMeta,
        overrides: CreateWorkflowVersionOpts | None = None,
    ) -> PutWorkflowRequest:
        """Build a ``PutWorkflowRequest`` named ``name``.

        Args:
            name: Name written into the resulting options.
            workflow: Either ready-made ``CreateWorkflowVersionOpts`` or an
                object whose ``get_create_opts(namespace)`` produces them.
            overrides: Optional message merged on top of the options.

        Raises:
            ValueError: If a ``grpc.RpcError`` surfaces while building.
        """
        try:
            opts: CreateWorkflowVersionOpts

            if isinstance(workflow, CreateWorkflowVersionOpts):
                # Copy before merging/renaming: MergeFrom appends to repeated
                # fields, so mutating the caller's message in place would
                # compound on every repeated call with the same object.
                opts = CreateWorkflowVersionOpts()
                opts.CopyFrom(workflow)
            else:
                # The namespace lives on the client itself (set in __init__ of
                # both subclasses), not on the gRPC stub stored in self.client.
                opts = workflow.get_create_opts(self.namespace)

            if overrides is not None:
                opts.MergeFrom(overrides)

            opts.name = name

            return PutWorkflowRequest(
                opts=opts,
            )
        except grpc.RpcError as e:
            raise ValueError(f"Could not put workflow: {e}") from e

    def _prepare_schedule_workflow_request(
        self,
        name: str,
        schedules: List[Union[datetime, timestamp_pb2.Timestamp]],
        input={},
        options: ScheduleTriggerWorkflowOptions | None = None,
    ) -> ScheduleWorkflowRequest:
        """Build a ``ScheduleWorkflowRequest`` for ``name``.

        Args:
            name: Workflow name placed in the request.
            schedules: Trigger times; ``datetime`` values are converted to
                protobuf timestamps via ``datetime.timestamp()``.
            input: JSON-serializable payload. The shared default dict is only
                read (serialized), never mutated.
            options: Extra request fields, forwarded as keyword arguments.

        Raises:
            ValueError: If an entry of ``schedules`` is neither a ``datetime``
                nor a ``timestamp_pb2.Timestamp``.
        """
        timestamp_schedules = []
        for schedule in schedules:
            if isinstance(schedule, datetime):
                # divmod floors, so the fractional part is always in [0, 1) and
                # stays consistent with the seconds even for pre-epoch values.
                whole_seconds, fraction = divmod(schedule.timestamp(), 1)
                timestamp_schedules.append(
                    timestamp_pb2.Timestamp(
                        seconds=int(whole_seconds), nanos=int(fraction * 1e9)
                    )
                )
            elif isinstance(schedule, timestamp_pb2.Timestamp):
                timestamp_schedules.append(schedule)
            else:
                raise ValueError(
                    "Invalid schedule type. Must be datetime or timestamp_pb2.Timestamp."
                )

        return ScheduleWorkflowRequest(
            name=name,
            schedules=timestamp_schedules,
            input=json.dumps(input),
            **(options or {}),
        )


# Type variable linking the callable passed to ``run`` with the ``RunRef[T]`` it returns.
T = TypeVar("T")


class AdminClientAioImpl(AdminClientBase):
    """Admin client whose RPC methods are coroutines."""

    def __init__(self, config: ClientConfig):
        """Open a connection via ``new_conn(config, True)`` and keep the settings used per call."""
        aio_conn = new_conn(config, True)
        self.config = config
        self.aio_client = WorkflowServiceStub(aio_conn)
        self.token = config.token
        self.listener_client = new_listener(config)
        self.namespace = config.namespace

    def _apply_namespace(
        self,
        name: str,
        options: Optional[Dict[str, Any]],
        default_namespace: Optional[str] = None,
    ) -> "tuple[str, Optional[Dict[str, Any]]]":
        """Prefix ``name`` with the effective namespace and strip the override key.

        The effective namespace is ``options["namespace"]`` when set, else
        ``default_namespace``, else ``self.namespace``.

        Returns:
            ``(name, options)`` where ``options`` is a copy without the
            ``"namespace"`` key (or ``None`` if ``None`` was given). The
            caller's mapping is left untouched, so calling again with the same
            dict still sees the override.
        """
        namespace = self.namespace if default_namespace is None else default_namespace

        remaining: Optional[Dict[str, Any]] = None
        if options is not None:
            override = options.get("namespace")
            if override is not None:
                namespace = override
            # "namespace" is consumed here and not forwarded with the request options.
            remaining = {k: v for k, v in options.items() if k != "namespace"}

        # Compare against the namespace actually in effect; checking the client
        # default instead would ignore an override whenever the default is "".
        if namespace and not name.startswith(namespace):
            name = f"{namespace}{name}"

        return name, remaining

    async def run(
        self,
        function: Union[str, Callable[[Any], T]],
        input: Any,
        options: TriggerWorkflowOptions | None = None,
    ) -> "RunRef[T]":
        """Trigger a run by name or by callable and return a typed ``RunRef``.

        A non-string ``function`` is expected to expose ``function_name``.
        """
        workflow_name = function

        if not isinstance(function, str):
            workflow_name = function.function_name

        wrr = await self.run_workflow(workflow_name, input, options)

        return RunRef[T](
            wrr.workflow_run_id, wrr.workflow_listener, wrr.workflow_run_event_listener
        )

    ## IMPORTANT: Keep this method's signature in sync with the wrapper in the OTel instrumentor
    @tenacity_retry
    async def run_workflow(
        self, workflow_name: str, input: Any, options: TriggerWorkflowOptions = None
    ) -> WorkflowRunRef:
        """Trigger a single workflow run.

        Raises:
            DedupeViolationErr: If the server answers ``ALREADY_EXISTS``.
            grpc.RpcError: Any other RPC failure, re-raised unchanged.
            ValueError: If the payload cannot be encoded.
        """
        try:
            if not self.pooled_workflow_listener:
                self.pooled_workflow_listener = PooledWorkflowRunListener(self.config)

            workflow_name, request_options = self._apply_namespace(
                workflow_name, options
            )

            request = self._prepare_workflow_request(
                workflow_name, input, request_options
            )

            resp: TriggerWorkflowResponse = await self.aio_client.TriggerWorkflow(
                request,
                metadata=get_metadata(self.token),
            )

            return WorkflowRunRef(
                workflow_run_id=resp.workflow_run_id,
                workflow_listener=self.pooled_workflow_listener,
                workflow_run_event_listener=self.listener_client,
            )
        except (grpc.RpcError, grpc.aio.AioRpcError) as e:
            if e.code() == grpc.StatusCode.ALREADY_EXISTS:
                raise DedupeViolationErr(e.details()) from e

            raise

    ## IMPORTANT: Keep this method's signature in sync with the wrapper in the OTel instrumentor
    @tenacity_retry
    async def run_workflows(
        self,
        workflows: list[WorkflowRunDict],
        options: TriggerWorkflowOptions | None = None,
    ) -> List[WorkflowRunRef]:
        """Trigger several workflow runs with one bulk request.

        Args:
            workflows: Entries with ``workflow_name``, ``input`` and an
                optional ``options`` mapping.
            options: Only its ``namespace`` is used, as the default namespace
                for every entry; an entry's own ``namespace`` takes precedence.

        Raises:
            ValueError: If ``workflows`` is empty.
        """
        if len(workflows) == 0:
            raise ValueError("No workflows to run")

        if not self.pooled_workflow_listener:
            self.pooled_workflow_listener = PooledWorkflowRunListener(self.config)

        default_namespace = self.namespace
        if options is not None and options.get("namespace") is not None:
            default_namespace = options["namespace"]

        workflow_run_requests: List[TriggerWorkflowRequest] = []

        for workflow in workflows:
            # "options" is an optional key of WorkflowRunDict, hence .get().
            workflow_name, workflow_options = self._apply_namespace(
                workflow["workflow_name"],
                workflow.get("options"),
                default_namespace,
            )

            workflow_run_requests.append(
                self._prepare_workflow_request(
                    workflow_name, workflow["input"], workflow_options
                )
            )

        request = BulkTriggerWorkflowRequest(workflows=workflow_run_requests)

        resp: BulkTriggerWorkflowResponse = await self.aio_client.BulkTriggerWorkflow(
            request,
            metadata=get_metadata(self.token),
        )

        return [
            WorkflowRunRef(
                workflow_run_id=workflow_run_id,
                workflow_listener=self.pooled_workflow_listener,
                workflow_run_event_listener=self.listener_client,
            )
            for workflow_run_id in resp.workflow_run_ids
        ]

    @tenacity_retry
    async def put_workflow(
        self,
        name: str,
        workflow: CreateWorkflowVersionOpts | WorkflowMeta,
        overrides: CreateWorkflowVersionOpts | None = None,
    ) -> WorkflowVersion:
        """Send a ``PutWorkflow`` request built from ``workflow`` and ``overrides``."""
        opts = self._prepare_put_workflow_request(name, workflow, overrides)

        return await self.aio_client.PutWorkflow(
            opts,
            metadata=get_metadata(self.token),
        )

    @tenacity_retry
    async def put_rate_limit(
        self,
        key: str,
        limit: int,
        duration: RateLimitDuration = RateLimitDuration.SECOND,
    ):
        """Send a ``PutRateLimit`` request for ``key`` with the given limit and duration."""
        await self.aio_client.PutRateLimit(
            PutRateLimitRequest(
                key=key,
                limit=limit,
                duration=duration,
            ),
            metadata=get_metadata(self.token),
        )

    @tenacity_retry
    async def schedule_workflow(
        self,
        name: str,
        schedules: List[Union[datetime, timestamp_pb2.Timestamp]],
        input={},
        options: ScheduleTriggerWorkflowOptions = None,
    ) -> WorkflowVersion:
        """Schedule ``name`` to run at each time in ``schedules``.

        The shared default ``input`` dict is only serialized, never mutated.

        Raises:
            DedupeViolationErr: If the server answers ``ALREADY_EXISTS``.
            grpc.RpcError: Any other RPC failure, re-raised unchanged.
            ValueError: If a schedule entry has an unsupported type.
        """
        try:
            name, request_options = self._apply_namespace(name, options)

            request = self._prepare_schedule_workflow_request(
                name, schedules, input, request_options
            )

            return await self.aio_client.ScheduleWorkflow(
                request,
                metadata=get_metadata(self.token),
            )
        except (grpc.aio.AioRpcError, grpc.RpcError) as e:
            if e.code() == grpc.StatusCode.ALREADY_EXISTS:
                raise DedupeViolationErr(e.details()) from e

            raise


class AdminClient(AdminClientBase):
    """Admin client with plain (non-async) methods; the coroutine variant is ``self.aio``."""

    def __init__(self, config: ClientConfig):
        """Open a connection via ``new_conn(config)`` and keep the settings used per call."""
        conn = new_conn(config)
        self.config = config
        self.client = WorkflowServiceStub(conn)
        self.aio = AdminClientAioImpl(config)
        self.token = config.token
        self.listener_client = new_listener(config)
        self.namespace = config.namespace

    def _apply_namespace(
        self,
        name: str,
        options: Optional[Dict[str, Any]],
        default_namespace: Optional[str] = None,
    ) -> "tuple[str, Optional[Dict[str, Any]]]":
        """Prefix ``name`` with the effective namespace and strip the override key.

        The effective namespace is ``options["namespace"]`` when set, else
        ``default_namespace``, else ``self.namespace``.

        Returns:
            ``(name, options)`` where ``options`` is a copy without the
            ``"namespace"`` key (or ``None`` if ``None`` was given). The
            caller's mapping is left untouched, so calling again with the same
            dict still sees the override.
        """
        namespace = self.namespace if default_namespace is None else default_namespace

        remaining: Optional[Dict[str, Any]] = None
        if options is not None:
            override = options.get("namespace")
            if override is not None:
                namespace = override
            # "namespace" is consumed here and not forwarded with the request options.
            remaining = {k: v for k, v in options.items() if k != "namespace"}

        # Compare against the namespace actually in effect; checking the client
        # default instead would ignore an override whenever the default is "".
        if namespace and not name.startswith(namespace):
            name = f"{namespace}{name}"

        return name, remaining

    @tenacity_retry
    def put_workflow(
        self,
        name: str,
        workflow: CreateWorkflowVersionOpts | WorkflowMeta,
        overrides: CreateWorkflowVersionOpts | None = None,
    ) -> WorkflowVersion:
        """Send a ``PutWorkflow`` request built from ``workflow`` and ``overrides``."""
        opts = self._prepare_put_workflow_request(name, workflow, overrides)

        resp: WorkflowVersion = self.client.PutWorkflow(
            opts,
            metadata=get_metadata(self.token),
        )

        return resp

    @tenacity_retry
    def put_rate_limit(
        self,
        key: str,
        limit: int,
        duration: Union[RateLimitDuration.Value, str] = RateLimitDuration.SECOND,
    ):
        """Send a ``PutRateLimit`` request for ``key`` with the given limit and duration."""
        self.client.PutRateLimit(
            PutRateLimitRequest(
                key=key,
                limit=limit,
                duration=duration,
            ),
            metadata=get_metadata(self.token),
        )

    @tenacity_retry
    def schedule_workflow(
        self,
        name: str,
        schedules: List[Union[datetime, timestamp_pb2.Timestamp]],
        input={},
        options: ScheduleTriggerWorkflowOptions = None,
    ) -> WorkflowVersion:
        """Schedule ``name`` to run at each time in ``schedules``.

        The shared default ``input`` dict is only serialized, never mutated.

        Raises:
            DedupeViolationErr: If the server answers ``ALREADY_EXISTS``.
            grpc.RpcError: Any other RPC failure, re-raised unchanged.
            ValueError: If a schedule entry has an unsupported type.
        """
        try:
            name, request_options = self._apply_namespace(name, options)

            request = self._prepare_schedule_workflow_request(
                name, schedules, input, request_options
            )

            return self.client.ScheduleWorkflow(
                request,
                metadata=get_metadata(self.token),
            )
        except (grpc.RpcError, grpc.aio.AioRpcError) as e:
            if e.code() == grpc.StatusCode.ALREADY_EXISTS:
                raise DedupeViolationErr(e.details()) from e

            raise

    ## IMPORTANT: Keep this method's signature in sync with the wrapper in the OTel instrumentor
    @tenacity_retry
    def run_workflow(
        self, workflow_name: str, input: Any, options: TriggerWorkflowOptions = None
    ) -> WorkflowRunRef:
        """Trigger a single workflow run.

        Raises:
            DedupeViolationErr: If the server answers ``ALREADY_EXISTS``.
            grpc.RpcError: Any other RPC failure, re-raised unchanged.
            ValueError: If the payload cannot be encoded.
        """
        try:
            if not self.pooled_workflow_listener:
                self.pooled_workflow_listener = PooledWorkflowRunListener(self.config)

            workflow_name, request_options = self._apply_namespace(
                workflow_name, options
            )

            request = self._prepare_workflow_request(
                workflow_name, input, request_options
            )

            resp: TriggerWorkflowResponse = self.client.TriggerWorkflow(
                request,
                metadata=get_metadata(self.token),
            )

            return WorkflowRunRef(
                workflow_run_id=resp.workflow_run_id,
                workflow_listener=self.pooled_workflow_listener,
                workflow_run_event_listener=self.listener_client,
            )
        except (grpc.RpcError, grpc.aio.AioRpcError) as e:
            if e.code() == grpc.StatusCode.ALREADY_EXISTS:
                raise DedupeViolationErr(e.details()) from e

            raise

    ## IMPORTANT: Keep this method's signature in sync with the wrapper in the OTel instrumentor
    @tenacity_retry
    def run_workflows(
        self, workflows: List[WorkflowRunDict], options: TriggerWorkflowOptions = None
    ) -> list[WorkflowRunRef]:
        """Trigger several workflow runs with one bulk request.

        Args:
            workflows: Entries with ``workflow_name``, ``input`` and an
                optional ``options`` mapping.
            options: Only its ``namespace`` is used, as the default namespace
                for every entry (matching the coroutine variant); an entry's
                own ``namespace`` takes precedence.

        Raises:
            ValueError: If ``workflows`` is empty.
        """
        # Without this guard an empty list reached the RPC call with no
        # request ever built, failing with UnboundLocalError.
        if len(workflows) == 0:
            raise ValueError("No workflows to run")

        if not self.pooled_workflow_listener:
            self.pooled_workflow_listener = PooledWorkflowRunListener(self.config)

        default_namespace = self.namespace
        if options is not None and options.get("namespace") is not None:
            default_namespace = options["namespace"]

        workflow_run_requests: List[TriggerWorkflowRequest] = []

        for workflow in workflows:
            # "options" is an optional key of WorkflowRunDict, hence .get().
            workflow_name, workflow_options = self._apply_namespace(
                workflow["workflow_name"],
                workflow.get("options"),
                default_namespace,
            )

            workflow_run_requests.append(
                self._prepare_workflow_request(
                    workflow_name, workflow["input"], workflow_options
                )
            )

        # Built once after the loop rather than rebuilt on every iteration.
        request = BulkTriggerWorkflowRequest(workflows=workflow_run_requests)

        resp: BulkTriggerWorkflowResponse = self.client.BulkTriggerWorkflow(
            request,
            metadata=get_metadata(self.token),
        )

        return [
            WorkflowRunRef(
                workflow_run_id=workflow_run_id,
                workflow_listener=self.pooled_workflow_listener,
                workflow_run_event_listener=self.listener_client,
            )
            for workflow_run_id in resp.workflow_run_ids
        ]

    def run(
        self,
        function: Union[str, Callable[[Any], T]],
        input: Any,
        options: TriggerWorkflowOptions | None = None,
    ) -> "RunRef[T]":
        """Trigger a run by name or by callable and return a typed ``RunRef``.

        A non-string ``function`` is expected to expose ``function_name``.
        """
        workflow_name = function

        if not isinstance(function, str):
            workflow_name = function.function_name

        wrr = self.run_workflow(workflow_name, input, options)

        return RunRef[T](
            wrr.workflow_run_id, wrr.workflow_listener, wrr.workflow_run_event_listener
        )

    def get_workflow_run(self, workflow_run_id: str) -> WorkflowRunRef:
        """Return a ``WorkflowRunRef`` for an existing run id.

        Raises:
            ValueError: If a ``grpc.RpcError`` surfaces while creating the
                pooled listener or the reference.
        """
        try:
            if not self.pooled_workflow_listener:
                self.pooled_workflow_listener = PooledWorkflowRunListener(self.config)

            return WorkflowRunRef(
                workflow_run_id=workflow_run_id,
                workflow_listener=self.pooled_workflow_listener,
                workflow_run_event_listener=self.listener_client,
            )
        except grpc.RpcError as e:
            raise ValueError(f"Could not get workflow run: {e}") from e