diff options
author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
---|---|---|
committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) | |
tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/litellm/llms/openai/common_utils.py | |
parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff) | |
download | gn-ai-master.tar.gz |
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/openai/common_utils.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/litellm/llms/openai/common_utils.py | 208 |
1 files changed, 208 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/openai/common_utils.py b/.venv/lib/python3.12/site-packages/litellm/llms/openai/common_utils.py new file mode 100644 index 00000000..55da16d6 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/llms/openai/common_utils.py @@ -0,0 +1,208 @@ +""" +Common helpers / utils across al OpenAI endpoints +""" + +import hashlib +import json +from typing import Any, Dict, List, Literal, Optional, Union + +import httpx +import openai +from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI + +import litellm +from litellm.llms.base_llm.chat.transformation import BaseLLMException +from litellm.llms.custom_httpx.http_handler import _DEFAULT_TTL_FOR_HTTPX_CLIENTS + + +class OpenAIError(BaseLLMException): + def __init__( + self, + status_code: int, + message: str, + request: Optional[httpx.Request] = None, + response: Optional[httpx.Response] = None, + headers: Optional[Union[dict, httpx.Headers]] = None, + body: Optional[dict] = None, + ): + self.status_code = status_code + self.message = message + self.headers = headers + if request: + self.request = request + else: + self.request = httpx.Request(method="POST", url="https://api.openai.com/v1") + if response: + self.response = response + else: + self.response = httpx.Response( + status_code=status_code, request=self.request + ) + super().__init__( + status_code=status_code, + message=self.message, + headers=self.headers, + request=self.request, + response=self.response, + body=body, + ) + + +####### Error Handling Utils for OpenAI API ####################### +################################################################### +def drop_params_from_unprocessable_entity_error( + e: Union[openai.UnprocessableEntityError, httpx.HTTPStatusError], + data: Dict[str, Any], +) -> Dict[str, Any]: + """ + Helper function to read OpenAI UnprocessableEntityError and drop the params that raised an error from the error message. + + Args: + e (UnprocessableEntityError): The UnprocessableEntityError exception + data (Dict[str, Any]): The original data dictionary containing all parameters + + Returns: + Dict[str, Any]: A new dictionary with invalid parameters removed + """ + invalid_params: List[str] = [] + if isinstance(e, httpx.HTTPStatusError): + error_json = e.response.json() + error_message = error_json.get("error", {}) + error_body = error_message + else: + error_body = e.body + if ( + error_body is not None + and isinstance(error_body, dict) + and error_body.get("message") + ): + message = error_body.get("message", {}) + if isinstance(message, str): + try: + message = json.loads(message) + except json.JSONDecodeError: + message = {"detail": message} + detail = message.get("detail") + + if isinstance(detail, List) and len(detail) > 0 and isinstance(detail[0], dict): + for error_dict in detail: + if ( + error_dict.get("loc") + and isinstance(error_dict.get("loc"), list) + and len(error_dict.get("loc")) == 2 + ): + invalid_params.append(error_dict["loc"][1]) + + new_data = {k: v for k, v in data.items() if k not in invalid_params} + + return new_data + + +class BaseOpenAILLM: + """ + Base class for OpenAI LLMs for getting their httpx clients and SSL verification settings + """ + + @staticmethod + def get_cached_openai_client( + client_initialization_params: dict, client_type: Literal["openai", "azure"] + ) -> Optional[Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]]: + """Retrieves the OpenAI client from the in-memory cache based on the client initialization parameters""" + _cache_key = BaseOpenAILLM.get_openai_client_cache_key( + client_initialization_params=client_initialization_params, + client_type=client_type, + ) + _cached_client = litellm.in_memory_llm_clients_cache.get_cache(_cache_key) + return _cached_client + + @staticmethod + def set_cached_openai_client( + openai_client: Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI], + client_type: Literal["openai", "azure"], + client_initialization_params: dict, + ): + """Stores the OpenAI client in the in-memory cache for _DEFAULT_TTL_FOR_HTTPX_CLIENTS SECONDS""" + _cache_key = BaseOpenAILLM.get_openai_client_cache_key( + client_initialization_params=client_initialization_params, + client_type=client_type, + ) + litellm.in_memory_llm_clients_cache.set_cache( + key=_cache_key, + value=openai_client, + ttl=_DEFAULT_TTL_FOR_HTTPX_CLIENTS, + ) + + @staticmethod + def get_openai_client_cache_key( + client_initialization_params: dict, client_type: Literal["openai", "azure"] + ) -> str: + """Creates a cache key for the OpenAI client based on the client initialization parameters""" + hashed_api_key = None + if client_initialization_params.get("api_key") is not None: + hash_object = hashlib.sha256( + client_initialization_params.get("api_key", "").encode() + ) + # Hexadecimal representation of the hash + hashed_api_key = hash_object.hexdigest() + + # Create a more readable cache key using a list of key-value pairs + key_parts = [ + f"hashed_api_key={hashed_api_key}", + f"is_async={client_initialization_params.get('is_async')}", + ] + + LITELLM_CLIENT_SPECIFIC_PARAMS = [ + "timeout", + "max_retries", + "organization", + "api_base", + ] + openai_client_fields = ( + BaseOpenAILLM.get_openai_client_initialization_param_fields( + client_type=client_type + ) + + LITELLM_CLIENT_SPECIFIC_PARAMS + ) + + for param in openai_client_fields: + key_parts.append(f"{param}={client_initialization_params.get(param)}") + + _cache_key = ",".join(key_parts) + return _cache_key + + @staticmethod + def get_openai_client_initialization_param_fields( + client_type: Literal["openai", "azure"] + ) -> List[str]: + """Returns a list of fields that are used to initialize the OpenAI client""" + import inspect + + from openai import AzureOpenAI, OpenAI + + if client_type == "openai": + signature = inspect.signature(OpenAI.__init__) + else: + signature = inspect.signature(AzureOpenAI.__init__) + + # Extract parameter names, excluding 'self' + param_names = [param for param in signature.parameters if param != "self"] + return param_names + + @staticmethod + def _get_async_http_client() -> Optional[httpx.AsyncClient]: + if litellm.aclient_session is not None: + return litellm.aclient_session + + return httpx.AsyncClient( + limits=httpx.Limits(max_connections=1000, max_keepalive_connections=100), + verify=litellm.ssl_verify, + ) + + @staticmethod + def _get_sync_http_client() -> Optional[httpx.Client]: + if litellm.client_session is not None: + return litellm.client_session + return httpx.Client( + limits=httpx.Limits(max_connections=1000, max_keepalive_connections=100), + verify=litellm.ssl_verify, + ) |