Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/custom_httpx/aiohttp_handler.py')
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/custom_httpx/aiohttp_handler.py | 595
1 file changed, 595 insertions(+), 0 deletions(-)
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/custom_httpx/aiohttp_handler.py b/.venv/lib/python3.12/site-packages/litellm/llms/custom_httpx/aiohttp_handler.py
new file mode 100644
index 00000000..c865fee1
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/custom_httpx/aiohttp_handler.py
@@ -0,0 +1,595 @@
+from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, Union, cast
+
+import aiohttp
+import httpx # type: ignore
+from aiohttp import ClientSession, FormData
+
+import litellm
+import litellm.litellm_core_utils
+import litellm.types
+import litellm.types.utils
+from litellm.llms.base_llm.chat.transformation import BaseConfig
+from litellm.llms.base_llm.image_variations.transformation import (
+ BaseImageVariationConfig,
+)
+from litellm.llms.custom_httpx.http_handler import (
+ AsyncHTTPHandler,
+ HTTPHandler,
+ _get_httpx_client,
+)
+from litellm.types.llms.openai import FileTypes
+from litellm.types.utils import HttpHandlerRequestFields, ImageResponse, LlmProviders
+from litellm.utils import CustomStreamWrapper, ModelResponse, ProviderConfigManager
+
+if TYPE_CHECKING:
+ from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
+
+ LiteLLMLoggingObj = _LiteLLMLoggingObj
+else:
+ LiteLLMLoggingObj = Any
+
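+# Fallback total request timeout, in seconds, used when no numeric timeout is supplied.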
+DEFAULT_TIMEOUT = 600
+
+
+class BaseLLMAIOHTTPHandler:
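+    """Async (aiohttp) transport for provider HTTP calls.
+
+    A shared aiohttp.ClientSession is created lazily and reused across requests;
+    the sync entry points fall back to the httpx-based HTTPHandler.
+    """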
+
+ def __init__(self):
+ self.client_session: Optional[aiohttp.ClientSession] = None
+
+ def _get_async_client_session(
+ self, dynamic_client_session: Optional[ClientSession] = None
+ ) -> ClientSession:
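+        """Return the caller-supplied session if given, else the lazily-created shared session."""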
+ if dynamic_client_session:
+ return dynamic_client_session
+ elif self.client_session:
+ return self.client_session
+ else:
+            # lazily initialize a shared client session, cache it, and return it
+ self.client_session = aiohttp.ClientSession()
+ return self.client_session
+
+ async def _make_common_async_call(
+ self,
+ async_client_session: Optional[ClientSession],
+ provider_config: BaseConfig,
+ api_base: str,
+ headers: dict,
+ data: Optional[dict],
+ timeout: Union[float, httpx.Timeout],
+ litellm_params: dict,
+ form_data: Optional[FormData] = None,
+ stream: bool = False,
+ ) -> aiohttp.ClientResponse:
+ """Common implementation across stream + non-stream calls. Meant to ensure consistent error-handling."""
+ max_retry_on_unprocessable_entity_error = (
+ provider_config.max_retry_on_unprocessable_entity_error
+ )
+
+ response: Optional[aiohttp.ClientResponse] = None
+ async_client_session = self._get_async_client_session(
+ dynamic_client_session=async_client_session
+ )
+
+        for i in range(max(max_retry_on_unprocessable_entity_error, 1)):
+            try:
+                response = await async_client_session.post(
+                    url=api_base,
+                    headers=headers,
+                    json=data,
+                    data=form_data,
+                    # aiohttp takes a ClientTimeout; an httpx.Timeout cannot be
+                    # mapped field-by-field, so fall back to DEFAULT_TIMEOUT
+                    timeout=aiohttp.ClientTimeout(
+                        total=(
+                            timeout
+                            if isinstance(timeout, (int, float))
+                            else DEFAULT_TIMEOUT
+                        )
+                    ),
+                )
+                response.raise_for_status()  # no-op for non-error status codes
+            except aiohttp.ClientResponseError as e:
+                # expose the aiohttp message under the attribute name that
+                # _handle_error reads for httpx errors
+                setattr(e, "text", e.message)
+                raise self._handle_error(e=e, provider_config=provider_config)
+            except Exception as e:
+                raise self._handle_error(e=e, provider_config=provider_config)
+            # unlike the sync path, the async path does not rewrite the request
+            # on 422 errors: the first attempt either succeeds or raises
+            break
+
+ if response is None:
+ raise provider_config.get_error_class(
+ error_message="No response from the API",
+ status_code=422,
+ headers={},
+ )
+
+ return response
+
+ def _make_common_sync_call(
+ self,
+ sync_httpx_client: HTTPHandler,
+ provider_config: BaseConfig,
+ api_base: str,
+ headers: dict,
+ data: dict,
+ timeout: Union[float, httpx.Timeout],
+ litellm_params: dict,
+ stream: bool = False,
+ files: Optional[dict] = None,
+ content: Any = None,
+ params: Optional[dict] = None,
+ ) -> httpx.Response:
+
+ max_retry_on_unprocessable_entity_error = (
+ provider_config.max_retry_on_unprocessable_entity_error
+ )
+
+ response: Optional[httpx.Response] = None
+
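+        # On 422-style errors the provider config may rewrite the request body
+        # and retry, up to max_retry_on_unprocessable_entity_error attempts.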
+ for i in range(max(max_retry_on_unprocessable_entity_error, 1)):
+ try:
+ response = sync_httpx_client.post(
+ url=api_base,
+ headers=headers,
+ data=data, # do not json dump the data here. let the individual endpoint handle this.
+ timeout=timeout,
+ stream=stream,
+ files=files,
+ content=content,
+ params=params,
+ )
+ except httpx.HTTPStatusError as e:
+ hit_max_retry = i + 1 == max_retry_on_unprocessable_entity_error
+ should_retry = provider_config.should_retry_llm_api_inside_llm_translation_on_http_error(
+ e=e, litellm_params=litellm_params
+ )
+ if should_retry and not hit_max_retry:
+ data = (
+ provider_config.transform_request_on_unprocessable_entity_error(
+ e=e, request_data=data
+ )
+ )
+ continue
+ else:
+ raise self._handle_error(e=e, provider_config=provider_config)
+ except Exception as e:
+ raise self._handle_error(e=e, provider_config=provider_config)
+ break
+
+ if response is None:
+ raise provider_config.get_error_class(
+ error_message="No response from the API",
+ status_code=422, # don't retry on this error
+ headers={},
+ )
+
+ return response
+
+ async def async_completion(
+ self,
+ custom_llm_provider: str,
+ provider_config: BaseConfig,
+ api_base: str,
+ headers: dict,
+ data: dict,
+ timeout: Union[float, httpx.Timeout],
+ model: str,
+ model_response: ModelResponse,
+ logging_obj: LiteLLMLoggingObj,
+ messages: list,
+ optional_params: dict,
+ litellm_params: dict,
+ encoding: Any,
+ api_key: Optional[str] = None,
+ client: Optional[ClientSession] = None,
+ ):
+ _response = await self._make_common_async_call(
+ async_client_session=client,
+ provider_config=provider_config,
+ api_base=api_base,
+ headers=headers,
+ data=data,
+ timeout=timeout,
+ litellm_params=litellm_params,
+ stream=False,
+ )
+ _transformed_response = await provider_config.transform_response( # type: ignore
+ model=model,
+ raw_response=_response, # type: ignore
+ model_response=model_response,
+ logging_obj=logging_obj,
+ api_key=api_key,
+ request_data=data,
+ messages=messages,
+ optional_params=optional_params,
+ litellm_params=litellm_params,
+ encoding=encoding,
+ )
+ return _transformed_response
+
+ def completion(
+ self,
+ model: str,
+ messages: list,
+ api_base: str,
+ custom_llm_provider: str,
+ model_response: ModelResponse,
+ encoding,
+ logging_obj: LiteLLMLoggingObj,
+ optional_params: dict,
+ timeout: Union[float, httpx.Timeout],
+ litellm_params: dict,
+ acompletion: bool,
+ stream: Optional[bool] = False,
+ fake_stream: bool = False,
+ api_key: Optional[str] = None,
+        headers: Optional[dict] = None,  # avoid a shared mutable default
+ client: Optional[Union[HTTPHandler, AsyncHTTPHandler, ClientSession]] = None,
+ ):
+ provider_config = ProviderConfigManager.get_provider_chat_config(
+ model=model, provider=litellm.LlmProviders(custom_llm_provider)
+ )
+        # validate the environment and build provider-specific request headers
+ headers = provider_config.validate_environment(
+ api_key=api_key,
+ headers=headers or {},
+ model=model,
+ messages=messages,
+ optional_params=optional_params,
+ api_base=api_base,
+ )
+
+ api_base = provider_config.get_complete_url(
+ api_base=api_base,
+ model=model,
+ optional_params=optional_params,
+ litellm_params=litellm_params,
+ stream=stream,
+ )
+
+ data = provider_config.transform_request(
+ model=model,
+ messages=messages,
+ optional_params=optional_params,
+ litellm_params=litellm_params,
+ headers=headers,
+ )
+
+ ## LOGGING
+ logging_obj.pre_call(
+ input=messages,
+ api_key=api_key,
+ additional_args={
+ "complete_input_dict": data,
+ "api_base": api_base,
+ "headers": headers,
+ },
+ )
+
+ if acompletion is True:
+ return self.async_completion(
+ custom_llm_provider=custom_llm_provider,
+ provider_config=provider_config,
+ api_base=api_base,
+ headers=headers,
+ data=data,
+ timeout=timeout,
+ model=model,
+ model_response=model_response,
+ logging_obj=logging_obj,
+ api_key=api_key,
+ messages=messages,
+ optional_params=optional_params,
+ litellm_params=litellm_params,
+ encoding=encoding,
+ client=(
+ client
+ if client is not None and isinstance(client, ClientSession)
+ else None
+ ),
+ )
+
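+        # sync streaming: get an iterator over the raw response lines and wrap
+        # it in CustomStreamWrapper for the caller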
+ if stream is True:
+ if fake_stream is not True:
+ data["stream"] = stream
+ completion_stream, headers = self.make_sync_call(
+ provider_config=provider_config,
+ api_base=api_base,
+ headers=headers, # type: ignore
+ data=data,
+ model=model,
+ messages=messages,
+ logging_obj=logging_obj,
+ timeout=timeout,
+ fake_stream=fake_stream,
+ client=(
+ client
+ if client is not None and isinstance(client, HTTPHandler)
+ else None
+ ),
+ litellm_params=litellm_params,
+ )
+ return CustomStreamWrapper(
+ completion_stream=completion_stream,
+ model=model,
+ custom_llm_provider=custom_llm_provider,
+ logging_obj=logging_obj,
+ )
+
+ if client is None or not isinstance(client, HTTPHandler):
+ sync_httpx_client = _get_httpx_client()
+ else:
+ sync_httpx_client = client
+
+ response = self._make_common_sync_call(
+ sync_httpx_client=sync_httpx_client,
+ provider_config=provider_config,
+ api_base=api_base,
+ headers=headers,
+ timeout=timeout,
+ litellm_params=litellm_params,
+ data=data,
+ )
+ return provider_config.transform_response(
+ model=model,
+ raw_response=response,
+ model_response=model_response,
+ logging_obj=logging_obj,
+ api_key=api_key,
+ request_data=data,
+ messages=messages,
+ optional_params=optional_params,
+ litellm_params=litellm_params,
+ encoding=encoding,
+ )
+
+ def make_sync_call(
+ self,
+ provider_config: BaseConfig,
+ api_base: str,
+ headers: dict,
+ data: dict,
+ model: str,
+ messages: list,
+ logging_obj,
+ litellm_params: dict,
+ timeout: Union[float, httpx.Timeout],
+ fake_stream: bool = False,
+ client: Optional[HTTPHandler] = None,
+ ) -> Tuple[Any, dict]:
+ if client is None or not isinstance(client, HTTPHandler):
+ sync_httpx_client = _get_httpx_client()
+ else:
+ sync_httpx_client = client
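+        # fake_stream: issue a normal (non-streaming) request, then replay the
+        # parsed JSON response through the provider's stream iterator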
+ stream = True
+ if fake_stream is True:
+ stream = False
+
+ response = self._make_common_sync_call(
+ sync_httpx_client=sync_httpx_client,
+ provider_config=provider_config,
+ api_base=api_base,
+ headers=headers,
+ data=data,
+ timeout=timeout,
+ litellm_params=litellm_params,
+ stream=stream,
+ )
+
+ if fake_stream is True:
+ completion_stream = provider_config.get_model_response_iterator(
+ streaming_response=response.json(), sync_stream=True
+ )
+ else:
+ completion_stream = provider_config.get_model_response_iterator(
+ streaming_response=response.iter_lines(), sync_stream=True
+ )
+
+ # LOGGING
+ logging_obj.post_call(
+ input=messages,
+ api_key="",
+ original_response="first stream response received",
+ additional_args={"complete_input_dict": data},
+ )
+
+ return completion_stream, dict(response.headers)
+
+ async def async_image_variations(
+ self,
+ client: Optional[ClientSession],
+ provider_config: BaseImageVariationConfig,
+ api_base: str,
+ headers: dict,
+ data: HttpHandlerRequestFields,
+ timeout: float,
+ litellm_params: dict,
+ model_response: ImageResponse,
+ logging_obj: LiteLLMLoggingObj,
+ api_key: str,
+ model: Optional[str],
+ image: FileTypes,
+ optional_params: dict,
+ ) -> ImageResponse:
+ # create aiohttp form data if files in data
+ form_data: Optional[FormData] = None
+ if "files" in data and "data" in data:
+ form_data = FormData()
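+            # each files entry is assumed to be a (filename, file_obj, content_type) tuple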
+ for k, v in data["files"].items():
+ form_data.add_field(k, v[1], filename=v[0], content_type=v[2])
+
+ for key, value in data["data"].items():
+ form_data.add_field(key, value)
+
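+        # send multipart form data when files are present; otherwise send a JSON body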
+ _response = await self._make_common_async_call(
+ async_client_session=client,
+ provider_config=provider_config,
+ api_base=api_base,
+ headers=headers,
+ data=None if form_data is not None else cast(dict, data),
+ form_data=form_data,
+ timeout=timeout,
+ litellm_params=litellm_params,
+ stream=False,
+ )
+
+ ## LOGGING
+ logging_obj.post_call(
+ api_key=api_key,
+            original_response=await _response.text(),  # ClientResponse.text is a coroutine method
+ additional_args={
+ "headers": headers,
+ "api_base": api_base,
+ },
+ )
+
+ ## RESPONSE OBJECT
+ return await provider_config.async_transform_response_image_variation(
+ model=model,
+ model_response=model_response,
+ raw_response=_response,
+ logging_obj=logging_obj,
+ request_data=cast(dict, data),
+ image=image,
+ optional_params=optional_params,
+ litellm_params=litellm_params,
+ encoding=None,
+ api_key=api_key,
+ )
+
+ def image_variations(
+ self,
+ model_response: ImageResponse,
+ api_key: str,
+ model: Optional[str],
+ image: FileTypes,
+ timeout: float,
+ custom_llm_provider: str,
+ logging_obj: LiteLLMLoggingObj,
+ optional_params: dict,
+ litellm_params: dict,
+ print_verbose: Optional[Callable] = None,
+ api_base: Optional[str] = None,
+ aimage_variation: bool = False,
+ logger_fn=None,
+ client=None,
+ organization: Optional[str] = None,
+ headers: Optional[dict] = None,
+ ) -> ImageResponse:
+ if model is None:
+ raise ValueError("model is required for non-openai image variations")
+
+ provider_config = ProviderConfigManager.get_provider_image_variation_config(
+ model=model, # openai defaults to dall-e-2
+ provider=LlmProviders(custom_llm_provider),
+ )
+
+ if provider_config is None:
+ raise ValueError(
+ f"image variation provider not found: {custom_llm_provider}."
+ )
+
+ api_base = provider_config.get_complete_url(
+ api_base=api_base,
+ model=model,
+ optional_params=optional_params,
+ litellm_params=litellm_params,
+ stream=False,
+ )
+
+ headers = provider_config.validate_environment(
+ api_key=api_key,
+ headers=headers or {},
+ model=model,
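+            # validate_environment expects chat messages; pass a placeholder for image calls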
+ messages=[{"role": "user", "content": "test"}],
+ optional_params=optional_params,
+ api_base=api_base,
+ )
+
+ data = provider_config.transform_request_image_variation(
+ model=model,
+ image=image,
+ optional_params=optional_params,
+ headers=headers,
+ )
+
+ ## LOGGING
+ logging_obj.pre_call(
+ input="",
+ api_key=api_key,
+ additional_args={
+ "headers": headers,
+ "api_base": api_base,
+ "complete_input_dict": data.copy(),
+ },
+ )
+
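+        # async path: this returns a coroutine even though the declared return
+        # type is ImageResponse (hence the type: ignore on the call below)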
+ if litellm_params.get("async_call", False):
+ return self.async_image_variations(
+ api_base=api_base,
+ data=data,
+ headers=headers,
+ model_response=model_response,
+ api_key=api_key,
+ logging_obj=logging_obj,
+ model=model,
+ timeout=timeout,
+ client=client,
+ optional_params=optional_params,
+ litellm_params=litellm_params,
+ image=image,
+ provider_config=provider_config,
+ ) # type: ignore
+
+ if client is None or not isinstance(client, HTTPHandler):
+ sync_httpx_client = _get_httpx_client()
+ else:
+ sync_httpx_client = client
+
+ response = self._make_common_sync_call(
+ sync_httpx_client=sync_httpx_client,
+ provider_config=provider_config,
+ api_base=api_base,
+ headers=headers,
+ timeout=timeout,
+ litellm_params=litellm_params,
+ stream=False,
+ data=data.get("data") or {},
+ files=data.get("files"),
+ content=data.get("content"),
+ params=data.get("params"),
+ )
+
+ ## LOGGING
+ logging_obj.post_call(
+ api_key=api_key,
+ original_response=response.text,
+ additional_args={
+ "headers": headers,
+ "api_base": api_base,
+ },
+ )
+
+ ## RESPONSE OBJECT
+ return provider_config.transform_response_image_variation(
+ model=model,
+ model_response=model_response,
+ raw_response=response,
+ logging_obj=logging_obj,
+ request_data=cast(dict, data),
+ image=image,
+ optional_params=optional_params,
+ litellm_params=litellm_params,
+ encoding=None,
+ api_key=api_key,
+ )
+
+ def _handle_error(self, e: Exception, provider_config: BaseConfig):
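+        """Map an httpx/aiohttp exception onto the provider's error class.
+
+        This method always raises; callers prefix calls with ``raise`` only so
+        static analyzers see that the branch terminates.
+        """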
+ status_code = getattr(e, "status_code", 500)
+ error_headers = getattr(e, "headers", None)
+ error_text = getattr(e, "text", str(e))
+ error_response = getattr(e, "response", None)
+ if error_headers is None and error_response:
+ error_headers = getattr(error_response, "headers", None)
+ if error_response and hasattr(error_response, "text"):
+ error_text = getattr(error_response, "text", error_text)
+ if error_headers:
+ error_headers = dict(error_headers)
+ else:
+ error_headers = {}
+ raise provider_config.get_error_class(
+ error_message=error_text,
+ status_code=status_code,
+ headers=error_headers,
+ )
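+
+
+# Usage sketch (illustrative only; `my_logging_obj` is a placeholder, and the
+# provider/model must have a chat config registered with ProviderConfigManager):
+#
+#   import asyncio
+#   from litellm.utils import ModelResponse
+#
+#   async def main():
+#       handler = BaseLLMAIOHTTPHandler()
+#       response = await handler.completion(
+#           model="my-model",
+#           messages=[{"role": "user", "content": "hello"}],
+#           api_base="https://api.example.com/v1/chat/completions",
+#           custom_llm_provider="my-provider",
+#           model_response=ModelResponse(),
+#           encoding=None,
+#           logging_obj=my_logging_obj,
+#           optional_params={},
+#           timeout=60.0,
+#           litellm_params={},
+#           acompletion=True,
+#       )
+#
+#   asyncio.run(main())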