author     S. Solomon Darnell  2025-03-28 21:52:21 -0500
committer  S. Solomon Darnell  2025-03-28 21:52:21 -0500
commit     4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree       ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/litellm/llms/custom_httpx/aiohttp_handler.py
parent     cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/custom_httpx/aiohttp_handler.py')
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/custom_httpx/aiohttp_handler.py  595
1 file changed, 595 insertions(+), 0 deletions(-)
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/custom_httpx/aiohttp_handler.py b/.venv/lib/python3.12/site-packages/litellm/llms/custom_httpx/aiohttp_handler.py
new file mode 100644
index 00000000..c865fee1
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/custom_httpx/aiohttp_handler.py
@@ -0,0 +1,595 @@
+from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, Union, cast
+
+import aiohttp
+import httpx  # type: ignore
+from aiohttp import ClientSession, FormData
+
+import litellm
+import litellm.litellm_core_utils
+import litellm.types
+import litellm.types.utils
+from litellm.llms.base_llm.chat.transformation import BaseConfig
+from litellm.llms.base_llm.image_variations.transformation import (
+    BaseImageVariationConfig,
+)
+from litellm.llms.custom_httpx.http_handler import (
+    AsyncHTTPHandler,
+    HTTPHandler,
+    _get_httpx_client,
+)
+from litellm.types.llms.openai import FileTypes
+from litellm.types.utils import HttpHandlerRequestFields, ImageResponse, LlmProviders
+from litellm.utils import CustomStreamWrapper, ModelResponse, ProviderConfigManager
+
+if TYPE_CHECKING:
+    from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
+
+    LiteLLMLoggingObj = _LiteLLMLoggingObj
+else:
+    LiteLLMLoggingObj = Any
+
+DEFAULT_TIMEOUT = 600
+
+
+class BaseLLMAIOHTTPHandler:
+
+    def __init__(self):
+        self.client_session: Optional[aiohttp.ClientSession] = None
+
+    def _get_async_client_session(
+        self, dynamic_client_session: Optional[ClientSession] = None
+    ) -> ClientSession:
+        if dynamic_client_session:
+            return dynamic_client_session
+        elif self.client_session:
+            return self.client_session
+        else:
+            # init client session, and then return new session
+            self.client_session = aiohttp.ClientSession()
+            return self.client_session
+
+    async def _make_common_async_call(
+        self,
+        async_client_session: Optional[ClientSession],
+        provider_config: BaseConfig,
+        api_base: str,
+        headers: dict,
+        data: Optional[dict],
+        timeout: Union[float, httpx.Timeout],
+        litellm_params: dict,
+        form_data: Optional[FormData] = None,
+        stream: bool = False,
+    ) -> aiohttp.ClientResponse:
+        """Common implementation across stream + non-stream calls.
+        Meant to ensure consistent error-handling."""
+        max_retry_on_unprocessable_entity_error = (
+            provider_config.max_retry_on_unprocessable_entity_error
+        )
+
+        response: Optional[aiohttp.ClientResponse] = None
+        async_client_session = self._get_async_client_session(
+            dynamic_client_session=async_client_session
+        )
+
+        for i in range(max(max_retry_on_unprocessable_entity_error, 1)):
+            try:
+                response = await async_client_session.post(
+                    url=api_base,
+                    headers=headers,
+                    json=data,
+                    data=form_data,
+                )
+                if not response.ok:
+                    response.raise_for_status()
+            except aiohttp.ClientResponseError as e:
+                setattr(e, "text", e.message)
+                raise self._handle_error(e=e, provider_config=provider_config)
+            except Exception as e:
+                raise self._handle_error(e=e, provider_config=provider_config)
+            break
+
+        if response is None:
+            raise provider_config.get_error_class(
+                error_message="No response from the API",
+                status_code=422,
+                headers={},
+            )
+
+        return response
+
+    def _make_common_sync_call(
+        self,
+        sync_httpx_client: HTTPHandler,
+        provider_config: BaseConfig,
+        api_base: str,
+        headers: dict,
+        data: dict,
+        timeout: Union[float, httpx.Timeout],
+        litellm_params: dict,
+        stream: bool = False,
+        files: Optional[dict] = None,
+        content: Any = None,
+        params: Optional[dict] = None,
+    ) -> httpx.Response:
+
+        max_retry_on_unprocessable_entity_error = (
+            provider_config.max_retry_on_unprocessable_entity_error
+        )
+
+        response: Optional[httpx.Response] = None
+
+        for i in range(max(max_retry_on_unprocessable_entity_error, 1)):
+            try:
+                response = sync_httpx_client.post(
+                    url=api_base,
+                    headers=headers,
+                    data=data,  # do not json dump the data here. let the individual endpoint handle this.
+                    timeout=timeout,
+                    stream=stream,
+                    files=files,
+                    content=content,
+                    params=params,
+                )
+            except httpx.HTTPStatusError as e:
+                hit_max_retry = i + 1 == max_retry_on_unprocessable_entity_error
+                should_retry = provider_config.should_retry_llm_api_inside_llm_translation_on_http_error(
+                    e=e, litellm_params=litellm_params
+                )
+                if should_retry and not hit_max_retry:
+                    data = (
+                        provider_config.transform_request_on_unprocessable_entity_error(
+                            e=e, request_data=data
+                        )
+                    )
+                    continue
+                else:
+                    raise self._handle_error(e=e, provider_config=provider_config)
+            except Exception as e:
+                raise self._handle_error(e=e, provider_config=provider_config)
+            break
+
+        if response is None:
+            raise provider_config.get_error_class(
+                error_message="No response from the API",
+                status_code=422,  # don't retry on this error
+                headers={},
+            )
+
+        return response
+
+    async def async_completion(
+        self,
+        custom_llm_provider: str,
+        provider_config: BaseConfig,
+        api_base: str,
+        headers: dict,
+        data: dict,
+        timeout: Union[float, httpx.Timeout],
+        model: str,
+        model_response: ModelResponse,
+        logging_obj: LiteLLMLoggingObj,
+        messages: list,
+        optional_params: dict,
+        litellm_params: dict,
+        encoding: Any,
+        api_key: Optional[str] = None,
+        client: Optional[ClientSession] = None,
+    ):
+        _response = await self._make_common_async_call(
+            async_client_session=client,
+            provider_config=provider_config,
+            api_base=api_base,
+            headers=headers,
+            data=data,
+            timeout=timeout,
+            litellm_params=litellm_params,
+            stream=False,
+        )
+        _transformed_response = await provider_config.transform_response(  # type: ignore
+            model=model,
+            raw_response=_response,  # type: ignore
+            model_response=model_response,
+            logging_obj=logging_obj,
+            api_key=api_key,
+            request_data=data,
+            messages=messages,
+            optional_params=optional_params,
+            litellm_params=litellm_params,
+            encoding=encoding,
+        )
+        return _transformed_response
+
+    def completion(
+        self,
+        model: str,
+        messages: list,
+        api_base: str,
+        custom_llm_provider: str,
+        model_response: ModelResponse,
+        encoding,
+        logging_obj: LiteLLMLoggingObj,
+        optional_params: dict,
+        timeout: Union[float, httpx.Timeout],
+        litellm_params: dict,
+        acompletion: bool,
+        stream: Optional[bool] = False,
+        fake_stream: bool = False,
+        api_key: Optional[str] = None,
+        headers: Optional[dict] = {},
+        client: Optional[Union[HTTPHandler, AsyncHTTPHandler, ClientSession]] = None,
+    ):
+        provider_config = ProviderConfigManager.get_provider_chat_config(
+            model=model, provider=litellm.LlmProviders(custom_llm_provider)
+        )
+        # get config from model, custom llm provider
+        headers = provider_config.validate_environment(
+            api_key=api_key,
+            headers=headers or {},
+            model=model,
+            messages=messages,
+            optional_params=optional_params,
+            api_base=api_base,
+        )
+
+        api_base = provider_config.get_complete_url(
+            api_base=api_base,
+            model=model,
+            optional_params=optional_params,
+            litellm_params=litellm_params,
+            stream=stream,
+        )
+
+        data = provider_config.transform_request(
+            model=model,
+            messages=messages,
+            optional_params=optional_params,
+            litellm_params=litellm_params,
+            headers=headers,
+        )
+
+        ## LOGGING
+        logging_obj.pre_call(
+            input=messages,
+            api_key=api_key,
+            additional_args={
+                "complete_input_dict": data,
+                "api_base": api_base,
+                "headers": headers,
+            },
+        )
+
+        if acompletion is True:
+            return self.async_completion(
+                custom_llm_provider=custom_llm_provider,
+                provider_config=provider_config,
+                api_base=api_base,
+                headers=headers,
+                data=data,
+                timeout=timeout,
+                model=model,
+                model_response=model_response,
+                logging_obj=logging_obj,
+                api_key=api_key,
+                messages=messages,
+                optional_params=optional_params,
+                litellm_params=litellm_params,
+                encoding=encoding,
+                client=(
+                    client
+                    if client is not None and isinstance(client, ClientSession)
+                    else None
+                ),
+            )
+
+        if stream is True:
+            if fake_stream is not True:
+                data["stream"] = stream
+            completion_stream, headers = self.make_sync_call(
+                provider_config=provider_config,
+                api_base=api_base,
+                headers=headers,  # type: ignore
+                data=data,
+                model=model,
+                messages=messages,
+                logging_obj=logging_obj,
+                timeout=timeout,
+                fake_stream=fake_stream,
+                client=(
+                    client
+                    if client is not None and isinstance(client, HTTPHandler)
+                    else None
+                ),
+                litellm_params=litellm_params,
+            )
+            return CustomStreamWrapper(
+                completion_stream=completion_stream,
+                model=model,
+                custom_llm_provider=custom_llm_provider,
+                logging_obj=logging_obj,
+            )
+
+        if client is None or not isinstance(client, HTTPHandler):
+            sync_httpx_client = _get_httpx_client()
+        else:
+            sync_httpx_client = client
+
+        response = self._make_common_sync_call(
+            sync_httpx_client=sync_httpx_client,
+            provider_config=provider_config,
+            api_base=api_base,
+            headers=headers,
+            timeout=timeout,
+            litellm_params=litellm_params,
+            data=data,
+        )
+        return provider_config.transform_response(
+            model=model,
+            raw_response=response,
+            model_response=model_response,
+            logging_obj=logging_obj,
+            api_key=api_key,
+            request_data=data,
+            messages=messages,
+            optional_params=optional_params,
+            litellm_params=litellm_params,
+            encoding=encoding,
+        )
+
+    def make_sync_call(
+        self,
+        provider_config: BaseConfig,
+        api_base: str,
+        headers: dict,
+        data: dict,
+        model: str,
+        messages: list,
+        logging_obj,
+        litellm_params: dict,
+        timeout: Union[float, httpx.Timeout],
+        fake_stream: bool = False,
+        client: Optional[HTTPHandler] = None,
+    ) -> Tuple[Any, dict]:
+        if client is None or not isinstance(client, HTTPHandler):
+            sync_httpx_client = _get_httpx_client()
+        else:
+            sync_httpx_client = client
+        stream = True
+        if fake_stream is True:
+            stream = False
+
+        response = self._make_common_sync_call(
+            sync_httpx_client=sync_httpx_client,
+            provider_config=provider_config,
+            api_base=api_base,
+            headers=headers,
+            data=data,
+            timeout=timeout,
+            litellm_params=litellm_params,
+            stream=stream,
+        )
+
+        if fake_stream is True:
+            completion_stream = provider_config.get_model_response_iterator(
+                streaming_response=response.json(), sync_stream=True
+            )
+        else:
+            completion_stream = provider_config.get_model_response_iterator(
+                streaming_response=response.iter_lines(), sync_stream=True
+            )
+
+        # LOGGING
+        logging_obj.post_call(
+            input=messages,
+            api_key="",
+            original_response="first stream response received",
+            additional_args={"complete_input_dict": data},
+        )
+
+        return completion_stream, dict(response.headers)
+
+    async def async_image_variations(
+        self,
+        client: Optional[ClientSession],
+        provider_config: BaseImageVariationConfig,
+        api_base: str,
+        headers: dict,
+        data: HttpHandlerRequestFields,
+        timeout: float,
+        litellm_params: dict,
+        model_response: ImageResponse,
+        logging_obj: LiteLLMLoggingObj,
+        api_key: str,
+        model: Optional[str],
+        image: FileTypes,
+        optional_params: dict,
+    ) -> ImageResponse:
+        # create aiohttp form data if files in data
+        form_data: Optional[FormData] = None
+        if "files" in data and "data" in data:
+            form_data = FormData()
+            for k, v in data["files"].items():
+                form_data.add_field(k, v[1], filename=v[0], content_type=v[2])
+
+            for key, value in data["data"].items():
+                form_data.add_field(key, value)
+
+        _response = await self._make_common_async_call(
+            async_client_session=client,
+            provider_config=provider_config,
+            api_base=api_base,
+            headers=headers,
+            data=None if form_data is not None else cast(dict, data),
+            form_data=form_data,
+            timeout=timeout,
+            litellm_params=litellm_params,
+            stream=False,
+        )
+
+        ## LOGGING
+        logging_obj.post_call(
+            api_key=api_key,
+            original_response=_response.text,
+            additional_args={
+                "headers": headers,
+                "api_base": api_base,
+            },
+        )
+
+        ## RESPONSE OBJECT
+        return await provider_config.async_transform_response_image_variation(
+            model=model,
+            model_response=model_response,
+            raw_response=_response,
+            logging_obj=logging_obj,
+            request_data=cast(dict, data),
+            image=image,
+            optional_params=optional_params,
+            litellm_params=litellm_params,
+            encoding=None,
+            api_key=api_key,
+        )
+
+    def image_variations(
+        self,
+        model_response: ImageResponse,
+        api_key: str,
+        model: Optional[str],
+        image: FileTypes,
+        timeout: float,
+        custom_llm_provider: str,
+        logging_obj: LiteLLMLoggingObj,
+        optional_params: dict,
+        litellm_params: dict,
+        print_verbose: Optional[Callable] = None,
+        api_base: Optional[str] = None,
+        aimage_variation: bool = False,
+        logger_fn=None,
+        client=None,
+        organization: Optional[str] = None,
+        headers: Optional[dict] = None,
+    ) -> ImageResponse:
+        if model is None:
+            raise ValueError("model is required for non-openai image variations")
+
+        provider_config = ProviderConfigManager.get_provider_image_variation_config(
+            model=model,  # openai defaults to dall-e-2
+            provider=LlmProviders(custom_llm_provider),
+        )
+
+        if provider_config is None:
+            raise ValueError(
+                f"image variation provider not found: {custom_llm_provider}."
+            )
+
+        api_base = provider_config.get_complete_url(
+            api_base=api_base,
+            model=model,
+            optional_params=optional_params,
+            litellm_params=litellm_params,
+            stream=False,
+        )
+
+        headers = provider_config.validate_environment(
+            api_key=api_key,
+            headers=headers or {},
+            model=model,
+            messages=[{"role": "user", "content": "test"}],
+            optional_params=optional_params,
+            api_base=api_base,
+        )
+
+        data = provider_config.transform_request_image_variation(
+            model=model,
+            image=image,
+            optional_params=optional_params,
+            headers=headers,
+        )
+
+        ## LOGGING
+        logging_obj.pre_call(
+            input="",
+            api_key=api_key,
+            additional_args={
+                "headers": headers,
+                "api_base": api_base,
+                "complete_input_dict": data.copy(),
+            },
+        )
+
+        if litellm_params.get("async_call", False):
+            return self.async_image_variations(
+                api_base=api_base,
+                data=data,
+                headers=headers,
+                model_response=model_response,
+                api_key=api_key,
+                logging_obj=logging_obj,
+                model=model,
+                timeout=timeout,
+                client=client,
+                optional_params=optional_params,
+                litellm_params=litellm_params,
+                image=image,
+                provider_config=provider_config,
+            )  # type: ignore
+
+        if client is None or not isinstance(client, HTTPHandler):
+            sync_httpx_client = _get_httpx_client()
+        else:
+            sync_httpx_client = client
+
+        response = self._make_common_sync_call(
+            sync_httpx_client=sync_httpx_client,
+            provider_config=provider_config,
+            api_base=api_base,
+            headers=headers,
+            timeout=timeout,
+            litellm_params=litellm_params,
+            stream=False,
+            data=data.get("data") or {},
+            files=data.get("files"),
+            content=data.get("content"),
+            params=data.get("params"),
+        )
+
+        ## LOGGING
+        logging_obj.post_call(
+            api_key=api_key,
+            original_response=response.text,
+            additional_args={
+                "headers": headers,
+                "api_base": api_base,
+            },
+        )
+
+        ## RESPONSE OBJECT
+        return provider_config.transform_response_image_variation(
+            model=model,
+            model_response=model_response,
+            raw_response=response,
+            logging_obj=logging_obj,
+            request_data=cast(dict, data),
+            image=image,
+            optional_params=optional_params,
+            litellm_params=litellm_params,
+            encoding=None,
+            api_key=api_key,
+        )
+
+    def _handle_error(self, e: Exception, provider_config: BaseConfig):
+        status_code = getattr(e, "status_code", 500)
+        error_headers = getattr(e, "headers", None)
+        error_text = getattr(e, "text", str(e))
+        error_response = getattr(e, "response", None)
+        if error_headers is None and error_response:
+            error_headers = getattr(error_response, "headers", None)
+        if error_response and hasattr(error_response, "text"):
+            error_text = getattr(error_response, "text", error_text)
+        if error_headers:
+            error_headers = dict(error_headers)
+        else:
+            error_headers = {}
+        raise provider_config.get_error_class(
+            error_message=error_text,
+            status_code=status_code,
+            headers=error_headers,
+        )
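For context, a minimal usage sketch (an assumption, not part of this commit): callers do not normally construct `BaseLLMAIOHTTPHandler` directly; litellm's public `completion`/`acompletion` entry points route through provider configs, which drive handlers like this one for providers on the aiohttp transport. The model name and prompt below are illustrative only.

```python
# Hypothetical usage sketch -- assumes litellm's public async API;
# model name and prompt are illustrative, not taken from this commit.
import asyncio

import litellm


async def main():
    # acompletion() is litellm's async entry point; for providers routed
    # through the aiohttp transport it ultimately exercises
    # BaseLLMAIOHTTPHandler.async_completion under the hood (assumed wiring).
    response = await litellm.acompletion(
        model="openai/gpt-4o-mini",  # illustrative model name
        messages=[{"role": "user", "content": "hello"}],
    )
    print(response.choices[0].message.content)


asyncio.run(main())
```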