aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/hatchet_sdk/clients/rest/tenacity_utils.py
blob: 377266a1ac9ba964f7198a6d173c0d08a9b225e0 (about) (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
from typing import Callable, ParamSpec, TypeVar

import grpc
import tenacity

from hatchet_sdk.logger import logger

# Type parameters for tenacity_retry's signature (Callable[P, R] -> Callable[P, R]):
# they let type checkers see that the decorated function keeps its original
# parameter list (P) and return type (R).
P = ParamSpec("P")
R = TypeVar("R")


def tenacity_retry(func: Callable[P, R]) -> Callable[P, R]:
    """Wrap ``func`` with the SDK's standard tenacity retry policy.

    Policy applied to every call of the returned wrapper:

    * an exception is retried only if ``tenacity_should_retry`` accepts it;
    * at most 5 attempts are made in total;
    * between attempts tenacity waits using exponential backoff with jitter;
    * ``tenacity_alert_retry`` is invoked before each wait;
    * once attempts are exhausted, the last exception is re-raised unchanged
      (``reraise=True``) instead of being wrapped in ``tenacity.RetryError``.

    Args:
        func: The callable to protect with retries.

    Returns:
        The retrying wrapper, with the same signature as ``func``.
    """
    retrying = tenacity.retry(
        retry=tenacity.retry_if_exception(tenacity_should_retry),
        stop=tenacity.stop_after_attempt(5),
        wait=tenacity.wait_exponential_jitter(),
        before_sleep=tenacity_alert_retry,
        reraise=True,
    )
    return retrying(func)


def tenacity_alert_retry(retry_state: tenacity.RetryCallState) -> None:
    """Log a debug message between tenacity retries.

    Installed as tenacity's ``before_sleep`` hook, so it runs after an attempt
    has finished and before tenacity waits and tries again.

    Args:
        retry_state: tenacity's bookkeeping for the call being retried; the
            wrapped function, the attempt number and the attempt's outcome
            are read from it.
    """
    outcome = retry_state.outcome
    # Report the raised exception (or returned value) itself. The outcome is a
    # Future whose repr only names the exception's class, which hides useful
    # details such as a gRPC status code and message.
    if outcome is None:
        ended_with: object = None
    elif outcome.failed:
        ended_with = outcome.exception()
    else:
        ended_with = outcome.result()
    logger.debug(
        f"Retrying {retry_state.fn}: attempt "
        f"{retry_state.attempt_number} ended with: {ended_with!r}",
    )


def tenacity_should_retry(ex: BaseException) -> bool:
    """Decide whether tenacity should make another attempt after ``ex``.

    Only gRPC errors are retried, and only when their status code is not
    ``UNIMPLEMENTED`` or ``NOT_FOUND``, which another attempt is not expected
    to fix.

    Args:
        ex: The exception raised by the decorated call.

    Returns:
        ``True`` to retry, ``False`` to stop.
    """
    # grpc.aio.AioRpcError subclasses grpc.RpcError, so this single check
    # covers both the sync and the asyncio clients.
    if not isinstance(ex, grpc.RpcError):
        return False

    # The grpc.RpcError base class does not itself define code(); only errors
    # that are also grpc.Call objects do. Calling it unguarded would raise
    # AttributeError inside tenacity's predicate and mask the original error.
    # An RPC error with no status is treated as retryable, like any other
    # RPC failure not explicitly excluded below.
    code = getattr(ex, "code", None)
    if not callable(code):
        return True

    return code() not in (
        grpc.StatusCode.UNIMPLEMENTED,
        grpc.StatusCode.NOT_FOUND,
    )