aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/litellm/llms/azure_ai/chat/transformation.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/azure_ai/chat/transformation.py')
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/llms/azure_ai/chat/transformation.py268
1 files changed, 268 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/azure_ai/chat/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/azure_ai/chat/transformation.py
new file mode 100644
index 00000000..154f3455
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/azure_ai/chat/transformation.py
@@ -0,0 +1,268 @@
+from typing import Any, List, Optional, Tuple, cast
+from urllib.parse import urlparse
+
+import httpx
+from httpx import Response
+
+import litellm
+from litellm._logging import verbose_logger
+from litellm.litellm_core_utils.prompt_templates.common_utils import (
+ _audio_or_image_in_message_content,
+ convert_content_list_to_str,
+)
+from litellm.llms.base_llm.chat.transformation import LiteLLMLoggingObj
+from litellm.llms.openai.common_utils import drop_params_from_unprocessable_entity_error
+from litellm.llms.openai.openai import OpenAIConfig
+from litellm.secret_managers.main import get_secret_str
+from litellm.types.llms.openai import AllMessageValues
+from litellm.types.utils import ModelResponse, ProviderField
+from litellm.utils import _add_path_to_api_base, supports_tool_choice
+
+
+class AzureAIStudioConfig(OpenAIConfig):
+ def get_supported_openai_params(self, model: str) -> List:
+ model_supports_tool_choice = True # azure ai supports this by default
+ if not supports_tool_choice(model=f"azure_ai/{model}"):
+ model_supports_tool_choice = False
+ supported_params = super().get_supported_openai_params(model)
+ if not model_supports_tool_choice:
+ filtered_supported_params = []
+ for param in supported_params:
+ if param != "tool_choice":
+ filtered_supported_params.append(param)
+ return filtered_supported_params
+ return supported_params
+
+ def validate_environment(
+ self,
+ headers: dict,
+ model: str,
+ messages: List[AllMessageValues],
+ optional_params: dict,
+ api_key: Optional[str] = None,
+ api_base: Optional[str] = None,
+ ) -> dict:
+ if api_base and self._should_use_api_key_header(api_base):
+ headers["api-key"] = api_key
+ else:
+ headers["Authorization"] = f"Bearer {api_key}"
+
+ return headers
+
+ def _should_use_api_key_header(self, api_base: str) -> bool:
+ """
+ Returns True if the request should use `api-key` header for authentication.
+ """
+ parsed_url = urlparse(api_base)
+ host = parsed_url.hostname
+ if host and (
+ host.endswith(".services.ai.azure.com")
+ or host.endswith(".openai.azure.com")
+ ):
+ return True
+ return False
+
+ def get_complete_url(
+ self,
+ api_base: Optional[str],
+ model: str,
+ optional_params: dict,
+ litellm_params: dict,
+ stream: Optional[bool] = None,
+ ) -> str:
+ """
+ Constructs a complete URL for the API request.
+
+ Args:
+ - api_base: Base URL, e.g.,
+ "https://litellm8397336933.services.ai.azure.com"
+ OR
+ "https://litellm8397336933.services.ai.azure.com/models/chat/completions?api-version=2024-05-01-preview"
+ - model: Model name.
+ - optional_params: Additional query parameters, including "api_version".
+ - stream: If streaming is required (optional).
+
+ Returns:
+ - A complete URL string, e.g.,
+ "https://litellm8397336933.services.ai.azure.com/models/chat/completions?api-version=2024-05-01-preview"
+ """
+ if api_base is None:
+ raise ValueError(
+ f"api_base is required for Azure AI Studio. Please set the api_base parameter. Passed `api_base={api_base}`"
+ )
+ original_url = httpx.URL(api_base)
+
+ # Extract api_version or use default
+ api_version = cast(Optional[str], litellm_params.get("api_version"))
+
+ # Create a new dictionary with existing params
+ query_params = dict(original_url.params)
+
+ # Add api_version if needed
+ if "api-version" not in query_params and api_version:
+ query_params["api-version"] = api_version
+
+ # Add the path to the base URL
+ if "services.ai.azure.com" in api_base:
+ new_url = _add_path_to_api_base(
+ api_base=api_base, ending_path="/models/chat/completions"
+ )
+ else:
+ new_url = _add_path_to_api_base(
+ api_base=api_base, ending_path="/chat/completions"
+ )
+
+ # Use the new query_params dictionary
+ final_url = httpx.URL(new_url).copy_with(params=query_params)
+
+ return str(final_url)
+
+ def get_required_params(self) -> List[ProviderField]:
+ """For a given provider, return it's required fields with a description"""
+ return [
+ ProviderField(
+ field_name="api_key",
+ field_type="string",
+ field_description="Your Azure AI Studio API Key.",
+ field_value="zEJ...",
+ ),
+ ProviderField(
+ field_name="api_base",
+ field_type="string",
+ field_description="Your Azure AI Studio API Base.",
+ field_value="https://Mistral-serverless.",
+ ),
+ ]
+
+ def _transform_messages(
+ self,
+ messages: List[AllMessageValues],
+ model: str,
+ ) -> List:
+ """
+ - Azure AI Studio doesn't support content as a list. This handles:
+ 1. Transforms list content to a string.
+ 2. If message contains an image or audio, send as is (user-intended)
+ """
+ for message in messages:
+
+ # Do nothing if the message contains an image or audio
+ if _audio_or_image_in_message_content(message):
+ continue
+
+ texts = convert_content_list_to_str(message=message)
+ if texts:
+ message["content"] = texts
+ return messages
+
+ def _is_azure_openai_model(self, model: str, api_base: Optional[str]) -> bool:
+ try:
+ if "/" in model:
+ model = model.split("/", 1)[1]
+ if (
+ model in litellm.open_ai_chat_completion_models
+ or model in litellm.open_ai_text_completion_models
+ or model in litellm.open_ai_embedding_models
+ ):
+ return True
+
+ except Exception:
+ return False
+ return False
+
+ def _get_openai_compatible_provider_info(
+ self,
+ model: str,
+ api_base: Optional[str],
+ api_key: Optional[str],
+ custom_llm_provider: str,
+ ) -> Tuple[Optional[str], Optional[str], str]:
+ api_base = api_base or get_secret_str("AZURE_AI_API_BASE")
+ dynamic_api_key = api_key or get_secret_str("AZURE_AI_API_KEY")
+
+ if self._is_azure_openai_model(model=model, api_base=api_base):
+ verbose_logger.debug(
+ "Model={} is Azure OpenAI model. Setting custom_llm_provider='azure'.".format(
+ model
+ )
+ )
+ custom_llm_provider = "azure"
+ return api_base, dynamic_api_key, custom_llm_provider
+
+ def transform_request(
+ self,
+ model: str,
+ messages: List[AllMessageValues],
+ optional_params: dict,
+ litellm_params: dict,
+ headers: dict,
+ ) -> dict:
+ extra_body = optional_params.pop("extra_body", {})
+ if extra_body and isinstance(extra_body, dict):
+ optional_params.update(extra_body)
+ optional_params.pop("max_retries", None)
+ return super().transform_request(
+ model, messages, optional_params, litellm_params, headers
+ )
+
+ def transform_response(
+ self,
+ model: str,
+ raw_response: Response,
+ model_response: ModelResponse,
+ logging_obj: LiteLLMLoggingObj,
+ request_data: dict,
+ messages: List[AllMessageValues],
+ optional_params: dict,
+ litellm_params: dict,
+ encoding: Any,
+ api_key: Optional[str] = None,
+ json_mode: Optional[bool] = None,
+ ) -> ModelResponse:
+ model_response.model = f"azure_ai/{model}"
+ return super().transform_response(
+ model=model,
+ raw_response=raw_response,
+ model_response=model_response,
+ logging_obj=logging_obj,
+ request_data=request_data,
+ messages=messages,
+ optional_params=optional_params,
+ litellm_params=litellm_params,
+ encoding=encoding,
+ api_key=api_key,
+ json_mode=json_mode,
+ )
+
+ def should_retry_llm_api_inside_llm_translation_on_http_error(
+ self, e: httpx.HTTPStatusError, litellm_params: dict
+ ) -> bool:
+ should_drop_params = litellm_params.get("drop_params") or litellm.drop_params
+ error_text = e.response.text
+ if should_drop_params and "Extra inputs are not permitted" in error_text:
+ return True
+ elif (
+ "unknown field: parameter index is not a valid field" in error_text
+ ): # remove index from tool calls
+ return True
+ return super().should_retry_llm_api_inside_llm_translation_on_http_error(
+ e=e, litellm_params=litellm_params
+ )
+
+ @property
+ def max_retry_on_unprocessable_entity_error(self) -> int:
+ return 2
+
+ def transform_request_on_unprocessable_entity_error(
+ self, e: httpx.HTTPStatusError, request_data: dict
+ ) -> dict:
+ _messages = cast(Optional[List[AllMessageValues]], request_data.get("messages"))
+ if (
+ "unknown field: parameter index is not a valid field" in e.response.text
+ and _messages is not None
+ ):
+ litellm.remove_index_from_tool_calls(
+ messages=_messages,
+ )
+ data = drop_params_from_unprocessable_entity_error(e=e, data=request_data)
+ return data