aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/hatchet_sdk/connection.py
blob: 185395e4cd9598f6a89a82339ceea34d80c630d8 (about) (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import os
from typing import TYPE_CHECKING, Any

import grpc

if TYPE_CHECKING:
    from hatchet_sdk.loader import ClientConfig


def _read_file(path: str) -> bytes:
    """Return the entire contents of ``path`` as bytes.

    A ``with`` block is used so the file handle is closed deterministically
    rather than being left open until the object is garbage-collected.
    """
    with open(path, "rb") as f:
        return f.read()


def new_conn(
    config: "ClientConfig", aio: bool = False
) -> "grpc.Channel | grpc.aio.Channel":
    """Create a gRPC channel to ``config.host_port``.

    Transport security is selected by ``config.tls_config.tls_strategy``:

    * ``"none"`` -- plaintext channel via ``insecure_channel``.
    * ``"tls"``  -- TLS channel. ``ca_file`` is optional; when it is unset or
      empty, ``root_certificates=None`` is passed so gRPC uses its default
      root certificates.
    * ``"mtls"`` -- like ``"tls"``, additionally presenting the client
      private key (``key_file``) and certificate chain (``cert_file``).

    For the TLS strategies, ``config.tls_config.server_name`` (when set) is
    passed as ``grpc.ssl_target_name_override``.

    Args:
        config: Client configuration supplying ``host_port``, ``tls_config``
            and the gRPC max send/receive message lengths.
        aio: When true, build the channel with ``grpc.aio`` instead of the
            synchronous ``grpc`` module.

    Returns:
        The channel returned by ``secure_channel`` / ``insecure_channel`` of
        the selected module.

    Raises:
        ValueError: If ``tls_strategy`` is not ``"none"``, ``"tls"`` or
            ``"mtls"``.
        OSError: If a configured CA, key or certificate file cannot be read.

    Side effect:
        Sets the ``GRPC_ENABLE_FORK_SUPPORT`` environment variable to
        ``"False"`` for the current process.
    """
    tls_config = config.tls_config
    strategy = tls_config.tls_strategy

    # load channel credentials
    credentials: grpc.ChannelCredentials | None = None
    if strategy == "tls":
        root = _read_file(tls_config.ca_file) if tls_config.ca_file else None
        credentials = grpc.ssl_channel_credentials(root_certificates=root)
    elif strategy == "mtls":
        # Same optional-CA handling as the "tls" branch; the client key and
        # certificate chain are mandatory for mutual TLS.
        root = _read_file(tls_config.ca_file) if tls_config.ca_file else None
        credentials = grpc.ssl_channel_credentials(
            root_certificates=root,
            private_key=_read_file(tls_config.key_file),
            certificate_chain=_read_file(tls_config.cert_file),
        )
    elif strategy != "none":
        # Fail early with a clear message instead of passing credentials=None
        # into secure_channel, which fails with an opaque error inside grpc.
        raise ValueError(
            f"unsupported tls_strategy {strategy!r}; "
            "expected 'none', 'tls' or 'mtls'"
        )

    channel_module = grpc.aio if aio else grpc

    channel_options = [
        ("grpc.max_send_message_length", config.grpc_max_send_message_length),
        ("grpc.max_receive_message_length", config.grpc_max_recv_message_length),
        ("grpc.keepalive_time_ms", 10 * 1000),
        ("grpc.keepalive_timeout_ms", 60 * 1000),
        ("grpc.client_idle_timeout_ms", 60 * 1000),
        ("grpc.http2.max_pings_without_data", 0),
        ("grpc.keepalive_permit_without_calls", 1),
    ]

    # Set environment variable to disable fork support. Reference: https://github.com/grpc/grpc/issues/28557
    # When steps execute via os.fork, we see `TSI_DATA_CORRUPTED` errors.
    # NOTE(review): grpc may read this variable only when it is first
    # imported/initialized; setting it here, after `import grpc`, may have no
    # effect -- confirm against the installed grpc version.
    os.environ["GRPC_ENABLE_FORK_SUPPORT"] = "False"

    if strategy == "none":
        return channel_module.insecure_channel(
            target=config.host_port,
            options=channel_options,
        )

    # Only override the TLS target name when one is configured; None is not
    # a valid channel-argument value.
    if tls_config.server_name:
        channel_options.append(
            ("grpc.ssl_target_name_override", tls_config.server_name)
        )

    return channel_module.secure_channel(
        target=config.host_port,
        credentials=credentials,
        options=channel_options,
    )