aboutsummaryrefslogtreecommitdiff
import os
from logging import Logger, getLogger
from typing import Dict, Optional
from warnings import warn

import yaml

from .token import get_addresses_from_jwt, get_tenant_id_from_jwt


class ClientTLSConfig:
    """Plain container for the client's TLS settings.

    Every value is stored exactly as given. No validation, defaulting or
    file access happens here; callers decide what each field means.
    """

    def __init__(
        self,
        tls_strategy: str,
        cert_file: str,
        key_file: str,
        ca_file: str,
        server_name: str,
    ):
        # tls_strategy: which TLS mode to use (the loader falls back to "tls").
        # cert_file / key_file / ca_file: certificate material locations; these
        #   may be None when neither client.yaml nor the environment sets them.
        # server_name: host name to present/verify for the connection.
        (
            self.tls_strategy,
            self.cert_file,
            self.key_file,
            self.ca_file,
            self.server_name,
        ) = (tls_strategy, cert_file, key_file, ca_file, server_name)


class ClientConfig:
    logInterceptor: Logger

    def __init__(
        self,
        tenant_id: str = None,
        tls_config: ClientTLSConfig = None,
        token: str = None,
        host_port: str = "localhost:7070",
        server_url: str = "https://app.dev.hatchet-tools.com",
        namespace: str = None,
        listener_v2_timeout: int = None,
        logger: Logger = None,
        grpc_max_recv_message_length: int = 4 * 1024 * 1024,  # 4MB
        grpc_max_send_message_length: int = 4 * 1024 * 1024,  # 4MB
        worker_healthcheck_port: int | None = None,
        worker_healthcheck_enabled: bool | None = None,
        worker_preset_labels: dict[str, str] = {},
        enable_force_kill_sync_threads: bool = False,
    ):
        self.tenant_id = tenant_id
        self.tls_config = tls_config
        self.host_port = host_port
        self.token = token
        self.server_url = server_url
        self.namespace = ""
        self.logInterceptor = logger
        self.grpc_max_recv_message_length = grpc_max_recv_message_length
        self.grpc_max_send_message_length = grpc_max_send_message_length
        self.worker_healthcheck_port = worker_healthcheck_port
        self.worker_healthcheck_enabled = worker_healthcheck_enabled
        self.worker_preset_labels = worker_preset_labels
        self.enable_force_kill_sync_threads = enable_force_kill_sync_threads

        if not self.logInterceptor:
            self.logInterceptor = getLogger()

        # case on whether the namespace already has a trailing underscore
        if namespace and not namespace.endswith("_"):
            self.namespace = f"{namespace}_"
        elif namespace:
            self.namespace = namespace

        self.namespace = self.namespace.lower()

        self.listener_v2_timeout = listener_v2_timeout


