about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/hatchet_sdk/loader.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/hatchet_sdk/loader.py')
-rw-r--r--.venv/lib/python3.12/site-packages/hatchet_sdk/loader.py244
1 files changed, 244 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/hatchet_sdk/loader.py b/.venv/lib/python3.12/site-packages/hatchet_sdk/loader.py
new file mode 100644
index 00000000..0252f33a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/hatchet_sdk/loader.py
@@ -0,0 +1,244 @@
+import os
+from logging import Logger, getLogger
+from typing import Dict, Optional
+from warnings import warn
+
+import yaml
+
+from .token import get_addresses_from_jwt, get_tenant_id_from_jwt
+
+
+class ClientTLSConfig:
+    def __init__(
+        self,
+        tls_strategy: str,
+        cert_file: str,
+        key_file: str,
+        ca_file: str,
+        server_name: str,
+    ):
+        self.tls_strategy = tls_strategy
+        self.cert_file = cert_file
+        self.key_file = key_file
+        self.ca_file = ca_file
+        self.server_name = server_name
+
+
+class ClientConfig:
+    logInterceptor: Logger
+
+    def __init__(
+        self,
+        tenant_id: str = None,
+        tls_config: ClientTLSConfig = None,
+        token: str = None,
+        host_port: str = "localhost:7070",
+        server_url: str = "https://app.dev.hatchet-tools.com",
+        namespace: str = None,
+        listener_v2_timeout: int = None,
+        logger: Logger = None,
+        grpc_max_recv_message_length: int = 4 * 1024 * 1024,  # 4MB
+        grpc_max_send_message_length: int = 4 * 1024 * 1024,  # 4MB
+        worker_healthcheck_port: int | None = None,
+        worker_healthcheck_enabled: bool | None = None,
+        worker_preset_labels: dict[str, str] = {},
+        enable_force_kill_sync_threads: bool = False,
+    ):
+        self.tenant_id = tenant_id
+        self.tls_config = tls_config
+        self.host_port = host_port
+        self.token = token
+        self.server_url = server_url
+        self.namespace = ""
+        self.logInterceptor = logger
+        self.grpc_max_recv_message_length = grpc_max_recv_message_length
+        self.grpc_max_send_message_length = grpc_max_send_message_length
+        self.worker_healthcheck_port = worker_healthcheck_port
+        self.worker_healthcheck_enabled = worker_healthcheck_enabled
+        self.worker_preset_labels = worker_preset_labels
+        self.enable_force_kill_sync_threads = enable_force_kill_sync_threads
+
+        if not self.logInterceptor:
+            self.logInterceptor = getLogger()
+
+        # case on whether the namespace already has a trailing underscore
+        if namespace and not namespace.endswith("_"):
+            self.namespace = f"{namespace}_"
+        elif namespace:
+            self.namespace = namespace
+
+        self.namespace = self.namespace.lower()
+
+        self.listener_v2_timeout = listener_v2_timeout
+
+
+class ConfigLoader:
+    def __init__(self, directory: str):
+        self.directory = directory
+
+    def load_client_config(self, defaults: ClientConfig) -> ClientConfig:
+        config_file_path = os.path.join(self.directory, "client.yaml")
+        config_data: object = {"tls": {}}
+
+        # determine if client.yaml exists
+        if os.path.exists(config_file_path):
+            with open(config_file_path, "r") as file:
+                config_data = yaml.safe_load(file)
+
+        def get_config_value(key, env_var):
+            if key in config_data:
+                return config_data[key]
+
+            if self._get_env_var(env_var) is not None:
+                return self._get_env_var(env_var)
+
+            return getattr(defaults, key, None)
+
+        namespace = get_config_value("namespace", "HATCHET_CLIENT_NAMESPACE")
+
+        tenant_id = get_config_value("tenantId", "HATCHET_CLIENT_TENANT_ID")
+        token = get_config_value("token", "HATCHET_CLIENT_TOKEN")
+        listener_v2_timeout = get_config_value(
+            "listener_v2_timeout", "HATCHET_CLIENT_LISTENER_V2_TIMEOUT"
+        )
+        listener_v2_timeout = int(listener_v2_timeout) if listener_v2_timeout else None
+
+        if not token:
+            raise ValueError(
+                "Token must be set via HATCHET_CLIENT_TOKEN environment variable"
+            )
+
+        host_port = get_config_value("hostPort", "HATCHET_CLIENT_HOST_PORT")
+        server_url: str | None = None
+
+        grpc_max_recv_message_length = get_config_value(
+            "grpc_max_recv_message_length",
+            "HATCHET_CLIENT_GRPC_MAX_RECV_MESSAGE_LENGTH",
+        )
+        grpc_max_send_message_length = get_config_value(
+            "grpc_max_send_message_length",
+            "HATCHET_CLIENT_GRPC_MAX_SEND_MESSAGE_LENGTH",
+        )
+
+        if grpc_max_recv_message_length:
+            grpc_max_recv_message_length = int(grpc_max_recv_message_length)
+
+        if grpc_max_send_message_length:
+            grpc_max_send_message_length = int(grpc_max_send_message_length)
+
+        if not host_port:
+            # extract host and port from token
+            server_url, grpc_broadcast_address = get_addresses_from_jwt(token)
+            host_port = grpc_broadcast_address
+
+        if not tenant_id:
+            tenant_id = get_tenant_id_from_jwt(token)
+
+        tls_config = self._load_tls_config(config_data["tls"], host_port)
+
+        worker_healthcheck_port = int(
+            get_config_value(
+                "worker_healthcheck_port", "HATCHET_CLIENT_WORKER_HEALTHCHECK_PORT"
+            )
+            or 8001
+        )
+
+        worker_healthcheck_enabled = (
+            str(
+                get_config_value(
+                    "worker_healthcheck_port",
+                    "HATCHET_CLIENT_WORKER_HEALTHCHECK_ENABLED",
+                )
+            )
+            == "True"
+        )
+
+        #  Add preset labels to the worker config
+        worker_preset_labels: dict[str, str] = defaults.worker_preset_labels
+
+        autoscaling_target = get_config_value(
+            "autoscaling_target", "HATCHET_CLIENT_AUTOSCALING_TARGET"
+        )
+
+        if autoscaling_target:
+            worker_preset_labels["hatchet-autoscaling-target"] = autoscaling_target
+
+        legacy_otlp_headers = get_config_value(
+            "otel_exporter_otlp_endpoint", "HATCHET_CLIENT_OTEL_EXPORTER_OTLP_ENDPOINT"
+        )
+
+        legacy_otlp_headers = get_config_value(
+            "otel_exporter_otlp_headers", "HATCHET_CLIENT_OTEL_EXPORTER_OTLP_HEADERS"
+        )
+
+        if legacy_otlp_headers or legacy_otlp_headers:
+            warn(
+                "The `otel_exporter_otlp_*` fields are no longer supported as of SDK version `0.46.0`. Please see the documentation on OpenTelemetry at https://docs.hatchet.run/home/features/opentelemetry for more information on how to migrate to the new `HatchetInstrumentor`."
+            )
+
+        enable_force_kill_sync_threads = bool(
+            get_config_value(
+                "enable_force_kill_sync_threads",
+                "HATCHET_CLIENT_ENABLE_FORCE_KILL_SYNC_THREADS",
+            )
+            == "True"
+            or False
+        )
+        return ClientConfig(
+            tenant_id=tenant_id,
+            tls_config=tls_config,
+            token=token,
+            host_port=host_port,
+            server_url=server_url,
+            namespace=namespace,
+            listener_v2_timeout=listener_v2_timeout,
+            logger=defaults.logInterceptor,
+            grpc_max_recv_message_length=grpc_max_recv_message_length,
+            grpc_max_send_message_length=grpc_max_send_message_length,
+            worker_healthcheck_port=worker_healthcheck_port,
+            worker_healthcheck_enabled=worker_healthcheck_enabled,
+            worker_preset_labels=worker_preset_labels,
+            enable_force_kill_sync_threads=enable_force_kill_sync_threads,
+        )
+
+    def _load_tls_config(self, tls_data: Dict, host_port) -> ClientTLSConfig:
+        tls_strategy = (
+            tls_data["tlsStrategy"]
+            if "tlsStrategy" in tls_data
+            else self._get_env_var("HATCHET_CLIENT_TLS_STRATEGY")
+        )
+
+        if not tls_strategy:
+            tls_strategy = "tls"
+
+        cert_file = (
+            tls_data["tlsCertFile"]
+            if "tlsCertFile" in tls_data
+            else self._get_env_var("HATCHET_CLIENT_TLS_CERT_FILE")
+        )
+        key_file = (
+            tls_data["tlsKeyFile"]
+            if "tlsKeyFile" in tls_data
+            else self._get_env_var("HATCHET_CLIENT_TLS_KEY_FILE")
+        )
+        ca_file = (
+            tls_data["tlsRootCAFile"]
+            if "tlsRootCAFile" in tls_data
+            else self._get_env_var("HATCHET_CLIENT_TLS_ROOT_CA_FILE")
+        )
+
+        server_name = (
+            tls_data["tlsServerName"]
+            if "tlsServerName" in tls_data
+            else self._get_env_var("HATCHET_CLIENT_TLS_SERVER_NAME")
+        )
+
+        # if server_name is not set, use the host from the host_port
+        if not server_name:
+            server_name = host_port.split(":")[0]
+
+        return ClientTLSConfig(tls_strategy, cert_file, key_file, ca_file, server_name)
+
+    @staticmethod
+    def _get_env_var(env_var: str, default: Optional[str] = None) -> str:
+        return os.environ.get(env_var, default)