1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
|
import os
from logging import Logger, getLogger
from typing import Dict, Optional
from warnings import warn
import yaml
from .token import get_addresses_from_jwt, get_tenant_id_from_jwt
class ClientTLSConfig:
    """Container for the TLS options used by the Hatchet client.

    The constructor performs no validation or defaulting: every argument is
    stored unchanged on an attribute of the same name. Callers may pass
    ``None`` for the file paths (``ConfigLoader`` does so when neither
    client.yaml nor the environment provides them).
    """

    def __init__(
        self,
        tls_strategy: str,
        cert_file: str,
        key_file: str,
        ca_file: str,
        server_name: str,
    ):
        # Name expected for the server (ConfigLoader falls back to the host
        # part of host_port when unset).
        self.server_name = server_name
        # Path from tlsRootCAFile / HATCHET_CLIENT_TLS_ROOT_CA_FILE.
        self.ca_file = ca_file
        # Path from tlsKeyFile / HATCHET_CLIENT_TLS_KEY_FILE.
        self.key_file = key_file
        # Path from tlsCertFile / HATCHET_CLIENT_TLS_CERT_FILE.
        self.cert_file = cert_file
        # Strategy name; ConfigLoader substitutes "tls" when it is unset.
        self.tls_strategy = tls_strategy
