aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/hatchet_sdk/loader.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/hatchet_sdk/loader.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two version of R2R are hereHEADmaster
Diffstat (limited to '.venv/lib/python3.12/site-packages/hatchet_sdk/loader.py')
-rw-r--r--.venv/lib/python3.12/site-packages/hatchet_sdk/loader.py244
1 files changed, 244 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/hatchet_sdk/loader.py b/.venv/lib/python3.12/site-packages/hatchet_sdk/loader.py
new file mode 100644
index 00000000..0252f33a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/hatchet_sdk/loader.py
@@ -0,0 +1,244 @@
+import os
+from logging import Logger, getLogger
+from typing import Dict, Optional
+from warnings import warn
+
+import yaml
+
+from .token import get_addresses_from_jwt, get_tenant_id_from_jwt
+
+
+class ClientTLSConfig:
+ def __init__(
+ self,
+ tls_strategy: str,
+ cert_file: str,
+ key_file: str,
+ ca_file: str,
+ server_name: str,
+ ):
+ self.tls_strategy = tls_strategy
+ self.cert_file = cert_file
+ self.key_file = key_file
+ self.ca_file = ca_file
+ self.server_name = server_name
+
+
+class ClientConfig:
+ logInterceptor: Logger
+
+ def __init__(
+ self,
+ tenant_id: str = None,
+ tls_config: ClientTLSConfig = None,
+ token: str = None,
+ host_port: str = "localhost:7070",
+ server_url: str = "https://app.dev.hatchet-tools.com",
+ namespace: str = None,
+ listener_v2_timeout: int = None,
+ logger: Logger = None,
+ grpc_max_recv_message_length: int = 4 * 1024 * 1024, # 4MB
+ grpc_max_send_message_length: int = 4 * 1024 * 1024, # 4MB
+ worker_healthcheck_port: int | None = None,
+ worker_healthcheck_enabled: bool | None = None,
+ worker_preset_labels: dict[str, str] = {},
+ enable_force_kill_sync_threads: bool = False,
+ ):
+ self.tenant_id = tenant_id
+ self.tls_config = tls_config
+ self.host_port = host_port
+ self.token = token
+ self.server_url = server_url
+ self.namespace = ""
+ self.logInterceptor = logger
+ self.grpc_max_recv_message_length = grpc_max_recv_message_length
+ self.grpc_max_send_message_length = grpc_max_send_message_length
+ self.worker_healthcheck_port = worker_healthcheck_port
+ self.worker_healthcheck_enabled = worker_healthcheck_enabled
+ self.worker_preset_labels = worker_preset_labels
+ self.enable_force_kill_sync_threads = enable_force_kill_sync_threads
+
+ if not self.logInterceptor:
+ self.logInterceptor = getLogger()
+
+ # case on whether the namespace already has a trailing underscore
+ if namespace and not namespace.endswith("_"):
+ self.namespace = f"{namespace}_"
+ elif namespace:
+ self.namespace = namespace
+
+ self.namespace = self.namespace.lower()
+
+ self.listener_v2_timeout = listener_v2_timeout
+
+
+class ConfigLoader:
+ def __init__(self, directory: str):
+ self.directory = directory
+
+ def load_client_config(self, defaults: ClientConfig) -> ClientConfig:
+ config_file_path = os.path.join(self.directory, "client.yaml")
+ config_data: object = {"tls": {}}
+
+ # determine if client.yaml exists
+ if os.path.exists(config_file_path):
+ with open(config_file_path, "r") as file:
+ config_data = yaml.safe_load(file)
+
+ def get_config_value(key, env_var):
+ if key in config_data:
+ return config_data[key]
+
+ if self._get_env_var(env_var) is not None:
+ return self._get_env_var(env_var)
+
+ return getattr(defaults, key, None)
+
+ namespace = get_config_value("namespace", "HATCHET_CLIENT_NAMESPACE")
+
+ tenant_id = get_config_value("tenantId", "HATCHET_CLIENT_TENANT_ID")
+ token = get_config_value("token", "HATCHET_CLIENT_TOKEN")
+ listener_v2_timeout = get_config_value(
+ "listener_v2_timeout", "HATCHET_CLIENT_LISTENER_V2_TIMEOUT"
+ )
+ listener_v2_timeout = int(listener_v2_timeout) if listener_v2_timeout else None
+
+ if not token:
+ raise ValueError(
+ "Token must be set via HATCHET_CLIENT_TOKEN environment variable"
+ )
+
+ host_port = get_config_value("hostPort", "HATCHET_CLIENT_HOST_PORT")
+ server_url: str | None = None
+
+ grpc_max_recv_message_length = get_config_value(
+ "grpc_max_recv_message_length",
+ "HATCHET_CLIENT_GRPC_MAX_RECV_MESSAGE_LENGTH",
+ )
+ grpc_max_send_message_length = get_config_value(
+ "grpc_max_send_message_length",
+ "HATCHET_CLIENT_GRPC_MAX_SEND_MESSAGE_LENGTH",
+ )
+
+ if grpc_max_recv_message_length:
+ grpc_max_recv_message_length = int(grpc_max_recv_message_length)
+
+ if grpc_max_send_message_length:
+ grpc_max_send_message_length = int(grpc_max_send_message_length)
+
+ if not host_port:
+ # extract host and port from token
+ server_url, grpc_broadcast_address = get_addresses_from_jwt(token)
+ host_port = grpc_broadcast_address
+
+ if not tenant_id:
+ tenant_id = get_tenant_id_from_jwt(token)
+
+ tls_config = self._load_tls_config(config_data["tls"], host_port)
+
+ worker_healthcheck_port = int(
+ get_config_value(
+ "worker_healthcheck_port", "HATCHET_CLIENT_WORKER_HEALTHCHECK_PORT"
+ )
+ or 8001
+ )
+
+ worker_healthcheck_enabled = (
+ str(
+ get_config_value(
+ "worker_healthcheck_port",
+ "HATCHET_CLIENT_WORKER_HEALTHCHECK_ENABLED",
+ )
+ )
+ == "True"
+ )
+
+ # Add preset labels to the worker config
+ worker_preset_labels: dict[str, str] = defaults.worker_preset_labels
+
+ autoscaling_target = get_config_value(
+ "autoscaling_target", "HATCHET_CLIENT_AUTOSCALING_TARGET"
+ )
+
+ if autoscaling_target:
+ worker_preset_labels["hatchet-autoscaling-target"] = autoscaling_target
+
+ legacy_otlp_headers = get_config_value(
+ "otel_exporter_otlp_endpoint", "HATCHET_CLIENT_OTEL_EXPORTER_OTLP_ENDPOINT"
+ )
+
+ legacy_otlp_headers = get_config_value(
+ "otel_exporter_otlp_headers", "HATCHET_CLIENT_OTEL_EXPORTER_OTLP_HEADERS"
+ )
+
+ if legacy_otlp_headers or legacy_otlp_headers:
+ warn(
+ "The `otel_exporter_otlp_*` fields are no longer supported as of SDK version `0.46.0`. Please see the documentation on OpenTelemetry at https://docs.hatchet.run/home/features/opentelemetry for more information on how to migrate to the new `HatchetInstrumentor`."
+ )
+
+ enable_force_kill_sync_threads = bool(
+ get_config_value(
+ "enable_force_kill_sync_threads",
+ "HATCHET_CLIENT_ENABLE_FORCE_KILL_SYNC_THREADS",
+ )
+ == "True"
+ or False
+ )
+ return ClientConfig(
+ tenant_id=tenant_id,
+ tls_config=tls_config,
+ token=token,
+ host_port=host_port,
+ server_url=server_url,
+ namespace=namespace,
+ listener_v2_timeout=listener_v2_timeout,
+ logger=defaults.logInterceptor,
+ grpc_max_recv_message_length=grpc_max_recv_message_length,
+ grpc_max_send_message_length=grpc_max_send_message_length,
+ worker_healthcheck_port=worker_healthcheck_port,
+ worker_healthcheck_enabled=worker_healthcheck_enabled,
+ worker_preset_labels=worker_preset_labels,
+ enable_force_kill_sync_threads=enable_force_kill_sync_threads,
+ )
+
+ def _load_tls_config(self, tls_data: Dict, host_port) -> ClientTLSConfig:
+ tls_strategy = (
+ tls_data["tlsStrategy"]
+ if "tlsStrategy" in tls_data
+ else self._get_env_var("HATCHET_CLIENT_TLS_STRATEGY")
+ )
+
+ if not tls_strategy:
+ tls_strategy = "tls"
+
+ cert_file = (
+ tls_data["tlsCertFile"]
+ if "tlsCertFile" in tls_data
+ else self._get_env_var("HATCHET_CLIENT_TLS_CERT_FILE")
+ )
+ key_file = (
+ tls_data["tlsKeyFile"]
+ if "tlsKeyFile" in tls_data
+ else self._get_env_var("HATCHET_CLIENT_TLS_KEY_FILE")
+ )
+ ca_file = (
+ tls_data["tlsRootCAFile"]
+ if "tlsRootCAFile" in tls_data
+ else self._get_env_var("HATCHET_CLIENT_TLS_ROOT_CA_FILE")
+ )
+
+ server_name = (
+ tls_data["tlsServerName"]
+ if "tlsServerName" in tls_data
+ else self._get_env_var("HATCHET_CLIENT_TLS_SERVER_NAME")
+ )
+
+ # if server_name is not set, use the host from the host_port
+ if not server_name:
+ server_name = host_port.split(":")[0]
+
+ return ClientTLSConfig(tls_strategy, cert_file, key_file, ca_file, server_name)
+
+ @staticmethod
+ def _get_env_var(env_var: str, default: Optional[str] = None) -> str:
+ return os.environ.get(env_var, default)