# Retry-capable client built on top of aiohttp's ClientSession.
from __future__ import annotations

import asyncio
import logging
import sys
from abc import abstractmethod
from dataclasses import dataclass
from typing import (
    TYPE_CHECKING,
    Any,
    Awaitable,
    Callable,
    Generator,
    List,
    Tuple,
    Union,
)

from aiohttp import ClientResponse, ClientSession, hdrs
from aiohttp.typedefs import StrOrURL
from yarl import URL as YARL_URL

from .retry_options import ExponentialRetry, RetryOptionsBase

# Lowest HTTP status code treated as a server error; responses at or above it
# are retried when the retry options enable ``retry_all_server_errors``.
_MIN_SERVER_ERROR_STATUS = 500

if TYPE_CHECKING:
    from types import TracebackType

if sys.version_info >= (3, 8):
    from typing import Protocol
else:
    from typing_extensions import Protocol


class _Logger(Protocol):
    """Structural description of the logger objects this module accepts.

    Any object providing ``debug``, ``warning`` and ``exception`` methods with
    these signatures satisfies the protocol.
    """

    @abstractmethod
    def debug(self, msg: str, *args: Any, **kwargs: Any) -> None:
        """Record a debug message."""

    @abstractmethod
    def warning(self, msg: str, *args: Any, **kwargs: Any) -> None:
        """Record a warning message."""

    @abstractmethod
    def exception(self, msg: str, *args: Any, **kwargs: Any) -> None:
        """Record a message about an exception."""


# A single URL, given either as a string or as a yarl URL object.
_RAW_URL_TYPE = Union[StrOrURL, YARL_URL]
# The URL itself, or a list/tuple of URLs to switch between on successive retries.
_URL_TYPE = Union[_RAW_URL_TYPE, List[_RAW_URL_TYPE], Tuple[_RAW_URL_TYPE, ...]]
# Either an object implementing the _Logger protocol or a standard library logger.
_LoggerType = Union[_Logger, logging.Logger]

# Callable that performs one request and resolves to the response.
RequestFunc = Callable[..., Awaitable[ClientResponse]]


@dataclass
class RequestParams:
    """Arguments for a single request attempt.

    ``_RequestContext._do_request`` selects one instance per attempt and
    forwards its fields to the underlying request function.
    """

    # HTTP method, passed as the first positional argument of the request call.
    method: str
    # Target URL, passed as the second positional argument of the request call.
    url: _RAW_URL_TYPE
    # Passed through as the ``headers=`` keyword argument.
    headers: dict[str, Any] | None = None
    # Entries merged into the ``trace_request_ctx`` dict alongside ``current_attempt``.
    trace_request_ctx: dict[str, Any] | None = None
    # Remaining keyword arguments, unpacked into the request call.
    kwargs: dict[str, Any] | None = None

