aboutsummaryrefslogtreecommitdiff
"""
Get the number of retries to use for an exception.

- Accounts for the retry policy by exception type.
"""

from typing import Dict, Optional, Union

from litellm.exceptions import (
    AuthenticationError,
    BadRequestError,
    ContentPolicyViolationError,
    RateLimitError,
    Timeout,
)
from litellm.types.router import RetryPolicy


def get_num_retries_from_retry_policy(
    exception: Exception,
    retry_policy: Optional[Union[RetryPolicy, dict]] = None,
    model_group: Optional[str] = None,
    model_group_retry_policy: Optional[Dict[str, RetryPolicy]] = None,
) -> Optional[int]:
    """
    Return the retry count configured for ``exception``, or None.

    The policy consulted is the entry for ``model_group`` in
    ``model_group_retry_policy`` when one exists; otherwise ``retry_policy``.
    A plain dict policy is converted to ``RetryPolicy`` before use.

    Policy fields and the exception types they apply to:
        ContentPolicyViolationErrorRetries -> ContentPolicyViolationError
        BadRequestErrorRetries             -> BadRequestError
        AuthenticationErrorRetries         -> AuthenticationError
        TimeoutErrorRetries                -> Timeout
        RateLimitErrorRetries              -> RateLimitError

    Args:
        exception: The exception raised by the failed call.
        retry_policy: Generic retry policy, as a ``RetryPolicy`` or a dict of
            its fields.
        model_group: Name of the model group the call belonged to.
        model_group_retry_policy: Per-model-group retry policies; an entry
            for ``model_group`` overrides ``retry_policy``.

    Returns:
        The retry count from the first matching exception type (in the order
        listed above) whose field is not None, or None if no policy applies
        or no matching field is set. A configured value of 0 is returned as 0.
    """
    if (
        model_group_retry_policy is not None
        and model_group is not None
        and model_group in model_group_retry_policy
    ):
        # A model-group-specific policy takes precedence over the generic one.
        retry_policy = model_group_retry_policy[model_group]  # type: ignore

    if retry_policy is None:
        return None
    if isinstance(retry_policy, dict):
        retry_policy = RetryPolicy(**retry_policy)

    # Most specific type first: ContentPolicyViolationError is checked before
    # BadRequestError. If it subclasses BadRequestError (as in litellm.exceptions
    # -- verify), checking BadRequestError first would let BadRequestErrorRetries
    # shadow the dedicated ContentPolicyViolationErrorRetries setting. When the
    # dedicated field is None, the BadRequestError entry still applies as before.
    type_to_retries = (
        (ContentPolicyViolationError, retry_policy.ContentPolicyViolationErrorRetries),
        (BadRequestError, retry_policy.BadRequestErrorRetries),
        (AuthenticationError, retry_policy.AuthenticationErrorRetries),
        (Timeout, retry_policy.TimeoutErrorRetries),
        (RateLimitError, retry_policy.RateLimitErrorRetries),
    )
    for exception_type, num_retries in type_to_retries:
        # `is not None` (not truthiness) so an explicit 0 retries is honored.
        if num_retries is not None and isinstance(exception, exception_type):
            return num_retries
    return None


def reset_retry_policy() -> RetryPolicy:
    """Return a brand-new ``RetryPolicy`` built with no arguments (all fields at their defaults)."""
    fresh_policy = RetryPolicy()
    return fresh_policy