Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/custom_httpx')
 .venv/lib/python3.12/site-packages/litellm/llms/custom_httpx/aiohttp_handler.py  |  595
 .venv/lib/python3.12/site-packages/litellm/llms/custom_httpx/http_handler.py     |  746
 .venv/lib/python3.12/site-packages/litellm/llms/custom_httpx/httpx_handler.py    |   49
 .venv/lib/python3.12/site-packages/litellm/llms/custom_httpx/llm_http_handler.py | 1260
4 files changed, 2650 insertions(+), 0 deletions(-)
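For orientation before the raw diff: the files added here center on cached HTTP client factories (_get_httpx_client / get_async_httpx_client in http_handler.py) that wrap httpx clients with litellm-specific timeout and error handling, plus handler classes that route provider requests through them. A minimal usage sketch, under assumptions not confirmed by this diff — litellm installed, LlmProviders.OPENAI being a valid enum member, and placeholder endpoint/key values:

import asyncio

import litellm
from litellm.llms.custom_httpx.http_handler import (
    AsyncHTTPHandler,
    get_async_httpx_client,
)


async def main() -> None:
    # get_async_httpx_client() caches handlers in litellm.in_memory_llm_clients_cache,
    # keyed on provider + params, so repeated calls reuse one connection pool.
    client: AsyncHTTPHandler = get_async_httpx_client(
        llm_provider=litellm.LlmProviders.OPENAI,  # assumed enum member
        params={"timeout": 30.0},
    )
    # AsyncHTTPHandler.post() raises litellm.Timeout on httpx timeouts and
    # re-raises httpx.HTTPStatusError with status_code/message/text attached.
    response = await client.post(
        url="https://api.example.com/v1/chat/completions",  # placeholder endpoint
        json={"model": "my-model", "messages": [{"role": "user", "content": "hi"}]},
        headers={"Authorization": "Bearer sk-placeholder"},  # placeholder key
    )
    print(response.status_code)


if __name__ == "__main__":
    asyncio.run(main())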
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/custom_httpx/aiohttp_handler.py b/.venv/lib/python3.12/site-packages/litellm/llms/custom_httpx/aiohttp_handler.py new file mode 100644 index 00000000..c865fee1 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/llms/custom_httpx/aiohttp_handler.py @@ -0,0 +1,595 @@ +from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, Union, cast + +import aiohttp +import httpx # type: ignore +from aiohttp import ClientSession, FormData + +import litellm +import litellm.litellm_core_utils +import litellm.types +import litellm.types.utils +from litellm.llms.base_llm.chat.transformation import BaseConfig +from litellm.llms.base_llm.image_variations.transformation import ( + BaseImageVariationConfig, +) +from litellm.llms.custom_httpx.http_handler import ( + AsyncHTTPHandler, + HTTPHandler, + _get_httpx_client, +) +from litellm.types.llms.openai import FileTypes +from litellm.types.utils import HttpHandlerRequestFields, ImageResponse, LlmProviders +from litellm.utils import CustomStreamWrapper, ModelResponse, ProviderConfigManager + +if TYPE_CHECKING: + from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj + + LiteLLMLoggingObj = _LiteLLMLoggingObj +else: + LiteLLMLoggingObj = Any + +DEFAULT_TIMEOUT = 600 + + +class BaseLLMAIOHTTPHandler: + + def __init__(self): + self.client_session: Optional[aiohttp.ClientSession] = None + + def _get_async_client_session( + self, dynamic_client_session: Optional[ClientSession] = None + ) -> ClientSession: + if dynamic_client_session: + return dynamic_client_session + elif self.client_session: + return self.client_session + else: + # init client session, and then return new session + self.client_session = aiohttp.ClientSession() + return self.client_session + + async def _make_common_async_call( + self, + async_client_session: Optional[ClientSession], + provider_config: BaseConfig, + api_base: str, + headers: dict, + data: Optional[dict], + timeout: Union[float, httpx.Timeout], + litellm_params: dict, + form_data: Optional[FormData] = None, + stream: bool = False, + ) -> aiohttp.ClientResponse: + """Common implementation across stream + non-stream calls. 
Meant to ensure consistent error-handling.""" + max_retry_on_unprocessable_entity_error = ( + provider_config.max_retry_on_unprocessable_entity_error + ) + + response: Optional[aiohttp.ClientResponse] = None + async_client_session = self._get_async_client_session( + dynamic_client_session=async_client_session + ) + + for i in range(max(max_retry_on_unprocessable_entity_error, 1)): + try: + response = await async_client_session.post( + url=api_base, + headers=headers, + json=data, + data=form_data, + ) + if not response.ok: + response.raise_for_status() + except aiohttp.ClientResponseError as e: + setattr(e, "text", e.message) + raise self._handle_error(e=e, provider_config=provider_config) + except Exception as e: + raise self._handle_error(e=e, provider_config=provider_config) + break + + if response is None: + raise provider_config.get_error_class( + error_message="No response from the API", + status_code=422, + headers={}, + ) + + return response + + def _make_common_sync_call( + self, + sync_httpx_client: HTTPHandler, + provider_config: BaseConfig, + api_base: str, + headers: dict, + data: dict, + timeout: Union[float, httpx.Timeout], + litellm_params: dict, + stream: bool = False, + files: Optional[dict] = None, + content: Any = None, + params: Optional[dict] = None, + ) -> httpx.Response: + + max_retry_on_unprocessable_entity_error = ( + provider_config.max_retry_on_unprocessable_entity_error + ) + + response: Optional[httpx.Response] = None + + for i in range(max(max_retry_on_unprocessable_entity_error, 1)): + try: + response = sync_httpx_client.post( + url=api_base, + headers=headers, + data=data, # do not json dump the data here. let the individual endpoint handle this. + timeout=timeout, + stream=stream, + files=files, + content=content, + params=params, + ) + except httpx.HTTPStatusError as e: + hit_max_retry = i + 1 == max_retry_on_unprocessable_entity_error + should_retry = provider_config.should_retry_llm_api_inside_llm_translation_on_http_error( + e=e, litellm_params=litellm_params + ) + if should_retry and not hit_max_retry: + data = ( + provider_config.transform_request_on_unprocessable_entity_error( + e=e, request_data=data + ) + ) + continue + else: + raise self._handle_error(e=e, provider_config=provider_config) + except Exception as e: + raise self._handle_error(e=e, provider_config=provider_config) + break + + if response is None: + raise provider_config.get_error_class( + error_message="No response from the API", + status_code=422, # don't retry on this error + headers={}, + ) + + return response + + async def async_completion( + self, + custom_llm_provider: str, + provider_config: BaseConfig, + api_base: str, + headers: dict, + data: dict, + timeout: Union[float, httpx.Timeout], + model: str, + model_response: ModelResponse, + logging_obj: LiteLLMLoggingObj, + messages: list, + optional_params: dict, + litellm_params: dict, + encoding: Any, + api_key: Optional[str] = None, + client: Optional[ClientSession] = None, + ): + _response = await self._make_common_async_call( + async_client_session=client, + provider_config=provider_config, + api_base=api_base, + headers=headers, + data=data, + timeout=timeout, + litellm_params=litellm_params, + stream=False, + ) + _transformed_response = await provider_config.transform_response( # type: ignore + model=model, + raw_response=_response, # type: ignore + model_response=model_response, + logging_obj=logging_obj, + api_key=api_key, + request_data=data, + messages=messages, + optional_params=optional_params, + 
litellm_params=litellm_params, + encoding=encoding, + ) + return _transformed_response + + def completion( + self, + model: str, + messages: list, + api_base: str, + custom_llm_provider: str, + model_response: ModelResponse, + encoding, + logging_obj: LiteLLMLoggingObj, + optional_params: dict, + timeout: Union[float, httpx.Timeout], + litellm_params: dict, + acompletion: bool, + stream: Optional[bool] = False, + fake_stream: bool = False, + api_key: Optional[str] = None, + headers: Optional[dict] = {}, + client: Optional[Union[HTTPHandler, AsyncHTTPHandler, ClientSession]] = None, + ): + provider_config = ProviderConfigManager.get_provider_chat_config( + model=model, provider=litellm.LlmProviders(custom_llm_provider) + ) + # get config from model, custom llm provider + headers = provider_config.validate_environment( + api_key=api_key, + headers=headers or {}, + model=model, + messages=messages, + optional_params=optional_params, + api_base=api_base, + ) + + api_base = provider_config.get_complete_url( + api_base=api_base, + model=model, + optional_params=optional_params, + litellm_params=litellm_params, + stream=stream, + ) + + data = provider_config.transform_request( + model=model, + messages=messages, + optional_params=optional_params, + litellm_params=litellm_params, + headers=headers, + ) + + ## LOGGING + logging_obj.pre_call( + input=messages, + api_key=api_key, + additional_args={ + "complete_input_dict": data, + "api_base": api_base, + "headers": headers, + }, + ) + + if acompletion is True: + return self.async_completion( + custom_llm_provider=custom_llm_provider, + provider_config=provider_config, + api_base=api_base, + headers=headers, + data=data, + timeout=timeout, + model=model, + model_response=model_response, + logging_obj=logging_obj, + api_key=api_key, + messages=messages, + optional_params=optional_params, + litellm_params=litellm_params, + encoding=encoding, + client=( + client + if client is not None and isinstance(client, ClientSession) + else None + ), + ) + + if stream is True: + if fake_stream is not True: + data["stream"] = stream + completion_stream, headers = self.make_sync_call( + provider_config=provider_config, + api_base=api_base, + headers=headers, # type: ignore + data=data, + model=model, + messages=messages, + logging_obj=logging_obj, + timeout=timeout, + fake_stream=fake_stream, + client=( + client + if client is not None and isinstance(client, HTTPHandler) + else None + ), + litellm_params=litellm_params, + ) + return CustomStreamWrapper( + completion_stream=completion_stream, + model=model, + custom_llm_provider=custom_llm_provider, + logging_obj=logging_obj, + ) + + if client is None or not isinstance(client, HTTPHandler): + sync_httpx_client = _get_httpx_client() + else: + sync_httpx_client = client + + response = self._make_common_sync_call( + sync_httpx_client=sync_httpx_client, + provider_config=provider_config, + api_base=api_base, + headers=headers, + timeout=timeout, + litellm_params=litellm_params, + data=data, + ) + return provider_config.transform_response( + model=model, + raw_response=response, + model_response=model_response, + logging_obj=logging_obj, + api_key=api_key, + request_data=data, + messages=messages, + optional_params=optional_params, + litellm_params=litellm_params, + encoding=encoding, + ) + + def make_sync_call( + self, + provider_config: BaseConfig, + api_base: str, + headers: dict, + data: dict, + model: str, + messages: list, + logging_obj, + litellm_params: dict, + timeout: Union[float, httpx.Timeout], + fake_stream: 
bool = False, + client: Optional[HTTPHandler] = None, + ) -> Tuple[Any, dict]: + if client is None or not isinstance(client, HTTPHandler): + sync_httpx_client = _get_httpx_client() + else: + sync_httpx_client = client + stream = True + if fake_stream is True: + stream = False + + response = self._make_common_sync_call( + sync_httpx_client=sync_httpx_client, + provider_config=provider_config, + api_base=api_base, + headers=headers, + data=data, + timeout=timeout, + litellm_params=litellm_params, + stream=stream, + ) + + if fake_stream is True: + completion_stream = provider_config.get_model_response_iterator( + streaming_response=response.json(), sync_stream=True + ) + else: + completion_stream = provider_config.get_model_response_iterator( + streaming_response=response.iter_lines(), sync_stream=True + ) + + # LOGGING + logging_obj.post_call( + input=messages, + api_key="", + original_response="first stream response received", + additional_args={"complete_input_dict": data}, + ) + + return completion_stream, dict(response.headers) + + async def async_image_variations( + self, + client: Optional[ClientSession], + provider_config: BaseImageVariationConfig, + api_base: str, + headers: dict, + data: HttpHandlerRequestFields, + timeout: float, + litellm_params: dict, + model_response: ImageResponse, + logging_obj: LiteLLMLoggingObj, + api_key: str, + model: Optional[str], + image: FileTypes, + optional_params: dict, + ) -> ImageResponse: + # create aiohttp form data if files in data + form_data: Optional[FormData] = None + if "files" in data and "data" in data: + form_data = FormData() + for k, v in data["files"].items(): + form_data.add_field(k, v[1], filename=v[0], content_type=v[2]) + + for key, value in data["data"].items(): + form_data.add_field(key, value) + + _response = await self._make_common_async_call( + async_client_session=client, + provider_config=provider_config, + api_base=api_base, + headers=headers, + data=None if form_data is not None else cast(dict, data), + form_data=form_data, + timeout=timeout, + litellm_params=litellm_params, + stream=False, + ) + + ## LOGGING + logging_obj.post_call( + api_key=api_key, + original_response=_response.text, + additional_args={ + "headers": headers, + "api_base": api_base, + }, + ) + + ## RESPONSE OBJECT + return await provider_config.async_transform_response_image_variation( + model=model, + model_response=model_response, + raw_response=_response, + logging_obj=logging_obj, + request_data=cast(dict, data), + image=image, + optional_params=optional_params, + litellm_params=litellm_params, + encoding=None, + api_key=api_key, + ) + + def image_variations( + self, + model_response: ImageResponse, + api_key: str, + model: Optional[str], + image: FileTypes, + timeout: float, + custom_llm_provider: str, + logging_obj: LiteLLMLoggingObj, + optional_params: dict, + litellm_params: dict, + print_verbose: Optional[Callable] = None, + api_base: Optional[str] = None, + aimage_variation: bool = False, + logger_fn=None, + client=None, + organization: Optional[str] = None, + headers: Optional[dict] = None, + ) -> ImageResponse: + if model is None: + raise ValueError("model is required for non-openai image variations") + + provider_config = ProviderConfigManager.get_provider_image_variation_config( + model=model, # openai defaults to dall-e-2 + provider=LlmProviders(custom_llm_provider), + ) + + if provider_config is None: + raise ValueError( + f"image variation provider not found: {custom_llm_provider}." 
+ ) + + api_base = provider_config.get_complete_url( + api_base=api_base, + model=model, + optional_params=optional_params, + litellm_params=litellm_params, + stream=False, + ) + + headers = provider_config.validate_environment( + api_key=api_key, + headers=headers or {}, + model=model, + messages=[{"role": "user", "content": "test"}], + optional_params=optional_params, + api_base=api_base, + ) + + data = provider_config.transform_request_image_variation( + model=model, + image=image, + optional_params=optional_params, + headers=headers, + ) + + ## LOGGING + logging_obj.pre_call( + input="", + api_key=api_key, + additional_args={ + "headers": headers, + "api_base": api_base, + "complete_input_dict": data.copy(), + }, + ) + + if litellm_params.get("async_call", False): + return self.async_image_variations( + api_base=api_base, + data=data, + headers=headers, + model_response=model_response, + api_key=api_key, + logging_obj=logging_obj, + model=model, + timeout=timeout, + client=client, + optional_params=optional_params, + litellm_params=litellm_params, + image=image, + provider_config=provider_config, + ) # type: ignore + + if client is None or not isinstance(client, HTTPHandler): + sync_httpx_client = _get_httpx_client() + else: + sync_httpx_client = client + + response = self._make_common_sync_call( + sync_httpx_client=sync_httpx_client, + provider_config=provider_config, + api_base=api_base, + headers=headers, + timeout=timeout, + litellm_params=litellm_params, + stream=False, + data=data.get("data") or {}, + files=data.get("files"), + content=data.get("content"), + params=data.get("params"), + ) + + ## LOGGING + logging_obj.post_call( + api_key=api_key, + original_response=response.text, + additional_args={ + "headers": headers, + "api_base": api_base, + }, + ) + + ## RESPONSE OBJECT + return provider_config.transform_response_image_variation( + model=model, + model_response=model_response, + raw_response=response, + logging_obj=logging_obj, + request_data=cast(dict, data), + image=image, + optional_params=optional_params, + litellm_params=litellm_params, + encoding=None, + api_key=api_key, + ) + + def _handle_error(self, e: Exception, provider_config: BaseConfig): + status_code = getattr(e, "status_code", 500) + error_headers = getattr(e, "headers", None) + error_text = getattr(e, "text", str(e)) + error_response = getattr(e, "response", None) + if error_headers is None and error_response: + error_headers = getattr(error_response, "headers", None) + if error_response and hasattr(error_response, "text"): + error_text = getattr(error_response, "text", error_text) + if error_headers: + error_headers = dict(error_headers) + else: + error_headers = {} + raise provider_config.get_error_class( + error_message=error_text, + status_code=status_code, + headers=error_headers, + ) diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/custom_httpx/http_handler.py b/.venv/lib/python3.12/site-packages/litellm/llms/custom_httpx/http_handler.py new file mode 100644 index 00000000..34d70434 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/llms/custom_httpx/http_handler.py @@ -0,0 +1,746 @@ +import asyncio +import os +import ssl +import time +from typing import TYPE_CHECKING, Any, Callable, List, Mapping, Optional, Union + +import httpx +from httpx import USE_CLIENT_DEFAULT, AsyncHTTPTransport, HTTPTransport + +import litellm +from litellm.litellm_core_utils.logging_utils import track_llm_api_timing +from litellm.types.llms.custom_http import * + +if TYPE_CHECKING: + from litellm 
import LlmProviders + from litellm.litellm_core_utils.litellm_logging import ( + Logging as LiteLLMLoggingObject, + ) +else: + LlmProviders = Any + LiteLLMLoggingObject = Any + +try: + from litellm._version import version +except Exception: + version = "0.0.0" + +headers = { + "User-Agent": f"litellm/{version}", +} + +# https://www.python-httpx.org/advanced/timeouts +_DEFAULT_TIMEOUT = httpx.Timeout(timeout=5.0, connect=5.0) +_DEFAULT_TTL_FOR_HTTPX_CLIENTS = 3600 # 1 hour, re-use the same httpx client for 1 hour + + +def mask_sensitive_info(error_message): + # Find the start of the key parameter + if isinstance(error_message, str): + key_index = error_message.find("key=") + else: + return error_message + + # If key is found + if key_index != -1: + # Find the end of the key parameter (next & or end of string) + next_param = error_message.find("&", key_index) + + if next_param == -1: + # If no more parameters, mask until the end of the string + masked_message = error_message[: key_index + 4] + "[REDACTED_API_KEY]" + else: + # Replace the key with redacted value, keeping other parameters + masked_message = ( + error_message[: key_index + 4] + + "[REDACTED_API_KEY]" + + error_message[next_param:] + ) + + return masked_message + + return error_message + + +class MaskedHTTPStatusError(httpx.HTTPStatusError): + def __init__( + self, original_error, message: Optional[str] = None, text: Optional[str] = None + ): + # Create a new error with the masked URL + masked_url = mask_sensitive_info(str(original_error.request.url)) + # Create a new error that looks like the original, but with a masked URL + + super().__init__( + message=original_error.message, + request=httpx.Request( + method=original_error.request.method, + url=masked_url, + headers=original_error.request.headers, + content=original_error.request.content, + ), + response=httpx.Response( + status_code=original_error.response.status_code, + content=original_error.response.content, + headers=original_error.response.headers, + ), + ) + self.message = message + self.text = text + + +class AsyncHTTPHandler: + def __init__( + self, + timeout: Optional[Union[float, httpx.Timeout]] = None, + event_hooks: Optional[Mapping[str, List[Callable[..., Any]]]] = None, + concurrent_limit=1000, + client_alias: Optional[str] = None, # name for client in logs + ssl_verify: Optional[VerifyTypes] = None, + ): + self.timeout = timeout + self.event_hooks = event_hooks + self.client = self.create_client( + timeout=timeout, + concurrent_limit=concurrent_limit, + event_hooks=event_hooks, + ssl_verify=ssl_verify, + ) + self.client_alias = client_alias + + def create_client( + self, + timeout: Optional[Union[float, httpx.Timeout]], + concurrent_limit: int, + event_hooks: Optional[Mapping[str, List[Callable[..., Any]]]], + ssl_verify: Optional[VerifyTypes] = None, + ) -> httpx.AsyncClient: + + # SSL certificates (a.k.a CA bundle) used to verify the identity of requested hosts. 
+ # /path/to/certificate.pem + if ssl_verify is None: + ssl_verify = os.getenv("SSL_VERIFY", litellm.ssl_verify) + + ssl_security_level = os.getenv("SSL_SECURITY_LEVEL") + + # If ssl_verify is not False and we need a lower security level + if ( + not ssl_verify + and ssl_security_level + and isinstance(ssl_security_level, str) + ): + # Create a custom SSL context with reduced security level + custom_ssl_context = ssl.create_default_context() + custom_ssl_context.set_ciphers(ssl_security_level) + + # If ssl_verify is a path to a CA bundle, load it into our custom context + if isinstance(ssl_verify, str) and os.path.exists(ssl_verify): + custom_ssl_context.load_verify_locations(cafile=ssl_verify) + + # Use our custom SSL context instead of the original ssl_verify value + ssl_verify = custom_ssl_context + + # An SSL certificate used by the requested host to authenticate the client. + # /path/to/client.pem + cert = os.getenv("SSL_CERTIFICATE", litellm.ssl_certificate) + + if timeout is None: + timeout = _DEFAULT_TIMEOUT + # Create a client with a connection pool + transport = self._create_async_transport() + + return httpx.AsyncClient( + transport=transport, + event_hooks=event_hooks, + timeout=timeout, + limits=httpx.Limits( + max_connections=concurrent_limit, + max_keepalive_connections=concurrent_limit, + ), + verify=ssl_verify, + cert=cert, + headers=headers, + ) + + async def close(self): + # Close the client when you're done with it + await self.client.aclose() + + async def __aenter__(self): + return self.client + + async def __aexit__(self): + # close the client when exiting + await self.client.aclose() + + async def get( + self, + url: str, + params: Optional[dict] = None, + headers: Optional[dict] = None, + follow_redirects: Optional[bool] = None, + ): + # Set follow_redirects to UseClientDefault if None + _follow_redirects = ( + follow_redirects if follow_redirects is not None else USE_CLIENT_DEFAULT + ) + + response = await self.client.get( + url, params=params, headers=headers, follow_redirects=_follow_redirects # type: ignore + ) + return response + + @track_llm_api_timing() + async def post( + self, + url: str, + data: Optional[Union[dict, str]] = None, # type: ignore + json: Optional[dict] = None, + params: Optional[dict] = None, + headers: Optional[dict] = None, + timeout: Optional[Union[float, httpx.Timeout]] = None, + stream: bool = False, + logging_obj: Optional[LiteLLMLoggingObject] = None, + ): + start_time = time.time() + try: + if timeout is None: + timeout = self.timeout + + req = self.client.build_request( + "POST", url, data=data, json=json, params=params, headers=headers, timeout=timeout # type: ignore + ) + response = await self.client.send(req, stream=stream) + response.raise_for_status() + return response + except (httpx.RemoteProtocolError, httpx.ConnectError): + # Retry the request with a new session if there is a connection error + new_client = self.create_client( + timeout=timeout, concurrent_limit=1, event_hooks=self.event_hooks + ) + try: + return await self.single_connection_post_request( + url=url, + client=new_client, + data=data, + json=json, + params=params, + headers=headers, + stream=stream, + ) + finally: + await new_client.aclose() + except httpx.TimeoutException as e: + end_time = time.time() + time_delta = round(end_time - start_time, 3) + headers = {} + error_response = getattr(e, "response", None) + if error_response is not None: + for key, value in error_response.headers.items(): + headers["response_headers-{}".format(key)] = value + + raise 
litellm.Timeout( + message=f"Connection timed out. Timeout passed={timeout}, time taken={time_delta} seconds", + model="default-model-name", + llm_provider="litellm-httpx-handler", + headers=headers, + ) + except httpx.HTTPStatusError as e: + if stream is True: + setattr(e, "message", await e.response.aread()) + setattr(e, "text", await e.response.aread()) + else: + setattr(e, "message", mask_sensitive_info(e.response.text)) + setattr(e, "text", mask_sensitive_info(e.response.text)) + + setattr(e, "status_code", e.response.status_code) + + raise e + except Exception as e: + raise e + + async def put( + self, + url: str, + data: Optional[Union[dict, str]] = None, # type: ignore + json: Optional[dict] = None, + params: Optional[dict] = None, + headers: Optional[dict] = None, + timeout: Optional[Union[float, httpx.Timeout]] = None, + stream: bool = False, + ): + try: + if timeout is None: + timeout = self.timeout + + req = self.client.build_request( + "PUT", url, data=data, json=json, params=params, headers=headers, timeout=timeout # type: ignore + ) + response = await self.client.send(req) + response.raise_for_status() + return response + except (httpx.RemoteProtocolError, httpx.ConnectError): + # Retry the request with a new session if there is a connection error + new_client = self.create_client( + timeout=timeout, concurrent_limit=1, event_hooks=self.event_hooks + ) + try: + return await self.single_connection_post_request( + url=url, + client=new_client, + data=data, + json=json, + params=params, + headers=headers, + stream=stream, + ) + finally: + await new_client.aclose() + except httpx.TimeoutException as e: + headers = {} + error_response = getattr(e, "response", None) + if error_response is not None: + for key, value in error_response.headers.items(): + headers["response_headers-{}".format(key)] = value + + raise litellm.Timeout( + message=f"Connection timed out after {timeout} seconds.", + model="default-model-name", + llm_provider="litellm-httpx-handler", + headers=headers, + ) + except httpx.HTTPStatusError as e: + setattr(e, "status_code", e.response.status_code) + if stream is True: + setattr(e, "message", await e.response.aread()) + else: + setattr(e, "message", e.response.text) + raise e + except Exception as e: + raise e + + async def patch( + self, + url: str, + data: Optional[Union[dict, str]] = None, # type: ignore + json: Optional[dict] = None, + params: Optional[dict] = None, + headers: Optional[dict] = None, + timeout: Optional[Union[float, httpx.Timeout]] = None, + stream: bool = False, + ): + try: + if timeout is None: + timeout = self.timeout + + req = self.client.build_request( + "PATCH", url, data=data, json=json, params=params, headers=headers, timeout=timeout # type: ignore + ) + response = await self.client.send(req) + response.raise_for_status() + return response + except (httpx.RemoteProtocolError, httpx.ConnectError): + # Retry the request with a new session if there is a connection error + new_client = self.create_client( + timeout=timeout, concurrent_limit=1, event_hooks=self.event_hooks + ) + try: + return await self.single_connection_post_request( + url=url, + client=new_client, + data=data, + json=json, + params=params, + headers=headers, + stream=stream, + ) + finally: + await new_client.aclose() + except httpx.TimeoutException as e: + headers = {} + error_response = getattr(e, "response", None) + if error_response is not None: + for key, value in error_response.headers.items(): + headers["response_headers-{}".format(key)] = value + + raise 
litellm.Timeout( + message=f"Connection timed out after {timeout} seconds.", + model="default-model-name", + llm_provider="litellm-httpx-handler", + headers=headers, + ) + except httpx.HTTPStatusError as e: + setattr(e, "status_code", e.response.status_code) + if stream is True: + setattr(e, "message", await e.response.aread()) + else: + setattr(e, "message", e.response.text) + raise e + except Exception as e: + raise e + + async def delete( + self, + url: str, + data: Optional[Union[dict, str]] = None, # type: ignore + json: Optional[dict] = None, + params: Optional[dict] = None, + headers: Optional[dict] = None, + timeout: Optional[Union[float, httpx.Timeout]] = None, + stream: bool = False, + ): + try: + if timeout is None: + timeout = self.timeout + req = self.client.build_request( + "DELETE", url, data=data, json=json, params=params, headers=headers, timeout=timeout # type: ignore + ) + response = await self.client.send(req, stream=stream) + response.raise_for_status() + return response + except (httpx.RemoteProtocolError, httpx.ConnectError): + # Retry the request with a new session if there is a connection error + new_client = self.create_client( + timeout=timeout, concurrent_limit=1, event_hooks=self.event_hooks + ) + try: + return await self.single_connection_post_request( + url=url, + client=new_client, + data=data, + json=json, + params=params, + headers=headers, + stream=stream, + ) + finally: + await new_client.aclose() + except httpx.HTTPStatusError as e: + setattr(e, "status_code", e.response.status_code) + if stream is True: + setattr(e, "message", await e.response.aread()) + else: + setattr(e, "message", e.response.text) + raise e + except Exception as e: + raise e + + async def single_connection_post_request( + self, + url: str, + client: httpx.AsyncClient, + data: Optional[Union[dict, str]] = None, # type: ignore + json: Optional[dict] = None, + params: Optional[dict] = None, + headers: Optional[dict] = None, + stream: bool = False, + ): + """ + Making POST request for a single connection client. + + Used for retrying connection client errors. + """ + req = client.build_request( + "POST", url, data=data, json=json, params=params, headers=headers # type: ignore + ) + response = await client.send(req, stream=stream) + response.raise_for_status() + return response + + def __del__(self) -> None: + try: + asyncio.get_running_loop().create_task(self.close()) + except Exception: + pass + + def _create_async_transport(self) -> Optional[AsyncHTTPTransport]: + """ + Create an async transport with IPv4 only if litellm.force_ipv4 is True. + Otherwise, return None. + + Some users have seen httpx ConnectionError when using ipv6 - forcing ipv4 resolves the issue for them + """ + if litellm.force_ipv4: + return AsyncHTTPTransport(local_address="0.0.0.0") + else: + return None + + +class HTTPHandler: + def __init__( + self, + timeout: Optional[Union[float, httpx.Timeout]] = None, + concurrent_limit=1000, + client: Optional[httpx.Client] = None, + ssl_verify: Optional[Union[bool, str]] = None, + ): + if timeout is None: + timeout = _DEFAULT_TIMEOUT + + # SSL certificates (a.k.a CA bundle) used to verify the identity of requested hosts. + # /path/to/certificate.pem + + if ssl_verify is None: + ssl_verify = os.getenv("SSL_VERIFY", litellm.ssl_verify) + + # An SSL certificate used by the requested host to authenticate the client. 
+ # /path/to/client.pem + cert = os.getenv("SSL_CERTIFICATE", litellm.ssl_certificate) + + if client is None: + transport = self._create_sync_transport() + + # Create a client with a connection pool + self.client = httpx.Client( + transport=transport, + timeout=timeout, + limits=httpx.Limits( + max_connections=concurrent_limit, + max_keepalive_connections=concurrent_limit, + ), + verify=ssl_verify, + cert=cert, + headers=headers, + ) + else: + self.client = client + + def close(self): + # Close the client when you're done with it + self.client.close() + + def get( + self, + url: str, + params: Optional[dict] = None, + headers: Optional[dict] = None, + follow_redirects: Optional[bool] = None, + ): + # Set follow_redirects to UseClientDefault if None + _follow_redirects = ( + follow_redirects if follow_redirects is not None else USE_CLIENT_DEFAULT + ) + + response = self.client.get( + url, params=params, headers=headers, follow_redirects=_follow_redirects # type: ignore + ) + return response + + def post( + self, + url: str, + data: Optional[Union[dict, str]] = None, + json: Optional[Union[dict, str, List]] = None, + params: Optional[dict] = None, + headers: Optional[dict] = None, + stream: bool = False, + timeout: Optional[Union[float, httpx.Timeout]] = None, + files: Optional[dict] = None, + content: Any = None, + logging_obj: Optional[LiteLLMLoggingObject] = None, + ): + try: + if timeout is not None: + req = self.client.build_request( + "POST", + url, + data=data, # type: ignore + json=json, + params=params, + headers=headers, + timeout=timeout, + files=files, + content=content, # type: ignore + ) + else: + req = self.client.build_request( + "POST", url, data=data, json=json, params=params, headers=headers, files=files, content=content # type: ignore + ) + response = self.client.send(req, stream=stream) + response.raise_for_status() + return response + except httpx.TimeoutException: + raise litellm.Timeout( + message=f"Connection timed out after {timeout} seconds.", + model="default-model-name", + llm_provider="litellm-httpx-handler", + ) + except httpx.HTTPStatusError as e: + if stream is True: + setattr(e, "message", mask_sensitive_info(e.response.read())) + setattr(e, "text", mask_sensitive_info(e.response.read())) + else: + error_text = mask_sensitive_info(e.response.text) + setattr(e, "message", error_text) + setattr(e, "text", error_text) + + setattr(e, "status_code", e.response.status_code) + + raise e + except Exception as e: + raise e + + def patch( + self, + url: str, + data: Optional[Union[dict, str]] = None, + json: Optional[Union[dict, str]] = None, + params: Optional[dict] = None, + headers: Optional[dict] = None, + stream: bool = False, + timeout: Optional[Union[float, httpx.Timeout]] = None, + ): + try: + + if timeout is not None: + req = self.client.build_request( + "PATCH", url, data=data, json=json, params=params, headers=headers, timeout=timeout # type: ignore + ) + else: + req = self.client.build_request( + "PATCH", url, data=data, json=json, params=params, headers=headers # type: ignore + ) + response = self.client.send(req, stream=stream) + response.raise_for_status() + return response + except httpx.TimeoutException: + raise litellm.Timeout( + message=f"Connection timed out after {timeout} seconds.", + model="default-model-name", + llm_provider="litellm-httpx-handler", + ) + except httpx.HTTPStatusError as e: + + if stream is True: + setattr(e, "message", mask_sensitive_info(e.response.read())) + setattr(e, "text", mask_sensitive_info(e.response.read())) + else: + 
error_text = mask_sensitive_info(e.response.text) + setattr(e, "message", error_text) + setattr(e, "text", error_text) + + setattr(e, "status_code", e.response.status_code) + + raise e + except Exception as e: + raise e + + def put( + self, + url: str, + data: Optional[Union[dict, str]] = None, + json: Optional[Union[dict, str]] = None, + params: Optional[dict] = None, + headers: Optional[dict] = None, + stream: bool = False, + timeout: Optional[Union[float, httpx.Timeout]] = None, + ): + try: + + if timeout is not None: + req = self.client.build_request( + "PUT", url, data=data, json=json, params=params, headers=headers, timeout=timeout # type: ignore + ) + else: + req = self.client.build_request( + "PUT", url, data=data, json=json, params=params, headers=headers # type: ignore + ) + response = self.client.send(req, stream=stream) + return response + except httpx.TimeoutException: + raise litellm.Timeout( + message=f"Connection timed out after {timeout} seconds.", + model="default-model-name", + llm_provider="litellm-httpx-handler", + ) + except Exception as e: + raise e + + def __del__(self) -> None: + try: + self.close() + except Exception: + pass + + def _create_sync_transport(self) -> Optional[HTTPTransport]: + """ + Create an HTTP transport with IPv4 only if litellm.force_ipv4 is True. + Otherwise, return None. + + Some users have seen httpx ConnectionError when using ipv6 - forcing ipv4 resolves the issue for them + """ + if litellm.force_ipv4: + return HTTPTransport(local_address="0.0.0.0") + else: + return None + + +def get_async_httpx_client( + llm_provider: Union[LlmProviders, httpxSpecialProvider], + params: Optional[dict] = None, +) -> AsyncHTTPHandler: + """ + Retrieves the async HTTP client from the cache + If not present, creates a new client + + Caches the new client and returns it. + """ + _params_key_name = "" + if params is not None: + for key, value in params.items(): + try: + _params_key_name += f"{key}_{value}" + except Exception: + pass + + _cache_key_name = "async_httpx_client" + _params_key_name + llm_provider + _cached_client = litellm.in_memory_llm_clients_cache.get_cache(_cache_key_name) + if _cached_client: + return _cached_client + + if params is not None: + _new_client = AsyncHTTPHandler(**params) + else: + _new_client = AsyncHTTPHandler( + timeout=httpx.Timeout(timeout=600.0, connect=5.0) + ) + + litellm.in_memory_llm_clients_cache.set_cache( + key=_cache_key_name, + value=_new_client, + ttl=_DEFAULT_TTL_FOR_HTTPX_CLIENTS, + ) + return _new_client + + +def _get_httpx_client(params: Optional[dict] = None) -> HTTPHandler: + """ + Retrieves the HTTP client from the cache + If not present, creates a new client + + Caches the new client and returns it. 
+ """ + _params_key_name = "" + if params is not None: + for key, value in params.items(): + try: + _params_key_name += f"{key}_{value}" + except Exception: + pass + + _cache_key_name = "httpx_client" + _params_key_name + + _cached_client = litellm.in_memory_llm_clients_cache.get_cache(_cache_key_name) + if _cached_client: + return _cached_client + + if params is not None: + _new_client = HTTPHandler(**params) + else: + _new_client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0)) + + litellm.in_memory_llm_clients_cache.set_cache( + key=_cache_key_name, + value=_new_client, + ttl=_DEFAULT_TTL_FOR_HTTPX_CLIENTS, + ) + return _new_client diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/custom_httpx/httpx_handler.py b/.venv/lib/python3.12/site-packages/litellm/llms/custom_httpx/httpx_handler.py new file mode 100644 index 00000000..6f684ba0 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/llms/custom_httpx/httpx_handler.py @@ -0,0 +1,49 @@ +from typing import Optional, Union + +import httpx + +try: + from litellm._version import version +except Exception: + version = "0.0.0" + +headers = { + "User-Agent": f"litellm/{version}", +} + + +class HTTPHandler: + def __init__(self, concurrent_limit=1000): + # Create a client with a connection pool + self.client = httpx.AsyncClient( + limits=httpx.Limits( + max_connections=concurrent_limit, + max_keepalive_connections=concurrent_limit, + ), + headers=headers, + ) + + async def close(self): + # Close the client when you're done with it + await self.client.aclose() + + async def get( + self, url: str, params: Optional[dict] = None, headers: Optional[dict] = None + ): + response = await self.client.get(url, params=params, headers=headers) + return response + + async def post( + self, + url: str, + data: Optional[Union[dict, str]] = None, + params: Optional[dict] = None, + headers: Optional[dict] = None, + ): + try: + response = await self.client.post( + url, data=data, params=params, headers=headers # type: ignore + ) + return response + except Exception as e: + raise e diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/custom_httpx/llm_http_handler.py b/.venv/lib/python3.12/site-packages/litellm/llms/custom_httpx/llm_http_handler.py new file mode 100644 index 00000000..00caf552 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/llms/custom_httpx/llm_http_handler.py @@ -0,0 +1,1260 @@ +import io +import json +from typing import TYPE_CHECKING, Any, Coroutine, Dict, Optional, Tuple, Union + +import httpx # type: ignore + +import litellm +import litellm.litellm_core_utils +import litellm.types +import litellm.types.utils +from litellm.llms.base_llm.chat.transformation import BaseConfig +from litellm.llms.base_llm.embedding.transformation import BaseEmbeddingConfig +from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig +from litellm.llms.base_llm.responses.transformation import BaseResponsesAPIConfig +from litellm.llms.custom_httpx.http_handler import ( + AsyncHTTPHandler, + HTTPHandler, + _get_httpx_client, + get_async_httpx_client, +) +from litellm.responses.streaming_iterator import ( + BaseResponsesAPIStreamingIterator, + MockResponsesAPIStreamingIterator, + ResponsesAPIStreamingIterator, + SyncResponsesAPIStreamingIterator, +) +from litellm.types.llms.openai import ResponseInputParam, ResponsesAPIResponse +from litellm.types.rerank import OptionalRerankParams, RerankResponse +from litellm.types.router import GenericLiteLLMParams +from litellm.types.utils import 
EmbeddingResponse, FileTypes, TranscriptionResponse +from litellm.utils import CustomStreamWrapper, ModelResponse, ProviderConfigManager + +if TYPE_CHECKING: + from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj + + LiteLLMLoggingObj = _LiteLLMLoggingObj +else: + LiteLLMLoggingObj = Any + + +class BaseLLMHTTPHandler: + + async def _make_common_async_call( + self, + async_httpx_client: AsyncHTTPHandler, + provider_config: BaseConfig, + api_base: str, + headers: dict, + data: dict, + timeout: Union[float, httpx.Timeout], + litellm_params: dict, + logging_obj: LiteLLMLoggingObj, + stream: bool = False, + ) -> httpx.Response: + """Common implementation across stream + non-stream calls. Meant to ensure consistent error-handling.""" + max_retry_on_unprocessable_entity_error = ( + provider_config.max_retry_on_unprocessable_entity_error + ) + + response: Optional[httpx.Response] = None + for i in range(max(max_retry_on_unprocessable_entity_error, 1)): + try: + response = await async_httpx_client.post( + url=api_base, + headers=headers, + data=json.dumps(data), + timeout=timeout, + stream=stream, + logging_obj=logging_obj, + ) + except httpx.HTTPStatusError as e: + hit_max_retry = i + 1 == max_retry_on_unprocessable_entity_error + should_retry = provider_config.should_retry_llm_api_inside_llm_translation_on_http_error( + e=e, litellm_params=litellm_params + ) + if should_retry and not hit_max_retry: + data = ( + provider_config.transform_request_on_unprocessable_entity_error( + e=e, request_data=data + ) + ) + continue + else: + raise self._handle_error(e=e, provider_config=provider_config) + except Exception as e: + raise self._handle_error(e=e, provider_config=provider_config) + break + + if response is None: + raise provider_config.get_error_class( + error_message="No response from the API", + status_code=422, # don't retry on this error + headers={}, + ) + + return response + + def _make_common_sync_call( + self, + sync_httpx_client: HTTPHandler, + provider_config: BaseConfig, + api_base: str, + headers: dict, + data: dict, + timeout: Union[float, httpx.Timeout], + litellm_params: dict, + logging_obj: LiteLLMLoggingObj, + stream: bool = False, + ) -> httpx.Response: + + max_retry_on_unprocessable_entity_error = ( + provider_config.max_retry_on_unprocessable_entity_error + ) + + response: Optional[httpx.Response] = None + + for i in range(max(max_retry_on_unprocessable_entity_error, 1)): + try: + response = sync_httpx_client.post( + url=api_base, + headers=headers, + data=json.dumps(data), + timeout=timeout, + stream=stream, + logging_obj=logging_obj, + ) + except httpx.HTTPStatusError as e: + hit_max_retry = i + 1 == max_retry_on_unprocessable_entity_error + should_retry = provider_config.should_retry_llm_api_inside_llm_translation_on_http_error( + e=e, litellm_params=litellm_params + ) + if should_retry and not hit_max_retry: + data = ( + provider_config.transform_request_on_unprocessable_entity_error( + e=e, request_data=data + ) + ) + continue + else: + raise self._handle_error(e=e, provider_config=provider_config) + except Exception as e: + raise self._handle_error(e=e, provider_config=provider_config) + break + + if response is None: + raise provider_config.get_error_class( + error_message="No response from the API", + status_code=422, # don't retry on this error + headers={}, + ) + + return response + + async def async_completion( + self, + custom_llm_provider: str, + provider_config: BaseConfig, + api_base: str, + headers: dict, + data: dict, + timeout: 
Union[float, httpx.Timeout], + model: str, + model_response: ModelResponse, + logging_obj: LiteLLMLoggingObj, + messages: list, + optional_params: dict, + litellm_params: dict, + encoding: Any, + api_key: Optional[str] = None, + client: Optional[AsyncHTTPHandler] = None, + json_mode: bool = False, + ): + if client is None: + async_httpx_client = get_async_httpx_client( + llm_provider=litellm.LlmProviders(custom_llm_provider), + params={"ssl_verify": litellm_params.get("ssl_verify", None)}, + ) + else: + async_httpx_client = client + + response = await self._make_common_async_call( + async_httpx_client=async_httpx_client, + provider_config=provider_config, + api_base=api_base, + headers=headers, + data=data, + timeout=timeout, + litellm_params=litellm_params, + stream=False, + logging_obj=logging_obj, + ) + return provider_config.transform_response( + model=model, + raw_response=response, + model_response=model_response, + logging_obj=logging_obj, + api_key=api_key, + request_data=data, + messages=messages, + optional_params=optional_params, + litellm_params=litellm_params, + encoding=encoding, + json_mode=json_mode, + ) + + def completion( + self, + model: str, + messages: list, + api_base: str, + custom_llm_provider: str, + model_response: ModelResponse, + encoding, + logging_obj: LiteLLMLoggingObj, + optional_params: dict, + timeout: Union[float, httpx.Timeout], + litellm_params: dict, + acompletion: bool, + stream: Optional[bool] = False, + fake_stream: bool = False, + api_key: Optional[str] = None, + headers: Optional[dict] = {}, + client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None, + ): + json_mode: bool = optional_params.pop("json_mode", False) + + provider_config = ProviderConfigManager.get_provider_chat_config( + model=model, provider=litellm.LlmProviders(custom_llm_provider) + ) + + # get config from model, custom llm provider + headers = provider_config.validate_environment( + api_key=api_key, + headers=headers or {}, + model=model, + messages=messages, + optional_params=optional_params, + api_base=api_base, + ) + + api_base = provider_config.get_complete_url( + api_base=api_base, + model=model, + optional_params=optional_params, + stream=stream, + litellm_params=litellm_params, + ) + + data = provider_config.transform_request( + model=model, + messages=messages, + optional_params=optional_params, + litellm_params=litellm_params, + headers=headers, + ) + + headers = provider_config.sign_request( + headers=headers, + optional_params=optional_params, + request_data=data, + api_base=api_base, + stream=stream, + fake_stream=fake_stream, + model=model, + ) + + ## LOGGING + logging_obj.pre_call( + input=messages, + api_key=api_key, + additional_args={ + "complete_input_dict": data, + "api_base": api_base, + "headers": headers, + }, + ) + + if acompletion is True: + if stream is True: + data = self._add_stream_param_to_request_body( + data=data, + provider_config=provider_config, + fake_stream=fake_stream, + ) + return self.acompletion_stream_function( + model=model, + messages=messages, + api_base=api_base, + headers=headers, + custom_llm_provider=custom_llm_provider, + provider_config=provider_config, + timeout=timeout, + logging_obj=logging_obj, + data=data, + fake_stream=fake_stream, + client=( + client + if client is not None and isinstance(client, AsyncHTTPHandler) + else None + ), + litellm_params=litellm_params, + json_mode=json_mode, + ) + + else: + return self.async_completion( + custom_llm_provider=custom_llm_provider, + provider_config=provider_config, + 
api_base=api_base, + headers=headers, + data=data, + timeout=timeout, + model=model, + model_response=model_response, + logging_obj=logging_obj, + api_key=api_key, + messages=messages, + optional_params=optional_params, + litellm_params=litellm_params, + encoding=encoding, + client=( + client + if client is not None and isinstance(client, AsyncHTTPHandler) + else None + ), + json_mode=json_mode, + ) + + if stream is True: + data = self._add_stream_param_to_request_body( + data=data, + provider_config=provider_config, + fake_stream=fake_stream, + ) + if provider_config.has_custom_stream_wrapper is True: + return provider_config.get_sync_custom_stream_wrapper( + model=model, + custom_llm_provider=custom_llm_provider, + logging_obj=logging_obj, + api_base=api_base, + headers=headers, + data=data, + messages=messages, + client=client, + json_mode=json_mode, + ) + completion_stream, headers = self.make_sync_call( + provider_config=provider_config, + api_base=api_base, + headers=headers, # type: ignore + data=data, + model=model, + messages=messages, + logging_obj=logging_obj, + timeout=timeout, + fake_stream=fake_stream, + client=( + client + if client is not None and isinstance(client, HTTPHandler) + else None + ), + litellm_params=litellm_params, + ) + return CustomStreamWrapper( + completion_stream=completion_stream, + model=model, + custom_llm_provider=custom_llm_provider, + logging_obj=logging_obj, + ) + + if client is None or not isinstance(client, HTTPHandler): + sync_httpx_client = _get_httpx_client( + params={"ssl_verify": litellm_params.get("ssl_verify", None)} + ) + else: + sync_httpx_client = client + + response = self._make_common_sync_call( + sync_httpx_client=sync_httpx_client, + provider_config=provider_config, + api_base=api_base, + headers=headers, + data=data, + timeout=timeout, + litellm_params=litellm_params, + logging_obj=logging_obj, + ) + return provider_config.transform_response( + model=model, + raw_response=response, + model_response=model_response, + logging_obj=logging_obj, + api_key=api_key, + request_data=data, + messages=messages, + optional_params=optional_params, + litellm_params=litellm_params, + encoding=encoding, + json_mode=json_mode, + ) + + def make_sync_call( + self, + provider_config: BaseConfig, + api_base: str, + headers: dict, + data: dict, + model: str, + messages: list, + logging_obj, + litellm_params: dict, + timeout: Union[float, httpx.Timeout], + fake_stream: bool = False, + client: Optional[HTTPHandler] = None, + ) -> Tuple[Any, dict]: + if client is None or not isinstance(client, HTTPHandler): + sync_httpx_client = _get_httpx_client( + { + "ssl_verify": litellm_params.get("ssl_verify", None), + } + ) + else: + sync_httpx_client = client + stream = True + if fake_stream is True: + stream = False + + response = self._make_common_sync_call( + sync_httpx_client=sync_httpx_client, + provider_config=provider_config, + api_base=api_base, + headers=headers, + data=data, + timeout=timeout, + litellm_params=litellm_params, + stream=stream, + logging_obj=logging_obj, + ) + + if fake_stream is True: + completion_stream = provider_config.get_model_response_iterator( + streaming_response=response.json(), sync_stream=True + ) + else: + completion_stream = provider_config.get_model_response_iterator( + streaming_response=response.iter_lines(), sync_stream=True + ) + + # LOGGING + logging_obj.post_call( + input=messages, + api_key="", + original_response="first stream response received", + additional_args={"complete_input_dict": data}, + ) + + return 
completion_stream, dict(response.headers) + + async def acompletion_stream_function( + self, + model: str, + messages: list, + api_base: str, + custom_llm_provider: str, + headers: dict, + provider_config: BaseConfig, + timeout: Union[float, httpx.Timeout], + logging_obj: LiteLLMLoggingObj, + data: dict, + litellm_params: dict, + fake_stream: bool = False, + client: Optional[AsyncHTTPHandler] = None, + json_mode: Optional[bool] = None, + ): + if provider_config.has_custom_stream_wrapper is True: + return provider_config.get_async_custom_stream_wrapper( + model=model, + custom_llm_provider=custom_llm_provider, + logging_obj=logging_obj, + api_base=api_base, + headers=headers, + data=data, + messages=messages, + client=client, + json_mode=json_mode, + ) + + completion_stream, _response_headers = await self.make_async_call_stream_helper( + custom_llm_provider=custom_llm_provider, + provider_config=provider_config, + api_base=api_base, + headers=headers, + data=data, + messages=messages, + logging_obj=logging_obj, + timeout=timeout, + fake_stream=fake_stream, + client=client, + litellm_params=litellm_params, + ) + streamwrapper = CustomStreamWrapper( + completion_stream=completion_stream, + model=model, + custom_llm_provider=custom_llm_provider, + logging_obj=logging_obj, + ) + return streamwrapper + + async def make_async_call_stream_helper( + self, + custom_llm_provider: str, + provider_config: BaseConfig, + api_base: str, + headers: dict, + data: dict, + messages: list, + logging_obj: LiteLLMLoggingObj, + timeout: Union[float, httpx.Timeout], + litellm_params: dict, + fake_stream: bool = False, + client: Optional[AsyncHTTPHandler] = None, + ) -> Tuple[Any, httpx.Headers]: + """ + Helper function for making an async call with stream. + + Handles fake stream as well. + """ + if client is None: + async_httpx_client = get_async_httpx_client( + llm_provider=litellm.LlmProviders(custom_llm_provider), + params={"ssl_verify": litellm_params.get("ssl_verify", None)}, + ) + else: + async_httpx_client = client + stream = True + if fake_stream is True: + stream = False + + response = await self._make_common_async_call( + async_httpx_client=async_httpx_client, + provider_config=provider_config, + api_base=api_base, + headers=headers, + data=data, + timeout=timeout, + litellm_params=litellm_params, + stream=stream, + logging_obj=logging_obj, + ) + + if fake_stream is True: + completion_stream = provider_config.get_model_response_iterator( + streaming_response=response.json(), sync_stream=False + ) + else: + completion_stream = provider_config.get_model_response_iterator( + streaming_response=response.aiter_lines(), sync_stream=False + ) + # LOGGING + logging_obj.post_call( + input=messages, + api_key="", + original_response="first stream response received", + additional_args={"complete_input_dict": data}, + ) + + return completion_stream, response.headers + + def _add_stream_param_to_request_body( + self, + data: dict, + provider_config: BaseConfig, + fake_stream: bool, + ) -> dict: + """ + Some providers like Bedrock invoke do not support the stream parameter in the request body, we only pass `stream` in the request body the provider supports it. 
+ """ + if fake_stream is True: + return data + if provider_config.supports_stream_param_in_request_body is True: + data["stream"] = True + return data + + def embedding( + self, + model: str, + input: list, + timeout: float, + custom_llm_provider: str, + logging_obj: LiteLLMLoggingObj, + api_base: Optional[str], + optional_params: dict, + litellm_params: dict, + model_response: EmbeddingResponse, + api_key: Optional[str] = None, + client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None, + aembedding: bool = False, + headers={}, + ) -> EmbeddingResponse: + + provider_config = ProviderConfigManager.get_provider_embedding_config( + model=model, provider=litellm.LlmProviders(custom_llm_provider) + ) + # get config from model, custom llm provider + headers = provider_config.validate_environment( + api_key=api_key, + headers=headers, + model=model, + messages=[], + optional_params=optional_params, + ) + + api_base = provider_config.get_complete_url( + api_base=api_base, + model=model, + optional_params=optional_params, + litellm_params=litellm_params, + ) + + data = provider_config.transform_embedding_request( + model=model, + input=input, + optional_params=optional_params, + headers=headers, + ) + + ## LOGGING + logging_obj.pre_call( + input=input, + api_key=api_key, + additional_args={ + "complete_input_dict": data, + "api_base": api_base, + "headers": headers, + }, + ) + + if aembedding is True: + return self.aembedding( # type: ignore + request_data=data, + api_base=api_base, + headers=headers, + model=model, + custom_llm_provider=custom_llm_provider, + provider_config=provider_config, + model_response=model_response, + logging_obj=logging_obj, + api_key=api_key, + timeout=timeout, + client=client, + optional_params=optional_params, + litellm_params=litellm_params, + ) + + if client is None or not isinstance(client, HTTPHandler): + sync_httpx_client = _get_httpx_client() + else: + sync_httpx_client = client + + try: + response = sync_httpx_client.post( + url=api_base, + headers=headers, + data=json.dumps(data), + timeout=timeout, + ) + except Exception as e: + raise self._handle_error( + e=e, + provider_config=provider_config, + ) + + return provider_config.transform_embedding_response( + model=model, + raw_response=response, + model_response=model_response, + logging_obj=logging_obj, + api_key=api_key, + request_data=data, + optional_params=optional_params, + litellm_params=litellm_params, + ) + + async def aembedding( + self, + request_data: dict, + api_base: str, + headers: dict, + model: str, + custom_llm_provider: str, + provider_config: BaseEmbeddingConfig, + model_response: EmbeddingResponse, + logging_obj: LiteLLMLoggingObj, + optional_params: dict, + litellm_params: dict, + api_key: Optional[str] = None, + timeout: Optional[Union[float, httpx.Timeout]] = None, + client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None, + ) -> EmbeddingResponse: + if client is None or not isinstance(client, AsyncHTTPHandler): + async_httpx_client = get_async_httpx_client( + llm_provider=litellm.LlmProviders(custom_llm_provider) + ) + else: + async_httpx_client = client + + try: + response = await async_httpx_client.post( + url=api_base, + headers=headers, + data=json.dumps(request_data), + timeout=timeout, + ) + except Exception as e: + raise self._handle_error(e=e, provider_config=provider_config) + + return provider_config.transform_embedding_response( + model=model, + raw_response=response, + model_response=model_response, + logging_obj=logging_obj, + api_key=api_key, + 
request_data=request_data, + optional_params=optional_params, + litellm_params=litellm_params, + ) + + def rerank( + self, + model: str, + custom_llm_provider: str, + logging_obj: LiteLLMLoggingObj, + provider_config: BaseRerankConfig, + optional_rerank_params: OptionalRerankParams, + timeout: Optional[Union[float, httpx.Timeout]], + model_response: RerankResponse, + _is_async: bool = False, + headers: dict = {}, + api_key: Optional[str] = None, + api_base: Optional[str] = None, + client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None, + ) -> RerankResponse: + + # get config from model, custom llm provider + headers = provider_config.validate_environment( + api_key=api_key, + headers=headers, + model=model, + ) + + api_base = provider_config.get_complete_url( + api_base=api_base, + model=model, + ) + + data = provider_config.transform_rerank_request( + model=model, + optional_rerank_params=optional_rerank_params, + headers=headers, + ) + + ## LOGGING + logging_obj.pre_call( + input=optional_rerank_params.get("query", ""), + api_key=api_key, + additional_args={ + "complete_input_dict": data, + "api_base": api_base, + "headers": headers, + }, + ) + + if _is_async is True: + return self.arerank( # type: ignore + model=model, + request_data=data, + custom_llm_provider=custom_llm_provider, + provider_config=provider_config, + logging_obj=logging_obj, + model_response=model_response, + api_base=api_base, + headers=headers, + api_key=api_key, + timeout=timeout, + client=client, + ) + + if client is None or not isinstance(client, HTTPHandler): + sync_httpx_client = _get_httpx_client() + else: + sync_httpx_client = client + + try: + response = sync_httpx_client.post( + url=api_base, + headers=headers, + data=json.dumps(data), + timeout=timeout, + ) + except Exception as e: + raise self._handle_error( + e=e, + provider_config=provider_config, + ) + + return provider_config.transform_rerank_response( + model=model, + raw_response=response, + model_response=model_response, + logging_obj=logging_obj, + api_key=api_key, + request_data=data, + ) + + async def arerank( + self, + model: str, + request_data: dict, + custom_llm_provider: str, + provider_config: BaseRerankConfig, + logging_obj: LiteLLMLoggingObj, + model_response: RerankResponse, + api_base: str, + headers: dict, + api_key: Optional[str] = None, + timeout: Optional[Union[float, httpx.Timeout]] = None, + client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None, + ) -> RerankResponse: + + if client is None or not isinstance(client, AsyncHTTPHandler): + async_httpx_client = get_async_httpx_client( + llm_provider=litellm.LlmProviders(custom_llm_provider) + ) + else: + async_httpx_client = client + try: + response = await async_httpx_client.post( + url=api_base, + headers=headers, + data=json.dumps(request_data), + timeout=timeout, + ) + except Exception as e: + raise self._handle_error(e=e, provider_config=provider_config) + + return provider_config.transform_rerank_response( + model=model, + raw_response=response, + model_response=model_response, + logging_obj=logging_obj, + api_key=api_key, + request_data=request_data, + ) + + def handle_audio_file(self, audio_file: FileTypes) -> bytes: + """ + Processes the audio file input based on its type and returns the binary data. + + Args: + audio_file: Can be a file path (str), a tuple (filename, file_content), or binary data (bytes). + + Returns: + The binary data of the audio file. 
+ """ + binary_data: bytes # Explicitly declare the type + + # Handle the audio file based on type + if isinstance(audio_file, str): + # If it's a file path + with open(audio_file, "rb") as f: + binary_data = f.read() # `f.read()` always returns `bytes` + elif isinstance(audio_file, tuple): + # Handle tuple case + _, file_content = audio_file[:2] + if isinstance(file_content, str): + with open(file_content, "rb") as f: + binary_data = f.read() # `f.read()` always returns `bytes` + elif isinstance(file_content, bytes): + binary_data = file_content + else: + raise TypeError( + f"Unexpected type in tuple: {type(file_content)}. Expected str or bytes." + ) + elif isinstance(audio_file, bytes): + # Assume it's already binary data + binary_data = audio_file + elif isinstance(audio_file, io.BufferedReader) or isinstance( + audio_file, io.BytesIO + ): + # Handle file-like objects + binary_data = audio_file.read() + + else: + raise TypeError(f"Unsupported type for audio_file: {type(audio_file)}") + + return binary_data + + def audio_transcriptions( + self, + model: str, + audio_file: FileTypes, + optional_params: dict, + model_response: TranscriptionResponse, + timeout: float, + max_retries: int, + logging_obj: LiteLLMLoggingObj, + api_key: Optional[str], + api_base: Optional[str], + custom_llm_provider: str, + client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None, + atranscription: bool = False, + headers: dict = {}, + litellm_params: dict = {}, + ) -> TranscriptionResponse: + provider_config = ProviderConfigManager.get_provider_audio_transcription_config( + model=model, provider=litellm.LlmProviders(custom_llm_provider) + ) + if provider_config is None: + raise ValueError( + f"No provider config found for model: {model} and provider: {custom_llm_provider}" + ) + headers = provider_config.validate_environment( + api_key=api_key, + headers=headers, + model=model, + messages=[], + optional_params=optional_params, + ) + + if client is None or not isinstance(client, HTTPHandler): + client = _get_httpx_client() + + complete_url = provider_config.get_complete_url( + api_base=api_base, + model=model, + optional_params=optional_params, + litellm_params=litellm_params, + ) + + # Handle the audio file based on type + binary_data = self.handle_audio_file(audio_file) + + try: + # Make the POST request + response = client.post( + url=complete_url, + headers=headers, + content=binary_data, + timeout=timeout, + ) + except Exception as e: + raise self._handle_error(e=e, provider_config=provider_config) + + if isinstance(provider_config, litellm.DeepgramAudioTranscriptionConfig): + returned_response = provider_config.transform_audio_transcription_response( + model=model, + raw_response=response, + model_response=model_response, + logging_obj=logging_obj, + request_data={}, + optional_params=optional_params, + litellm_params={}, + api_key=api_key, + ) + return returned_response + return model_response + + def response_api_handler( + self, + model: str, + input: Union[str, ResponseInputParam], + responses_api_provider_config: BaseResponsesAPIConfig, + response_api_optional_request_params: Dict, + custom_llm_provider: str, + litellm_params: GenericLiteLLMParams, + logging_obj: LiteLLMLoggingObj, + extra_headers: Optional[Dict[str, Any]] = None, + extra_body: Optional[Dict[str, Any]] = None, + timeout: Optional[Union[float, httpx.Timeout]] = None, + client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None, + _is_async: bool = False, + fake_stream: bool = False, + ) -> Union[ + ResponsesAPIResponse, + 
BaseResponsesAPIStreamingIterator, + Coroutine[ + Any, Any, Union[ResponsesAPIResponse, BaseResponsesAPIStreamingIterator] + ], + ]: + """ + Handles responses API requests. + When _is_async=True, returns a coroutine instead of making the call directly. + """ + if _is_async: + # Return the async coroutine if called with _is_async=True + return self.async_response_api_handler( + model=model, + input=input, + responses_api_provider_config=responses_api_provider_config, + response_api_optional_request_params=response_api_optional_request_params, + custom_llm_provider=custom_llm_provider, + litellm_params=litellm_params, + logging_obj=logging_obj, + extra_headers=extra_headers, + extra_body=extra_body, + timeout=timeout, + client=client if isinstance(client, AsyncHTTPHandler) else None, + fake_stream=fake_stream, + ) + + if client is None or not isinstance(client, HTTPHandler): + sync_httpx_client = _get_httpx_client( + params={"ssl_verify": litellm_params.get("ssl_verify", None)} + ) + else: + sync_httpx_client = client + + headers = responses_api_provider_config.validate_environment( + api_key=litellm_params.api_key, + headers=response_api_optional_request_params.get("extra_headers", {}) or {}, + model=model, + ) + + if extra_headers: + headers.update(extra_headers) + + api_base = responses_api_provider_config.get_complete_url( + api_base=litellm_params.api_base, + model=model, + ) + + data = responses_api_provider_config.transform_responses_api_request( + model=model, + input=input, + response_api_optional_request_params=response_api_optional_request_params, + litellm_params=litellm_params, + headers=headers, + ) + + ## LOGGING + logging_obj.pre_call( + input=input, + api_key="", + additional_args={ + "complete_input_dict": data, + "api_base": api_base, + "headers": headers, + }, + ) + + # Check if streaming is requested + stream = response_api_optional_request_params.get("stream", False) + + try: + if stream: + # For streaming, use stream=True in the request + if fake_stream is True: + stream, data = self._prepare_fake_stream_request( + stream=stream, + data=data, + fake_stream=fake_stream, + ) + response = sync_httpx_client.post( + url=api_base, + headers=headers, + data=json.dumps(data), + timeout=timeout + or response_api_optional_request_params.get("timeout"), + stream=stream, + ) + if fake_stream is True: + return MockResponsesAPIStreamingIterator( + response=response, + model=model, + logging_obj=logging_obj, + responses_api_provider_config=responses_api_provider_config, + ) + + return SyncResponsesAPIStreamingIterator( + response=response, + model=model, + logging_obj=logging_obj, + responses_api_provider_config=responses_api_provider_config, + ) + else: + # For non-streaming requests + response = sync_httpx_client.post( + url=api_base, + headers=headers, + data=json.dumps(data), + timeout=timeout + or response_api_optional_request_params.get("timeout"), + ) + except Exception as e: + raise self._handle_error( + e=e, + provider_config=responses_api_provider_config, + ) + + return responses_api_provider_config.transform_response_api_response( + model=model, + raw_response=response, + logging_obj=logging_obj, + ) + + async def async_response_api_handler( + self, + model: str, + input: Union[str, ResponseInputParam], + responses_api_provider_config: BaseResponsesAPIConfig, + response_api_optional_request_params: Dict, + custom_llm_provider: str, + litellm_params: GenericLiteLLMParams, + logging_obj: LiteLLMLoggingObj, + extra_headers: Optional[Dict[str, Any]] = None, + extra_body: 
Optional[Dict[str, Any]] = None, + timeout: Optional[Union[float, httpx.Timeout]] = None, + client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None, + fake_stream: bool = False, + ) -> Union[ResponsesAPIResponse, BaseResponsesAPIStreamingIterator]: + """ + Async version of the responses API handler. + Uses async HTTP client to make requests. + """ + if client is None or not isinstance(client, AsyncHTTPHandler): + async_httpx_client = get_async_httpx_client( + llm_provider=litellm.LlmProviders(custom_llm_provider), + params={"ssl_verify": litellm_params.get("ssl_verify", None)}, + ) + else: + async_httpx_client = client + + headers = responses_api_provider_config.validate_environment( + api_key=litellm_params.api_key, + headers=response_api_optional_request_params.get("extra_headers", {}) or {}, + model=model, + ) + + if extra_headers: + headers.update(extra_headers) + + api_base = responses_api_provider_config.get_complete_url( + api_base=litellm_params.api_base, + model=model, + ) + + data = responses_api_provider_config.transform_responses_api_request( + model=model, + input=input, + response_api_optional_request_params=response_api_optional_request_params, + litellm_params=litellm_params, + headers=headers, + ) + + ## LOGGING + logging_obj.pre_call( + input=input, + api_key="", + additional_args={ + "complete_input_dict": data, + "api_base": api_base, + "headers": headers, + }, + ) + # Check if streaming is requested + stream = response_api_optional_request_params.get("stream", False) + + try: + if stream: + # For streaming, we need to use stream=True in the request + if fake_stream is True: + stream, data = self._prepare_fake_stream_request( + stream=stream, + data=data, + fake_stream=fake_stream, + ) + + response = await async_httpx_client.post( + url=api_base, + headers=headers, + data=json.dumps(data), + timeout=timeout + or response_api_optional_request_params.get("timeout"), + stream=stream, + ) + + if fake_stream is True: + return MockResponsesAPIStreamingIterator( + response=response, + model=model, + logging_obj=logging_obj, + responses_api_provider_config=responses_api_provider_config, + ) + + # Return the streaming iterator + return ResponsesAPIStreamingIterator( + response=response, + model=model, + logging_obj=logging_obj, + responses_api_provider_config=responses_api_provider_config, + ) + else: + # For non-streaming, proceed as before + response = await async_httpx_client.post( + url=api_base, + headers=headers, + data=json.dumps(data), + timeout=timeout + or response_api_optional_request_params.get("timeout"), + ) + + except Exception as e: + raise self._handle_error( + e=e, + provider_config=responses_api_provider_config, + ) + + return responses_api_provider_config.transform_response_api_response( + model=model, + raw_response=response, + logging_obj=logging_obj, + ) + + def _prepare_fake_stream_request( + self, + stream: bool, + data: dict, + fake_stream: bool, + ) -> Tuple[bool, dict]: + """ + Handles preparing a request when `fake_stream` is True. 
+        """
+        if fake_stream is True:
+            stream = False
+            data.pop("stream", None)
+            return stream, data
+        return stream, data
+
+    def _handle_error(
+        self,
+        e: Exception,
+        provider_config: Union[BaseConfig, BaseRerankConfig, BaseResponsesAPIConfig],
+    ):
+        # Pull status, headers, and message off the exception, preferring details
+        # from an attached response object when one is present.
+        status_code = getattr(e, "status_code", 500)
+        error_headers = getattr(e, "headers", None)
+        error_text = getattr(e, "text", str(e))
+        error_response = getattr(e, "response", None)
+        if error_headers is None and error_response:
+            error_headers = getattr(error_response, "headers", None)
+        if error_response and hasattr(error_response, "text"):
+            error_text = getattr(error_response, "text", error_text)
+        if error_headers:
+            error_headers = dict(error_headers)
+        else:
+            error_headers = {}
+        raise provider_config.get_error_class(
+            error_message=error_text,
+            status_code=status_code,
+            headers=error_headers,
+        )
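Taken together, `_prepare_fake_stream_request` and `_add_stream_param_to_request_body` above decide whether the outgoing request actually streams: when a provider cannot stream, the handler strips `stream`, sends a plain request, and later replays the complete response through a streaming-style iterator; otherwise `stream` is only added to the body when the provider supports it there. The sketch below condenses that decision into a single hypothetical `prepare_stream_request` helper over a plain dict request body; it is illustrative only and not part of the vendored file.

def prepare_stream_request(
    data: dict, fake_stream: bool, supports_stream_param: bool
) -> tuple[bool, dict]:
    # Fake stream: strip `stream` and issue a plain request; the full
    # response is replayed later through a streaming-style iterator.
    if fake_stream:
        data.pop("stream", None)
        return False, data
    # Real stream: only advertise `stream` when the provider accepts it
    # in the request body.
    if supports_stream_param:
        data["stream"] = True
    return True, data

stream, payload = prepare_stream_request(
    {"input": "hi", "stream": True}, fake_stream=True, supports_stream_param=False
)
assert stream is False and "stream" not in payload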