Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/custom_httpx')
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/custom_httpx/aiohttp_handler.py    595
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/custom_httpx/http_handler.py       746
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/custom_httpx/httpx_handler.py       49
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/custom_httpx/llm_http_handler.py  1260
4 files changed, 2650 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/custom_httpx/aiohttp_handler.py b/.venv/lib/python3.12/site-packages/litellm/llms/custom_httpx/aiohttp_handler.py
new file mode 100644
index 00000000..c865fee1
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/custom_httpx/aiohttp_handler.py
@@ -0,0 +1,595 @@
+from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, Union, cast
+
+import aiohttp
+import httpx # type: ignore
+from aiohttp import ClientSession, FormData
+
+import litellm
+import litellm.litellm_core_utils
+import litellm.types
+import litellm.types.utils
+from litellm.llms.base_llm.chat.transformation import BaseConfig
+from litellm.llms.base_llm.image_variations.transformation import (
+ BaseImageVariationConfig,
+)
+from litellm.llms.custom_httpx.http_handler import (
+ AsyncHTTPHandler,
+ HTTPHandler,
+ _get_httpx_client,
+)
+from litellm.types.llms.openai import FileTypes
+from litellm.types.utils import HttpHandlerRequestFields, ImageResponse, LlmProviders
+from litellm.utils import CustomStreamWrapper, ModelResponse, ProviderConfigManager
+
+if TYPE_CHECKING:
+ from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
+
+ LiteLLMLoggingObj = _LiteLLMLoggingObj
+else:
+ LiteLLMLoggingObj = Any
+
+DEFAULT_TIMEOUT = 600
+
+
+class BaseLLMAIOHTTPHandler:
+
+ def __init__(self):
+ self.client_session: Optional[aiohttp.ClientSession] = None
+
+ def _get_async_client_session(
+ self, dynamic_client_session: Optional[ClientSession] = None
+ ) -> ClientSession:
+ if dynamic_client_session:
+ return dynamic_client_session
+ elif self.client_session:
+ return self.client_session
+ else:
+            # initialize and cache a new client session, then return it
+ self.client_session = aiohttp.ClientSession()
+ return self.client_session
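+
+    # A shared aiohttp.ClientSession is created lazily on first use and re-used
+    # for subsequent calls, unless the caller supplies its own session.
+    # Illustrative sketch (hypothetical caller, not part of the library API):
+    #
+    #   handler = BaseLLMAIOHTTPHandler()
+    #   session = handler._get_async_client_session()          # creates and caches a session
+    #   assert handler._get_async_client_session() is session  # cached session is re-used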
+
+ async def _make_common_async_call(
+ self,
+ async_client_session: Optional[ClientSession],
+ provider_config: BaseConfig,
+ api_base: str,
+ headers: dict,
+ data: Optional[dict],
+ timeout: Union[float, httpx.Timeout],
+ litellm_params: dict,
+ form_data: Optional[FormData] = None,
+ stream: bool = False,
+ ) -> aiohttp.ClientResponse:
+ """Common implementation across stream + non-stream calls. Meant to ensure consistent error-handling."""
+ max_retry_on_unprocessable_entity_error = (
+ provider_config.max_retry_on_unprocessable_entity_error
+ )
+
+ response: Optional[aiohttp.ClientResponse] = None
+ async_client_session = self._get_async_client_session(
+ dynamic_client_session=async_client_session
+ )
+
+ for i in range(max(max_retry_on_unprocessable_entity_error, 1)):
+ try:
+ response = await async_client_session.post(
+ url=api_base,
+ headers=headers,
+ json=data,
+ data=form_data,
+ )
+ if not response.ok:
+ response.raise_for_status()
+ except aiohttp.ClientResponseError as e:
+ setattr(e, "text", e.message)
+ raise self._handle_error(e=e, provider_config=provider_config)
+ except Exception as e:
+ raise self._handle_error(e=e, provider_config=provider_config)
+ break
+
+ if response is None:
+ raise provider_config.get_error_class(
+ error_message="No response from the API",
+ status_code=422,
+ headers={},
+ )
+
+ return response
+
+ def _make_common_sync_call(
+ self,
+ sync_httpx_client: HTTPHandler,
+ provider_config: BaseConfig,
+ api_base: str,
+ headers: dict,
+ data: dict,
+ timeout: Union[float, httpx.Timeout],
+ litellm_params: dict,
+ stream: bool = False,
+ files: Optional[dict] = None,
+ content: Any = None,
+ params: Optional[dict] = None,
+ ) -> httpx.Response:
+
+ max_retry_on_unprocessable_entity_error = (
+ provider_config.max_retry_on_unprocessable_entity_error
+ )
+
+ response: Optional[httpx.Response] = None
+
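+        # Retry-on-422 flow: if the provider raises an httpx.HTTPStatusError and
+        # the provider config says the request can be rewritten (e.g. dropping an
+        # unsupported parameter), the request body is transformed and the call is
+        # retried, up to max_retry_on_unprocessable_entity_error attempts.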
+ for i in range(max(max_retry_on_unprocessable_entity_error, 1)):
+ try:
+ response = sync_httpx_client.post(
+ url=api_base,
+ headers=headers,
+ data=data, # do not json dump the data here. let the individual endpoint handle this.
+ timeout=timeout,
+ stream=stream,
+ files=files,
+ content=content,
+ params=params,
+ )
+ except httpx.HTTPStatusError as e:
+ hit_max_retry = i + 1 == max_retry_on_unprocessable_entity_error
+ should_retry = provider_config.should_retry_llm_api_inside_llm_translation_on_http_error(
+ e=e, litellm_params=litellm_params
+ )
+ if should_retry and not hit_max_retry:
+ data = (
+ provider_config.transform_request_on_unprocessable_entity_error(
+ e=e, request_data=data
+ )
+ )
+ continue
+ else:
+ raise self._handle_error(e=e, provider_config=provider_config)
+ except Exception as e:
+ raise self._handle_error(e=e, provider_config=provider_config)
+ break
+
+ if response is None:
+ raise provider_config.get_error_class(
+ error_message="No response from the API",
+ status_code=422, # don't retry on this error
+ headers={},
+ )
+
+ return response
+
+ async def async_completion(
+ self,
+ custom_llm_provider: str,
+ provider_config: BaseConfig,
+ api_base: str,
+ headers: dict,
+ data: dict,
+ timeout: Union[float, httpx.Timeout],
+ model: str,
+ model_response: ModelResponse,
+ logging_obj: LiteLLMLoggingObj,
+ messages: list,
+ optional_params: dict,
+ litellm_params: dict,
+ encoding: Any,
+ api_key: Optional[str] = None,
+ client: Optional[ClientSession] = None,
+ ):
+ _response = await self._make_common_async_call(
+ async_client_session=client,
+ provider_config=provider_config,
+ api_base=api_base,
+ headers=headers,
+ data=data,
+ timeout=timeout,
+ litellm_params=litellm_params,
+ stream=False,
+ )
+ _transformed_response = await provider_config.transform_response( # type: ignore
+ model=model,
+ raw_response=_response, # type: ignore
+ model_response=model_response,
+ logging_obj=logging_obj,
+ api_key=api_key,
+ request_data=data,
+ messages=messages,
+ optional_params=optional_params,
+ litellm_params=litellm_params,
+ encoding=encoding,
+ )
+ return _transformed_response
+
+ def completion(
+ self,
+ model: str,
+ messages: list,
+ api_base: str,
+ custom_llm_provider: str,
+ model_response: ModelResponse,
+ encoding,
+ logging_obj: LiteLLMLoggingObj,
+ optional_params: dict,
+ timeout: Union[float, httpx.Timeout],
+ litellm_params: dict,
+ acompletion: bool,
+ stream: Optional[bool] = False,
+ fake_stream: bool = False,
+ api_key: Optional[str] = None,
+ headers: Optional[dict] = {},
+ client: Optional[Union[HTTPHandler, AsyncHTTPHandler, ClientSession]] = None,
+ ):
+ provider_config = ProviderConfigManager.get_provider_chat_config(
+ model=model, provider=litellm.LlmProviders(custom_llm_provider)
+ )
+ # get config from model, custom llm provider
+ headers = provider_config.validate_environment(
+ api_key=api_key,
+ headers=headers or {},
+ model=model,
+ messages=messages,
+ optional_params=optional_params,
+ api_base=api_base,
+ )
+
+ api_base = provider_config.get_complete_url(
+ api_base=api_base,
+ model=model,
+ optional_params=optional_params,
+ litellm_params=litellm_params,
+ stream=stream,
+ )
+
+ data = provider_config.transform_request(
+ model=model,
+ messages=messages,
+ optional_params=optional_params,
+ litellm_params=litellm_params,
+ headers=headers,
+ )
+
+ ## LOGGING
+ logging_obj.pre_call(
+ input=messages,
+ api_key=api_key,
+ additional_args={
+ "complete_input_dict": data,
+ "api_base": api_base,
+ "headers": headers,
+ },
+ )
+
+ if acompletion is True:
+ return self.async_completion(
+ custom_llm_provider=custom_llm_provider,
+ provider_config=provider_config,
+ api_base=api_base,
+ headers=headers,
+ data=data,
+ timeout=timeout,
+ model=model,
+ model_response=model_response,
+ logging_obj=logging_obj,
+ api_key=api_key,
+ messages=messages,
+ optional_params=optional_params,
+ litellm_params=litellm_params,
+ encoding=encoding,
+ client=(
+ client
+ if client is not None and isinstance(client, ClientSession)
+ else None
+ ),
+ )
+
+ if stream is True:
+ if fake_stream is not True:
+ data["stream"] = stream
+ completion_stream, headers = self.make_sync_call(
+ provider_config=provider_config,
+ api_base=api_base,
+ headers=headers, # type: ignore
+ data=data,
+ model=model,
+ messages=messages,
+ logging_obj=logging_obj,
+ timeout=timeout,
+ fake_stream=fake_stream,
+ client=(
+ client
+ if client is not None and isinstance(client, HTTPHandler)
+ else None
+ ),
+ litellm_params=litellm_params,
+ )
+ return CustomStreamWrapper(
+ completion_stream=completion_stream,
+ model=model,
+ custom_llm_provider=custom_llm_provider,
+ logging_obj=logging_obj,
+ )
+
+ if client is None or not isinstance(client, HTTPHandler):
+ sync_httpx_client = _get_httpx_client()
+ else:
+ sync_httpx_client = client
+
+ response = self._make_common_sync_call(
+ sync_httpx_client=sync_httpx_client,
+ provider_config=provider_config,
+ api_base=api_base,
+ headers=headers,
+ timeout=timeout,
+ litellm_params=litellm_params,
+ data=data,
+ )
+ return provider_config.transform_response(
+ model=model,
+ raw_response=response,
+ model_response=model_response,
+ logging_obj=logging_obj,
+ api_key=api_key,
+ request_data=data,
+ messages=messages,
+ optional_params=optional_params,
+ litellm_params=litellm_params,
+ encoding=encoding,
+ )
+
+ def make_sync_call(
+ self,
+ provider_config: BaseConfig,
+ api_base: str,
+ headers: dict,
+ data: dict,
+ model: str,
+ messages: list,
+ logging_obj,
+ litellm_params: dict,
+ timeout: Union[float, httpx.Timeout],
+ fake_stream: bool = False,
+ client: Optional[HTTPHandler] = None,
+ ) -> Tuple[Any, dict]:
+ if client is None or not isinstance(client, HTTPHandler):
+ sync_httpx_client = _get_httpx_client()
+ else:
+ sync_httpx_client = client
+ stream = True
+ if fake_stream is True:
+ stream = False
+
+ response = self._make_common_sync_call(
+ sync_httpx_client=sync_httpx_client,
+ provider_config=provider_config,
+ api_base=api_base,
+ headers=headers,
+ data=data,
+ timeout=timeout,
+ litellm_params=litellm_params,
+ stream=stream,
+ )
+
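+        # With fake_stream=True the request above was made without streaming, so
+        # the full JSON body is handed to the model-response iterator; otherwise
+        # the raw line iterator from the streaming response is used.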
+ if fake_stream is True:
+ completion_stream = provider_config.get_model_response_iterator(
+ streaming_response=response.json(), sync_stream=True
+ )
+ else:
+ completion_stream = provider_config.get_model_response_iterator(
+ streaming_response=response.iter_lines(), sync_stream=True
+ )
+
+ # LOGGING
+ logging_obj.post_call(
+ input=messages,
+ api_key="",
+ original_response="first stream response received",
+ additional_args={"complete_input_dict": data},
+ )
+
+ return completion_stream, dict(response.headers)
+
+ async def async_image_variations(
+ self,
+ client: Optional[ClientSession],
+ provider_config: BaseImageVariationConfig,
+ api_base: str,
+ headers: dict,
+ data: HttpHandlerRequestFields,
+ timeout: float,
+ litellm_params: dict,
+ model_response: ImageResponse,
+ logging_obj: LiteLLMLoggingObj,
+ api_key: str,
+ model: Optional[str],
+ image: FileTypes,
+ optional_params: dict,
+ ) -> ImageResponse:
+ # create aiohttp form data if files in data
+ form_data: Optional[FormData] = None
+ if "files" in data and "data" in data:
+ form_data = FormData()
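+            # each entry in data["files"] is expected to be a
+            # (filename, file_content, content_type) tuple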
+ for k, v in data["files"].items():
+ form_data.add_field(k, v[1], filename=v[0], content_type=v[2])
+
+ for key, value in data["data"].items():
+ form_data.add_field(key, value)
+
+ _response = await self._make_common_async_call(
+ async_client_session=client,
+ provider_config=provider_config,
+ api_base=api_base,
+ headers=headers,
+ data=None if form_data is not None else cast(dict, data),
+ form_data=form_data,
+ timeout=timeout,
+ litellm_params=litellm_params,
+ stream=False,
+ )
+
+ ## LOGGING
+ logging_obj.post_call(
+ api_key=api_key,
+ original_response=_response.text,
+ additional_args={
+ "headers": headers,
+ "api_base": api_base,
+ },
+ )
+
+ ## RESPONSE OBJECT
+ return await provider_config.async_transform_response_image_variation(
+ model=model,
+ model_response=model_response,
+ raw_response=_response,
+ logging_obj=logging_obj,
+ request_data=cast(dict, data),
+ image=image,
+ optional_params=optional_params,
+ litellm_params=litellm_params,
+ encoding=None,
+ api_key=api_key,
+ )
+
+ def image_variations(
+ self,
+ model_response: ImageResponse,
+ api_key: str,
+ model: Optional[str],
+ image: FileTypes,
+ timeout: float,
+ custom_llm_provider: str,
+ logging_obj: LiteLLMLoggingObj,
+ optional_params: dict,
+ litellm_params: dict,
+ print_verbose: Optional[Callable] = None,
+ api_base: Optional[str] = None,
+ aimage_variation: bool = False,
+ logger_fn=None,
+ client=None,
+ organization: Optional[str] = None,
+ headers: Optional[dict] = None,
+ ) -> ImageResponse:
+ if model is None:
+ raise ValueError("model is required for non-openai image variations")
+
+ provider_config = ProviderConfigManager.get_provider_image_variation_config(
+ model=model, # openai defaults to dall-e-2
+ provider=LlmProviders(custom_llm_provider),
+ )
+
+ if provider_config is None:
+ raise ValueError(
+ f"image variation provider not found: {custom_llm_provider}."
+ )
+
+ api_base = provider_config.get_complete_url(
+ api_base=api_base,
+ model=model,
+ optional_params=optional_params,
+ litellm_params=litellm_params,
+ stream=False,
+ )
+
+ headers = provider_config.validate_environment(
+ api_key=api_key,
+ headers=headers or {},
+ model=model,
+ messages=[{"role": "user", "content": "test"}],
+ optional_params=optional_params,
+ api_base=api_base,
+ )
+
+ data = provider_config.transform_request_image_variation(
+ model=model,
+ image=image,
+ optional_params=optional_params,
+ headers=headers,
+ )
+
+ ## LOGGING
+ logging_obj.pre_call(
+ input="",
+ api_key=api_key,
+ additional_args={
+ "headers": headers,
+ "api_base": api_base,
+ "complete_input_dict": data.copy(),
+ },
+ )
+
+ if litellm_params.get("async_call", False):
+ return self.async_image_variations(
+ api_base=api_base,
+ data=data,
+ headers=headers,
+ model_response=model_response,
+ api_key=api_key,
+ logging_obj=logging_obj,
+ model=model,
+ timeout=timeout,
+ client=client,
+ optional_params=optional_params,
+ litellm_params=litellm_params,
+ image=image,
+ provider_config=provider_config,
+ ) # type: ignore
+
+ if client is None or not isinstance(client, HTTPHandler):
+ sync_httpx_client = _get_httpx_client()
+ else:
+ sync_httpx_client = client
+
+ response = self._make_common_sync_call(
+ sync_httpx_client=sync_httpx_client,
+ provider_config=provider_config,
+ api_base=api_base,
+ headers=headers,
+ timeout=timeout,
+ litellm_params=litellm_params,
+ stream=False,
+ data=data.get("data") or {},
+ files=data.get("files"),
+ content=data.get("content"),
+ params=data.get("params"),
+ )
+
+ ## LOGGING
+ logging_obj.post_call(
+ api_key=api_key,
+ original_response=response.text,
+ additional_args={
+ "headers": headers,
+ "api_base": api_base,
+ },
+ )
+
+ ## RESPONSE OBJECT
+ return provider_config.transform_response_image_variation(
+ model=model,
+ model_response=model_response,
+ raw_response=response,
+ logging_obj=logging_obj,
+ request_data=cast(dict, data),
+ image=image,
+ optional_params=optional_params,
+ litellm_params=litellm_params,
+ encoding=None,
+ api_key=api_key,
+ )
+
+ def _handle_error(self, e: Exception, provider_config: BaseConfig):
+ status_code = getattr(e, "status_code", 500)
+ error_headers = getattr(e, "headers", None)
+ error_text = getattr(e, "text", str(e))
+ error_response = getattr(e, "response", None)
+ if error_headers is None and error_response:
+ error_headers = getattr(error_response, "headers", None)
+ if error_response and hasattr(error_response, "text"):
+ error_text = getattr(error_response, "text", error_text)
+ if error_headers:
+ error_headers = dict(error_headers)
+ else:
+ error_headers = {}
+ raise provider_config.get_error_class(
+ error_message=error_text,
+ status_code=status_code,
+ headers=error_headers,
+ )
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/custom_httpx/http_handler.py b/.venv/lib/python3.12/site-packages/litellm/llms/custom_httpx/http_handler.py
new file mode 100644
index 00000000..34d70434
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/custom_httpx/http_handler.py
@@ -0,0 +1,746 @@
+import asyncio
+import os
+import ssl
+import time
+from typing import TYPE_CHECKING, Any, Callable, List, Mapping, Optional, Union
+
+import httpx
+from httpx import USE_CLIENT_DEFAULT, AsyncHTTPTransport, HTTPTransport
+
+import litellm
+from litellm.litellm_core_utils.logging_utils import track_llm_api_timing
+from litellm.types.llms.custom_http import *
+
+if TYPE_CHECKING:
+ from litellm import LlmProviders
+ from litellm.litellm_core_utils.litellm_logging import (
+ Logging as LiteLLMLoggingObject,
+ )
+else:
+ LlmProviders = Any
+ LiteLLMLoggingObject = Any
+
+try:
+ from litellm._version import version
+except Exception:
+ version = "0.0.0"
+
+headers = {
+ "User-Agent": f"litellm/{version}",
+}
+
+# https://www.python-httpx.org/advanced/timeouts
+_DEFAULT_TIMEOUT = httpx.Timeout(timeout=5.0, connect=5.0)
+_DEFAULT_TTL_FOR_HTTPX_CLIENTS = 3600 # 1 hour, re-use the same httpx client for 1 hour
+
+
+def mask_sensitive_info(error_message):
+ # Find the start of the key parameter
+ if isinstance(error_message, str):
+ key_index = error_message.find("key=")
+ else:
+ return error_message
+
+ # If key is found
+ if key_index != -1:
+ # Find the end of the key parameter (next & or end of string)
+ next_param = error_message.find("&", key_index)
+
+ if next_param == -1:
+ # If no more parameters, mask until the end of the string
+ masked_message = error_message[: key_index + 4] + "[REDACTED_API_KEY]"
+ else:
+ # Replace the key with redacted value, keeping other parameters
+ masked_message = (
+ error_message[: key_index + 4]
+ + "[REDACTED_API_KEY]"
+ + error_message[next_param:]
+ )
+
+ return masked_message
+
+ return error_message
+
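+# Illustrative examples of the masking behaviour (hypothetical URLs):
+#
+#   mask_sensitive_info("https://example.com/v1?key=sk-123&model=gpt-4o")
+#   -> "https://example.com/v1?key=[REDACTED_API_KEY]&model=gpt-4o"
+#
+#   mask_sensitive_info("https://example.com/v1?key=sk-123")
+#   -> "https://example.com/v1?key=[REDACTED_API_KEY]"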
+
+class MaskedHTTPStatusError(httpx.HTTPStatusError):
+ def __init__(
+ self, original_error, message: Optional[str] = None, text: Optional[str] = None
+ ):
+        # Mask any API key passed as a `?key=` query parameter in the request URL
+ masked_url = mask_sensitive_info(str(original_error.request.url))
+ # Create a new error that looks like the original, but with a masked URL
+
+ super().__init__(
+ message=original_error.message,
+ request=httpx.Request(
+ method=original_error.request.method,
+ url=masked_url,
+ headers=original_error.request.headers,
+ content=original_error.request.content,
+ ),
+ response=httpx.Response(
+ status_code=original_error.response.status_code,
+ content=original_error.response.content,
+ headers=original_error.response.headers,
+ ),
+ )
+ self.message = message
+ self.text = text
+
+
+class AsyncHTTPHandler:
+ def __init__(
+ self,
+ timeout: Optional[Union[float, httpx.Timeout]] = None,
+ event_hooks: Optional[Mapping[str, List[Callable[..., Any]]]] = None,
+ concurrent_limit=1000,
+ client_alias: Optional[str] = None, # name for client in logs
+ ssl_verify: Optional[VerifyTypes] = None,
+ ):
+ self.timeout = timeout
+ self.event_hooks = event_hooks
+ self.client = self.create_client(
+ timeout=timeout,
+ concurrent_limit=concurrent_limit,
+ event_hooks=event_hooks,
+ ssl_verify=ssl_verify,
+ )
+ self.client_alias = client_alias
+
+ def create_client(
+ self,
+ timeout: Optional[Union[float, httpx.Timeout]],
+ concurrent_limit: int,
+ event_hooks: Optional[Mapping[str, List[Callable[..., Any]]]],
+ ssl_verify: Optional[VerifyTypes] = None,
+ ) -> httpx.AsyncClient:
+
+ # SSL certificates (a.k.a CA bundle) used to verify the identity of requested hosts.
+ # /path/to/certificate.pem
+ if ssl_verify is None:
+ ssl_verify = os.getenv("SSL_VERIFY", litellm.ssl_verify)
+
+ ssl_security_level = os.getenv("SSL_SECURITY_LEVEL")
+
+    # If certificate verification is disabled and a custom security level is requested
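+    # e.g. SSL_SECURITY_LEVEL="DEFAULT@SECLEVEL=1" is an OpenSSL cipher string,
+    # applied below via set_ciphers(), that relaxes the security level for hosts
+    # whose certificates fail the default policy.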
+ if (
+ not ssl_verify
+ and ssl_security_level
+ and isinstance(ssl_security_level, str)
+ ):
+ # Create a custom SSL context with reduced security level
+ custom_ssl_context = ssl.create_default_context()
+ custom_ssl_context.set_ciphers(ssl_security_level)
+
+ # If ssl_verify is a path to a CA bundle, load it into our custom context
+ if isinstance(ssl_verify, str) and os.path.exists(ssl_verify):
+ custom_ssl_context.load_verify_locations(cafile=ssl_verify)
+
+ # Use our custom SSL context instead of the original ssl_verify value
+ ssl_verify = custom_ssl_context
+
+ # An SSL certificate used by the requested host to authenticate the client.
+ # /path/to/client.pem
+ cert = os.getenv("SSL_CERTIFICATE", litellm.ssl_certificate)
+
+ if timeout is None:
+ timeout = _DEFAULT_TIMEOUT
+ # Create a client with a connection pool
+ transport = self._create_async_transport()
+
+ return httpx.AsyncClient(
+ transport=transport,
+ event_hooks=event_hooks,
+ timeout=timeout,
+ limits=httpx.Limits(
+ max_connections=concurrent_limit,
+ max_keepalive_connections=concurrent_limit,
+ ),
+ verify=ssl_verify,
+ cert=cert,
+ headers=headers,
+ )
+
+ async def close(self):
+ # Close the client when you're done with it
+ await self.client.aclose()
+
+ async def __aenter__(self):
+ return self.client
+
+    async def __aexit__(self, exc_type, exc_value, traceback):
+ # close the client when exiting
+ await self.client.aclose()
+
+ async def get(
+ self,
+ url: str,
+ params: Optional[dict] = None,
+ headers: Optional[dict] = None,
+ follow_redirects: Optional[bool] = None,
+ ):
+ # Set follow_redirects to UseClientDefault if None
+ _follow_redirects = (
+ follow_redirects if follow_redirects is not None else USE_CLIENT_DEFAULT
+ )
+
+ response = await self.client.get(
+ url, params=params, headers=headers, follow_redirects=_follow_redirects # type: ignore
+ )
+ return response
+
+ @track_llm_api_timing()
+ async def post(
+ self,
+ url: str,
+ data: Optional[Union[dict, str]] = None, # type: ignore
+ json: Optional[dict] = None,
+ params: Optional[dict] = None,
+ headers: Optional[dict] = None,
+ timeout: Optional[Union[float, httpx.Timeout]] = None,
+ stream: bool = False,
+ logging_obj: Optional[LiteLLMLoggingObject] = None,
+ ):
+ start_time = time.time()
+ try:
+ if timeout is None:
+ timeout = self.timeout
+
+ req = self.client.build_request(
+ "POST", url, data=data, json=json, params=params, headers=headers, timeout=timeout # type: ignore
+ )
+ response = await self.client.send(req, stream=stream)
+ response.raise_for_status()
+ return response
+ except (httpx.RemoteProtocolError, httpx.ConnectError):
+ # Retry the request with a new session if there is a connection error
+ new_client = self.create_client(
+ timeout=timeout, concurrent_limit=1, event_hooks=self.event_hooks
+ )
+ try:
+ return await self.single_connection_post_request(
+ url=url,
+ client=new_client,
+ data=data,
+ json=json,
+ params=params,
+ headers=headers,
+ stream=stream,
+ )
+ finally:
+ await new_client.aclose()
+ except httpx.TimeoutException as e:
+ end_time = time.time()
+ time_delta = round(end_time - start_time, 3)
+ headers = {}
+ error_response = getattr(e, "response", None)
+ if error_response is not None:
+ for key, value in error_response.headers.items():
+ headers["response_headers-{}".format(key)] = value
+
+ raise litellm.Timeout(
+ message=f"Connection timed out. Timeout passed={timeout}, time taken={time_delta} seconds",
+ model="default-model-name",
+ llm_provider="litellm-httpx-handler",
+ headers=headers,
+ )
+ except httpx.HTTPStatusError as e:
+ if stream is True:
+ setattr(e, "message", await e.response.aread())
+ setattr(e, "text", await e.response.aread())
+ else:
+ setattr(e, "message", mask_sensitive_info(e.response.text))
+ setattr(e, "text", mask_sensitive_info(e.response.text))
+
+ setattr(e, "status_code", e.response.status_code)
+
+ raise e
+ except Exception as e:
+ raise e
+
+ async def put(
+ self,
+ url: str,
+ data: Optional[Union[dict, str]] = None, # type: ignore
+ json: Optional[dict] = None,
+ params: Optional[dict] = None,
+ headers: Optional[dict] = None,
+ timeout: Optional[Union[float, httpx.Timeout]] = None,
+ stream: bool = False,
+ ):
+ try:
+ if timeout is None:
+ timeout = self.timeout
+
+ req = self.client.build_request(
+ "PUT", url, data=data, json=json, params=params, headers=headers, timeout=timeout # type: ignore
+ )
+ response = await self.client.send(req)
+ response.raise_for_status()
+ return response
+ except (httpx.RemoteProtocolError, httpx.ConnectError):
+ # Retry the request with a new session if there is a connection error
+ new_client = self.create_client(
+ timeout=timeout, concurrent_limit=1, event_hooks=self.event_hooks
+ )
+ try:
+ return await self.single_connection_post_request(
+ url=url,
+ client=new_client,
+ data=data,
+ json=json,
+ params=params,
+ headers=headers,
+ stream=stream,
+ )
+ finally:
+ await new_client.aclose()
+ except httpx.TimeoutException as e:
+ headers = {}
+ error_response = getattr(e, "response", None)
+ if error_response is not None:
+ for key, value in error_response.headers.items():
+ headers["response_headers-{}".format(key)] = value
+
+ raise litellm.Timeout(
+ message=f"Connection timed out after {timeout} seconds.",
+ model="default-model-name",
+ llm_provider="litellm-httpx-handler",
+ headers=headers,
+ )
+ except httpx.HTTPStatusError as e:
+ setattr(e, "status_code", e.response.status_code)
+ if stream is True:
+ setattr(e, "message", await e.response.aread())
+ else:
+ setattr(e, "message", e.response.text)
+ raise e
+ except Exception as e:
+ raise e
+
+ async def patch(
+ self,
+ url: str,
+ data: Optional[Union[dict, str]] = None, # type: ignore
+ json: Optional[dict] = None,
+ params: Optional[dict] = None,
+ headers: Optional[dict] = None,
+ timeout: Optional[Union[float, httpx.Timeout]] = None,
+ stream: bool = False,
+ ):
+ try:
+ if timeout is None:
+ timeout = self.timeout
+
+ req = self.client.build_request(
+ "PATCH", url, data=data, json=json, params=params, headers=headers, timeout=timeout # type: ignore
+ )
+ response = await self.client.send(req)
+ response.raise_for_status()
+ return response
+ except (httpx.RemoteProtocolError, httpx.ConnectError):
+ # Retry the request with a new session if there is a connection error
+ new_client = self.create_client(
+ timeout=timeout, concurrent_limit=1, event_hooks=self.event_hooks
+ )
+ try:
+ return await self.single_connection_post_request(
+ url=url,
+ client=new_client,
+ data=data,
+ json=json,
+ params=params,
+ headers=headers,
+ stream=stream,
+ )
+ finally:
+ await new_client.aclose()
+ except httpx.TimeoutException as e:
+ headers = {}
+ error_response = getattr(e, "response", None)
+ if error_response is not None:
+ for key, value in error_response.headers.items():
+ headers["response_headers-{}".format(key)] = value
+
+ raise litellm.Timeout(
+ message=f"Connection timed out after {timeout} seconds.",
+ model="default-model-name",
+ llm_provider="litellm-httpx-handler",
+ headers=headers,
+ )
+ except httpx.HTTPStatusError as e:
+ setattr(e, "status_code", e.response.status_code)
+ if stream is True:
+ setattr(e, "message", await e.response.aread())
+ else:
+ setattr(e, "message", e.response.text)
+ raise e
+ except Exception as e:
+ raise e
+
+ async def delete(
+ self,
+ url: str,
+ data: Optional[Union[dict, str]] = None, # type: ignore
+ json: Optional[dict] = None,
+ params: Optional[dict] = None,
+ headers: Optional[dict] = None,
+ timeout: Optional[Union[float, httpx.Timeout]] = None,
+ stream: bool = False,
+ ):
+ try:
+ if timeout is None:
+ timeout = self.timeout
+ req = self.client.build_request(
+ "DELETE", url, data=data, json=json, params=params, headers=headers, timeout=timeout # type: ignore
+ )
+ response = await self.client.send(req, stream=stream)
+ response.raise_for_status()
+ return response
+ except (httpx.RemoteProtocolError, httpx.ConnectError):
+ # Retry the request with a new session if there is a connection error
+ new_client = self.create_client(
+ timeout=timeout, concurrent_limit=1, event_hooks=self.event_hooks
+ )
+ try:
+ return await self.single_connection_post_request(
+ url=url,
+ client=new_client,
+ data=data,
+ json=json,
+ params=params,
+ headers=headers,
+ stream=stream,
+ )
+ finally:
+ await new_client.aclose()
+ except httpx.HTTPStatusError as e:
+ setattr(e, "status_code", e.response.status_code)
+ if stream is True:
+ setattr(e, "message", await e.response.aread())
+ else:
+ setattr(e, "message", e.response.text)
+ raise e
+ except Exception as e:
+ raise e
+
+ async def single_connection_post_request(
+ self,
+ url: str,
+ client: httpx.AsyncClient,
+ data: Optional[Union[dict, str]] = None, # type: ignore
+ json: Optional[dict] = None,
+ params: Optional[dict] = None,
+ headers: Optional[dict] = None,
+ stream: bool = False,
+ ):
+ """
+        Make a POST request using a single-connection client.
+
+        Used to retry requests after connection-level errors.
+ """
+ req = client.build_request(
+ "POST", url, data=data, json=json, params=params, headers=headers # type: ignore
+ )
+ response = await client.send(req, stream=stream)
+ response.raise_for_status()
+ return response
+
+ def __del__(self) -> None:
+ try:
+ asyncio.get_running_loop().create_task(self.close())
+ except Exception:
+ pass
+
+ def _create_async_transport(self) -> Optional[AsyncHTTPTransport]:
+ """
+        Create an IPv4-only async transport if litellm.force_ipv4 is True.
+        Otherwise, return None.
+
+        Some users have seen httpx ConnectionError when using IPv6; forcing IPv4 resolves the issue for them.
+ """
+ if litellm.force_ipv4:
+ return AsyncHTTPTransport(local_address="0.0.0.0")
+ else:
+ return None
+
+
+class HTTPHandler:
+ def __init__(
+ self,
+ timeout: Optional[Union[float, httpx.Timeout]] = None,
+ concurrent_limit=1000,
+ client: Optional[httpx.Client] = None,
+ ssl_verify: Optional[Union[bool, str]] = None,
+ ):
+ if timeout is None:
+ timeout = _DEFAULT_TIMEOUT
+
+ # SSL certificates (a.k.a CA bundle) used to verify the identity of requested hosts.
+ # /path/to/certificate.pem
+
+ if ssl_verify is None:
+ ssl_verify = os.getenv("SSL_VERIFY", litellm.ssl_verify)
+
+ # An SSL certificate used by the requested host to authenticate the client.
+ # /path/to/client.pem
+ cert = os.getenv("SSL_CERTIFICATE", litellm.ssl_certificate)
+
+ if client is None:
+ transport = self._create_sync_transport()
+
+ # Create a client with a connection pool
+ self.client = httpx.Client(
+ transport=transport,
+ timeout=timeout,
+ limits=httpx.Limits(
+ max_connections=concurrent_limit,
+ max_keepalive_connections=concurrent_limit,
+ ),
+ verify=ssl_verify,
+ cert=cert,
+ headers=headers,
+ )
+ else:
+ self.client = client
+
+ def close(self):
+ # Close the client when you're done with it
+ self.client.close()
+
+ def get(
+ self,
+ url: str,
+ params: Optional[dict] = None,
+ headers: Optional[dict] = None,
+ follow_redirects: Optional[bool] = None,
+ ):
+ # Set follow_redirects to UseClientDefault if None
+ _follow_redirects = (
+ follow_redirects if follow_redirects is not None else USE_CLIENT_DEFAULT
+ )
+
+ response = self.client.get(
+ url, params=params, headers=headers, follow_redirects=_follow_redirects # type: ignore
+ )
+ return response
+
+ def post(
+ self,
+ url: str,
+ data: Optional[Union[dict, str]] = None,
+ json: Optional[Union[dict, str, List]] = None,
+ params: Optional[dict] = None,
+ headers: Optional[dict] = None,
+ stream: bool = False,
+ timeout: Optional[Union[float, httpx.Timeout]] = None,
+ files: Optional[dict] = None,
+ content: Any = None,
+ logging_obj: Optional[LiteLLMLoggingObject] = None,
+ ):
+ try:
+ if timeout is not None:
+ req = self.client.build_request(
+ "POST",
+ url,
+ data=data, # type: ignore
+ json=json,
+ params=params,
+ headers=headers,
+ timeout=timeout,
+ files=files,
+ content=content, # type: ignore
+ )
+ else:
+ req = self.client.build_request(
+ "POST", url, data=data, json=json, params=params, headers=headers, files=files, content=content # type: ignore
+ )
+ response = self.client.send(req, stream=stream)
+ response.raise_for_status()
+ return response
+ except httpx.TimeoutException:
+ raise litellm.Timeout(
+ message=f"Connection timed out after {timeout} seconds.",
+ model="default-model-name",
+ llm_provider="litellm-httpx-handler",
+ )
+ except httpx.HTTPStatusError as e:
+ if stream is True:
+ setattr(e, "message", mask_sensitive_info(e.response.read()))
+ setattr(e, "text", mask_sensitive_info(e.response.read()))
+ else:
+ error_text = mask_sensitive_info(e.response.text)
+ setattr(e, "message", error_text)
+ setattr(e, "text", error_text)
+
+ setattr(e, "status_code", e.response.status_code)
+
+ raise e
+ except Exception as e:
+ raise e
+
+ def patch(
+ self,
+ url: str,
+ data: Optional[Union[dict, str]] = None,
+ json: Optional[Union[dict, str]] = None,
+ params: Optional[dict] = None,
+ headers: Optional[dict] = None,
+ stream: bool = False,
+ timeout: Optional[Union[float, httpx.Timeout]] = None,
+ ):
+ try:
+
+ if timeout is not None:
+ req = self.client.build_request(
+ "PATCH", url, data=data, json=json, params=params, headers=headers, timeout=timeout # type: ignore
+ )
+ else:
+ req = self.client.build_request(
+ "PATCH", url, data=data, json=json, params=params, headers=headers # type: ignore
+ )
+ response = self.client.send(req, stream=stream)
+ response.raise_for_status()
+ return response
+ except httpx.TimeoutException:
+ raise litellm.Timeout(
+ message=f"Connection timed out after {timeout} seconds.",
+ model="default-model-name",
+ llm_provider="litellm-httpx-handler",
+ )
+ except httpx.HTTPStatusError as e:
+
+ if stream is True:
+ setattr(e, "message", mask_sensitive_info(e.response.read()))
+ setattr(e, "text", mask_sensitive_info(e.response.read()))
+ else:
+ error_text = mask_sensitive_info(e.response.text)
+ setattr(e, "message", error_text)
+ setattr(e, "text", error_text)
+
+ setattr(e, "status_code", e.response.status_code)
+
+ raise e
+ except Exception as e:
+ raise e
+
+ def put(
+ self,
+ url: str,
+ data: Optional[Union[dict, str]] = None,
+ json: Optional[Union[dict, str]] = None,
+ params: Optional[dict] = None,
+ headers: Optional[dict] = None,
+ stream: bool = False,
+ timeout: Optional[Union[float, httpx.Timeout]] = None,
+ ):
+ try:
+
+ if timeout is not None:
+ req = self.client.build_request(
+ "PUT", url, data=data, json=json, params=params, headers=headers, timeout=timeout # type: ignore
+ )
+ else:
+ req = self.client.build_request(
+ "PUT", url, data=data, json=json, params=params, headers=headers # type: ignore
+ )
+ response = self.client.send(req, stream=stream)
+ return response
+ except httpx.TimeoutException:
+ raise litellm.Timeout(
+ message=f"Connection timed out after {timeout} seconds.",
+ model="default-model-name",
+ llm_provider="litellm-httpx-handler",
+ )
+ except Exception as e:
+ raise e
+
+ def __del__(self) -> None:
+ try:
+ self.close()
+ except Exception:
+ pass
+
+ def _create_sync_transport(self) -> Optional[HTTPTransport]:
+ """
+        Create an IPv4-only HTTP transport if litellm.force_ipv4 is True.
+        Otherwise, return None.
+
+        Some users have seen httpx ConnectionError when using IPv6; forcing IPv4 resolves the issue for them.
+ """
+ if litellm.force_ipv4:
+ return HTTPTransport(local_address="0.0.0.0")
+ else:
+ return None
+
+
+def get_async_httpx_client(
+ llm_provider: Union[LlmProviders, httpxSpecialProvider],
+ params: Optional[dict] = None,
+) -> AsyncHTTPHandler:
+ """
+    Retrieves the async HTTP client from the cache.
+    If not present, creates a new client.
+
+ Caches the new client and returns it.
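+
+    Illustrative usage (LlmProviders.OPENAI and the timeout value are arbitrary
+    examples, not defaults):
+
+        client = get_async_httpx_client(
+            llm_provider=litellm.LlmProviders.OPENAI,
+            params={"timeout": httpx.Timeout(timeout=30.0, connect=5.0)},
+        )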
+ """
+ _params_key_name = ""
+ if params is not None:
+ for key, value in params.items():
+ try:
+ _params_key_name += f"{key}_{value}"
+ except Exception:
+ pass
+
+ _cache_key_name = "async_httpx_client" + _params_key_name + llm_provider
+ _cached_client = litellm.in_memory_llm_clients_cache.get_cache(_cache_key_name)
+ if _cached_client:
+ return _cached_client
+
+ if params is not None:
+ _new_client = AsyncHTTPHandler(**params)
+ else:
+ _new_client = AsyncHTTPHandler(
+ timeout=httpx.Timeout(timeout=600.0, connect=5.0)
+ )
+
+ litellm.in_memory_llm_clients_cache.set_cache(
+ key=_cache_key_name,
+ value=_new_client,
+ ttl=_DEFAULT_TTL_FOR_HTTPX_CLIENTS,
+ )
+ return _new_client
+
+
+def _get_httpx_client(params: Optional[dict] = None) -> HTTPHandler:
+ """
+    Retrieves the HTTP client from the cache.
+    If not present, creates a new client.
+
+ Caches the new client and returns it.
+ """
+ _params_key_name = ""
+ if params is not None:
+ for key, value in params.items():
+ try:
+ _params_key_name += f"{key}_{value}"
+ except Exception:
+ pass
+
+ _cache_key_name = "httpx_client" + _params_key_name
+
+ _cached_client = litellm.in_memory_llm_clients_cache.get_cache(_cache_key_name)
+ if _cached_client:
+ return _cached_client
+
+ if params is not None:
+ _new_client = HTTPHandler(**params)
+ else:
+ _new_client = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
+
+ litellm.in_memory_llm_clients_cache.set_cache(
+ key=_cache_key_name,
+ value=_new_client,
+ ttl=_DEFAULT_TTL_FOR_HTTPX_CLIENTS,
+ )
+ return _new_client
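+
+
+# Illustrative cache behaviour (hypothetical params):
+#
+#   c1 = _get_httpx_client({"concurrent_limit": 10})
+#   c2 = _get_httpx_client({"concurrent_limit": 10})
+#   assert c1 is c2  # same handler is served from the cache until the 1 hour TTL expires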
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/custom_httpx/httpx_handler.py b/.venv/lib/python3.12/site-packages/litellm/llms/custom_httpx/httpx_handler.py
new file mode 100644
index 00000000..6f684ba0
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/custom_httpx/httpx_handler.py
@@ -0,0 +1,49 @@
+from typing import Optional, Union
+
+import httpx
+
+try:
+ from litellm._version import version
+except Exception:
+ version = "0.0.0"
+
+headers = {
+ "User-Agent": f"litellm/{version}",
+}
+
+
+class HTTPHandler:
+ def __init__(self, concurrent_limit=1000):
+ # Create a client with a connection pool
+ self.client = httpx.AsyncClient(
+ limits=httpx.Limits(
+ max_connections=concurrent_limit,
+ max_keepalive_connections=concurrent_limit,
+ ),
+ headers=headers,
+ )
+
+ async def close(self):
+ # Close the client when you're done with it
+ await self.client.aclose()
+
+ async def get(
+ self, url: str, params: Optional[dict] = None, headers: Optional[dict] = None
+ ):
+ response = await self.client.get(url, params=params, headers=headers)
+ return response
+
+ async def post(
+ self,
+ url: str,
+ data: Optional[Union[dict, str]] = None,
+ params: Optional[dict] = None,
+ headers: Optional[dict] = None,
+ ):
+ try:
+ response = await self.client.post(
+ url, data=data, params=params, headers=headers # type: ignore
+ )
+ return response
+ except Exception as e:
+ raise e
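+
+
+# Note: despite the name, this HTTPHandler wraps an httpx.AsyncClient, so get()
+# and post() are coroutines and must be awaited; the synchronous handler lives
+# in http_handler.py.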
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/custom_httpx/llm_http_handler.py b/.venv/lib/python3.12/site-packages/litellm/llms/custom_httpx/llm_http_handler.py
new file mode 100644
index 00000000..00caf552
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/custom_httpx/llm_http_handler.py
@@ -0,0 +1,1260 @@
+import io
+import json
+from typing import TYPE_CHECKING, Any, Coroutine, Dict, Optional, Tuple, Union
+
+import httpx # type: ignore
+
+import litellm
+import litellm.litellm_core_utils
+import litellm.types
+import litellm.types.utils
+from litellm.llms.base_llm.chat.transformation import BaseConfig
+from litellm.llms.base_llm.embedding.transformation import BaseEmbeddingConfig
+from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
+from litellm.llms.base_llm.responses.transformation import BaseResponsesAPIConfig
+from litellm.llms.custom_httpx.http_handler import (
+ AsyncHTTPHandler,
+ HTTPHandler,
+ _get_httpx_client,
+ get_async_httpx_client,
+)
+from litellm.responses.streaming_iterator import (
+ BaseResponsesAPIStreamingIterator,
+ MockResponsesAPIStreamingIterator,
+ ResponsesAPIStreamingIterator,
+ SyncResponsesAPIStreamingIterator,
+)
+from litellm.types.llms.openai import ResponseInputParam, ResponsesAPIResponse
+from litellm.types.rerank import OptionalRerankParams, RerankResponse
+from litellm.types.router import GenericLiteLLMParams
+from litellm.types.utils import EmbeddingResponse, FileTypes, TranscriptionResponse
+from litellm.utils import CustomStreamWrapper, ModelResponse, ProviderConfigManager
+
+if TYPE_CHECKING:
+ from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
+
+ LiteLLMLoggingObj = _LiteLLMLoggingObj
+else:
+ LiteLLMLoggingObj = Any
+
+
+class BaseLLMHTTPHandler:
+
+ async def _make_common_async_call(
+ self,
+ async_httpx_client: AsyncHTTPHandler,
+ provider_config: BaseConfig,
+ api_base: str,
+ headers: dict,
+ data: dict,
+ timeout: Union[float, httpx.Timeout],
+ litellm_params: dict,
+ logging_obj: LiteLLMLoggingObj,
+ stream: bool = False,
+ ) -> httpx.Response:
+ """Common implementation across stream + non-stream calls. Meant to ensure consistent error-handling."""
+ max_retry_on_unprocessable_entity_error = (
+ provider_config.max_retry_on_unprocessable_entity_error
+ )
+
+ response: Optional[httpx.Response] = None
+ for i in range(max(max_retry_on_unprocessable_entity_error, 1)):
+ try:
+ response = await async_httpx_client.post(
+ url=api_base,
+ headers=headers,
+ data=json.dumps(data),
+ timeout=timeout,
+ stream=stream,
+ logging_obj=logging_obj,
+ )
+ except httpx.HTTPStatusError as e:
+ hit_max_retry = i + 1 == max_retry_on_unprocessable_entity_error
+ should_retry = provider_config.should_retry_llm_api_inside_llm_translation_on_http_error(
+ e=e, litellm_params=litellm_params
+ )
+ if should_retry and not hit_max_retry:
+ data = (
+ provider_config.transform_request_on_unprocessable_entity_error(
+ e=e, request_data=data
+ )
+ )
+ continue
+ else:
+ raise self._handle_error(e=e, provider_config=provider_config)
+ except Exception as e:
+ raise self._handle_error(e=e, provider_config=provider_config)
+ break
+
+ if response is None:
+ raise provider_config.get_error_class(
+ error_message="No response from the API",
+ status_code=422, # don't retry on this error
+ headers={},
+ )
+
+ return response
+
+ def _make_common_sync_call(
+ self,
+ sync_httpx_client: HTTPHandler,
+ provider_config: BaseConfig,
+ api_base: str,
+ headers: dict,
+ data: dict,
+ timeout: Union[float, httpx.Timeout],
+ litellm_params: dict,
+ logging_obj: LiteLLMLoggingObj,
+ stream: bool = False,
+ ) -> httpx.Response:
+
+ max_retry_on_unprocessable_entity_error = (
+ provider_config.max_retry_on_unprocessable_entity_error
+ )
+
+ response: Optional[httpx.Response] = None
+
+ for i in range(max(max_retry_on_unprocessable_entity_error, 1)):
+ try:
+ response = sync_httpx_client.post(
+ url=api_base,
+ headers=headers,
+ data=json.dumps(data),
+ timeout=timeout,
+ stream=stream,
+ logging_obj=logging_obj,
+ )
+ except httpx.HTTPStatusError as e:
+ hit_max_retry = i + 1 == max_retry_on_unprocessable_entity_error
+ should_retry = provider_config.should_retry_llm_api_inside_llm_translation_on_http_error(
+ e=e, litellm_params=litellm_params
+ )
+ if should_retry and not hit_max_retry:
+ data = (
+ provider_config.transform_request_on_unprocessable_entity_error(
+ e=e, request_data=data
+ )
+ )
+ continue
+ else:
+ raise self._handle_error(e=e, provider_config=provider_config)
+ except Exception as e:
+ raise self._handle_error(e=e, provider_config=provider_config)
+ break
+
+ if response is None:
+ raise provider_config.get_error_class(
+ error_message="No response from the API",
+ status_code=422, # don't retry on this error
+ headers={},
+ )
+
+ return response
+
+ async def async_completion(
+ self,
+ custom_llm_provider: str,
+ provider_config: BaseConfig,
+ api_base: str,
+ headers: dict,
+ data: dict,
+ timeout: Union[float, httpx.Timeout],
+ model: str,
+ model_response: ModelResponse,
+ logging_obj: LiteLLMLoggingObj,
+ messages: list,
+ optional_params: dict,
+ litellm_params: dict,
+ encoding: Any,
+ api_key: Optional[str] = None,
+ client: Optional[AsyncHTTPHandler] = None,
+ json_mode: bool = False,
+ ):
+ if client is None:
+ async_httpx_client = get_async_httpx_client(
+ llm_provider=litellm.LlmProviders(custom_llm_provider),
+ params={"ssl_verify": litellm_params.get("ssl_verify", None)},
+ )
+ else:
+ async_httpx_client = client
+
+ response = await self._make_common_async_call(
+ async_httpx_client=async_httpx_client,
+ provider_config=provider_config,
+ api_base=api_base,
+ headers=headers,
+ data=data,
+ timeout=timeout,
+ litellm_params=litellm_params,
+ stream=False,
+ logging_obj=logging_obj,
+ )
+ return provider_config.transform_response(
+ model=model,
+ raw_response=response,
+ model_response=model_response,
+ logging_obj=logging_obj,
+ api_key=api_key,
+ request_data=data,
+ messages=messages,
+ optional_params=optional_params,
+ litellm_params=litellm_params,
+ encoding=encoding,
+ json_mode=json_mode,
+ )
+
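+    # Request pipeline for completion():
+    #   1. resolve the provider chat config for (model, custom_llm_provider)
+    #   2. validate_environment() builds the auth headers
+    #   3. get_complete_url() + transform_request() produce the final URL and body
+    #   4. sign_request() lets providers that sign the final request body re-sign the headers
+    #   5. pre-call logging, then dispatch to the async / streaming / sync paths
+    #   6. transform_response() maps the raw httpx.Response into a ModelResponse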
+ def completion(
+ self,
+ model: str,
+ messages: list,
+ api_base: str,
+ custom_llm_provider: str,
+ model_response: ModelResponse,
+ encoding,
+ logging_obj: LiteLLMLoggingObj,
+ optional_params: dict,
+ timeout: Union[float, httpx.Timeout],
+ litellm_params: dict,
+ acompletion: bool,
+ stream: Optional[bool] = False,
+ fake_stream: bool = False,
+ api_key: Optional[str] = None,
+ headers: Optional[dict] = {},
+ client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
+ ):
+ json_mode: bool = optional_params.pop("json_mode", False)
+
+ provider_config = ProviderConfigManager.get_provider_chat_config(
+ model=model, provider=litellm.LlmProviders(custom_llm_provider)
+ )
+
+ # get config from model, custom llm provider
+ headers = provider_config.validate_environment(
+ api_key=api_key,
+ headers=headers or {},
+ model=model,
+ messages=messages,
+ optional_params=optional_params,
+ api_base=api_base,
+ )
+
+ api_base = provider_config.get_complete_url(
+ api_base=api_base,
+ model=model,
+ optional_params=optional_params,
+ stream=stream,
+ litellm_params=litellm_params,
+ )
+
+ data = provider_config.transform_request(
+ model=model,
+ messages=messages,
+ optional_params=optional_params,
+ litellm_params=litellm_params,
+ headers=headers,
+ )
+
+ headers = provider_config.sign_request(
+ headers=headers,
+ optional_params=optional_params,
+ request_data=data,
+ api_base=api_base,
+ stream=stream,
+ fake_stream=fake_stream,
+ model=model,
+ )
+
+ ## LOGGING
+ logging_obj.pre_call(
+ input=messages,
+ api_key=api_key,
+ additional_args={
+ "complete_input_dict": data,
+ "api_base": api_base,
+ "headers": headers,
+ },
+ )
+
+ if acompletion is True:
+ if stream is True:
+ data = self._add_stream_param_to_request_body(
+ data=data,
+ provider_config=provider_config,
+ fake_stream=fake_stream,
+ )
+ return self.acompletion_stream_function(
+ model=model,
+ messages=messages,
+ api_base=api_base,
+ headers=headers,
+ custom_llm_provider=custom_llm_provider,
+ provider_config=provider_config,
+ timeout=timeout,
+ logging_obj=logging_obj,
+ data=data,
+ fake_stream=fake_stream,
+ client=(
+ client
+ if client is not None and isinstance(client, AsyncHTTPHandler)
+ else None
+ ),
+ litellm_params=litellm_params,
+ json_mode=json_mode,
+ )
+
+ else:
+ return self.async_completion(
+ custom_llm_provider=custom_llm_provider,
+ provider_config=provider_config,
+ api_base=api_base,
+ headers=headers,
+ data=data,
+ timeout=timeout,
+ model=model,
+ model_response=model_response,
+ logging_obj=logging_obj,
+ api_key=api_key,
+ messages=messages,
+ optional_params=optional_params,
+ litellm_params=litellm_params,
+ encoding=encoding,
+ client=(
+ client
+ if client is not None and isinstance(client, AsyncHTTPHandler)
+ else None
+ ),
+ json_mode=json_mode,
+ )
+
+ if stream is True:
+ data = self._add_stream_param_to_request_body(
+ data=data,
+ provider_config=provider_config,
+ fake_stream=fake_stream,
+ )
+ if provider_config.has_custom_stream_wrapper is True:
+ return provider_config.get_sync_custom_stream_wrapper(
+ model=model,
+ custom_llm_provider=custom_llm_provider,
+ logging_obj=logging_obj,
+ api_base=api_base,
+ headers=headers,
+ data=data,
+ messages=messages,
+ client=client,
+ json_mode=json_mode,
+ )
+ completion_stream, headers = self.make_sync_call(
+ provider_config=provider_config,
+ api_base=api_base,
+ headers=headers, # type: ignore
+ data=data,
+ model=model,
+ messages=messages,
+ logging_obj=logging_obj,
+ timeout=timeout,
+ fake_stream=fake_stream,
+ client=(
+ client
+ if client is not None and isinstance(client, HTTPHandler)
+ else None
+ ),
+ litellm_params=litellm_params,
+ )
+ return CustomStreamWrapper(
+ completion_stream=completion_stream,
+ model=model,
+ custom_llm_provider=custom_llm_provider,
+ logging_obj=logging_obj,
+ )
+
+ if client is None or not isinstance(client, HTTPHandler):
+ sync_httpx_client = _get_httpx_client(
+ params={"ssl_verify": litellm_params.get("ssl_verify", None)}
+ )
+ else:
+ sync_httpx_client = client
+
+ response = self._make_common_sync_call(
+ sync_httpx_client=sync_httpx_client,
+ provider_config=provider_config,
+ api_base=api_base,
+ headers=headers,
+ data=data,
+ timeout=timeout,
+ litellm_params=litellm_params,
+ logging_obj=logging_obj,
+ )
+ return provider_config.transform_response(
+ model=model,
+ raw_response=response,
+ model_response=model_response,
+ logging_obj=logging_obj,
+ api_key=api_key,
+ request_data=data,
+ messages=messages,
+ optional_params=optional_params,
+ litellm_params=litellm_params,
+ encoding=encoding,
+ json_mode=json_mode,
+ )
+
+ def make_sync_call(
+ self,
+ provider_config: BaseConfig,
+ api_base: str,
+ headers: dict,
+ data: dict,
+ model: str,
+ messages: list,
+ logging_obj,
+ litellm_params: dict,
+ timeout: Union[float, httpx.Timeout],
+ fake_stream: bool = False,
+ client: Optional[HTTPHandler] = None,
+ ) -> Tuple[Any, dict]:
+ if client is None or not isinstance(client, HTTPHandler):
+ sync_httpx_client = _get_httpx_client(
+ {
+ "ssl_verify": litellm_params.get("ssl_verify", None),
+ }
+ )
+ else:
+ sync_httpx_client = client
+ stream = True
+ if fake_stream is True:
+ stream = False
+
+ response = self._make_common_sync_call(
+ sync_httpx_client=sync_httpx_client,
+ provider_config=provider_config,
+ api_base=api_base,
+ headers=headers,
+ data=data,
+ timeout=timeout,
+ litellm_params=litellm_params,
+ stream=stream,
+ logging_obj=logging_obj,
+ )
+
+ if fake_stream is True:
+ completion_stream = provider_config.get_model_response_iterator(
+ streaming_response=response.json(), sync_stream=True
+ )
+ else:
+ completion_stream = provider_config.get_model_response_iterator(
+ streaming_response=response.iter_lines(), sync_stream=True
+ )
+
+ # LOGGING
+ logging_obj.post_call(
+ input=messages,
+ api_key="",
+ original_response="first stream response received",
+ additional_args={"complete_input_dict": data},
+ )
+
+ return completion_stream, dict(response.headers)
+
+ async def acompletion_stream_function(
+ self,
+ model: str,
+ messages: list,
+ api_base: str,
+ custom_llm_provider: str,
+ headers: dict,
+ provider_config: BaseConfig,
+ timeout: Union[float, httpx.Timeout],
+ logging_obj: LiteLLMLoggingObj,
+ data: dict,
+ litellm_params: dict,
+ fake_stream: bool = False,
+ client: Optional[AsyncHTTPHandler] = None,
+ json_mode: Optional[bool] = None,
+ ):
+ if provider_config.has_custom_stream_wrapper is True:
+ return provider_config.get_async_custom_stream_wrapper(
+ model=model,
+ custom_llm_provider=custom_llm_provider,
+ logging_obj=logging_obj,
+ api_base=api_base,
+ headers=headers,
+ data=data,
+ messages=messages,
+ client=client,
+ json_mode=json_mode,
+ )
+
+ completion_stream, _response_headers = await self.make_async_call_stream_helper(
+ custom_llm_provider=custom_llm_provider,
+ provider_config=provider_config,
+ api_base=api_base,
+ headers=headers,
+ data=data,
+ messages=messages,
+ logging_obj=logging_obj,
+ timeout=timeout,
+ fake_stream=fake_stream,
+ client=client,
+ litellm_params=litellm_params,
+ )
+ streamwrapper = CustomStreamWrapper(
+ completion_stream=completion_stream,
+ model=model,
+ custom_llm_provider=custom_llm_provider,
+ logging_obj=logging_obj,
+ )
+ return streamwrapper
+
+ async def make_async_call_stream_helper(
+ self,
+ custom_llm_provider: str,
+ provider_config: BaseConfig,
+ api_base: str,
+ headers: dict,
+ data: dict,
+ messages: list,
+ logging_obj: LiteLLMLoggingObj,
+ timeout: Union[float, httpx.Timeout],
+ litellm_params: dict,
+ fake_stream: bool = False,
+ client: Optional[AsyncHTTPHandler] = None,
+ ) -> Tuple[Any, httpx.Headers]:
+ """
+        Helper function for making an async streaming call.
+
+        Also handles fake_stream.
+ """
+ if client is None:
+ async_httpx_client = get_async_httpx_client(
+ llm_provider=litellm.LlmProviders(custom_llm_provider),
+ params={"ssl_verify": litellm_params.get("ssl_verify", None)},
+ )
+ else:
+ async_httpx_client = client
+ stream = True
+ if fake_stream is True:
+ stream = False
+
+ response = await self._make_common_async_call(
+ async_httpx_client=async_httpx_client,
+ provider_config=provider_config,
+ api_base=api_base,
+ headers=headers,
+ data=data,
+ timeout=timeout,
+ litellm_params=litellm_params,
+ stream=stream,
+ logging_obj=logging_obj,
+ )
+
+ if fake_stream is True:
+ completion_stream = provider_config.get_model_response_iterator(
+ streaming_response=response.json(), sync_stream=False
+ )
+ else:
+ completion_stream = provider_config.get_model_response_iterator(
+ streaming_response=response.aiter_lines(), sync_stream=False
+ )
+ # LOGGING
+ logging_obj.post_call(
+ input=messages,
+ api_key="",
+ original_response="first stream response received",
+ additional_args={"complete_input_dict": data},
+ )
+
+ return completion_stream, response.headers
+
+ def _add_stream_param_to_request_body(
+ self,
+ data: dict,
+ provider_config: BaseConfig,
+ fake_stream: bool,
+ ) -> dict:
+ """
+        Some providers, like Bedrock Invoke, do not support the stream parameter in the request body, so we only pass `stream` in the request body when the provider supports it.
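+
+        Examples: with fake_stream=True the data is returned unchanged; otherwise,
+        when the provider supports the stream param in the request body,
+        data["stream"] = True is added before the request is sent.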
+ """
+ if fake_stream is True:
+ return data
+ if provider_config.supports_stream_param_in_request_body is True:
+ data["stream"] = True
+ return data
+
+ def embedding(
+ self,
+ model: str,
+ input: list,
+ timeout: float,
+ custom_llm_provider: str,
+ logging_obj: LiteLLMLoggingObj,
+ api_base: Optional[str],
+ optional_params: dict,
+ litellm_params: dict,
+ model_response: EmbeddingResponse,
+ api_key: Optional[str] = None,
+ client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
+ aembedding: bool = False,
+ headers={},
+ ) -> EmbeddingResponse:
+
+ provider_config = ProviderConfigManager.get_provider_embedding_config(
+ model=model, provider=litellm.LlmProviders(custom_llm_provider)
+ )
+ # get config from model, custom llm provider
+ headers = provider_config.validate_environment(
+ api_key=api_key,
+ headers=headers,
+ model=model,
+ messages=[],
+ optional_params=optional_params,
+ )
+
+ api_base = provider_config.get_complete_url(
+ api_base=api_base,
+ model=model,
+ optional_params=optional_params,
+ litellm_params=litellm_params,
+ )
+
+ data = provider_config.transform_embedding_request(
+ model=model,
+ input=input,
+ optional_params=optional_params,
+ headers=headers,
+ )
+
+ ## LOGGING
+ logging_obj.pre_call(
+ input=input,
+ api_key=api_key,
+ additional_args={
+ "complete_input_dict": data,
+ "api_base": api_base,
+ "headers": headers,
+ },
+ )
+
+ if aembedding is True:
+ return self.aembedding( # type: ignore
+ request_data=data,
+ api_base=api_base,
+ headers=headers,
+ model=model,
+ custom_llm_provider=custom_llm_provider,
+ provider_config=provider_config,
+ model_response=model_response,
+ logging_obj=logging_obj,
+ api_key=api_key,
+ timeout=timeout,
+ client=client,
+ optional_params=optional_params,
+ litellm_params=litellm_params,
+ )
+
+ if client is None or not isinstance(client, HTTPHandler):
+ sync_httpx_client = _get_httpx_client()
+ else:
+ sync_httpx_client = client
+
+ try:
+ response = sync_httpx_client.post(
+ url=api_base,
+ headers=headers,
+ data=json.dumps(data),
+ timeout=timeout,
+ )
+ except Exception as e:
+ raise self._handle_error(
+ e=e,
+ provider_config=provider_config,
+ )
+
+ return provider_config.transform_embedding_response(
+ model=model,
+ raw_response=response,
+ model_response=model_response,
+ logging_obj=logging_obj,
+ api_key=api_key,
+ request_data=data,
+ optional_params=optional_params,
+ litellm_params=litellm_params,
+ )
+
+ async def aembedding(
+ self,
+ request_data: dict,
+ api_base: str,
+ headers: dict,
+ model: str,
+ custom_llm_provider: str,
+ provider_config: BaseEmbeddingConfig,
+ model_response: EmbeddingResponse,
+ logging_obj: LiteLLMLoggingObj,
+ optional_params: dict,
+ litellm_params: dict,
+ api_key: Optional[str] = None,
+ timeout: Optional[Union[float, httpx.Timeout]] = None,
+ client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
+ ) -> EmbeddingResponse:
+ if client is None or not isinstance(client, AsyncHTTPHandler):
+ async_httpx_client = get_async_httpx_client(
+ llm_provider=litellm.LlmProviders(custom_llm_provider)
+ )
+ else:
+ async_httpx_client = client
+
+ try:
+ response = await async_httpx_client.post(
+ url=api_base,
+ headers=headers,
+ data=json.dumps(request_data),
+ timeout=timeout,
+ )
+ except Exception as e:
+ raise self._handle_error(e=e, provider_config=provider_config)
+
+ return provider_config.transform_embedding_response(
+ model=model,
+ raw_response=response,
+ model_response=model_response,
+ logging_obj=logging_obj,
+ api_key=api_key,
+ request_data=request_data,
+ optional_params=optional_params,
+ litellm_params=litellm_params,
+ )
+
+ def rerank(
+ self,
+ model: str,
+ custom_llm_provider: str,
+ logging_obj: LiteLLMLoggingObj,
+ provider_config: BaseRerankConfig,
+ optional_rerank_params: OptionalRerankParams,
+ timeout: Optional[Union[float, httpx.Timeout]],
+ model_response: RerankResponse,
+ _is_async: bool = False,
+ headers: dict = {},
+ api_key: Optional[str] = None,
+ api_base: Optional[str] = None,
+ client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
+ ) -> RerankResponse:
+
+ # get config from model, custom llm provider
+ headers = provider_config.validate_environment(
+ api_key=api_key,
+ headers=headers,
+ model=model,
+ )
+
+ api_base = provider_config.get_complete_url(
+ api_base=api_base,
+ model=model,
+ )
+
+ data = provider_config.transform_rerank_request(
+ model=model,
+ optional_rerank_params=optional_rerank_params,
+ headers=headers,
+ )
+
+ ## LOGGING
+ logging_obj.pre_call(
+ input=optional_rerank_params.get("query", ""),
+ api_key=api_key,
+ additional_args={
+ "complete_input_dict": data,
+ "api_base": api_base,
+ "headers": headers,
+ },
+ )
+
+ if _is_async is True:
+ return self.arerank( # type: ignore
+ model=model,
+ request_data=data,
+ custom_llm_provider=custom_llm_provider,
+ provider_config=provider_config,
+ logging_obj=logging_obj,
+ model_response=model_response,
+ api_base=api_base,
+ headers=headers,
+ api_key=api_key,
+ timeout=timeout,
+ client=client,
+ )
+
+ if client is None or not isinstance(client, HTTPHandler):
+ sync_httpx_client = _get_httpx_client()
+ else:
+ sync_httpx_client = client
+
+ try:
+ response = sync_httpx_client.post(
+ url=api_base,
+ headers=headers,
+ data=json.dumps(data),
+ timeout=timeout,
+ )
+ except Exception as e:
+ raise self._handle_error(
+ e=e,
+ provider_config=provider_config,
+ )
+
+ return provider_config.transform_rerank_response(
+ model=model,
+ raw_response=response,
+ model_response=model_response,
+ logging_obj=logging_obj,
+ api_key=api_key,
+ request_data=data,
+ )
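+
+ # Illustrative usage sketch (comments only): a synchronous rerank call. The
+ # provider config object, model, query, and documents below are assumptions
+ # for illustration; `optional_rerank_params` is the typed dict consumed by the
+ # provider's transform_rerank_request().
+ #
+ #   result = handler.rerank(
+ #       model="my-rerank-model",
+ #       custom_llm_provider="some_provider",
+ #       logging_obj=logging_obj,
+ #       provider_config=rerank_config,
+ #       optional_rerank_params={"query": "best pizza", "documents": docs},
+ #       timeout=30.0,
+ #       model_response=RerankResponse(),
+ #       api_key=api_key,
+ #   )
+ #   # _is_async=True returns the awaitable from arerank() instead.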
+
+ async def arerank(
+ self,
+ model: str,
+ request_data: dict,
+ custom_llm_provider: str,
+ provider_config: BaseRerankConfig,
+ logging_obj: LiteLLMLoggingObj,
+ model_response: RerankResponse,
+ api_base: str,
+ headers: dict,
+ api_key: Optional[str] = None,
+ timeout: Optional[Union[float, httpx.Timeout]] = None,
+ client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
+ ) -> RerankResponse:
+
+ if client is None or not isinstance(client, AsyncHTTPHandler):
+ async_httpx_client = get_async_httpx_client(
+ llm_provider=litellm.LlmProviders(custom_llm_provider)
+ )
+ else:
+ async_httpx_client = client
+ try:
+ response = await async_httpx_client.post(
+ url=api_base,
+ headers=headers,
+ data=json.dumps(request_data),
+ timeout=timeout,
+ )
+ except Exception as e:
+ raise self._handle_error(e=e, provider_config=provider_config)
+
+ return provider_config.transform_rerank_response(
+ model=model,
+ raw_response=response,
+ model_response=model_response,
+ logging_obj=logging_obj,
+ api_key=api_key,
+ request_data=request_data,
+ )
+
+ def handle_audio_file(self, audio_file: FileTypes) -> bytes:
+ """
+ Processes the audio file input based on its type and returns the binary data.
+
+ Args:
+ audio_file: Can be a file path (str), a tuple (filename, file_content), binary data (bytes), or a file-like object (io.BufferedReader / io.BytesIO).
+
+ Returns:
+ The binary data of the audio file.
+ """
+ binary_data: bytes # Explicitly declare the type
+
+ # Handle the audio file based on type
+ if isinstance(audio_file, str):
+ # If it's a file path
+ with open(audio_file, "rb") as f:
+ binary_data = f.read() # `f.read()` always returns `bytes`
+ elif isinstance(audio_file, tuple):
+ # Handle tuple case
+ _, file_content = audio_file[:2]
+ if isinstance(file_content, str):
+ with open(file_content, "rb") as f:
+ binary_data = f.read() # `f.read()` always returns `bytes`
+ elif isinstance(file_content, bytes):
+ binary_data = file_content
+ else:
+ raise TypeError(
+ f"Unexpected type in tuple: {type(file_content)}. Expected str or bytes."
+ )
+ elif isinstance(audio_file, bytes):
+ # Assume it's already binary data
+ binary_data = audio_file
+ elif isinstance(audio_file, (io.BufferedReader, io.BytesIO)):
+ # Handle file-like objects
+ binary_data = audio_file.read()
+ else:
+ raise TypeError(f"Unsupported type for audio_file: {type(audio_file)}")
+
+ return binary_data
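+
+ # Illustrative sketch (comments only) of the accepted audio_file shapes; the
+ # file path below is an assumption for illustration:
+ #
+ #   handler.handle_audio_file("speech.wav")                  # path -> bytes
+ #   handler.handle_audio_file(("speech.wav", b"RIFF..."))    # (name, content)
+ #   handler.handle_audio_file(b"RIFF...")                    # raw bytes
+ #   handler.handle_audio_file(open("speech.wav", "rb"))      # file-like object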
+
+ def audio_transcriptions(
+ self,
+ model: str,
+ audio_file: FileTypes,
+ optional_params: dict,
+ model_response: TranscriptionResponse,
+ timeout: float,
+ max_retries: int,
+ logging_obj: LiteLLMLoggingObj,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ custom_llm_provider: str,
+ client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
+ atranscription: bool = False,
+ headers: dict = {},
+ litellm_params: dict = {},
+ ) -> TranscriptionResponse:
+ provider_config = ProviderConfigManager.get_provider_audio_transcription_config(
+ model=model, provider=litellm.LlmProviders(custom_llm_provider)
+ )
+ if provider_config is None:
+ raise ValueError(
+ f"No provider config found for model: {model} and provider: {custom_llm_provider}"
+ )
+ headers = provider_config.validate_environment(
+ api_key=api_key,
+ headers=headers,
+ model=model,
+ messages=[],
+ optional_params=optional_params,
+ )
+
+ if client is None or not isinstance(client, HTTPHandler):
+ client = _get_httpx_client()
+
+ complete_url = provider_config.get_complete_url(
+ api_base=api_base,
+ model=model,
+ optional_params=optional_params,
+ litellm_params=litellm_params,
+ )
+
+ # Handle the audio file based on type
+ binary_data = self.handle_audio_file(audio_file)
+
+ try:
+ # Make the POST request
+ response = client.post(
+ url=complete_url,
+ headers=headers,
+ content=binary_data,
+ timeout=timeout,
+ )
+ except Exception as e:
+ raise self._handle_error(e=e, provider_config=provider_config)
+
+ if isinstance(provider_config, litellm.DeepgramAudioTranscriptionConfig):
+ returned_response = provider_config.transform_audio_transcription_response(
+ model=model,
+ raw_response=response,
+ model_response=model_response,
+ logging_obj=logging_obj,
+ request_data={},
+ optional_params=optional_params,
+ litellm_params={},
+ api_key=api_key,
+ )
+ return returned_response
+ return model_response
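+
+ # Illustrative usage sketch (comments only): a synchronous transcription call.
+ # Note that the raw response is only transformed when the resolved provider
+ # config is a DeepgramAudioTranscriptionConfig; otherwise the untouched
+ # model_response is returned. Model and provider names are assumptions.
+ #
+ #   result = handler.audio_transcriptions(
+ #       model="nova-2",
+ #       audio_file="speech.wav",
+ #       optional_params={},
+ #       model_response=TranscriptionResponse(),
+ #       timeout=600.0,
+ #       max_retries=2,
+ #       logging_obj=logging_obj,
+ #       api_key=api_key,
+ #       api_base=None,
+ #       custom_llm_provider="deepgram",
+ #   )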
+
+ def response_api_handler(
+ self,
+ model: str,
+ input: Union[str, ResponseInputParam],
+ responses_api_provider_config: BaseResponsesAPIConfig,
+ response_api_optional_request_params: Dict,
+ custom_llm_provider: str,
+ litellm_params: GenericLiteLLMParams,
+ logging_obj: LiteLLMLoggingObj,
+ extra_headers: Optional[Dict[str, Any]] = None,
+ extra_body: Optional[Dict[str, Any]] = None,
+ timeout: Optional[Union[float, httpx.Timeout]] = None,
+ client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
+ _is_async: bool = False,
+ fake_stream: bool = False,
+ ) -> Union[
+ ResponsesAPIResponse,
+ BaseResponsesAPIStreamingIterator,
+ Coroutine[
+ Any, Any, Union[ResponsesAPIResponse, BaseResponsesAPIStreamingIterator]
+ ],
+ ]:
+ """
+ Handles responses API requests.
+ When _is_async=True, returns a coroutine instead of making the call directly.
+ """
+ if _is_async:
+ # Return the async coroutine if called with _is_async=True
+ return self.async_response_api_handler(
+ model=model,
+ input=input,
+ responses_api_provider_config=responses_api_provider_config,
+ response_api_optional_request_params=response_api_optional_request_params,
+ custom_llm_provider=custom_llm_provider,
+ litellm_params=litellm_params,
+ logging_obj=logging_obj,
+ extra_headers=extra_headers,
+ extra_body=extra_body,
+ timeout=timeout,
+ client=client if isinstance(client, AsyncHTTPHandler) else None,
+ fake_stream=fake_stream,
+ )
+
+ if client is None or not isinstance(client, HTTPHandler):
+ sync_httpx_client = _get_httpx_client(
+ params={"ssl_verify": litellm_params.get("ssl_verify", None)}
+ )
+ else:
+ sync_httpx_client = client
+
+ headers = responses_api_provider_config.validate_environment(
+ api_key=litellm_params.api_key,
+ headers=response_api_optional_request_params.get("extra_headers", {}) or {},
+ model=model,
+ )
+
+ if extra_headers:
+ headers.update(extra_headers)
+
+ api_base = responses_api_provider_config.get_complete_url(
+ api_base=litellm_params.api_base,
+ model=model,
+ )
+
+ data = responses_api_provider_config.transform_responses_api_request(
+ model=model,
+ input=input,
+ response_api_optional_request_params=response_api_optional_request_params,
+ litellm_params=litellm_params,
+ headers=headers,
+ )
+
+ ## LOGGING
+ logging_obj.pre_call(
+ input=input,
+ api_key="",
+ additional_args={
+ "complete_input_dict": data,
+ "api_base": api_base,
+ "headers": headers,
+ },
+ )
+
+ # Check if streaming is requested
+ stream = response_api_optional_request_params.get("stream", False)
+
+ try:
+ if stream:
+ # For streaming, use stream=True in the request
+ if fake_stream is True:
+ stream, data = self._prepare_fake_stream_request(
+ stream=stream,
+ data=data,
+ fake_stream=fake_stream,
+ )
+ response = sync_httpx_client.post(
+ url=api_base,
+ headers=headers,
+ data=json.dumps(data),
+ timeout=timeout
+ or response_api_optional_request_params.get("timeout"),
+ stream=stream,
+ )
+ if fake_stream is True:
+ return MockResponsesAPIStreamingIterator(
+ response=response,
+ model=model,
+ logging_obj=logging_obj,
+ responses_api_provider_config=responses_api_provider_config,
+ )
+
+ return SyncResponsesAPIStreamingIterator(
+ response=response,
+ model=model,
+ logging_obj=logging_obj,
+ responses_api_provider_config=responses_api_provider_config,
+ )
+ else:
+ # For non-streaming requests
+ response = sync_httpx_client.post(
+ url=api_base,
+ headers=headers,
+ data=json.dumps(data),
+ timeout=timeout
+ or response_api_optional_request_params.get("timeout"),
+ )
+ except Exception as e:
+ raise self._handle_error(
+ e=e,
+ provider_config=responses_api_provider_config,
+ )
+
+ return responses_api_provider_config.transform_response_api_response(
+ model=model,
+ raw_response=response,
+ logging_obj=logging_obj,
+ )
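+
+ # Illustrative usage sketch (comments only): a synchronous Responses API call.
+ # The provider config, params object, and input below are assumptions for
+ # illustration; with stream=True the method returns a streaming iterator, and
+ # with _is_async=True it returns the coroutine from async_response_api_handler.
+ #
+ #   result = handler.response_api_handler(
+ #       model="my-model",
+ #       input="Write a haiku about HTTP handlers.",
+ #       responses_api_provider_config=responses_config,
+ #       response_api_optional_request_params={"stream": False},
+ #       custom_llm_provider="some_provider",
+ #       litellm_params=GenericLiteLLMParams(api_key=api_key),
+ #       logging_obj=logging_obj,
+ #       _is_async=False,
+ #   )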
+
+ async def async_response_api_handler(
+ self,
+ model: str,
+ input: Union[str, ResponseInputParam],
+ responses_api_provider_config: BaseResponsesAPIConfig,
+ response_api_optional_request_params: Dict,
+ custom_llm_provider: str,
+ litellm_params: GenericLiteLLMParams,
+ logging_obj: LiteLLMLoggingObj,
+ extra_headers: Optional[Dict[str, Any]] = None,
+ extra_body: Optional[Dict[str, Any]] = None,
+ timeout: Optional[Union[float, httpx.Timeout]] = None,
+ client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
+ fake_stream: bool = False,
+ ) -> Union[ResponsesAPIResponse, BaseResponsesAPIStreamingIterator]:
+ """
+ Async version of the responses API handler.
+ Uses async HTTP client to make requests.
+ """
+ if client is None or not isinstance(client, AsyncHTTPHandler):
+ async_httpx_client = get_async_httpx_client(
+ llm_provider=litellm.LlmProviders(custom_llm_provider),
+ params={"ssl_verify": litellm_params.get("ssl_verify", None)},
+ )
+ else:
+ async_httpx_client = client
+
+ headers = responses_api_provider_config.validate_environment(
+ api_key=litellm_params.api_key,
+ headers=response_api_optional_request_params.get("extra_headers", {}) or {},
+ model=model,
+ )
+
+ if extra_headers:
+ headers.update(extra_headers)
+
+ api_base = responses_api_provider_config.get_complete_url(
+ api_base=litellm_params.api_base,
+ model=model,
+ )
+
+ data = responses_api_provider_config.transform_responses_api_request(
+ model=model,
+ input=input,
+ response_api_optional_request_params=response_api_optional_request_params,
+ litellm_params=litellm_params,
+ headers=headers,
+ )
+
+ ## LOGGING
+ logging_obj.pre_call(
+ input=input,
+ api_key="",
+ additional_args={
+ "complete_input_dict": data,
+ "api_base": api_base,
+ "headers": headers,
+ },
+ )
+ # Check if streaming is requested
+ stream = response_api_optional_request_params.get("stream", False)
+
+ try:
+ if stream:
+ # For streaming, we need to use stream=True in the request
+ if fake_stream is True:
+ stream, data = self._prepare_fake_stream_request(
+ stream=stream,
+ data=data,
+ fake_stream=fake_stream,
+ )
+
+ response = await async_httpx_client.post(
+ url=api_base,
+ headers=headers,
+ data=json.dumps(data),
+ timeout=timeout
+ or response_api_optional_request_params.get("timeout"),
+ stream=stream,
+ )
+
+ if fake_stream is True:
+ return MockResponsesAPIStreamingIterator(
+ response=response,
+ model=model,
+ logging_obj=logging_obj,
+ responses_api_provider_config=responses_api_provider_config,
+ )
+
+ # Return the streaming iterator
+ return ResponsesAPIStreamingIterator(
+ response=response,
+ model=model,
+ logging_obj=logging_obj,
+ responses_api_provider_config=responses_api_provider_config,
+ )
+ else:
+ # For non-streaming, proceed as before
+ response = await async_httpx_client.post(
+ url=api_base,
+ headers=headers,
+ data=json.dumps(data),
+ timeout=timeout
+ or response_api_optional_request_params.get("timeout"),
+ )
+
+ except Exception as e:
+ raise self._handle_error(
+ e=e,
+ provider_config=responses_api_provider_config,
+ )
+
+ return responses_api_provider_config.transform_response_api_response(
+ model=model,
+ raw_response=response,
+ logging_obj=logging_obj,
+ )
+
+ def _prepare_fake_stream_request(
+ self,
+ stream: bool,
+ data: dict,
+ fake_stream: bool,
+ ) -> Tuple[bool, dict]:
+ """
+ Handles preparing a request when `fake_stream` is True.
+ """
+ if fake_stream is True:
+ stream = False
+ data.pop("stream", None)
+ return stream, data
+ return stream, data
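+
+ # Illustrative sketch (comments only): when fake_stream is True the request is
+ # downgraded to a normal (non-streaming) call and any `stream` key is removed
+ # from the body, so the full response can later be replayed as a mock stream:
+ #
+ #   stream, data = handler._prepare_fake_stream_request(
+ #       stream=True, data={"input": "hi", "stream": True}, fake_stream=True
+ #   )
+ #   # stream == False, data == {"input": "hi"}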
+
+ def _handle_error(
+ self,
+ e: Exception,
+ provider_config: Union[BaseConfig, BaseRerankConfig, BaseResponsesAPIConfig],
+ ):
+ status_code = getattr(e, "status_code", 500)
+ error_headers = getattr(e, "headers", None)
+ error_text = getattr(e, "text", str(e))
+ error_response = getattr(e, "response", None)
+ if error_headers is None and error_response:
+ error_headers = getattr(error_response, "headers", None)
+ if error_response and hasattr(error_response, "text"):
+ error_text = getattr(error_response, "text", error_text)
+ if error_headers:
+ error_headers = dict(error_headers)
+ else:
+ error_headers = {}
+ raise provider_config.get_error_class(
+ error_message=error_text,
+ status_code=status_code,
+ headers=error_headers,
+ )
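+
+ # Illustrative sketch (comments only): _handle_error pulls the status code,
+ # headers, and body text off the underlying exception (falling back to its
+ # `.response` attribute) and raises the provider-specific error class, which
+ # is why callers wrap it in a bare `raise`:
+ #
+ #   try:
+ #       sync_httpx_client.post(url=api_base, headers=headers, data=payload)
+ #   except Exception as e:
+ #       raise handler._handle_error(e=e, provider_config=provider_config)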