author:    S. Solomon Darnell  2025-03-28 21:52:21 -0500
committer: S. Solomon Darnell  2025-03-28 21:52:21 -0500
commit:    4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree:      ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/litellm/llms/azure_ai
parent:    cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
download:  gn-ai-master.tar.gz
commit message: two versions of R2R are here (HEAD, master)
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/azure_ai')
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/azure_ai/README.md                       |   1
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/azure_ai/chat/handler.py                 |   3
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/azure_ai/chat/transformation.py          | 268
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/azure_ai/embed/__init__.py               |   1
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/azure_ai/embed/cohere_transformation.py  |  99
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/azure_ai/embed/handler.py                | 292
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/azure_ai/rerank/handler.py               |   5
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/azure_ai/rerank/transformation.py        |  90
8 files changed, 759 insertions(+), 0 deletions(-)
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/azure_ai/README.md b/.venv/lib/python3.12/site-packages/litellm/llms/azure_ai/README.md
new file mode 100644
index 00000000..8c521519
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/azure_ai/README.md
@@ -0,0 +1 @@
+`/chat/completions` calls are routed via `openai.py`.
\ No newline at end of file
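
For context, a minimal usage sketch of the route this README describes. The deployment name, endpoint, and key are illustrative placeholders, not part of this diff:

    # Hedged sketch: a chat completion routed through the azure_ai provider.
    import litellm

    response = litellm.completion(
        model="azure_ai/my-deployment",  # hypothetical deployment name
        messages=[{"role": "user", "content": "Hello"}],
        api_base="https://example.services.ai.azure.com",  # placeholder endpoint
        api_key="...",  # or set the AZURE_AI_API_KEY env var
    )
    print(response.choices[0].message.content)
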
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/azure_ai/chat/handler.py b/.venv/lib/python3.12/site-packages/litellm/llms/azure_ai/chat/handler.py
new file mode 100644
index 00000000..d141498c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/azure_ai/chat/handler.py
@@ -0,0 +1,3 @@
+"""
+LLM Calling done in `openai/openai.py`
+"""
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/azure_ai/chat/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/azure_ai/chat/transformation.py
new file mode 100644
index 00000000..154f3455
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/azure_ai/chat/transformation.py
@@ -0,0 +1,268 @@
+from typing import Any, List, Optional, Tuple, cast
+from urllib.parse import urlparse
+
+import httpx
+from httpx import Response
+
+import litellm
+from litellm._logging import verbose_logger
+from litellm.litellm_core_utils.prompt_templates.common_utils import (
+    _audio_or_image_in_message_content,
+    convert_content_list_to_str,
+)
+from litellm.llms.base_llm.chat.transformation import LiteLLMLoggingObj
+from litellm.llms.openai.common_utils import drop_params_from_unprocessable_entity_error
+from litellm.llms.openai.openai import OpenAIConfig
+from litellm.secret_managers.main import get_secret_str
+from litellm.types.llms.openai import AllMessageValues
+from litellm.types.utils import ModelResponse, ProviderField
+from litellm.utils import _add_path_to_api_base, supports_tool_choice
+
+
+class AzureAIStudioConfig(OpenAIConfig):
+    def get_supported_openai_params(self, model: str) -> List:
+        model_supports_tool_choice = True  # azure ai supports this by default
+        if not supports_tool_choice(model=f"azure_ai/{model}"):
+            model_supports_tool_choice = False
+        supported_params = super().get_supported_openai_params(model)
+        if not model_supports_tool_choice:
+            filtered_supported_params = []
+            for param in supported_params:
+                if param != "tool_choice":
+                    filtered_supported_params.append(param)
+            return filtered_supported_params
+        return supported_params
+
+    def validate_environment(
+        self,
+        headers: dict,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+    ) -> dict:
+        if api_base and self._should_use_api_key_header(api_base):
+            headers["api-key"] = api_key
+        else:
+            headers["Authorization"] = f"Bearer {api_key}"
+
+        return headers
+
+    def _should_use_api_key_header(self, api_base: str) -> bool:
+        """
+        Returns True if the request should use `api-key` header for authentication.
+        """
+        parsed_url = urlparse(api_base)
+        host = parsed_url.hostname
+        if host and (
+            host.endswith(".services.ai.azure.com")
+            or host.endswith(".openai.azure.com")
+        ):
+            return True
+        return False
+
+    def get_complete_url(
+        self,
+        api_base: Optional[str],
+        model: str,
+        optional_params: dict,
+        litellm_params: dict,
+        stream: Optional[bool] = None,
+    ) -> str:
+        """
+        Constructs a complete URL for the API request.
+
+        Args:
+        - api_base: Base URL, e.g.,
+            "https://litellm8397336933.services.ai.azure.com"
+            OR
+            "https://litellm8397336933.services.ai.azure.com/models/chat/completions?api-version=2024-05-01-preview"
+        - model: Model name.
+        - litellm_params: Additional litellm params, including "api_version".
+        - stream: If streaming is required (optional).
+
+        Returns:
+        - A complete URL string, e.g.,
+        "https://litellm8397336933.services.ai.azure.com/models/chat/completions?api-version=2024-05-01-preview"
+        """
+        if api_base is None:
+            raise ValueError(
+                f"api_base is required for Azure AI Studio. Please set the api_base parameter. Passed `api_base={api_base}`"
+            )
+        original_url = httpx.URL(api_base)
+
+        # Extract api_version or use default
+        api_version = cast(Optional[str], litellm_params.get("api_version"))
+
+        # Create a new dictionary with existing params
+        query_params = dict(original_url.params)
+
+        # Add api_version if needed
+        if "api-version" not in query_params and api_version:
+            query_params["api-version"] = api_version
+
+        # Add the path to the base URL
+        if "services.ai.azure.com" in api_base:
+            new_url = _add_path_to_api_base(
+                api_base=api_base, ending_path="/models/chat/completions"
+            )
+        else:
+            new_url = _add_path_to_api_base(
+                api_base=api_base, ending_path="/chat/completions"
+            )
+
+        # Use the new query_params dictionary
+        final_url = httpx.URL(new_url).copy_with(params=query_params)
+
+        return str(final_url)
+
+    def get_required_params(self) -> List[ProviderField]:
+        """For a given provider, return it's required fields with a description"""
+        return [
+            ProviderField(
+                field_name="api_key",
+                field_type="string",
+                field_description="Your Azure AI Studio API Key.",
+                field_value="zEJ...",
+            ),
+            ProviderField(
+                field_name="api_base",
+                field_type="string",
+                field_description="Your Azure AI Studio API Base.",
+                field_value="https://Mistral-serverless.",
+            ),
+        ]
+
+    def _transform_messages(
+        self,
+        messages: List[AllMessageValues],
+        model: str,
+    ) -> List:
+        """
+        - Azure AI Studio doesn't support content as a list. This handles:
+            1. Transforms list content to a string.
+            2. If message contains an image or audio, send as is (user-intended)
+        """
+        for message in messages:
+
+            # Do nothing if the message contains an image or audio
+            if _audio_or_image_in_message_content(message):
+                continue
+
+            texts = convert_content_list_to_str(message=message)
+            if texts:
+                message["content"] = texts
+        return messages
+
+    def _is_azure_openai_model(self, model: str, api_base: Optional[str]) -> bool:
+        try:
+            if "/" in model:
+                model = model.split("/", 1)[1]
+            if (
+                model in litellm.open_ai_chat_completion_models
+                or model in litellm.open_ai_text_completion_models
+                or model in litellm.open_ai_embedding_models
+            ):
+                return True
+
+        except Exception:
+            return False
+        return False
+
+    def _get_openai_compatible_provider_info(
+        self,
+        model: str,
+        api_base: Optional[str],
+        api_key: Optional[str],
+        custom_llm_provider: str,
+    ) -> Tuple[Optional[str], Optional[str], str]:
+        api_base = api_base or get_secret_str("AZURE_AI_API_BASE")
+        dynamic_api_key = api_key or get_secret_str("AZURE_AI_API_KEY")
+
+        if self._is_azure_openai_model(model=model, api_base=api_base):
+            verbose_logger.debug(
+                "Model={} is Azure OpenAI model. Setting custom_llm_provider='azure'.".format(
+                    model
+                )
+            )
+            custom_llm_provider = "azure"
+        return api_base, dynamic_api_key, custom_llm_provider
+
+    def transform_request(
+        self,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        headers: dict,
+    ) -> dict:
+        extra_body = optional_params.pop("extra_body", {})
+        if extra_body and isinstance(extra_body, dict):
+            optional_params.update(extra_body)
+        optional_params.pop("max_retries", None)
+        return super().transform_request(
+            model, messages, optional_params, litellm_params, headers
+        )
+
+    def transform_response(
+        self,
+        model: str,
+        raw_response: Response,
+        model_response: ModelResponse,
+        logging_obj: LiteLLMLoggingObj,
+        request_data: dict,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        encoding: Any,
+        api_key: Optional[str] = None,
+        json_mode: Optional[bool] = None,
+    ) -> ModelResponse:
+        model_response.model = f"azure_ai/{model}"
+        return super().transform_response(
+            model=model,
+            raw_response=raw_response,
+            model_response=model_response,
+            logging_obj=logging_obj,
+            request_data=request_data,
+            messages=messages,
+            optional_params=optional_params,
+            litellm_params=litellm_params,
+            encoding=encoding,
+            api_key=api_key,
+            json_mode=json_mode,
+        )
+
+    def should_retry_llm_api_inside_llm_translation_on_http_error(
+        self, e: httpx.HTTPStatusError, litellm_params: dict
+    ) -> bool:
+        should_drop_params = litellm_params.get("drop_params") or litellm.drop_params
+        error_text = e.response.text
+        if should_drop_params and "Extra inputs are not permitted" in error_text:
+            return True
+        elif (
+            "unknown field: parameter index is not a valid field" in error_text
+        ):  # remove index from tool calls
+            return True
+        return super().should_retry_llm_api_inside_llm_translation_on_http_error(
+            e=e, litellm_params=litellm_params
+        )
+
+    @property
+    def max_retry_on_unprocessable_entity_error(self) -> int:
+        return 2
+
+    def transform_request_on_unprocessable_entity_error(
+        self, e: httpx.HTTPStatusError, request_data: dict
+    ) -> dict:
+        _messages = cast(Optional[List[AllMessageValues]], request_data.get("messages"))
+        if (
+            "unknown field: parameter index is not a valid field" in e.response.text
+            and _messages is not None
+        ):
+            litellm.remove_index_from_tool_calls(
+                messages=_messages,
+            )
+        data = drop_params_from_unprocessable_entity_error(e=e, data=request_data)
+        return data
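
To illustrate the URL construction above, a hedged sketch using the host from the docstring (the model name is hypothetical):

    # Sketch: get_complete_url appends "/models/chat/completions" for
    # *.services.ai.azure.com hosts and adds the api-version query param.
    from litellm.llms.azure_ai.chat.transformation import AzureAIStudioConfig

    config = AzureAIStudioConfig()
    url = config.get_complete_url(
        api_base="https://litellm8397336933.services.ai.azure.com",
        model="my-model",  # hypothetical
        optional_params={},
        litellm_params={"api_version": "2024-05-01-preview"},
    )
    # -> "https://litellm8397336933.services.ai.azure.com/models/chat/completions?api-version=2024-05-01-preview"
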
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/azure_ai/embed/__init__.py b/.venv/lib/python3.12/site-packages/litellm/llms/azure_ai/embed/__init__.py
new file mode 100644
index 00000000..e0d67acb
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/azure_ai/embed/__init__.py
@@ -0,0 +1 @@
+from .handler import AzureAIEmbedding
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/azure_ai/embed/cohere_transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/azure_ai/embed/cohere_transformation.py
new file mode 100644
index 00000000..38b0dbbe
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/azure_ai/embed/cohere_transformation.py
@@ -0,0 +1,99 @@
+"""
+Transformation logic from OpenAI /v1/embeddings format to Azure AI Cohere's /v1/embed. 
+
+Why a separate file? To make it easy to see how the transformation works.
+
+Covers:
+- Cohere request format
+
+Docs - https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-embed-text.html
+"""
+
+from typing import List, Optional, Tuple
+
+from litellm.types.llms.azure_ai import ImageEmbeddingInput, ImageEmbeddingRequest
+from litellm.types.llms.openai import EmbeddingCreateParams
+from litellm.types.utils import EmbeddingResponse, Usage
+from litellm.utils import is_base64_encoded
+
+
+class AzureAICohereConfig:
+    def __init__(self) -> None:
+        pass
+
+    def _map_azure_model_group(self, model: str) -> str:
+
+        if model == "offer-cohere-embed-multili-paygo":
+            return "Cohere-embed-v3-multilingual"
+        elif model == "offer-cohere-embed-english-paygo":
+            return "Cohere-embed-v3-english"
+
+        return model
+
+    def _transform_request_image_embeddings(
+        self, input: List[str], optional_params: dict
+    ) -> ImageEmbeddingRequest:
+        """
+        Assume all str in list is base64 encoded string
+        """
+        image_input: List[ImageEmbeddingInput] = []
+        for i in input:
+            embedding_input = ImageEmbeddingInput(image=i)
+            image_input.append(embedding_input)
+        return ImageEmbeddingRequest(input=image_input, **optional_params)
+
+    def _transform_request(
+        self, input: List[str], optional_params: dict, model: str
+    ) -> Tuple[ImageEmbeddingRequest, EmbeddingCreateParams, List[int]]:
+        """
+        Return the list of input to `/image/embeddings`, `/v1/embeddings`, list of image_embedding_idx for recombination
+        """
+        image_embeddings: List[str] = []
+        image_embedding_idx: List[int] = []
+        for idx, i in enumerate(input):
+            """
+            - is base64 -> route to image embeddings
+            - is ImageEmbeddingInput -> route to image embeddings
+            - else -> route to `/v1/embeddings`
+            """
+            if is_base64_encoded(i):
+                image_embeddings.append(i)
+                image_embedding_idx.append(idx)
+
+        ## REMOVE IMAGE EMBEDDINGS FROM input list
+        filtered_input = [
+            item for idx, item in enumerate(input) if idx not in image_embedding_idx
+        ]
+
+        v1_embeddings_request = EmbeddingCreateParams(
+            input=filtered_input, model=model, **optional_params
+        )
+        image_embeddings_request = self._transform_request_image_embeddings(
+            input=image_embeddings, optional_params=optional_params
+        )
+
+        return image_embeddings_request, v1_embeddings_request, image_embedding_idx
+
+    def _transform_response(self, response: EmbeddingResponse) -> EmbeddingResponse:
+        additional_headers: Optional[dict] = response._hidden_params.get(
+            "additional_headers"
+        )
+        if additional_headers:
+            # CALCULATE USAGE
+            input_tokens: Optional[str] = additional_headers.get(
+                "llm_provider-num_tokens"
+            )
+            if input_tokens:
+                if response.usage:
+                    response.usage.prompt_tokens = int(input_tokens)
+                else:
+                    response.usage = Usage(prompt_tokens=int(input_tokens))
+
+            # SET MODEL
+            base_model: Optional[str] = additional_headers.get(
+                "llm_provider-azureml-model-group"
+            )
+            if base_model:
+                response.model = self._map_azure_model_group(base_model)
+
+        return response
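
A hedged sketch of the split performed by `_transform_request`; the base64 payload is just the PNG magic bytes, and the model name is illustrative:

    # Sketch: mixed text/image input is partitioned by index, so the handler
    # can recombine the two responses in the original order.
    from litellm.llms.azure_ai.embed.cohere_transformation import AzureAICohereConfig

    mixed_input = [
        "hello world",                         # text -> /v1/embeddings
        "data:image/png;base64,iVBORw0KGgo=",  # base64 -> image embeddings
    ]
    image_req, text_req, image_idx = AzureAICohereConfig()._transform_request(
        input=mixed_input, optional_params={}, model="Cohere-embed-v3-english"
    )
    # image_idx == [1]; text_req["input"] == ["hello world"]
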
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/azure_ai/embed/handler.py b/.venv/lib/python3.12/site-packages/litellm/llms/azure_ai/embed/handler.py
new file mode 100644
index 00000000..f33c979c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/azure_ai/embed/handler.py
@@ -0,0 +1,292 @@
+from typing import List, Optional, Union
+
+from openai import OpenAI
+
+import litellm
+from litellm.llms.custom_httpx.http_handler import (
+    AsyncHTTPHandler,
+    HTTPHandler,
+    get_async_httpx_client,
+)
+from litellm.llms.openai.openai import OpenAIChatCompletion
+from litellm.types.llms.azure_ai import ImageEmbeddingRequest
+from litellm.types.utils import EmbeddingResponse
+from litellm.utils import convert_to_model_response_object
+
+from .cohere_transformation import AzureAICohereConfig
+
+
+class AzureAIEmbedding(OpenAIChatCompletion):
+
+    def _process_response(
+        self,
+        image_embedding_responses: Optional[List],
+        text_embedding_responses: Optional[List],
+        image_embeddings_idx: List[int],
+        model_response: EmbeddingResponse,
+        input: List,
+    ):
+        combined_responses = []
+        if (
+            image_embedding_responses is not None
+            and text_embedding_responses is not None
+        ):
+            # Combine and order the results
+            text_idx = 0
+            image_idx = 0
+
+            for idx in range(len(input)):
+                if idx in image_embeddings_idx:
+                    combined_responses.append(image_embedding_responses[image_idx])
+                    image_idx += 1
+                else:
+                    combined_responses.append(text_embedding_responses[text_idx])
+                    text_idx += 1
+
+            model_response.data = combined_responses
+        elif image_embedding_responses is not None:
+            model_response.data = image_embedding_responses
+        elif text_embedding_responses is not None:
+            model_response.data = text_embedding_responses
+
+        response = AzureAICohereConfig()._transform_response(response=model_response)  # type: ignore
+
+        return response
+
+    async def async_image_embedding(
+        self,
+        model: str,
+        data: ImageEmbeddingRequest,
+        timeout: float,
+        logging_obj,
+        model_response: litellm.EmbeddingResponse,
+        optional_params: dict,
+        api_key: Optional[str],
+        api_base: Optional[str],
+        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
+    ) -> EmbeddingResponse:
+        if client is None or not isinstance(client, AsyncHTTPHandler):
+            client = get_async_httpx_client(
+                llm_provider=litellm.LlmProviders.AZURE_AI,
+                params={"timeout": timeout},
+            )
+
+        url = "{}/images/embeddings".format(api_base)
+
+        response = await client.post(
+            url=url,
+            json=data,  # type: ignore
+            headers={"Authorization": "Bearer {}".format(api_key)},
+        )
+
+        embedding_response = response.json()
+        embedding_headers = dict(response.headers)
+        returned_response: EmbeddingResponse = convert_to_model_response_object(  # type: ignore
+            response_object=embedding_response,
+            model_response_object=model_response,
+            response_type="embedding",
+            stream=False,
+            _response_headers=embedding_headers,
+        )
+        return returned_response
+
+    def image_embedding(
+        self,
+        model: str,
+        data: ImageEmbeddingRequest,
+        timeout: float,
+        logging_obj,
+        model_response: EmbeddingResponse,
+        optional_params: dict,
+        api_key: Optional[str],
+        api_base: Optional[str],
+        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
+    ):
+        if api_base is None:
+            raise ValueError(
+                "api_base is None. Please set AZURE_AI_API_BASE or dynamically via `api_base` param, to make the request."
+            )
+        if api_key is None:
+            raise ValueError(
+                "api_key is None. Please set AZURE_AI_API_KEY or dynamically via `api_key` param, to make the request."
+            )
+
+        if client is None or not isinstance(client, HTTPHandler):
+            client = HTTPHandler(timeout=timeout, concurrent_limit=1)
+
+        url = "{}/images/embeddings".format(api_base)
+
+        response = client.post(
+            url=url,
+            json=data,  # type: ignore
+            headers={"Authorization": "Bearer {}".format(api_key)},
+        )
+
+        embedding_response = response.json()
+        embedding_headers = dict(response.headers)
+        returned_response: EmbeddingResponse = convert_to_model_response_object(  # type: ignore
+            response_object=embedding_response,
+            model_response_object=model_response,
+            response_type="embedding",
+            stream=False,
+            _response_headers=embedding_headers,
+        )
+        return returned_response
+
+    async def async_embedding(
+        self,
+        model: str,
+        input: List,
+        timeout: float,
+        logging_obj,
+        model_response: litellm.EmbeddingResponse,
+        optional_params: dict,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+        client=None,
+    ) -> EmbeddingResponse:
+
+        (
+            image_embeddings_request,
+            v1_embeddings_request,
+            image_embeddings_idx,
+        ) = AzureAICohereConfig()._transform_request(
+            input=input, optional_params=optional_params, model=model
+        )
+
+        image_embedding_responses: Optional[List] = None
+        text_embedding_responses: Optional[List] = None
+
+        if image_embeddings_request["input"]:
+            image_response = await self.async_image_embedding(
+                model=model,
+                data=image_embeddings_request,
+                timeout=timeout,
+                logging_obj=logging_obj,
+                model_response=model_response,
+                optional_params=optional_params,
+                api_key=api_key,
+                api_base=api_base,
+                client=client,
+            )
+
+            image_embedding_responses = image_response.data
+            if image_embedding_responses is None:
+                raise Exception("/image/embeddings route returned None Embeddings.")
+
+        if v1_embeddings_request["input"]:
+            response: EmbeddingResponse = await super().embedding(  # type: ignore
+                model=model,
+                input=v1_embeddings_request["input"],
+                timeout=timeout,
+                logging_obj=logging_obj,
+                model_response=model_response,
+                optional_params=optional_params,
+                api_key=api_key,
+                api_base=api_base,
+                client=client,
+                aembedding=True,
+            )
+            text_embedding_responses = response.data
+            if text_embedding_responses is None:
+                raise Exception("/v1/embeddings route returned None Embeddings.")
+
+        return self._process_response(
+            image_embedding_responses=image_embedding_responses,
+            text_embedding_responses=text_embedding_responses,
+            image_embeddings_idx=image_embeddings_idx,
+            model_response=model_response,
+            input=input,
+        )
+
+    def embedding(
+        self,
+        model: str,
+        input: List,
+        timeout: float,
+        logging_obj,
+        model_response: EmbeddingResponse,
+        optional_params: dict,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+        client=None,
+        aembedding=None,
+        max_retries: Optional[int] = None,
+    ) -> EmbeddingResponse:
+        """
+        - Separate image url from text
+        -> route image url call to `/image/embeddings`
+        -> route text call to `/v1/embeddings` (OpenAI route)
+
+        assemble result in-order, and return
+        """
+        if aembedding is True:
+            return self.async_embedding(  # type: ignore
+                model,
+                input,
+                timeout,
+                logging_obj,
+                model_response,
+                optional_params,
+                api_key,
+                api_base,
+                client,
+            )
+
+        (
+            image_embeddings_request,
+            v1_embeddings_request,
+            image_embeddings_idx,
+        ) = AzureAICohereConfig()._transform_request(
+            input=input, optional_params=optional_params, model=model
+        )
+
+        image_embedding_responses: Optional[List] = None
+        text_embedding_responses: Optional[List] = None
+
+        if image_embeddings_request["input"]:
+            image_response = self.image_embedding(
+                model=model,
+                data=image_embeddings_request,
+                timeout=timeout,
+                logging_obj=logging_obj,
+                model_response=model_response,
+                optional_params=optional_params,
+                api_key=api_key,
+                api_base=api_base,
+                client=client,
+            )
+
+            image_embedding_responses = image_response.data
+            if image_embedding_responses is None:
+                raise Exception("/image/embeddings route returned None Embeddings.")
+
+        if v1_embeddings_request["input"]:
+            response: EmbeddingResponse = super().embedding(  # type: ignore
+                model,
+                v1_embeddings_request["input"],
+                timeout,
+                logging_obj,
+                model_response,
+                optional_params,
+                api_key,
+                api_base,
+                client=(
+                    client
+                    if client is not None and isinstance(client, OpenAI)
+                    else None
+                ),
+                aembedding=aembedding,
+            )
+
+            text_embedding_responses = response.data
+            if text_embedding_responses is None:
+                raise Exception("/v1/embeddings route returned None Embeddings.")
+
+        return self._process_response(
+            image_embedding_responses=image_embedding_responses,
+            text_embedding_responses=text_embedding_responses,
+            image_embeddings_idx=image_embeddings_idx,
+            model_response=model_response,
+            input=input,
+        )
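
Putting the handler together, a hedged end-to-end sketch (endpoint, key, and deployment name are placeholders):

    # Sketch: one embedding call mixing text and a base64 image. The handler
    # splits the inputs, calls both routes, and reassembles results in order.
    import litellm

    response = litellm.embedding(
        model="azure_ai/Cohere-embed-v3-english",  # hypothetical deployment
        input=["a text passage", "data:image/png;base64,iVBORw0KGgo="],
        api_base="https://example.services.ai.azure.com",  # placeholder
        api_key="...",  # or set AZURE_AI_API_KEY
    )
    # response.data preserves the original input order.
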
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/azure_ai/rerank/handler.py b/.venv/lib/python3.12/site-packages/litellm/llms/azure_ai/rerank/handler.py
new file mode 100644
index 00000000..57e7cefd
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/azure_ai/rerank/handler.py
@@ -0,0 +1,5 @@
+"""
+Azure AI Rerank - uses `llm_http_handler.py` to make httpx requests
+
+Request/Response transformation is handled in `transformation.py`
+"""
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/azure_ai/rerank/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/azure_ai/rerank/transformation.py
new file mode 100644
index 00000000..842511f3
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/azure_ai/rerank/transformation.py
@@ -0,0 +1,90 @@
+"""
+Translate between Cohere's `/rerank` format and Azure AI's `/rerank` format. 
+"""
+
+from typing import Optional
+
+import httpx
+
+import litellm
+from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+from litellm.llms.cohere.rerank.transformation import CohereRerankConfig
+from litellm.secret_managers.main import get_secret_str
+from litellm.types.utils import RerankResponse
+
+
+class AzureAIRerankConfig(CohereRerankConfig):
+    """
+    Azure AI Rerank - Follows the same Spec as Cohere Rerank
+    """
+    def get_complete_url(self, api_base: Optional[str], model: str) -> str:
+        if api_base is None:
+            raise ValueError(
+                "Azure AI API Base is required. api_base=None. Set in call or via `AZURE_AI_API_BASE` env var."
+            )
+        if not api_base.endswith("/v1/rerank"):
+            api_base = f"{api_base}/v1/rerank"
+        return api_base
+
+    def validate_environment(
+        self,
+        headers: dict,
+        model: str,
+        api_key: Optional[str] = None,
+    ) -> dict:
+        if api_key is None:
+            api_key = get_secret_str("AZURE_AI_API_KEY") or litellm.azure_key
+
+        if api_key is None:
+            raise ValueError(
+                "Azure AI API key is required. Please set 'AZURE_AI_API_KEY' or 'litellm.azure_key'"
+            )
+
+        default_headers = {
+            "Authorization": f"Bearer {api_key}",
+            "accept": "application/json",
+            "content-type": "application/json",
+        }
+
+        # If 'Authorization' is provided in headers, it overrides the default.
+        if "Authorization" in headers:
+            default_headers["Authorization"] = headers["Authorization"]
+
+        # Merge remaining headers; caller-provided values override the defaults.
+        return {**default_headers, **headers}
+
+    def transform_rerank_response(
+        self,
+        model: str,
+        raw_response: httpx.Response,
+        model_response: RerankResponse,
+        logging_obj: LiteLLMLoggingObj,
+        api_key: Optional[str] = None,
+        request_data: dict = {},
+        optional_params: dict = {},
+        litellm_params: dict = {},
+    ) -> RerankResponse:
+        rerank_response = super().transform_rerank_response(
+            model=model,
+            raw_response=raw_response,
+            model_response=model_response,
+            logging_obj=logging_obj,
+            api_key=api_key,
+            request_data=request_data,
+            optional_params=optional_params,
+            litellm_params=litellm_params,
+        )
+        base_model = self._get_base_model(
+            rerank_response._hidden_params.get("llm_provider-azureml-model-group")
+        )
+        rerank_response._hidden_params["model"] = base_model
+        return rerank_response
+
+    def _get_base_model(self, azure_model_group: Optional[str]) -> Optional[str]:
+        if azure_model_group is None:
+            return None
+        if azure_model_group == "offer-cohere-rerank-mul-paygo":
+            return "azure_ai/cohere-rerank-v3-multilingual"
+        if azure_model_group == "offer-cohere-rerank-eng-paygo":
+            return "azure_ai/cohere-rerank-v3-english"
+        return azure_model_group
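
Finally, a hedged sketch of a rerank call against this config; the endpoint, key, and documents are placeholders:

    # Sketch: rerank via the azure_ai provider. AzureAIRerankConfig appends
    # /v1/rerank to the base URL and maps the llm_provider-azureml-model-group
    # header back to a base model name in the response's hidden params.
    import litellm

    response = litellm.rerank(
        model="azure_ai/cohere-rerank-v3-multilingual",  # hypothetical deployment
        query="What does Azure AI Studio host?",
        documents=["Azure AI Studio hosts serverless models.", "Unrelated text."],
        api_base="https://example.services.ai.azure.com",  # placeholder
        api_key="...",  # or set AZURE_AI_API_KEY
    )
    print(response.results)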