aboutsummaryrefslogtreecommitdiff
import logging
from typing import Any, cast

from google.protobuf.timestamp_pb2 import Timestamp

from hatchet_sdk.clients.dispatcher.action_listener import (
    Action,
    ActionListener,
    GetActionListenerRequest,
)
from hatchet_sdk.clients.rest.tenacity_utils import tenacity_retry
from hatchet_sdk.connection import new_conn
from hatchet_sdk.contracts.dispatcher_pb2 import (
    STEP_EVENT_TYPE_COMPLETED,
    STEP_EVENT_TYPE_FAILED,
    ActionEventResponse,
    GroupKeyActionEvent,
    GroupKeyActionEventType,
    OverridesData,
    RefreshTimeoutRequest,
    ReleaseSlotRequest,
    StepActionEvent,
    StepActionEventType,
    UpsertWorkerLabelsRequest,
    WorkerLabels,
    WorkerRegisterRequest,
    WorkerRegisterResponse,
)
from hatchet_sdk.contracts.dispatcher_pb2_grpc import DispatcherStub

from ...loader import ClientConfig
from ...metadata import get_metadata

# Value (seconds) passed as the gRPC ``timeout=`` argument to worker
# registration and to the other dispatcher calls in this module that set one
# (release slot, refresh timeout, upsert worker labels).
DEFAULT_REGISTER_TIMEOUT = 30


def new_dispatcher(config: ClientConfig) -> "DispatcherClient":
    """Construct a :class:`DispatcherClient` for ``config``.

    Args:
        config: Client configuration handed straight to the constructor.

    Returns:
        The newly built dispatcher client.
    """
    dispatcher = DispatcherClient(config=config)
    return dispatcher