class _RequestContext:
    """Awaitable and async context manager that runs one request with retries.

    Awaiting an instance, or entering it with ``async with``, executes the
    retry loop in ``_do_request`` and yields the accepted ``ClientResponse``.
    Leaving the ``async with`` block closes that response if it is still open.
    """

    def __init__(
        self,
        request_func: RequestFunc,
        params_list: list[RequestParams],
        logger: _LoggerType,
        retry_options: RetryOptionsBase,
        raise_for_status: bool = False,
    ) -> None:
        """Store everything the retry loop needs.

        Args:
            request_func: Coroutine function invoked as
                ``request_func(method, url, headers=..., trace_request_ctx=..., **kwargs)``.
            params_list: Request parameters per attempt. Attempt ``n`` uses item
                ``n - 1``; attempts beyond the end reuse the last item. Must not
                be empty.
            logger: Receives debug messages about attempts and retries.
            retry_options: Retry policy (attempt count, retryable methods,
                statuses and exceptions, response callback, delays).
            raise_for_status: When true, ``raise_for_status()`` is called on the
                response that is about to be returned.
        """
        assert len(params_list) > 0  # noqa: S101

        self._request_func = request_func
        self._params_list = params_list
        self._logger = logger
        self._retry_options = retry_options
        self._raise_for_status = raise_for_status

        # Set once a response has been accepted; used by __aexit__ for cleanup.
        self._response: ClientResponse | None = None

    async def _is_skip_retry(self, current_attempt: int, response: ClientResponse) -> bool:
        """Decide whether ``response`` is final.

        Returns:
            ``True`` when the response must be handed to the caller, ``False``
            when another attempt should be made.
        """
        # Out of attempts: whatever came back is the final answer.
        if current_attempt == self._retry_options.attempts:
            return True

        # Only methods listed in the retry options are ever retried.
        if response.method.upper() not in self._retry_options.methods:
            return True

        if response.status >= _MIN_SERVER_ERROR_STATUS and self._retry_options.retry_all_server_errors:
            return False

        if response.status in self._retry_options.statuses:
            return False

        if self._retry_options.evaluate_response_callback is None:
            return True

        # The callback's result is used directly as the "skip retry" flag.
        return await self._retry_options.evaluate_response_callback(response)

    async def _do_request(self) -> ClientResponse:
        """Run the request, retrying according to the retry options.

        Returns:
            The first response for which ``_is_skip_retry`` returns ``True``.

        Raises:
            Exception: Any exception raised during an attempt is re-raised when
                the attempts are exhausted or when it is not an instance of one
                of ``retry_options.exceptions``.
        """
        current_attempt = 0

        while True:
            self._logger.debug(f"Attempt {current_attempt+1} out of {self._retry_options.attempts}")

            current_attempt += 1
            try:
                # Attempt n uses the n-th params entry; once the list is
                # exhausted every further attempt reuses the last entry.
                params = self._params_list[min(current_attempt, len(self._params_list)) - 1]

                response: ClientResponse = await self._request_func(
                    params.method,
                    params.url,
                    headers=params.headers,
                    trace_request_ctx={
                        "current_attempt": current_attempt,
                        **(params.trace_request_ctx or {}),
                    },
                    **(params.kwargs or {}),
                )

                debug_message = f"Retrying after response code: {response.status}"
                skip_retry = await self._is_skip_retry(current_attempt, response)

                if skip_retry:
                    if self._raise_for_status:
                        response.raise_for_status()
                    self._response = response
                    return self._response
                retry_wait = self._retry_options.get_timeout(attempt=current_attempt, response=response)

                # The rejected response is never handed to the caller. Release
                # it now so its connection is not held during the back-off
                # sleep and the next attempt.
                response.release()

            except Exception as e:
                if current_attempt >= self._retry_options.attempts:
                    raise

                # Retry only for exception types listed in the retry options.
                if not isinstance(e, tuple(self._retry_options.exceptions)):
                    raise

                debug_message = f"Retrying after exception: {e!r}"
                retry_wait = self._retry_options.get_timeout(attempt=current_attempt, response=None)

            self._logger.debug(debug_message)
            await asyncio.sleep(retry_wait)

    def __await__(self) -> Generator[Any, None, ClientResponse]:
        """Allow ``await context`` as a shorthand for running the request."""
        return self.__aenter__().__await__()

    async def __aenter__(self) -> ClientResponse:
        """Run the request loop and return the accepted response."""
        return await self._do_request()

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Close the accepted response if one exists and is still open."""
        if self._response is not None and not self._response.closed:
            self._response.close()


def _url_to_urls(url: _URL_TYPE) -> tuple[StrOrURL, ...]:
    """Normalize ``url`` into a non-empty tuple of URLs.

    A single string or yarl URL becomes a one-element tuple; a list is
    converted to a tuple; a tuple is returned unchanged.

    Raises:
        ValueError: If ``url`` is neither a string/URL nor a list/tuple, or if
            the list/tuple is empty.
    """
    # Single URL: wrap it and return immediately.
    if isinstance(url, (str, YARL_URL)):
        return (url,)

    # Anything that is not a list or tuple is rejected up front.
    if not isinstance(url, (list, tuple)):
        msg = "you can pass url only by str or list/tuple"  # type: ignore[unreachable]
        raise ValueError(msg)  # noqa: TRY004

    candidates = tuple(url) if isinstance(url, list) else url
    if not candidates:
        msg = "you can pass url by str or list/tuple with attempts count size"
        raise ValueError(msg)

    return candidates


class RetryClient:
    """HTTP client that sends requests through an aiohttp session with retries.

    Each request method returns a ``_RequestContext`` that can be awaited or
    used with ``async with``; the retry loop runs at that point.
    """

    def __init__(
        self,
        client_session: ClientSession | None = None,
        logger: _LoggerType | None = None,
        retry_options: RetryOptionsBase | None = None,
        raise_for_status: bool = False,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """Create the client.

        Args:
            client_session: Existing session to send requests through. When
                omitted, a new ``ClientSession(*args, **kwargs)`` is created.
            logger: Logger to use; defaults to ``logging.getLogger("aiohttp_retry")``.
            retry_options: Default retry policy; defaults to ``ExponentialRetry()``.
            raise_for_status: Default for whether ``raise_for_status()`` is
                called on the returned response.
            *args: Positional arguments for the internally created session.
            **kwargs: Keyword arguments for the internally created session.
        """
        if client_session is not None:
            client = client_session
            # None means "session supplied by the caller": __del__ then stays silent.
            closed = None
        else:
            client = ClientSession(*args, **kwargs)
            closed = False

        self._client = client
        self._closed = closed

        self._logger: _LoggerType = logger or logging.getLogger("aiohttp_retry")
        self._retry_options: RetryOptionsBase = retry_options or ExponentialRetry()
        self._raise_for_status = raise_for_status

    @property
    def retry_options(self) -> RetryOptionsBase:
        """Default retry policy used when a call does not pass its own."""
        return self._retry_options

    def requests(
        self,
        params_list: list[RequestParams],
        retry_options: RetryOptionsBase | None = None,
        raise_for_status: bool | None = None,
    ) -> _RequestContext:
        """Build a request context from explicit per-attempt parameters."""
        return self._make_requests(
            params_list=params_list,
            retry_options=retry_options,
            raise_for_status=raise_for_status,
        )

    def request(
        self,
        method: str,
        url: _URL_TYPE,
        retry_options: RetryOptionsBase | None = None,
        raise_for_status: bool | None = None,
        **kwargs: Any,
    ) -> _RequestContext:
        """Build a request context for an arbitrary HTTP method."""
        return self._make_request(
            method=method,
            url=url,
            retry_options=retry_options,
            raise_for_status=raise_for_status,
            **kwargs,
        )

    def get(
        self,
        url: _URL_TYPE,
        retry_options: RetryOptionsBase | None = None,
        raise_for_status: bool | None = None,
        **kwargs: Any,
    ) -> _RequestContext:
        """Build a request context for a GET request."""
        return self._make_request(
            method=hdrs.METH_GET,
            url=url,
            retry_options=retry_options,
            raise_for_status=raise_for_status,
            **kwargs,
        )

    def options(
        self,
        url: _URL_TYPE,
        retry_options: RetryOptionsBase | None = None,
        raise_for_status: bool | None = None,
        **kwargs: Any,
    ) -> _RequestContext:
        """Build a request context for an OPTIONS request."""
        return self._make_request(
            method=hdrs.METH_OPTIONS,
            url=url,
            retry_options=retry_options,
            raise_for_status=raise_for_status,
            **kwargs,
        )

    def head(
        self,
        url: _URL_TYPE,
        retry_options: RetryOptionsBase | None = None,
        raise_for_status: bool | None = None,
        **kwargs: Any,
    ) -> _RequestContext:
        """Build a request context for a HEAD request."""
        return self._make_request(
            method=hdrs.METH_HEAD,
            url=url,
            retry_options=retry_options,
            raise_for_status=raise_for_status,
            **kwargs,
        )

    def post(
        self,
        url: _URL_TYPE,
        retry_options: RetryOptionsBase | None = None,
        raise_for_status: bool | None = None,
        **kwargs: Any,
    ) -> _RequestContext:
        """Build a request context for a POST request."""
        return self._make_request(
            method=hdrs.METH_POST,
            url=url,
            retry_options=retry_options,
            raise_for_status=raise_for_status,
            **kwargs,
        )

    def put(
        self,
        url: _URL_TYPE,
        retry_options: RetryOptionsBase | None = None,
        raise_for_status: bool | None = None,
        **kwargs: Any,
    ) -> _RequestContext:
        """Build a request context for a PUT request."""
        return self._make_request(
            method=hdrs.METH_PUT,
            url=url,
            retry_options=retry_options,
            raise_for_status=raise_for_status,
            **kwargs,
        )

    def patch(
        self,
        url: _URL_TYPE,
        retry_options: RetryOptionsBase | None = None,
        raise_for_status: bool | None = None,
        **kwargs: Any,
    ) -> _RequestContext:
        """Build a request context for a PATCH request."""
        return self._make_request(
            method=hdrs.METH_PATCH,
            url=url,
            retry_options=retry_options,
            raise_for_status=raise_for_status,
            **kwargs,
        )

    def delete(
        self,
        url: _URL_TYPE,
        retry_options: RetryOptionsBase | None = None,
        raise_for_status: bool | None = None,
        **kwargs: Any,
    ) -> _RequestContext:
        """Build a request context for a DELETE request."""
        return self._make_request(
            method=hdrs.METH_DELETE,
            url=url,
            retry_options=retry_options,
            raise_for_status=raise_for_status,
            **kwargs,
        )

    async def close(self) -> None:
        """Close the underlying session and mark this client as closed."""
        await self._client.close()
        self._closed = True

    def _make_request(
        self,
        method: str,
        url: _URL_TYPE,
        retry_options: RetryOptionsBase | None = None,
        raise_for_status: bool | None = None,
        **kwargs: Any,
    ) -> _RequestContext:
        """Expand ``url`` into per-attempt parameters and build the request context.

        Args:
            method: HTTP method used for every attempt.
            url: A single URL, or a list/tuple of URLs to use on successive attempts.
            retry_options: Per-call retry policy; falls back to the client default.
            raise_for_status: Per-call override; falls back to the client default.
            **kwargs: Forwarded to the request call. ``headers`` and
                ``trace_request_ctx`` are extracted and applied to every URL.

        Raises:
            ValueError: Propagated from ``_url_to_urls`` for an unsupported or
                empty ``url`` value.
        """
        url_list = _url_to_urls(url)

        # Extract these once, before building the list. Popping inside the
        # per-URL loop would hand the caller's headers and trace context to the
        # first URL only and leave every later URL with the defaults.
        headers = kwargs.pop("headers", {})
        trace_request_ctx = kwargs.pop("trace_request_ctx", None)

        params_list = [
            RequestParams(
                method=method,
                url=single_url,
                headers=headers,
                trace_request_ctx=trace_request_ctx,
                kwargs=kwargs,
            )
            for single_url in url_list
        ]

        return self._make_requests(
            params_list=params_list,
            retry_options=retry_options,
            raise_for_status=raise_for_status,
        )

    def _make_requests(
        self,
        params_list: list[RequestParams],
        retry_options: RetryOptionsBase | None = None,
        raise_for_status: bool | None = None,
    ) -> _RequestContext:
        """Build a request context, filling in client defaults for missing options."""
        if retry_options is None:
            retry_options = self._retry_options
        if raise_for_status is None:
            raise_for_status = self._raise_for_status
        return _RequestContext(
            request_func=self._client.request,
            params_list=params_list,
            logger=self._logger,
            retry_options=retry_options,
            raise_for_status=raise_for_status,
        )

    async def __aenter__(self) -> RetryClient:  # noqa: PYI034
        """Return the client itself for use in ``async with``."""
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Close the client when the ``async with`` block ends."""
        await self.close()

    def __del__(self) -> None:
        """Warn when a client that created its own session was never closed."""
        if getattr(self, "_closed", None) is None:
            # Either __init__ raised before setting the attribute, or the
            # session was supplied by the caller; nothing to report.
            return

        if not self._closed:
            self._logger.warning("Aiohttp retry client was not closed")