class ConfigLoader:
    """Assemble a ClientConfig from ``client.yaml``, environment variables and defaults.

    For most settings the lookup order is: a key in ``<directory>/client.yaml``,
    then a ``HATCHET_CLIENT_*`` environment variable, then the attribute with
    the same name as the YAML key on the ``defaults`` ClientConfig.
    """

    def __init__(self, directory: str):
        # Directory searched for an optional ``client.yaml``.
        self.directory = directory

    def load_client_config(self, defaults: ClientConfig) -> ClientConfig:
        """Resolve every client setting and return a new ClientConfig.

        Args:
            defaults: Fallback values, read via ``getattr`` using the YAML key.
                Its ``worker_preset_labels`` dict is copied, never mutated.

        Returns:
            A freshly constructed ClientConfig.

        Raises:
            ValueError: If no token is available from any source.
        """
        config_file_path = os.path.join(self.directory, "client.yaml")
        config_data: dict = {}

        if os.path.exists(config_file_path):
            with open(config_file_path, "r") as file:
                # safe_load returns None for an empty document; treat that as
                # "no settings" instead of failing on `key in None` below.
                config_data = yaml.safe_load(file) or {}

        def get_config_value(key, env_var):
            # Precedence: client.yaml, then environment, then ``defaults``.
            if key in config_data:
                return config_data[key]

            env_value = self._get_env_var(env_var)
            if env_value is not None:
                return env_value

            # NOTE(review): camelCase keys such as "tenantId"/"hostPort" do not
            # match ClientConfig attribute names, so those never fall back to
            # ``defaults`` and resolve to None here.
            return getattr(defaults, key, None)

        namespace = get_config_value("namespace", "HATCHET_CLIENT_NAMESPACE")

        tenant_id = get_config_value("tenantId", "HATCHET_CLIENT_TENANT_ID")
        token = get_config_value("token", "HATCHET_CLIENT_TOKEN")
        listener_v2_timeout = get_config_value(
            "listener_v2_timeout", "HATCHET_CLIENT_LISTENER_V2_TIMEOUT"
        )
        listener_v2_timeout = int(listener_v2_timeout) if listener_v2_timeout else None

        if not token:
            raise ValueError(
                "Token must be set via HATCHET_CLIENT_TOKEN environment variable"
            )

        host_port = get_config_value("hostPort", "HATCHET_CLIENT_HOST_PORT")
        # Only derived from the token when hostPort is not configured; when
        # hostPort is set this stays None and is passed through as-is.
        server_url: str | None = None

        grpc_max_recv_message_length = get_config_value(
            "grpc_max_recv_message_length",
            "HATCHET_CLIENT_GRPC_MAX_RECV_MESSAGE_LENGTH",
        )
        grpc_max_send_message_length = get_config_value(
            "grpc_max_send_message_length",
            "HATCHET_CLIENT_GRPC_MAX_SEND_MESSAGE_LENGTH",
        )

        # Environment values arrive as strings; coerce non-empty values to int.
        if grpc_max_recv_message_length:
            grpc_max_recv_message_length = int(grpc_max_recv_message_length)

        if grpc_max_send_message_length:
            grpc_max_send_message_length = int(grpc_max_send_message_length)

        if not host_port:
            # extract host and port from token
            server_url, grpc_broadcast_address = get_addresses_from_jwt(token)
            host_port = grpc_broadcast_address

        if not tenant_id:
            tenant_id = get_tenant_id_from_jwt(token)

        # A missing or empty `tls:` section means "use env vars / built-in defaults".
        tls_config = self._load_tls_config(config_data.get("tls") or {}, host_port)

        worker_healthcheck_port = int(
            get_config_value(
                "worker_healthcheck_port", "HATCHET_CLIENT_WORKER_HEALTHCHECK_PORT"
            )
            or 8001
        )

        worker_healthcheck_enabled = self._is_true_flag(
            get_config_value(
                "worker_healthcheck_enabled",
                "HATCHET_CLIENT_WORKER_HEALTHCHECK_ENABLED",
            )
        )

        # Add preset labels to the worker config. Copy first so that neither
        # ``defaults`` nor any dict it shares is mutated by this call.
        worker_preset_labels: dict[str, str] = dict(
            defaults.worker_preset_labels or {}
        )

        autoscaling_target = get_config_value(
            "autoscaling_target", "HATCHET_CLIENT_AUTOSCALING_TARGET"
        )

        if autoscaling_target:
            worker_preset_labels["hatchet-autoscaling-target"] = autoscaling_target

        legacy_otlp_endpoint = get_config_value(
            "otel_exporter_otlp_endpoint", "HATCHET_CLIENT_OTEL_EXPORTER_OTLP_ENDPOINT"
        )

        legacy_otlp_headers = get_config_value(
            "otel_exporter_otlp_headers", "HATCHET_CLIENT_OTEL_EXPORTER_OTLP_HEADERS"
        )

        # Either removed setting being present should trigger the migration hint.
        if legacy_otlp_endpoint or legacy_otlp_headers:
            warn(
                "The `otel_exporter_otlp_*` fields are no longer supported as of SDK version `0.46.0`. Please see the documentation on OpenTelemetry at https://docs.hatchet.run/home/features/opentelemetry for more information on how to migrate to the new `HatchetInstrumentor`."
            )

        enable_force_kill_sync_threads = self._is_true_flag(
            get_config_value(
                "enable_force_kill_sync_threads",
                "HATCHET_CLIENT_ENABLE_FORCE_KILL_SYNC_THREADS",
            )
        )

        return ClientConfig(
            tenant_id=tenant_id,
            tls_config=tls_config,
            token=token,
            host_port=host_port,
            server_url=server_url,
            namespace=namespace,
            listener_v2_timeout=listener_v2_timeout,
            logger=defaults.logInterceptor,
            grpc_max_recv_message_length=grpc_max_recv_message_length,
            grpc_max_send_message_length=grpc_max_send_message_length,
            worker_healthcheck_port=worker_healthcheck_port,
            worker_healthcheck_enabled=worker_healthcheck_enabled,
            worker_preset_labels=worker_preset_labels,
            enable_force_kill_sync_threads=enable_force_kill_sync_threads,
        )

    @staticmethod
    def _is_true_flag(value: object) -> bool:
        """Return True when ``value`` stringifies to exactly ``"True"``.

        This accepts a boolean ``True`` (from YAML or ``defaults``) and the
        literal environment string ``"True"``. Other spellings such as
        ``"true"`` or ``"1"`` count as false.
        """
        return str(value) == "True"

    def _load_tls_config(self, tls_data: Dict, host_port) -> ClientTLSConfig:
        """Build TLS settings from the ``tls`` section of client.yaml or env vars.

        Each field comes from ``tls_data`` when its key is present, otherwise
        from the matching ``HATCHET_CLIENT_TLS_*`` environment variable. The
        strategy falls back to ``"tls"``; the server name falls back to the
        host part of ``host_port`` (the text before the first ``:``).
        """

        def lookup(key: str, env_var: str):
            # A key present in client.yaml wins; the env var is not consulted.
            if key in tls_data:
                return tls_data[key]
            return self._get_env_var(env_var)

        tls_strategy = lookup("tlsStrategy", "HATCHET_CLIENT_TLS_STRATEGY") or "tls"
        cert_file = lookup("tlsCertFile", "HATCHET_CLIENT_TLS_CERT_FILE")
        key_file = lookup("tlsKeyFile", "HATCHET_CLIENT_TLS_KEY_FILE")
        ca_file = lookup("tlsRootCAFile", "HATCHET_CLIENT_TLS_ROOT_CA_FILE")
        server_name = (
            lookup("tlsServerName", "HATCHET_CLIENT_TLS_SERVER_NAME")
            or host_port.split(":")[0]
        )

        return ClientTLSConfig(tls_strategy, cert_file, key_file, ca_file, server_name)

    @staticmethod
    def _get_env_var(env_var: str, default: Optional[str] = None) -> Optional[str]:
        """Return the environment variable ``env_var``, or ``default`` if unset."""
        return os.environ.get(env_var, default)