class DispatcherClient:
    """Client for the Hatchet dispatcher gRPC service.

    Holds two ``DispatcherStub`` instances: ``client`` backs the synchronous
    methods and ``aio_client`` backs the ``async`` ones. Every request carries
    the metadata produced by ``get_metadata(self.token)``.
    """

    config: ClientConfig

    def __init__(self, config: ClientConfig):
        """Open the connections described by ``config`` and build both stubs.

        Args:
            config: Client configuration. Its ``token`` is stored for request
                metadata, and the object itself is kept on ``self.config``.
        """
        conn = new_conn(config)
        self.client = DispatcherStub(conn)  # type: ignore[no-untyped-call]

        # NOTE(review): the second argument presumably selects an asyncio
        # channel, since this stub is the one that is awaited below - confirm
        # against new_conn.
        aio_conn = new_conn(config, True)
        self.aio_client = DispatcherStub(aio_conn)  # type: ignore[no-untyped-call]
        self.token = config.token
        self.config = config

    @staticmethod
    def _create_labels(labels: dict[str, str | int]) -> dict[str, WorkerLabels]:
        """Convert plain label values into ``WorkerLabels`` messages.

        ``int`` values populate ``intValue``; every other value is converted
        with ``str()`` into ``strValue``. Key order is preserved.
        """
        return {
            key: (
                WorkerLabels(intValue=value)
                if isinstance(value, int)
                else WorkerLabels(strValue=str(value))
            )
            for key, value in labels.items()
        }

    async def get_action_listener(
        self, req: GetActionListenerRequest
    ) -> ActionListener:
        """Register a worker with the dispatcher and return a listener for it.

        The configured ``worker_preset_labels`` are written into
        ``req.labels``. This mutates ``req`` in place and overrides any label
        with the same key. The labels are written before the ``Register`` RPC
        is issued.

        Args:
            req: Worker name, actions, services, max runs and labels to
                register.

        Returns:
            An ``ActionListener`` bound to the worker id the dispatcher
            returned.
        """
        # Preset labels from the config take precedence over request labels.
        for key, value in self.config.worker_preset_labels.items():
            req.labels[key] = WorkerLabels(strValue=str(value))

        response: WorkerRegisterResponse = await self.aio_client.Register(
            WorkerRegisterRequest(
                workerName=req.worker_name,
                actions=req.actions,
                services=req.services,
                maxRuns=req.max_runs,
                labels=req.labels,
            ),
            timeout=DEFAULT_REGISTER_TIMEOUT,
            metadata=get_metadata(self.token),
        )

        return ActionListener(self.config, response.workerId)

    async def send_step_action_event(
        self, action: Action, event_type: StepActionEventType, payload: str
    ) -> Any:
        """Send a step action event, with a failure fallback for final events.

        The send goes through ``_try_send_step_action_event``, which is wrapped
        in ``tenacity_retry`` (presumably retrying before it raises). If it
        still fails, the error is logged. Then:

        * For COMPLETED or FAILED events, a FAILED event describing the send
          error is sent instead. If that fallback also fails, its exception
          propagates to the caller.
        * For any other event type, the failure is not re-raised.

        Returns:
            The RPC response on success, or ``None`` when the original send
            failed.
        """
        try:
            return await self._try_send_step_action_event(action, event_type, payload)
        except Exception as e:
            # Previously swallowed silently for non-final event types; log so
            # the lost event is at least visible.
            logging.getLogger(__name__).exception(
                "failed to send step action event %s for step run %s",
                event_type,
                action.step_run_id,
            )

            if event_type in (STEP_EVENT_TYPE_COMPLETED, STEP_EVENT_TYPE_FAILED):
                # The final event could not be delivered: report the send
                # error itself as a FAILED event.
                await self._try_send_step_action_event(
                    action,
                    STEP_EVENT_TYPE_FAILED,
                    "Failed to send finished event: " + str(e),
                )

            return None

    @tenacity_retry
    async def _try_send_step_action_event(
        self, action: Action, event_type: StepActionEventType, payload: str
    ) -> Any:
        """Build a ``StepActionEvent`` from ``action`` and send it once.

        Retries are whatever ``tenacity_retry`` applies. The event is
        timestamped with the current time.
        """
        event_timestamp = Timestamp()
        event_timestamp.GetCurrentTime()

        event = StepActionEvent(
            workerId=action.worker_id,
            jobId=action.job_id,
            jobRunId=action.job_run_id,
            stepId=action.step_id,
            stepRunId=action.step_run_id,
            actionId=action.action_id,
            eventTimestamp=event_timestamp,
            eventType=event_type,
            eventPayload=payload,
            retryCount=action.retry_count,
        )

        # NOTE(review): the response type is not annotated here; presumably
        # an ActionEventResponse (the type put_overrides_data casts to) -
        # confirm against the dispatcher proto.
        return await self.aio_client.SendStepActionEvent(
            event,
            metadata=get_metadata(self.token),
        )

    async def send_group_key_action_event(
        self, action: Action, event_type: GroupKeyActionEventType, payload: str
    ) -> Any:
        """Send a ``GroupKeyActionEvent`` built from ``action``.

        The event is timestamped with the current time. Unlike step action
        events, this send is not wrapped in ``tenacity_retry`` and has no
        failure fallback.
        """
        event_timestamp = Timestamp()
        event_timestamp.GetCurrentTime()

        event = GroupKeyActionEvent(
            workerId=action.worker_id,
            workflowRunId=action.workflow_run_id,
            getGroupKeyRunId=action.get_group_key_run_id,
            actionId=action.action_id,
            eventTimestamp=event_timestamp,
            eventType=event_type,
            eventPayload=payload,
        )

        # NOTE(review): response type presumably ActionEventResponse - confirm
        # against the dispatcher proto.
        return await self.aio_client.SendGroupKeyActionEvent(
            event,
            metadata=get_metadata(self.token),
        )

    def put_overrides_data(self, data: OverridesData) -> ActionEventResponse:
        """Synchronously send ``data`` via the ``PutOverridesData`` RPC.

        No timeout is passed on this call.
        """
        return cast(
            ActionEventResponse,
            self.client.PutOverridesData(
                data,
                metadata=get_metadata(self.token),
            ),
        )

    def release_slot(self, step_run_id: str) -> None:
        """Synchronously call ``ReleaseSlot`` for ``step_run_id``."""
        self.client.ReleaseSlot(
            ReleaseSlotRequest(stepRunId=step_run_id),
            timeout=DEFAULT_REGISTER_TIMEOUT,
            metadata=get_metadata(self.token),
        )

    def refresh_timeout(self, step_run_id: str, increment_by: str) -> None:
        """Synchronously call ``RefreshTimeout`` for ``step_run_id``.

        Args:
            step_run_id: Step run whose timeout is refreshed.
            increment_by: Sent unchanged as ``incrementTimeoutBy``. Its
                expected format is defined by the server, not here.
        """
        self.client.RefreshTimeout(
            RefreshTimeoutRequest(
                stepRunId=step_run_id,
                incrementTimeoutBy=increment_by,
            ),
            timeout=DEFAULT_REGISTER_TIMEOUT,
            metadata=get_metadata(self.token),
        )

    def upsert_worker_labels(
        self, worker_id: str | None, labels: dict[str, str | int]
    ) -> None:
        """Synchronously upsert ``labels`` on the worker ``worker_id``.

        ``int`` values are sent as integer labels; all others are sent as
        strings.
        """
        self.client.UpsertWorkerLabels(
            UpsertWorkerLabelsRequest(
                workerId=worker_id, labels=self._create_labels(labels)
            ),
            timeout=DEFAULT_REGISTER_TIMEOUT,
            metadata=get_metadata(self.token),
        )

    async def async_upsert_worker_labels(
        self,
        worker_id: str | None,
        labels: dict[str, str | int],
    ) -> None:
        """Async counterpart of :meth:`upsert_worker_labels`.

        Uses ``aio_client`` instead of the synchronous stub.
        """
        await self.aio_client.UpsertWorkerLabels(
            UpsertWorkerLabelsRequest(
                workerId=worker_id, labels=self._create_labels(labels)
            ),
            timeout=DEFAULT_REGISTER_TIMEOUT,
            metadata=get_metadata(self.token),
        )