import os
from logging import Logger, getLogger
from typing import Dict, Optional
from warnings import warn

import yaml

from .token import get_addresses_from_jwt, get_tenant_id_from_jwt


class ClientTLSConfig:
    """Plain container for the TLS settings used by the client connection."""

    def __init__(
        self,
        tls_strategy: str,
        cert_file: Optional[str],
        key_file: Optional[str],
        ca_file: Optional[str],
        server_name: str,
    ):
        self.tls_strategy = tls_strategy
        self.cert_file = cert_file
        self.key_file = key_file
        self.ca_file = ca_file
        self.server_name = server_name


class ClientConfig:
    """Resolved client settings.

    The constructor stores its arguments as attributes, with three
    normalisations:

    * ``logger`` falls back to the root logger when not supplied.
    * ``namespace`` is lower-cased and given a trailing underscore if it is
      non-empty and does not already end with one; a missing namespace
      becomes the empty string.
    * ``worker_preset_labels`` defaults to a fresh, per-instance dict.
    """

    logInterceptor: Logger

    def __init__(
        self,
        tenant_id: str | None = None,
        tls_config: ClientTLSConfig | None = None,
        token: str | None = None,
        host_port: str = "localhost:7070",
        server_url: str | None = "https://app.dev.hatchet-tools.com",
        namespace: str | None = None,
        listener_v2_timeout: int | None = None,
        logger: Logger | None = None,
        grpc_max_recv_message_length: int = 4 * 1024 * 1024,  # 4MB
        grpc_max_send_message_length: int = 4 * 1024 * 1024,  # 4MB
        worker_healthcheck_port: int | None = None,
        worker_healthcheck_enabled: bool | None = None,
        worker_preset_labels: dict[str, str] | None = None,
        enable_force_kill_sync_threads: bool = False,
    ):
        self.tenant_id = tenant_id
        self.tls_config = tls_config
        self.host_port = host_port
        self.token = token
        self.server_url = server_url
        self.listener_v2_timeout = listener_v2_timeout

        self.logInterceptor = logger or getLogger()

        self.grpc_max_recv_message_length = grpc_max_recv_message_length
        self.grpc_max_send_message_length = grpc_max_send_message_length

        self.worker_healthcheck_port = worker_healthcheck_port
        self.worker_healthcheck_enabled = worker_healthcheck_enabled

        # A `{}` default in the signature would be one dict shared by every
        # instance; create a new one per instance instead. A dict passed by
        # the caller is stored as-is (not copied).
        self.worker_preset_labels = (
            {} if worker_preset_labels is None else worker_preset_labels
        )
        self.enable_force_kill_sync_threads = enable_force_kill_sync_threads

        # Ensure a non-empty namespace ends with exactly the separator it
        # was given or a single appended underscore, then lower-case it.
        normalized_namespace = namespace or ""
        if normalized_namespace and not normalized_namespace.endswith("_"):
            normalized_namespace = f"{normalized_namespace}_"
        self.namespace = normalized_namespace.lower()