class ClientConfig:
logInterceptor: Logger
def __init__(
self,
tenant_id: str = None,
tls_config: ClientTLSConfig = None,
token: str = None,
host_port: str = "localhost:7070",
server_url: str = "https://app.dev.hatchet-tools.com",
namespace: str = None,
listener_v2_timeout: int = None,
logger: Logger = None,
grpc_max_recv_message_length: int = 4 * 1024 * 1024, # 4MB
grpc_max_send_message_length: int = 4 * 1024 * 1024, # 4MB
worker_healthcheck_port: int | None = None,
worker_healthcheck_enabled: bool | None = None,
worker_preset_labels: dict[str, str] = {},
enable_force_kill_sync_threads: bool = False,
):
self.tenant_id = tenant_id
self.tls_config = tls_config
self.host_port = host_port
self.token = token
self.server_url = server_url
self.namespace = ""
self.logInterceptor = logger
self.grpc_max_recv_message_length = grpc_max_recv_message_length
self.grpc_max_send_message_length = grpc_max_send_message_length
self.worker_healthcheck_port = worker_healthcheck_port
self.worker_healthcheck_enabled = worker_healthcheck_enabled
self.worker_preset_labels = worker_preset_labels
self.enable_force_kill_sync_threads = enable_force_kill_sync_threads
if not self.logInterceptor:
self.logInterceptor = getLogger()
# case on whether the namespace already has a trailing underscore
if namespace and not namespace.endswith("_"):
self.namespace = f"{namespace}_"
elif namespace:
self.namespace = namespace
self.namespace = self.namespace.lower()
self.listener_v2_timeout = listener_v2_timeout
class ConfigLoader:
    """Assemble a ``ClientConfig`` from client.yaml, the environment and defaults.

    For most settings the lookup order is: the key in ``<directory>/client.yaml``,
    then the named ``HATCHET_CLIENT_*`` environment variable, then the attribute
    of the same name on the ``defaults`` config passed to ``load_client_config``.
    """

    def __init__(self, directory: str):
        # Directory searched for an optional ``client.yaml``.
        self.directory = directory

    def load_client_config(self, defaults: ClientConfig) -> ClientConfig:
        """Resolve every client setting and return a new ``ClientConfig``.

        Args:
            defaults: Fallback values. Its ``logInterceptor`` is reused and its
                ``worker_preset_labels`` are copied; the object is not mutated.

        Returns:
            A new ``ClientConfig``. When no host/port is configured, the gRPC
            address and server URL are extracted from the token; a missing
            tenant id is likewise extracted from the token.

        Raises:
            ValueError: if no token is configured.
        """
        config_data = self._read_config_file()

        def get_config_value(key, env_var):
            # client.yaml wins whenever the key is present, even if falsy.
            if key in config_data:
                return config_data[key]
            env_value = self._get_env_var(env_var)
            if env_value is not None:
                return env_value
            # NOTE(review): camelCase keys ("tenantId", "hostPort") never match
            # a ClientConfig attribute, so they resolve to None here.
            return getattr(defaults, key, None)

        namespace = get_config_value("namespace", "HATCHET_CLIENT_NAMESPACE")
        tenant_id = get_config_value("tenantId", "HATCHET_CLIENT_TENANT_ID")
        token = get_config_value("token", "HATCHET_CLIENT_TOKEN")
        listener_v2_timeout = get_config_value(
            "listener_v2_timeout", "HATCHET_CLIENT_LISTENER_V2_TIMEOUT"
        )
        listener_v2_timeout = int(listener_v2_timeout) if listener_v2_timeout else None

        if not token:
            raise ValueError(
                "Token must be set via HATCHET_CLIENT_TOKEN environment variable"
            )

        host_port = get_config_value("hostPort", "HATCHET_CLIENT_HOST_PORT")
        server_url: str | None = None

        grpc_max_recv_message_length = get_config_value(
            "grpc_max_recv_message_length",
            "HATCHET_CLIENT_GRPC_MAX_RECV_MESSAGE_LENGTH",
        )
        grpc_max_send_message_length = get_config_value(
            "grpc_max_send_message_length",
            "HATCHET_CLIENT_GRPC_MAX_SEND_MESSAGE_LENGTH",
        )
        # Environment variables arrive as strings.
        if grpc_max_recv_message_length:
            grpc_max_recv_message_length = int(grpc_max_recv_message_length)
        if grpc_max_send_message_length:
            grpc_max_send_message_length = int(grpc_max_send_message_length)

        if not host_port:
            # extract host and port from token
            server_url, grpc_broadcast_address = get_addresses_from_jwt(token)
            host_port = grpc_broadcast_address

        if not tenant_id:
            tenant_id = get_tenant_id_from_jwt(token)

        # A missing or empty ``tls:`` section means no TLS overrides in the file.
        tls_config = self._load_tls_config(config_data.get("tls") or {}, host_port)

        worker_healthcheck_port = int(
            get_config_value(
                "worker_healthcheck_port", "HATCHET_CLIENT_WORKER_HEALTHCHECK_PORT"
            )
            or 8001
        )
        # Must read the *_enabled key; the port key is unrelated.
        worker_healthcheck_enabled = self._is_true(
            get_config_value(
                "worker_healthcheck_enabled",
                "HATCHET_CLIENT_WORKER_HEALTHCHECK_ENABLED",
            )
        )

        # Add preset labels to the worker config. Copy first so the caller's
        # ``defaults`` object is never mutated.
        worker_preset_labels: dict[str, str] = dict(
            defaults.worker_preset_labels or {}
        )
        autoscaling_target = get_config_value(
            "autoscaling_target", "HATCHET_CLIENT_AUTOSCALING_TARGET"
        )
        if autoscaling_target:
            worker_preset_labels["hatchet-autoscaling-target"] = autoscaling_target

        legacy_otlp_endpoint = get_config_value(
            "otel_exporter_otlp_endpoint", "HATCHET_CLIENT_OTEL_EXPORTER_OTLP_ENDPOINT"
        )
        legacy_otlp_headers = get_config_value(
            "otel_exporter_otlp_headers", "HATCHET_CLIENT_OTEL_EXPORTER_OTLP_HEADERS"
        )
        if legacy_otlp_endpoint or legacy_otlp_headers:
            warn(
                "The `otel_exporter_otlp_*` fields are no longer supported as of SDK version `0.46.0`. Please see the documentation on OpenTelemetry at https://docs.hatchet.run/home/features/opentelemetry for more information on how to migrate to the new `HatchetInstrumentor`."
            )

        # Same parsing as worker_healthcheck_enabled, so YAML booleans and a
        # bool default are honored (not only the literal string "True").
        enable_force_kill_sync_threads = self._is_true(
            get_config_value(
                "enable_force_kill_sync_threads",
                "HATCHET_CLIENT_ENABLE_FORCE_KILL_SYNC_THREADS",
            )
        )

        return ClientConfig(
            tenant_id=tenant_id,
            tls_config=tls_config,
            token=token,
            host_port=host_port,
            server_url=server_url,
            namespace=namespace,
            listener_v2_timeout=listener_v2_timeout,
            logger=defaults.logInterceptor,
            grpc_max_recv_message_length=grpc_max_recv_message_length,
            grpc_max_send_message_length=grpc_max_send_message_length,
            worker_healthcheck_port=worker_healthcheck_port,
            worker_healthcheck_enabled=worker_healthcheck_enabled,
            worker_preset_labels=worker_preset_labels,
            enable_force_kill_sync_threads=enable_force_kill_sync_threads,
        )

    def _read_config_file(self) -> dict:
        """Return the parsed ``client.yaml`` mapping, or ``{}`` if absent or empty."""
        config_file_path = os.path.join(self.directory, "client.yaml")
        if not os.path.exists(config_file_path):
            return {}
        with open(config_file_path, "r") as file:
            # safe_load returns None for an empty document; treat it as no settings.
            return yaml.safe_load(file) or {}

    def _load_tls_config(self, tls_data: Dict, host_port) -> ClientTLSConfig:
        """Build the TLS settings from the ``tls`` section, env vars and defaults.

        Each value comes from ``tls_data`` when its key is present, else from the
        matching ``HATCHET_CLIENT_TLS_*`` environment variable. The strategy
        defaults to ``"tls"`` and the server name to the host part of
        ``host_port`` (text before the first ``:``).
        """
        tls_data = tls_data or {}

        def lookup(key: str, env_var: str) -> Optional[str]:
            # The file value wins whenever the key is present, even if empty.
            if key in tls_data:
                return tls_data[key]
            return self._get_env_var(env_var)

        tls_strategy = lookup("tlsStrategy", "HATCHET_CLIENT_TLS_STRATEGY") or "tls"
        cert_file = lookup("tlsCertFile", "HATCHET_CLIENT_TLS_CERT_FILE")
        key_file = lookup("tlsKeyFile", "HATCHET_CLIENT_TLS_KEY_FILE")
        ca_file = lookup("tlsRootCAFile", "HATCHET_CLIENT_TLS_ROOT_CA_FILE")
        # if server_name is not set, use the host from the host_port
        server_name = (
            lookup("tlsServerName", "HATCHET_CLIENT_TLS_SERVER_NAME")
            or host_port.split(":")[0]
        )

        return ClientTLSConfig(tls_strategy, cert_file, key_file, ca_file, server_name)

    @staticmethod
    def _is_true(value) -> bool:
        """Return True when ``value`` stringifies to exactly ``"True"``.

        Accepts both the bool ``True`` (YAML, defaults) and the string ``"True"``
        (environment variables); anything else, including ``None``, is False.
        """
        return str(value) == "True"

    @staticmethod
    def _get_env_var(env_var: str, default: Optional[str] = None) -> Optional[str]:
        """Return the environment variable's value, or ``default`` when unset."""
        return os.environ.get(env_var, default)
|