diff options
Diffstat (limited to '.venv/lib/python3.12/site-packages/hatchet_sdk/loader.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/hatchet_sdk/loader.py | 244 |
1 files changed, 244 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/hatchet_sdk/loader.py b/.venv/lib/python3.12/site-packages/hatchet_sdk/loader.py new file mode 100644 index 00000000..0252f33a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/hatchet_sdk/loader.py @@ -0,0 +1,244 @@ +import os +from logging import Logger, getLogger +from typing import Dict, Optional +from warnings import warn + +import yaml + +from .token import get_addresses_from_jwt, get_tenant_id_from_jwt + + +class ClientTLSConfig: + def __init__( + self, + tls_strategy: str, + cert_file: str, + key_file: str, + ca_file: str, + server_name: str, + ): + self.tls_strategy = tls_strategy + self.cert_file = cert_file + self.key_file = key_file + self.ca_file = ca_file + self.server_name = server_name + + +class ClientConfig: + logInterceptor: Logger + + def __init__( + self, + tenant_id: str = None, + tls_config: ClientTLSConfig = None, + token: str = None, + host_port: str = "localhost:7070", + server_url: str = "https://app.dev.hatchet-tools.com", + namespace: str = None, + listener_v2_timeout: int = None, + logger: Logger = None, + grpc_max_recv_message_length: int = 4 * 1024 * 1024, # 4MB + grpc_max_send_message_length: int = 4 * 1024 * 1024, # 4MB + worker_healthcheck_port: int | None = None, + worker_healthcheck_enabled: bool | None = None, + worker_preset_labels: dict[str, str] = {}, + enable_force_kill_sync_threads: bool = False, + ): + self.tenant_id = tenant_id + self.tls_config = tls_config + self.host_port = host_port + self.token = token + self.server_url = server_url + self.namespace = "" + self.logInterceptor = logger + self.grpc_max_recv_message_length = grpc_max_recv_message_length + self.grpc_max_send_message_length = grpc_max_send_message_length + self.worker_healthcheck_port = worker_healthcheck_port + self.worker_healthcheck_enabled = worker_healthcheck_enabled + self.worker_preset_labels = worker_preset_labels + self.enable_force_kill_sync_threads = enable_force_kill_sync_threads + + if not self.logInterceptor: + self.logInterceptor = getLogger() + + # case on whether the namespace already has a trailing underscore + if namespace and not namespace.endswith("_"): + self.namespace = f"{namespace}_" + elif namespace: + self.namespace = namespace + + self.namespace = self.namespace.lower() + + self.listener_v2_timeout = listener_v2_timeout + + +class ConfigLoader: + def __init__(self, directory: str): + self.directory = directory + + def load_client_config(self, defaults: ClientConfig) -> ClientConfig: + config_file_path = os.path.join(self.directory, "client.yaml") + config_data: object = {"tls": {}} + + # determine if client.yaml exists + if os.path.exists(config_file_path): + with open(config_file_path, "r") as file: + config_data = yaml.safe_load(file) + + def get_config_value(key, env_var): + if key in config_data: + return config_data[key] + + if self._get_env_var(env_var) is not None: + return self._get_env_var(env_var) + + return getattr(defaults, key, None) + + namespace = get_config_value("namespace", "HATCHET_CLIENT_NAMESPACE") + + tenant_id = get_config_value("tenantId", "HATCHET_CLIENT_TENANT_ID") + token = get_config_value("token", "HATCHET_CLIENT_TOKEN") + listener_v2_timeout = get_config_value( + "listener_v2_timeout", "HATCHET_CLIENT_LISTENER_V2_TIMEOUT" + ) + listener_v2_timeout = int(listener_v2_timeout) if listener_v2_timeout else None + + if not token: + raise ValueError( + "Token must be set via HATCHET_CLIENT_TOKEN environment variable" + ) + + host_port = get_config_value("hostPort", "HATCHET_CLIENT_HOST_PORT") + server_url: str | None = None + + grpc_max_recv_message_length = get_config_value( + "grpc_max_recv_message_length", + "HATCHET_CLIENT_GRPC_MAX_RECV_MESSAGE_LENGTH", + ) + grpc_max_send_message_length = get_config_value( + "grpc_max_send_message_length", + "HATCHET_CLIENT_GRPC_MAX_SEND_MESSAGE_LENGTH", + ) + + if grpc_max_recv_message_length: + grpc_max_recv_message_length = int(grpc_max_recv_message_length) + + if grpc_max_send_message_length: + grpc_max_send_message_length = int(grpc_max_send_message_length) + + if not host_port: + # extract host and port from token + server_url, grpc_broadcast_address = get_addresses_from_jwt(token) + host_port = grpc_broadcast_address + + if not tenant_id: + tenant_id = get_tenant_id_from_jwt(token) + + tls_config = self._load_tls_config(config_data["tls"], host_port) + + worker_healthcheck_port = int( + get_config_value( + "worker_healthcheck_port", "HATCHET_CLIENT_WORKER_HEALTHCHECK_PORT" + ) + or 8001 + ) + + worker_healthcheck_enabled = ( + str( + get_config_value( + "worker_healthcheck_port", + "HATCHET_CLIENT_WORKER_HEALTHCHECK_ENABLED", + ) + ) + == "True" + ) + + # Add preset labels to the worker config + worker_preset_labels: dict[str, str] = defaults.worker_preset_labels + + autoscaling_target = get_config_value( + "autoscaling_target", "HATCHET_CLIENT_AUTOSCALING_TARGET" + ) + + if autoscaling_target: + worker_preset_labels["hatchet-autoscaling-target"] = autoscaling_target + + legacy_otlp_headers = get_config_value( + "otel_exporter_otlp_endpoint", "HATCHET_CLIENT_OTEL_EXPORTER_OTLP_ENDPOINT" + ) + + legacy_otlp_headers = get_config_value( + "otel_exporter_otlp_headers", "HATCHET_CLIENT_OTEL_EXPORTER_OTLP_HEADERS" + ) + + if legacy_otlp_headers or legacy_otlp_headers: + warn( + "The `otel_exporter_otlp_*` fields are no longer supported as of SDK version `0.46.0`. Please see the documentation on OpenTelemetry at https://docs.hatchet.run/home/features/opentelemetry for more information on how to migrate to the new `HatchetInstrumentor`." + ) + + enable_force_kill_sync_threads = bool( + get_config_value( + "enable_force_kill_sync_threads", + "HATCHET_CLIENT_ENABLE_FORCE_KILL_SYNC_THREADS", + ) + == "True" + or False + ) + return ClientConfig( + tenant_id=tenant_id, + tls_config=tls_config, + token=token, + host_port=host_port, + server_url=server_url, + namespace=namespace, + listener_v2_timeout=listener_v2_timeout, + logger=defaults.logInterceptor, + grpc_max_recv_message_length=grpc_max_recv_message_length, + grpc_max_send_message_length=grpc_max_send_message_length, + worker_healthcheck_port=worker_healthcheck_port, + worker_healthcheck_enabled=worker_healthcheck_enabled, + worker_preset_labels=worker_preset_labels, + enable_force_kill_sync_threads=enable_force_kill_sync_threads, + ) + + def _load_tls_config(self, tls_data: Dict, host_port) -> ClientTLSConfig: + tls_strategy = ( + tls_data["tlsStrategy"] + if "tlsStrategy" in tls_data + else self._get_env_var("HATCHET_CLIENT_TLS_STRATEGY") + ) + + if not tls_strategy: + tls_strategy = "tls" + + cert_file = ( + tls_data["tlsCertFile"] + if "tlsCertFile" in tls_data + else self._get_env_var("HATCHET_CLIENT_TLS_CERT_FILE") + ) + key_file = ( + tls_data["tlsKeyFile"] + if "tlsKeyFile" in tls_data + else self._get_env_var("HATCHET_CLIENT_TLS_KEY_FILE") + ) + ca_file = ( + tls_data["tlsRootCAFile"] + if "tlsRootCAFile" in tls_data + else self._get_env_var("HATCHET_CLIENT_TLS_ROOT_CA_FILE") + ) + + server_name = ( + tls_data["tlsServerName"] + if "tlsServerName" in tls_data + else self._get_env_var("HATCHET_CLIENT_TLS_SERVER_NAME") + ) + + # if server_name is not set, use the host from the host_port + if not server_name: + server_name = host_port.split(":")[0] + + return ClientTLSConfig(tls_strategy, cert_file, key_file, ca_file, server_name) + + @staticmethod + def _get_env_var(env_var: str, default: Optional[str] = None) -> str: + return os.environ.get(env_var, default) |