class ConfigLoader:
    """Builds a ``ClientConfig`` from ``client.yaml``, environment variables
    and caller-supplied defaults."""

    def __init__(self, directory: str):
        # Directory that is searched for `client.yaml`.
        self.directory = directory

    def load_client_config(self, defaults: ClientConfig) -> ClientConfig:
        """Resolve the client configuration.

        Each setting is looked up, in order, in ``client.yaml`` (inside
        ``self.directory``, if the file exists), in its environment
        variable, and finally as an attribute of ``defaults``.

        Args:
            defaults: Fallback values. Its ``logInterceptor`` is reused for
                the returned config. It is not modified.

        Returns:
            A new ``ClientConfig`` with the resolved values.

        Raises:
            ValueError: If no token can be resolved, or if ``client.yaml``
                contains something other than a mapping.
        """
        config_file_path = os.path.join(self.directory, "client.yaml")

        config_data: dict = {}

        # determine if client.yaml exists
        if os.path.exists(config_file_path):
            with open(config_file_path, "r") as file:
                loaded = yaml.safe_load(file)

            # safe_load returns None for an empty document; treat that the
            # same as a missing file rather than failing on `key in None`.
            if loaded is not None:
                if not isinstance(loaded, dict):
                    raise ValueError(
                        f"{config_file_path} must contain a YAML mapping"
                    )
                config_data = loaded

        def get_config_value(key, env_var):
            # Precedence: config file, then environment, then `defaults`.
            # NOTE(review): the camelCase keys ("tenantId", "hostPort") do
            # not match any attribute name on `defaults`, so for those two
            # the final fallback is always None. Kept as-is: letting
            # `defaults.host_port` apply would bypass the JWT-derived
            # address below.
            if key in config_data:
                return config_data[key]

            env_value = self._get_env_var(env_var)
            if env_value is not None:
                return env_value

            return getattr(defaults, key, None)

        namespace = get_config_value("namespace", "HATCHET_CLIENT_NAMESPACE")
        tenant_id = get_config_value("tenantId", "HATCHET_CLIENT_TENANT_ID")
        token = get_config_value("token", "HATCHET_CLIENT_TOKEN")

        listener_v2_timeout = get_config_value(
            "listener_v2_timeout", "HATCHET_CLIENT_LISTENER_V2_TIMEOUT"
        )
        listener_v2_timeout = int(listener_v2_timeout) if listener_v2_timeout else None

        if not token:
            raise ValueError(
                "Token must be set via HATCHET_CLIENT_TOKEN environment variable"
            )

        host_port = get_config_value("hostPort", "HATCHET_CLIENT_HOST_PORT")
        # NOTE(review): server_url is only populated from the token, and only
        # when host_port was not configured; otherwise None is passed on.
        server_url: str | None = None

        grpc_max_recv_message_length = get_config_value(
            "grpc_max_recv_message_length",
            "HATCHET_CLIENT_GRPC_MAX_RECV_MESSAGE_LENGTH",
        )
        grpc_max_send_message_length = get_config_value(
            "grpc_max_send_message_length",
            "HATCHET_CLIENT_GRPC_MAX_SEND_MESSAGE_LENGTH",
        )

        # Environment values arrive as strings; normalise to int.
        if grpc_max_recv_message_length:
            grpc_max_recv_message_length = int(grpc_max_recv_message_length)

        if grpc_max_send_message_length:
            grpc_max_send_message_length = int(grpc_max_send_message_length)

        if not host_port:
            # extract host and port from token
            server_url, grpc_broadcast_address = get_addresses_from_jwt(token)
            host_port = grpc_broadcast_address

        if not tenant_id:
            tenant_id = get_tenant_id_from_jwt(token)

        # A config file without a `tls` section is valid; use empty settings.
        tls_config = self._load_tls_config(config_data.get("tls") or {}, host_port)

        worker_healthcheck_port = int(
            get_config_value(
                "worker_healthcheck_port", "HATCHET_CLIENT_WORKER_HEALTHCHECK_PORT"
            )
            or 8001
        )

        # Looked up under its own key (previously read the *port* key, which
        # made the flag False whenever a port was configured or defaulted).
        worker_healthcheck_enabled = self._is_true(
            get_config_value(
                "worker_healthcheck_enabled",
                "HATCHET_CLIENT_WORKER_HEALTHCHECK_ENABLED",
            )
        )

        # Add preset labels to the worker config. Work on a copy so the
        # caller's `defaults` object is not mutated.
        worker_preset_labels: dict[str, str] = dict(
            defaults.worker_preset_labels or {}
        )

        autoscaling_target = get_config_value(
            "autoscaling_target", "HATCHET_CLIENT_AUTOSCALING_TARGET"
        )

        if autoscaling_target:
            worker_preset_labels["hatchet-autoscaling-target"] = autoscaling_target

        legacy_otlp_endpoint = get_config_value(
            "otel_exporter_otlp_endpoint", "HATCHET_CLIENT_OTEL_EXPORTER_OTLP_ENDPOINT"
        )

        legacy_otlp_headers = get_config_value(
            "otel_exporter_otlp_headers", "HATCHET_CLIENT_OTEL_EXPORTER_OTLP_HEADERS"
        )

        # Warn if either of the removed settings is still present.
        if legacy_otlp_endpoint or legacy_otlp_headers:
            warn(
                "The `otel_exporter_otlp_*` fields are no longer supported as of SDK version `0.46.0`. Please see the documentation on OpenTelemetry at https://docs.hatchet.run/home/features/opentelemetry for more information on how to migrate to the new `HatchetInstrumentor`."
            )

        # Stringify before comparing so a YAML boolean or a bool default is
        # honoured the same way as the "True" environment string.
        enable_force_kill_sync_threads = self._is_true(
            get_config_value(
                "enable_force_kill_sync_threads",
                "HATCHET_CLIENT_ENABLE_FORCE_KILL_SYNC_THREADS",
            )
        )

        return ClientConfig(
            tenant_id=tenant_id,
            tls_config=tls_config,
            token=token,
            host_port=host_port,
            server_url=server_url,
            namespace=namespace,
            listener_v2_timeout=listener_v2_timeout,
            logger=defaults.logInterceptor,
            grpc_max_recv_message_length=grpc_max_recv_message_length,
            grpc_max_send_message_length=grpc_max_send_message_length,
            worker_healthcheck_port=worker_healthcheck_port,
            worker_healthcheck_enabled=worker_healthcheck_enabled,
            worker_preset_labels=worker_preset_labels,
            enable_force_kill_sync_threads=enable_force_kill_sync_threads,
        )

    def _load_tls_config(self, tls_data: Dict, host_port) -> ClientTLSConfig:
        """Build the TLS settings.

        Each field is taken from ``tls_data`` when the key is present,
        otherwise from its ``HATCHET_CLIENT_TLS_*`` environment variable.
        The strategy defaults to ``"tls"`` and the server name defaults to
        the part of ``host_port`` before the first ``":"``.
        """

        def pick(key: str, env_var: str):
            # Presence of the key wins, even if its value is empty.
            if key in tls_data:
                return tls_data[key]
            return self._get_env_var(env_var)

        tls_strategy = pick("tlsStrategy", "HATCHET_CLIENT_TLS_STRATEGY")
        if not tls_strategy:
            tls_strategy = "tls"

        cert_file = pick("tlsCertFile", "HATCHET_CLIENT_TLS_CERT_FILE")
        key_file = pick("tlsKeyFile", "HATCHET_CLIENT_TLS_KEY_FILE")
        ca_file = pick("tlsRootCAFile", "HATCHET_CLIENT_TLS_ROOT_CA_FILE")
        server_name = pick("tlsServerName", "HATCHET_CLIENT_TLS_SERVER_NAME")

        # if server_name is not set, use the host from the host_port
        if not server_name:
            server_name = host_port.split(":")[0]

        return ClientTLSConfig(tls_strategy, cert_file, key_file, ca_file, server_name)

    @staticmethod
    def _is_true(value: object) -> bool:
        """Return True only when ``value`` stringifies to exactly ``"True"``."""
        return str(value) == "True"

    @staticmethod
    def _get_env_var(env_var: str, default: Optional[str] = None) -> Optional[str]:
        """Return the environment variable's value, or ``default`` if unset."""
        return os.environ.get(env_var, default)