diff options
author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
---|---|---|
committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) | |
tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/litellm/llms/azure_ai/chat/transformation.py | |
parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff) | |
download | gn-ai-master.tar.gz |
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/azure_ai/chat/transformation.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/litellm/llms/azure_ai/chat/transformation.py | 268 |
1 files changed, 268 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/azure_ai/chat/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/azure_ai/chat/transformation.py new file mode 100644 index 00000000..154f3455 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/llms/azure_ai/chat/transformation.py @@ -0,0 +1,268 @@ +from typing import Any, List, Optional, Tuple, cast +from urllib.parse import urlparse + +import httpx +from httpx import Response + +import litellm +from litellm._logging import verbose_logger +from litellm.litellm_core_utils.prompt_templates.common_utils import ( + _audio_or_image_in_message_content, + convert_content_list_to_str, +) +from litellm.llms.base_llm.chat.transformation import LiteLLMLoggingObj +from litellm.llms.openai.common_utils import drop_params_from_unprocessable_entity_error +from litellm.llms.openai.openai import OpenAIConfig +from litellm.secret_managers.main import get_secret_str +from litellm.types.llms.openai import AllMessageValues +from litellm.types.utils import ModelResponse, ProviderField +from litellm.utils import _add_path_to_api_base, supports_tool_choice + + +class AzureAIStudioConfig(OpenAIConfig): + def get_supported_openai_params(self, model: str) -> List: + model_supports_tool_choice = True # azure ai supports this by default + if not supports_tool_choice(model=f"azure_ai/{model}"): + model_supports_tool_choice = False + supported_params = super().get_supported_openai_params(model) + if not model_supports_tool_choice: + filtered_supported_params = [] + for param in supported_params: + if param != "tool_choice": + filtered_supported_params.append(param) + return filtered_supported_params + return supported_params + + def validate_environment( + self, + headers: dict, + model: str, + messages: List[AllMessageValues], + optional_params: dict, + api_key: Optional[str] = None, + api_base: Optional[str] = None, + ) -> dict: + if api_base and self._should_use_api_key_header(api_base): + headers["api-key"] = api_key + else: + headers["Authorization"] = f"Bearer {api_key}" + + return headers + + def _should_use_api_key_header(self, api_base: str) -> bool: + """ + Returns True if the request should use `api-key` header for authentication. + """ + parsed_url = urlparse(api_base) + host = parsed_url.hostname + if host and ( + host.endswith(".services.ai.azure.com") + or host.endswith(".openai.azure.com") + ): + return True + return False + + def get_complete_url( + self, + api_base: Optional[str], + model: str, + optional_params: dict, + litellm_params: dict, + stream: Optional[bool] = None, + ) -> str: + """ + Constructs a complete URL for the API request. + + Args: + - api_base: Base URL, e.g., + "https://litellm8397336933.services.ai.azure.com" + OR + "https://litellm8397336933.services.ai.azure.com/models/chat/completions?api-version=2024-05-01-preview" + - model: Model name. + - optional_params: Additional query parameters, including "api_version". + - stream: If streaming is required (optional). + + Returns: + - A complete URL string, e.g., + "https://litellm8397336933.services.ai.azure.com/models/chat/completions?api-version=2024-05-01-preview" + """ + if api_base is None: + raise ValueError( + f"api_base is required for Azure AI Studio. Please set the api_base parameter. Passed `api_base={api_base}`" + ) + original_url = httpx.URL(api_base) + + # Extract api_version or use default + api_version = cast(Optional[str], litellm_params.get("api_version")) + + # Create a new dictionary with existing params + query_params = dict(original_url.params) + + # Add api_version if needed + if "api-version" not in query_params and api_version: + query_params["api-version"] = api_version + + # Add the path to the base URL + if "services.ai.azure.com" in api_base: + new_url = _add_path_to_api_base( + api_base=api_base, ending_path="/models/chat/completions" + ) + else: + new_url = _add_path_to_api_base( + api_base=api_base, ending_path="/chat/completions" + ) + + # Use the new query_params dictionary + final_url = httpx.URL(new_url).copy_with(params=query_params) + + return str(final_url) + + def get_required_params(self) -> List[ProviderField]: + """For a given provider, return it's required fields with a description""" + return [ + ProviderField( + field_name="api_key", + field_type="string", + field_description="Your Azure AI Studio API Key.", + field_value="zEJ...", + ), + ProviderField( + field_name="api_base", + field_type="string", + field_description="Your Azure AI Studio API Base.", + field_value="https://Mistral-serverless.", + ), + ] + + def _transform_messages( + self, + messages: List[AllMessageValues], + model: str, + ) -> List: + """ + - Azure AI Studio doesn't support content as a list. This handles: + 1. Transforms list content to a string. + 2. If message contains an image or audio, send as is (user-intended) + """ + for message in messages: + + # Do nothing if the message contains an image or audio + if _audio_or_image_in_message_content(message): + continue + + texts = convert_content_list_to_str(message=message) + if texts: + message["content"] = texts + return messages + + def _is_azure_openai_model(self, model: str, api_base: Optional[str]) -> bool: + try: + if "/" in model: + model = model.split("/", 1)[1] + if ( + model in litellm.open_ai_chat_completion_models + or model in litellm.open_ai_text_completion_models + or model in litellm.open_ai_embedding_models + ): + return True + + except Exception: + return False + return False + + def _get_openai_compatible_provider_info( + self, + model: str, + api_base: Optional[str], + api_key: Optional[str], + custom_llm_provider: str, + ) -> Tuple[Optional[str], Optional[str], str]: + api_base = api_base or get_secret_str("AZURE_AI_API_BASE") + dynamic_api_key = api_key or get_secret_str("AZURE_AI_API_KEY") + + if self._is_azure_openai_model(model=model, api_base=api_base): + verbose_logger.debug( + "Model={} is Azure OpenAI model. Setting custom_llm_provider='azure'.".format( + model + ) + ) + custom_llm_provider = "azure" + return api_base, dynamic_api_key, custom_llm_provider + + def transform_request( + self, + model: str, + messages: List[AllMessageValues], + optional_params: dict, + litellm_params: dict, + headers: dict, + ) -> dict: + extra_body = optional_params.pop("extra_body", {}) + if extra_body and isinstance(extra_body, dict): + optional_params.update(extra_body) + optional_params.pop("max_retries", None) + return super().transform_request( + model, messages, optional_params, litellm_params, headers + ) + + def transform_response( + self, + model: str, + raw_response: Response, + model_response: ModelResponse, + logging_obj: LiteLLMLoggingObj, + request_data: dict, + messages: List[AllMessageValues], + optional_params: dict, + litellm_params: dict, + encoding: Any, + api_key: Optional[str] = None, + json_mode: Optional[bool] = None, + ) -> ModelResponse: + model_response.model = f"azure_ai/{model}" + return super().transform_response( + model=model, + raw_response=raw_response, + model_response=model_response, + logging_obj=logging_obj, + request_data=request_data, + messages=messages, + optional_params=optional_params, + litellm_params=litellm_params, + encoding=encoding, + api_key=api_key, + json_mode=json_mode, + ) + + def should_retry_llm_api_inside_llm_translation_on_http_error( + self, e: httpx.HTTPStatusError, litellm_params: dict + ) -> bool: + should_drop_params = litellm_params.get("drop_params") or litellm.drop_params + error_text = e.response.text + if should_drop_params and "Extra inputs are not permitted" in error_text: + return True + elif ( + "unknown field: parameter index is not a valid field" in error_text + ): # remove index from tool calls + return True + return super().should_retry_llm_api_inside_llm_translation_on_http_error( + e=e, litellm_params=litellm_params + ) + + @property + def max_retry_on_unprocessable_entity_error(self) -> int: + return 2 + + def transform_request_on_unprocessable_entity_error( + self, e: httpx.HTTPStatusError, request_data: dict + ) -> dict: + _messages = cast(Optional[List[AllMessageValues]], request_data.get("messages")) + if ( + "unknown field: parameter index is not a valid field" in e.response.text + and _messages is not None + ): + litellm.remove_index_from_tool_calls( + messages=_messages, + ) + data = drop_params_from_unprocessable_entity_error(e=e, data=request_data) + return data |