Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/azure/azure.py')
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/azure/azure.py  1347
1 file changed, 1347 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/azure/azure.py b/.venv/lib/python3.12/site-packages/litellm/llms/azure/azure.py
new file mode 100644
index 00000000..03c5cc09
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/azure/azure.py
@@ -0,0 +1,1347 @@
+import asyncio
+import json
+import time
+from typing import Any, Callable, Coroutine, Dict, List, Optional, Union
+
+import httpx  # type: ignore
+from openai import APITimeoutError, AsyncAzureOpenAI, AzureOpenAI
+
+import litellm
+from litellm.constants import DEFAULT_MAX_RETRIES
+from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+from litellm.litellm_core_utils.logging_utils import track_llm_api_timing
+from litellm.llms.custom_httpx.http_handler import (
+    AsyncHTTPHandler,
+    HTTPHandler,
+    get_async_httpx_client,
+)
+from litellm.types.utils import (
+    EmbeddingResponse,
+    ImageResponse,
+    LlmProviders,
+    ModelResponse,
+)
+from litellm.utils import (
+    CustomStreamWrapper,
+    convert_to_model_response_object,
+    modify_url,
+)
+
+from ...types.llms.openai import HttpxBinaryResponseContent
+from ..base import BaseLLM
+from .common_utils import (
+    AzureOpenAIError,
+    BaseAzureLLM,
+    get_azure_ad_token_from_oidc,
+    process_azure_headers,
+    select_azure_base_url_or_endpoint,
+)
+
+
+class AzureOpenAIAssistantsAPIConfig:
+    """
+    Reference: https://learn.microsoft.com/en-us/azure/ai-services/openai/assistants-reference-messages?tabs=python#create-message
+    """
+
+    def __init__(
+        self,
+    ) -> None:
+        pass
+
+    def get_supported_openai_create_message_params(self):
+        return [
+            "role",
+            "content",
+            "attachments",
+            "metadata",
+        ]
+
+    def map_openai_params_create_message_params(
+        self, non_default_params: dict, optional_params: dict
+    ):
+        for param, value in non_default_params.items():
+            if param == "role":
+                optional_params["role"] = value
+            elif param == "metadata":
+                optional_params["metadata"] = value
+            elif param == "content":  # only string accepted
+                if isinstance(value, str):
+                    optional_params["content"] = value
+                else:
+                    raise litellm.utils.UnsupportedParamsError(
+                        message="Azure only accepts content as a string.",
+                        status_code=400,
+                    )
+            elif (
+                param == "attachments"
+            ):  # 'attachments' is a v2 param; Azure currently supports the older 'file_ids' param
+                file_ids: List[str] = []
+                if isinstance(value, list):
+                    for item in value:
+                        if "file_id" in item:
+                            file_ids.append(item["file_id"])
+                        else:
+                            if litellm.drop_params is True:
+                                pass
+                            else:
+                                raise litellm.utils.UnsupportedParamsError(
+                                    message="Azure doesn't support {}. To drop it from the call, set `litellm.drop_params = True`.".format(
+                                        value
+                                    ),
+                                    status_code=400,
+                                )
+                else:
+                    raise litellm.utils.UnsupportedParamsError(
+                        message="Invalid param. attachments should always be a list. Got={}, Expected=List. Raw value={}".format(
+                            type(value), value
+                        ),
+                        status_code=400,
+                    )
+        return optional_params
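+    # Illustrative usage (a sketch; the values are hypothetical):
+    #   config = AzureOpenAIAssistantsAPIConfig()
+    #   config.map_openai_params_create_message_params(
+    #       non_default_params={"role": "user", "content": "hello"},
+    #       optional_params={},
+    #   )
+    #   # -> {"role": "user", "content": "hello"}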
+
+
+def _check_dynamic_azure_params(
+    azure_client_params: dict,
+    azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]],
+) -> bool:
+    """
+    Returns True if user passed in client params != initialized azure client
+
+    Currently only implemented for api version
+    """
+    if azure_client is None:
+        return True
+
+    dynamic_params = ["api_version"]
+    for k, v in azure_client_params.items():
+        if k in dynamic_params and k == "api_version":
+            if v is not None and v != azure_client._custom_query["api-version"]:
+                return True
+
+    return False
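+# For example (a sketch; the client construction below is illustrative):
+#   client = AzureOpenAI(api_key="...", api_version="2024-02-01", azure_endpoint="...")
+#   _check_dynamic_azure_params({"api_version": "2024-06-01"}, client)  # -> True
+#   _check_dynamic_azure_params({"api_version": "2024-02-01"}, client)  # -> False
+#   _check_dynamic_azure_params({"api_version": "2024-02-01"}, None)    # -> True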
+
+
+class AzureChatCompletion(BaseAzureLLM, BaseLLM):
+    def __init__(self) -> None:
+        super().__init__()
+
+    def validate_environment(self, api_key, azure_ad_token, azure_ad_token_provider):
+        headers = {
+            "content-type": "application/json",
+        }
+        if api_key is not None:
+            headers["api-key"] = api_key
+        elif azure_ad_token is not None:
+            if azure_ad_token.startswith("oidc/"):
+                azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
+            headers["Authorization"] = f"Bearer {azure_ad_token}"
+        elif azure_ad_token_provider is not None:
+            azure_ad_token = azure_ad_token_provider()
+            headers["Authorization"] = f"Bearer {azure_ad_token}"
+
+        return headers
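+    # For example (a sketch): validate_environment(api_key="my-key",
+    # azure_ad_token=None, azure_ad_token_provider=None) returns
+    #   {"content-type": "application/json", "api-key": "my-key"}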
+
+    def make_sync_azure_openai_chat_completion_request(
+        self,
+        azure_client: AzureOpenAI,
+        data: dict,
+        timeout: Union[float, httpx.Timeout],
+    ):
+        """
+        Helper to:
+        - call chat.completions.create.with_raw_response when litellm.return_response_headers is True
+        - call chat.completions.create by default
+        """
+        raw_response = azure_client.chat.completions.with_raw_response.create(
+            **data, timeout=timeout
+        )
+
+        headers = dict(raw_response.headers)
+        response = raw_response.parse()
+        return headers, response
+
+    @track_llm_api_timing()
+    async def make_azure_openai_chat_completion_request(
+        self,
+        azure_client: AsyncAzureOpenAI,
+        data: dict,
+        timeout: Union[float, httpx.Timeout],
+        logging_obj: LiteLLMLoggingObj,
+    ):
+        """
+        Helper to:
+        - call chat.completions.create.with_raw_response when litellm.return_response_headers is True
+        - call chat.completions.create by default
+        """
+        start_time = time.time()
+        try:
+            raw_response = await azure_client.chat.completions.with_raw_response.create(
+                **data, timeout=timeout
+            )
+
+            headers = dict(raw_response.headers)
+            response = raw_response.parse()
+            return headers, response
+        except APITimeoutError as e:
+            end_time = time.time()
+            time_delta = round(end_time - start_time, 2)
+            e.message += f" - timeout value={timeout}, time taken={time_delta} seconds"
+            raise e
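+    # Callers unpack the raw headers alongside the parsed object, e.g. (a sketch):
+    #   headers, response = await self.make_azure_openai_chat_completion_request(
+    #       azure_client=azure_client, data=data, timeout=timeout, logging_obj=logging_obj
+    #   )
+    #   headers.get("x-ratelimit-remaining-requests")  # rate-limit info, when present
+    #   response.choices[0].message.content            # parsed ChatCompletion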
+
+    def completion(  # noqa: PLR0915
+        self,
+        model: str,
+        messages: list,
+        model_response: ModelResponse,
+        api_key: str,
+        api_base: str,
+        api_version: str,
+        api_type: str,
+        azure_ad_token: str,
+        azure_ad_token_provider: Callable,
+        dynamic_params: bool,
+        print_verbose: Callable,
+        timeout: Union[float, httpx.Timeout],
+        logging_obj: LiteLLMLoggingObj,
+        optional_params,
+        litellm_params,
+        logger_fn,
+        acompletion: bool = False,
+        headers: Optional[dict] = None,
+        client=None,
+    ):
+        if headers:
+            optional_params["extra_headers"] = headers
+        try:
+            if model is None or messages is None:
+                raise AzureOpenAIError(
+                    status_code=422, message="Missing model or messages"
+                )
+
+            max_retries = optional_params.pop("max_retries", None)
+            if max_retries is None:
+                max_retries = DEFAULT_MAX_RETRIES
+            json_mode: Optional[bool] = optional_params.pop("json_mode", False)
+
+            ### CHECK IF CLOUDFLARE AI GATEWAY ###
+            ### if so - set the model as part of the base url
+            if "gateway.ai.cloudflare.com" in api_base:
+                client = self._init_azure_client_for_cloudflare_ai_gateway(
+                    api_base=api_base,
+                    model=model,
+                    api_version=api_version,
+                    max_retries=max_retries,
+                    timeout=timeout,
+                    api_key=api_key,
+                    azure_ad_token=azure_ad_token,
+                    azure_ad_token_provider=azure_ad_token_provider,
+                    acompletion=acompletion,
+                    client=client,
+                )
+
+                data = {"model": None, "messages": messages, **optional_params}
+            else:
+                data = litellm.AzureOpenAIConfig().transform_request(
+                    model=model,
+                    messages=messages,
+                    optional_params=optional_params,
+                    litellm_params=litellm_params,
+                    headers=headers or {},
+                )
+
+            if acompletion is True:
+                if optional_params.get("stream", False):
+                    return self.async_streaming(
+                        logging_obj=logging_obj,
+                        api_base=api_base,
+                        dynamic_params=dynamic_params,
+                        data=data,
+                        model=model,
+                        api_key=api_key,
+                        api_version=api_version,
+                        azure_ad_token=azure_ad_token,
+                        azure_ad_token_provider=azure_ad_token_provider,
+                        timeout=timeout,
+                        client=client,
+                        max_retries=max_retries,
+                        litellm_params=litellm_params,
+                    )
+                else:
+                    return self.acompletion(
+                        api_base=api_base,
+                        data=data,
+                        model_response=model_response,
+                        api_key=api_key,
+                        api_version=api_version,
+                        model=model,
+                        azure_ad_token=azure_ad_token,
+                        azure_ad_token_provider=azure_ad_token_provider,
+                        dynamic_params=dynamic_params,
+                        timeout=timeout,
+                        client=client,
+                        logging_obj=logging_obj,
+                        max_retries=max_retries,
+                        convert_tool_call_to_json_mode=json_mode,
+                        litellm_params=litellm_params,
+                    )
+            elif "stream" in optional_params and optional_params["stream"] is True:
+                return self.streaming(
+                    logging_obj=logging_obj,
+                    api_base=api_base,
+                    dynamic_params=dynamic_params,
+                    data=data,
+                    model=model,
+                    api_key=api_key,
+                    api_version=api_version,
+                    azure_ad_token=azure_ad_token,
+                    azure_ad_token_provider=azure_ad_token_provider,
+                    timeout=timeout,
+                    client=client,
+                    max_retries=max_retries,
+                    litellm_params=litellm_params,
+                )
+            else:
+                ## LOGGING
+                logging_obj.pre_call(
+                    input=messages,
+                    api_key=api_key,
+                    additional_args={
+                        "headers": {
+                            "api_key": api_key,
+                            "azure_ad_token": azure_ad_token,
+                        },
+                        "api_version": api_version,
+                        "api_base": api_base,
+                        "complete_input_dict": data,
+                    },
+                )
+                if not isinstance(max_retries, int):
+                    raise AzureOpenAIError(
+                        status_code=422, message="max retries must be an int"
+                    )
+                # init AzureOpenAI Client
+                azure_client = self.get_azure_openai_client(
+                    api_version=api_version,
+                    api_base=api_base,
+                    api_key=api_key,
+                    model=model,
+                    client=client,
+                    _is_async=False,
+                    litellm_params=litellm_params,
+                )
+                if not isinstance(azure_client, AzureOpenAI):
+                    raise AzureOpenAIError(
+                        status_code=500,
+                        message="azure_client is not an instance of AzureOpenAI",
+                    )
+
+                headers, response = self.make_sync_azure_openai_chat_completion_request(
+                    azure_client=azure_client, data=data, timeout=timeout
+                )
+                stringified_response = response.model_dump()
+                ## LOGGING
+                logging_obj.post_call(
+                    input=messages,
+                    api_key=api_key,
+                    original_response=stringified_response,
+                    additional_args={
+                        "headers": headers,
+                        "api_version": api_version,
+                        "api_base": api_base,
+                    },
+                )
+                return convert_to_model_response_object(
+                    response_object=stringified_response,
+                    model_response_object=model_response,
+                    convert_tool_call_to_json_mode=json_mode,
+                    _response_headers=headers,
+                )
+        except AzureOpenAIError as e:
+            raise e
+        except Exception as e:
+            status_code = getattr(e, "status_code", 500)
+            error_headers = getattr(e, "headers", None)
+            error_response = getattr(e, "response", None)
+            error_body = getattr(e, "body", None)
+            if error_headers is None and error_response:
+                error_headers = getattr(error_response, "headers", None)
+            raise AzureOpenAIError(
+                status_code=status_code,
+                message=str(e),
+                headers=error_headers,
+                body=error_body,
+            )
+
+    async def acompletion(
+        self,
+        api_key: str,
+        api_version: str,
+        model: str,
+        api_base: str,
+        data: dict,
+        timeout: Any,
+        dynamic_params: bool,
+        model_response: ModelResponse,
+        logging_obj: LiteLLMLoggingObj,
+        max_retries: int,
+        azure_ad_token: Optional[str] = None,
+        azure_ad_token_provider: Optional[Callable] = None,
+        convert_tool_call_to_json_mode: Optional[bool] = None,
+        client=None,  # this is the AsyncAzureOpenAI
+        litellm_params: Optional[dict] = {},
+    ):
+        response = None
+        try:
+            # setting Azure client
+            azure_client = self.get_azure_openai_client(
+                api_version=api_version,
+                api_base=api_base,
+                api_key=api_key,
+                model=model,
+                client=client,
+                _is_async=True,
+                litellm_params=litellm_params,
+            )
+            if not isinstance(azure_client, AsyncAzureOpenAI):
+                raise ValueError("Azure client is not an instance of AsyncAzureOpenAI")
+            ## LOGGING
+            logging_obj.pre_call(
+                input=data["messages"],
+                api_key=azure_client.api_key,
+                additional_args={
+                    "headers": {
+                        "api_key": api_key,
+                        "azure_ad_token": azure_ad_token,
+                    },
+                    "api_base": azure_client._base_url._uri_reference,
+                    "acompletion": True,
+                    "complete_input_dict": data,
+                },
+            )
+
+            headers, response = await self.make_azure_openai_chat_completion_request(
+                azure_client=azure_client,
+                data=data,
+                timeout=timeout,
+                logging_obj=logging_obj,
+            )
+            logging_obj.model_call_details["response_headers"] = headers
+
+            stringified_response = response.model_dump()
+            logging_obj.post_call(
+                input=data["messages"],
+                api_key=api_key,
+                original_response=stringified_response,
+                additional_args={"complete_input_dict": data},
+            )
+
+            return convert_to_model_response_object(
+                response_object=stringified_response,
+                model_response_object=model_response,
+                hidden_params={"headers": headers},
+                _response_headers=headers,
+                convert_tool_call_to_json_mode=convert_tool_call_to_json_mode,
+            )
+        except AzureOpenAIError as e:
+            ## LOGGING
+            logging_obj.post_call(
+                input=data["messages"],
+                api_key=api_key,
+                additional_args={"complete_input_dict": data},
+                original_response=str(e),
+            )
+            raise e
+        except asyncio.CancelledError as e:
+            ## LOGGING
+            logging_obj.post_call(
+                input=data["messages"],
+                api_key=api_key,
+                additional_args={"complete_input_dict": data},
+                original_response=str(e),
+            )
+            raise AzureOpenAIError(status_code=500, message=str(e))
+        except Exception as e:
+            message = getattr(e, "message", str(e))
+            body = getattr(e, "body", None)
+            ## LOGGING
+            logging_obj.post_call(
+                input=data["messages"],
+                api_key=api_key,
+                additional_args={"complete_input_dict": data},
+                original_response=str(e),
+            )
+            if hasattr(e, "status_code"):
+                raise e
+            else:
+                raise AzureOpenAIError(status_code=500, message=message, body=body)
+
+    def streaming(
+        self,
+        logging_obj,
+        api_base: str,
+        api_key: str,
+        api_version: str,
+        dynamic_params: bool,
+        data: dict,
+        model: str,
+        timeout: Any,
+        max_retries: int,
+        azure_ad_token: Optional[str] = None,
+        azure_ad_token_provider: Optional[Callable] = None,
+        client=None,
+        litellm_params: Optional[dict] = {},
+    ):
+        # init AzureOpenAI Client
+        azure_client_params = {
+            "api_version": api_version,
+            "azure_endpoint": api_base,
+            "azure_deployment": model,
+            "http_client": litellm.client_session,
+            "max_retries": max_retries,
+            "timeout": timeout,
+        }
+        azure_client_params = select_azure_base_url_or_endpoint(
+            azure_client_params=azure_client_params
+        )
+        if api_key is not None:
+            azure_client_params["api_key"] = api_key
+        elif azure_ad_token is not None:
+            if azure_ad_token.startswith("oidc/"):
+                azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
+            azure_client_params["azure_ad_token"] = azure_ad_token
+        elif azure_ad_token_provider is not None:
+            azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider
+
+        azure_client = self.get_azure_openai_client(
+            api_version=api_version,
+            api_base=api_base,
+            api_key=api_key,
+            model=model,
+            client=client,
+            _is_async=False,
+            litellm_params=litellm_params,
+        )
+        if not isinstance(azure_client, AzureOpenAI):
+            raise AzureOpenAIError(
+                status_code=500,
+                message="azure_client is not an instance of AzureOpenAI",
+            )
+        ## LOGGING
+        logging_obj.pre_call(
+            input=data["messages"],
+            api_key=azure_client.api_key,
+            additional_args={
+                "headers": {
+                    "api_key": api_key,
+                    "azure_ad_token": azure_ad_token,
+                },
+                "api_base": azure_client._base_url._uri_reference,
+                "acompletion": True,
+                "complete_input_dict": data,
+            },
+        )
+        headers, response = self.make_sync_azure_openai_chat_completion_request(
+            azure_client=azure_client, data=data, timeout=timeout
+        )
+        streamwrapper = CustomStreamWrapper(
+            completion_stream=response,
+            model=model,
+            custom_llm_provider="azure",
+            logging_obj=logging_obj,
+            stream_options=data.get("stream_options", None),
+            _response_headers=process_azure_headers(headers),
+        )
+        return streamwrapper
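+    # The returned CustomStreamWrapper is consumed like the raw OpenAI stream,
+    # e.g. (a sketch; the caller name is hypothetical):
+    #   for chunk in azure_chat_completion.streaming(...):
+    #       print(chunk.choices[0].delta)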
+
+    async def async_streaming(
+        self,
+        logging_obj: LiteLLMLoggingObj,
+        api_base: str,
+        api_key: str,
+        api_version: str,
+        dynamic_params: bool,
+        data: dict,
+        model: str,
+        timeout: Any,
+        max_retries: int,
+        azure_ad_token: Optional[str] = None,
+        azure_ad_token_provider: Optional[Callable] = None,
+        client=None,
+        litellm_params: Optional[dict] = {},
+    ):
+        try:
+            azure_client = self.get_azure_openai_client(
+                api_version=api_version,
+                api_base=api_base,
+                api_key=api_key,
+                model=model,
+                client=client,
+                _is_async=True,
+                litellm_params=litellm_params,
+            )
+            if not isinstance(azure_client, AsyncAzureOpenAI):
+                raise ValueError("Azure client is not an instance of AsyncAzureOpenAI")
+
+            ## LOGGING
+            logging_obj.pre_call(
+                input=data["messages"],
+                api_key=azure_client.api_key,
+                additional_args={
+                    "headers": {
+                        "api_key": api_key,
+                        "azure_ad_token": azure_ad_token,
+                    },
+                    "api_base": azure_client._base_url._uri_reference,
+                    "acompletion": True,
+                    "complete_input_dict": data,
+                },
+            )
+
+            headers, response = await self.make_azure_openai_chat_completion_request(
+                azure_client=azure_client,
+                data=data,
+                timeout=timeout,
+                logging_obj=logging_obj,
+            )
+            logging_obj.model_call_details["response_headers"] = headers
+
+            # return response
+            streamwrapper = CustomStreamWrapper(
+                completion_stream=response,
+                model=model,
+                custom_llm_provider="azure",
+                logging_obj=logging_obj,
+                stream_options=data.get("stream_options", None),
+                _response_headers=headers,
+            )
+            return streamwrapper  ## DO NOT make this into an async for ... loop, it will yield an async generator, which won't raise errors if the response fails
+        except Exception as e:
+            status_code = getattr(e, "status_code", 500)
+            error_headers = getattr(e, "headers", None)
+            error_response = getattr(e, "response", None)
+            message = getattr(e, "message", str(e))
+            error_body = getattr(e, "body", None)
+            if error_headers is None and error_response:
+                error_headers = getattr(error_response, "headers", None)
+            raise AzureOpenAIError(
+                status_code=status_code,
+                message=message,
+                headers=error_headers,
+                body=error_body,
+            )
+
+    async def aembedding(
+        self,
+        model: str,
+        data: dict,
+        model_response: EmbeddingResponse,
+        input: list,
+        logging_obj: LiteLLMLoggingObj,
+        api_base: str,
+        api_key: Optional[str] = None,
+        api_version: Optional[str] = None,
+        client: Optional[AsyncAzureOpenAI] = None,
+        timeout: Optional[Union[float, httpx.Timeout]] = None,
+        max_retries: Optional[int] = None,
+        azure_ad_token: Optional[str] = None,
+        azure_ad_token_provider: Optional[Callable] = None,
+        litellm_params: Optional[dict] = {},
+    ) -> EmbeddingResponse:
+        response = None
+        try:
+
+            openai_aclient = self.get_azure_openai_client(
+                api_version=api_version,
+                api_base=api_base,
+                api_key=api_key,
+                model=model,
+                _is_async=True,
+                client=client,
+                litellm_params=litellm_params,
+            )
+            if not isinstance(openai_aclient, AsyncAzureOpenAI):
+                raise ValueError("Azure client is not an instance of AsyncAzureOpenAI")
+
+            raw_response = await openai_aclient.embeddings.with_raw_response.create(
+                **data, timeout=timeout
+            )
+            headers = dict(raw_response.headers)
+            response = raw_response.parse()
+            stringified_response = response.model_dump()
+            ## LOGGING
+            logging_obj.post_call(
+                input=input,
+                api_key=api_key,
+                additional_args={"complete_input_dict": data},
+                original_response=stringified_response,
+            )
+            embedding_response = convert_to_model_response_object(
+                response_object=stringified_response,
+                model_response_object=model_response,
+                hidden_params={"headers": headers},
+                _response_headers=process_azure_headers(headers),
+                response_type="embedding",
+            )
+            if not isinstance(embedding_response, EmbeddingResponse):
+                raise AzureOpenAIError(
+                    status_code=500,
+                    message="embedding_response is not an instance of EmbeddingResponse",
+                )
+            return embedding_response
+        except Exception as e:
+            ## LOGGING
+            logging_obj.post_call(
+                input=input,
+                api_key=api_key,
+                additional_args={"complete_input_dict": data},
+                original_response=str(e),
+            )
+            raise e
+
+    def embedding(
+        self,
+        model: str,
+        input: list,
+        api_base: str,
+        api_version: str,
+        timeout: float,
+        logging_obj: LiteLLMLoggingObj,
+        model_response: EmbeddingResponse,
+        optional_params: dict,
+        api_key: Optional[str] = None,
+        azure_ad_token: Optional[str] = None,
+        azure_ad_token_provider: Optional[Callable] = None,
+        max_retries: Optional[int] = None,
+        client=None,
+        aembedding=None,
+        headers: Optional[dict] = None,
+        litellm_params: Optional[dict] = None,
+    ) -> Union[EmbeddingResponse, Coroutine[Any, Any, EmbeddingResponse]]:
+        if headers:
+            optional_params["extra_headers"] = headers
+        if self._client_session is None:
+            self._client_session = self.create_client_session()
+        try:
+            data = {"model": model, "input": input, **optional_params}
+            if max_retries is None:
+                max_retries = litellm.DEFAULT_MAX_RETRIES
+            ## LOGGING
+            logging_obj.pre_call(
+                input=input,
+                api_key=api_key,
+                additional_args={
+                    "complete_input_dict": data,
+                    "headers": {"api_key": api_key, "azure_ad_token": azure_ad_token},
+                },
+            )
+
+            if aembedding is True:
+                return self.aembedding(
+                    data=data,
+                    input=input,
+                    model=model,
+                    logging_obj=logging_obj,
+                    api_key=api_key,
+                    model_response=model_response,
+                    timeout=timeout,
+                    client=client,
+                    litellm_params=litellm_params,
+                    api_base=api_base,
+                )
+            azure_client = self.get_azure_openai_client(
+                api_version=api_version,
+                api_base=api_base,
+                api_key=api_key,
+                model=model,
+                _is_async=False,
+                client=client,
+                litellm_params=litellm_params,
+            )
+            if not isinstance(azure_client, AzureOpenAI):
+                raise AzureOpenAIError(
+                    status_code=500,
+                    message="azure_client is not an instance of AzureOpenAI",
+                )
+
+            ## COMPLETION CALL
+            raw_response = azure_client.embeddings.with_raw_response.create(**data, timeout=timeout)  # type: ignore
+            headers = dict(raw_response.headers)
+            response = raw_response.parse()
+            ## LOGGING
+            logging_obj.post_call(
+                input=input,
+                api_key=api_key,
+                additional_args={"complete_input_dict": data, "api_base": api_base},
+                original_response=response,
+            )
+
+            return convert_to_model_response_object(response_object=response.model_dump(), model_response_object=model_response, response_type="embedding", _response_headers=process_azure_headers(headers))  # type: ignore
+        except AzureOpenAIError as e:
+            raise e
+        except Exception as e:
+            status_code = getattr(e, "status_code", 500)
+            error_headers = getattr(e, "headers", None)
+            error_response = getattr(e, "response", None)
+            if error_headers is None and error_response:
+                error_headers = getattr(error_response, "headers", None)
+            raise AzureOpenAIError(
+                status_code=status_code, message=str(e), headers=error_headers
+            )
+
+    async def make_async_azure_httpx_request(
+        self,
+        client: Optional[AsyncHTTPHandler],
+        timeout: Optional[Union[float, httpx.Timeout]],
+        api_base: str,
+        api_version: str,
+        api_key: str,
+        data: dict,
+        headers: dict,
+    ) -> httpx.Response:
+        """
+        Implemented for azure dall-e-2 image gen calls
+
+        Alternative to needing a custom transport implementation
+        """
+        if client is None:
+            _params = {}
+            if timeout is not None:
+                if isinstance(timeout, (int, float)):
+                    _params["timeout"] = httpx.Timeout(timeout)
+                else:
+                    # assume an httpx.Timeout was passed in - forward it as-is
+                    _params["timeout"] = timeout
+            else:
+                _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
+
+            async_handler = get_async_httpx_client(
+                llm_provider=LlmProviders.AZURE,
+                params=_params,
+            )
+        else:
+            async_handler = client  # type: ignore
+
+        if (
+            "images/generations" in api_base
+            and api_version
+            in [  # dall-e-3 starts from `2023-12-01-preview` so we should be able to avoid conflict
+                "2023-06-01-preview",
+                "2023-07-01-preview",
+                "2023-08-01-preview",
+                "2023-09-01-preview",
+                "2023-10-01-preview",
+            ]
+        ):  # CREATE + POLL for azure dall-e-2 calls
+
+            api_base = modify_url(
+                original_url=api_base, new_path="/openai/images/generations:submit"
+            )
+
+            data.pop(
+                "model", None
+            )  # REMOVE 'model' from dall-e-2 arg https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#request-a-generated-image-dall-e-2-preview
+            response = await async_handler.post(
+                url=api_base,
+                data=json.dumps(data),
+                headers=headers,
+            )
+            if "operation-location" in response.headers:
+                operation_location_url = response.headers["operation-location"]
+            else:
+                raise AzureOpenAIError(status_code=500, message=response.text)
+            response = await async_handler.get(
+                url=operation_location_url,
+                headers=headers,
+            )
+
+            await response.aread()
+
+            timeout_secs: int = 120
+            start_time = time.time()
+            if "status" not in response.json():
+                raise Exception(
+                    "Expected 'status' in response. Got={}".format(response.json())
+                )
+            while response.json()["status"] not in ["succeeded", "failed"]:
+                if time.time() - start_time > timeout_secs:
+                    raise AzureOpenAIError(
+                        status_code=408, message="Operation polling timed out."
+                    )
+
+                await asyncio.sleep(int(response.headers.get("retry-after") or 10))
+                response = await async_handler.get(
+                    url=operation_location_url,
+                    headers=headers,
+                )
+                await response.aread()
+
+            if response.json()["status"] == "failed":
+                error_data = response.json()
+                raise AzureOpenAIError(status_code=400, message=json.dumps(error_data))
+
+            result = response.json()["result"]
+            return httpx.Response(
+                status_code=200,
+                headers=response.headers,
+                content=json.dumps(result).encode("utf-8"),
+                request=httpx.Request(method="POST", url="https://api.openai.com/v1"),
+            )
+        return await async_handler.post(
+            url=api_base,
+            json=data,
+            headers=headers,
+        )
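+    # Flow for the legacy api versions above (a summary of the code path, not a
+    # spec): POST the payload to .../openai/images/generations:submit, read the
+    # `operation-location` response header, then GET that URL (honoring
+    # `retry-after`, with a 120s overall cap) until the body's "status" is
+    # "succeeded" or "failed".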
+
+    def make_sync_azure_httpx_request(
+        self,
+        client: Optional[HTTPHandler],
+        timeout: Optional[Union[float, httpx.Timeout]],
+        api_base: str,
+        api_version: str,
+        api_key: str,
+        data: dict,
+        headers: dict,
+    ) -> httpx.Response:
+        """
+        Implemented for azure dall-e-2 image gen calls
+
+        Alternative to needing a custom transport implementation
+        """
+        if client is None:
+            _params = {}
+            if timeout is not None:
+                if isinstance(timeout, (int, float)):
+                    _params["timeout"] = httpx.Timeout(timeout)
+                else:
+                    # assume an httpx.Timeout was passed in - forward it as-is
+                    _params["timeout"] = timeout
+            else:
+                _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
+
+            sync_handler = HTTPHandler(**_params, client=litellm.client_session)  # type: ignore
+        else:
+            sync_handler = client  # type: ignore
+
+        if (
+            "images/generations" in api_base
+            and api_version
+            in [  # dall-e-3 starts from `2023-12-01-preview` so we should be able to avoid conflict
+                "2023-06-01-preview",
+                "2023-07-01-preview",
+                "2023-08-01-preview",
+                "2023-09-01-preview",
+                "2023-10-01-preview",
+            ]
+        ):  # CREATE + POLL for azure dall-e-2 calls
+
+            api_base = modify_url(
+                original_url=api_base, new_path="/openai/images/generations:submit"
+            )
+
+            data.pop(
+                "model", None
+            )  # REMOVE 'model' from dall-e-2 arg https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#request-a-generated-image-dall-e-2-preview
+            response = sync_handler.post(
+                url=api_base,
+                data=json.dumps(data),
+                headers=headers,
+            )
+            if "operation-location" in response.headers:
+                operation_location_url = response.headers["operation-location"]
+            else:
+                raise AzureOpenAIError(status_code=500, message=response.text)
+            response = sync_handler.get(
+                url=operation_location_url,
+                headers=headers,
+            )
+
+            response.read()
+
+            timeout_secs: int = 120
+            start_time = time.time()
+            if "status" not in response.json():
+                raise Exception(
+                    "Expected 'status' in response. Got={}".format(response.json())
+                )
+            while response.json()["status"] not in ["succeeded", "failed"]:
+                if time.time() - start_time > timeout_secs:
+                    raise AzureOpenAIError(
+                        status_code=408, message="Operation polling timed out."
+                    )
+
+                time.sleep(int(response.headers.get("retry-after") or 10))
+                response = sync_handler.get(
+                    url=operation_location_url,
+                    headers=headers,
+                )
+                response.read()
+
+            if response.json()["status"] == "failed":
+                error_data = response.json()
+                raise AzureOpenAIError(status_code=400, message=json.dumps(error_data))
+
+            result = response.json()["result"]
+            return httpx.Response(
+                status_code=200,
+                headers=response.headers,
+                content=json.dumps(result).encode("utf-8"),
+                request=httpx.Request(method="POST", url="https://api.openai.com/v1"),
+            )
+        return sync_handler.post(
+            url=api_base,
+            json=data,
+            headers=headers,
+        )
+
+    def create_azure_base_url(
+        self, azure_client_params: dict, model: Optional[str]
+    ) -> str:
+        api_base: str = azure_client_params.get(
+            "azure_endpoint", ""
+        )  # "https://example-endpoint.openai.azure.com"
+        if api_base.endswith("/"):
+            api_base = api_base.rstrip("/")
+        api_version: str = azure_client_params.get("api_version", "")
+        if model is None:
+            model = ""
+
+        if "/openai/deployments/" in api_base:
+            base_url_with_deployment = api_base
+        else:
+            base_url_with_deployment = api_base + "/openai/deployments/" + model
+
+        base_url_with_deployment += "/images/generations"
+        base_url_with_deployment += "?api-version=" + api_version
+
+        return base_url_with_deployment
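+    # Worked example (a sketch):
+    #   create_azure_base_url(
+    #       azure_client_params={
+    #           "azure_endpoint": "https://example-endpoint.openai.azure.com",
+    #           "api_version": "2023-06-01-preview",
+    #       },
+    #       model="dall-e",
+    #   )
+    #   # -> "https://example-endpoint.openai.azure.com/openai/deployments/dall-e
+    #   #     /images/generations?api-version=2023-06-01-preview"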
+
+    async def aimage_generation(
+        self,
+        data: dict,
+        model_response: ModelResponse,
+        azure_client_params: dict,
+        api_key: str,
+        input: list,
+        logging_obj: LiteLLMLoggingObj,
+        headers: dict,
+        client=None,
+        timeout=None,
+    ) -> litellm.ImageResponse:
+        response: Optional[dict] = None
+        try:
+            api_version: str = azure_client_params.get("api_version", "")
+            img_gen_api_base = self.create_azure_base_url(
+                azure_client_params=azure_client_params, model=data.get("model", "")
+            )
+
+            ## LOGGING
+            logging_obj.pre_call(
+                input=data["prompt"],
+                api_key=api_key,
+                additional_args={
+                    "complete_input_dict": data,
+                    "api_base": img_gen_api_base,
+                    "headers": headers,
+                },
+            )
+            httpx_response: httpx.Response = await self.make_async_azure_httpx_request(
+                client=None,
+                timeout=timeout,
+                api_base=img_gen_api_base,
+                api_version=api_version,
+                api_key=api_key,
+                data=data,
+                headers=headers,
+            )
+            response = httpx_response.json()
+
+            stringified_response = response
+            ## LOGGING
+            logging_obj.post_call(
+                input=input,
+                api_key=api_key,
+                additional_args={"complete_input_dict": data},
+                original_response=stringified_response,
+            )
+            return convert_to_model_response_object(  # type: ignore
+                response_object=stringified_response,
+                model_response_object=model_response,
+                response_type="image_generation",
+            )
+        except Exception as e:
+            ## LOGGING
+            logging_obj.post_call(
+                input=input,
+                api_key=api_key,
+                additional_args={"complete_input_dict": data},
+                original_response=str(e),
+            )
+            raise e
+
+    def image_generation(
+        self,
+        prompt: str,
+        timeout: float,
+        optional_params: dict,
+        logging_obj: LiteLLMLoggingObj,
+        headers: dict,
+        model: Optional[str] = None,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+        api_version: Optional[str] = None,
+        model_response: Optional[ImageResponse] = None,
+        azure_ad_token: Optional[str] = None,
+        azure_ad_token_provider: Optional[Callable] = None,
+        client=None,
+        aimg_generation=None,
+        litellm_params: Optional[dict] = None,
+    ) -> ImageResponse:
+        try:
+            model = model or None
+
+            ## BASE MODEL CHECK
+            if (
+                model_response is not None
+                and optional_params.get("base_model", None) is not None
+            ):
+                model_response._hidden_params["model"] = optional_params.pop(
+                    "base_model"
+                )
+
+            data = {"model": model, "prompt": prompt, **optional_params}
+            max_retries = data.pop("max_retries", 2)
+            if not isinstance(max_retries, int):
+                raise AzureOpenAIError(
+                    status_code=422, message="max retries must be an int"
+                )
+
+            # init AzureOpenAI Client
+            azure_client_params: Dict[str, Any] = self.initialize_azure_sdk_client(
+                litellm_params=litellm_params or {},
+                api_key=api_key,
+                model_name=model or "",
+                api_version=api_version,
+                api_base=api_base,
+                is_async=False,
+            )
+            if aimg_generation is True:
+                return self.aimage_generation(data=data, input=prompt, logging_obj=logging_obj, model_response=model_response, api_key=api_key, client=client, azure_client_params=azure_client_params, timeout=timeout, headers=headers)  # type: ignore
+
+            img_gen_api_base = self.create_azure_base_url(
+                azure_client_params=azure_client_params, model=data.get("model", "")
+            )
+
+            ## LOGGING
+            logging_obj.pre_call(
+                input=data["prompt"],
+                api_key=api_key,
+                additional_args={
+                    "complete_input_dict": data,
+                    "api_base": img_gen_api_base,
+                    "headers": headers,
+                },
+            )
+            httpx_response: httpx.Response = self.make_sync_azure_httpx_request(
+                client=None,
+                timeout=timeout,
+                api_base=img_gen_api_base,
+                api_version=api_version or "",
+                api_key=api_key or "",
+                data=data,
+                headers=headers,
+            )
+            response = httpx_response.json()
+
+            ## LOGGING
+            logging_obj.post_call(
+                input=prompt,
+                api_key=api_key,
+                additional_args={"complete_input_dict": data},
+                original_response=response,
+            )
+            # return response
+            return convert_to_model_response_object(response_object=response, model_response_object=model_response, response_type="image_generation")  # type: ignore
+        except AzureOpenAIError as e:
+            raise e
+        except Exception as e:
+            error_code = getattr(e, "status_code", None)
+            raise AzureOpenAIError(status_code=error_code or 500, message=str(e))
+
+    def audio_speech(
+        self,
+        model: str,
+        input: str,
+        voice: str,
+        optional_params: dict,
+        api_key: Optional[str],
+        api_base: Optional[str],
+        api_version: Optional[str],
+        organization: Optional[str],
+        max_retries: int,
+        timeout: Union[float, httpx.Timeout],
+        azure_ad_token: Optional[str] = None,
+        azure_ad_token_provider: Optional[Callable] = None,
+        aspeech: Optional[bool] = None,
+        client=None,
+        litellm_params: Optional[dict] = None,
+    ) -> HttpxBinaryResponseContent:
+
+        max_retries = optional_params.pop("max_retries", 2)
+
+        if aspeech is not None and aspeech is True:
+            return self.async_audio_speech(
+                model=model,
+                input=input,
+                voice=voice,
+                optional_params=optional_params,
+                api_key=api_key,
+                api_base=api_base,
+                api_version=api_version,
+                azure_ad_token=azure_ad_token,
+                azure_ad_token_provider=azure_ad_token_provider,
+                max_retries=max_retries,
+                timeout=timeout,
+                client=client,
+                litellm_params=litellm_params,
+            )  # type: ignore
+
+        azure_client: AzureOpenAI = self.get_azure_openai_client(
+            api_base=api_base,
+            api_version=api_version,
+            api_key=api_key,
+            model=model,
+            _is_async=False,
+            client=client,
+            litellm_params=litellm_params,
+        )  # type: ignore
+
+        response = azure_client.audio.speech.create(
+            model=model,
+            voice=voice,  # type: ignore
+            input=input,
+            **optional_params,
+        )
+        return HttpxBinaryResponseContent(response=response.response)
+
+    async def async_audio_speech(
+        self,
+        model: str,
+        input: str,
+        voice: str,
+        optional_params: dict,
+        api_key: Optional[str],
+        api_base: Optional[str],
+        api_version: Optional[str],
+        azure_ad_token: Optional[str],
+        azure_ad_token_provider: Optional[Callable],
+        max_retries: int,
+        timeout: Union[float, httpx.Timeout],
+        client=None,
+        litellm_params: Optional[dict] = None,
+    ) -> HttpxBinaryResponseContent:
+
+        azure_client: AsyncAzureOpenAI = self.get_azure_openai_client(
+            api_base=api_base,
+            api_version=api_version,
+            api_key=api_key,
+            model=model,
+            _is_async=True,
+            client=client,
+            litellm_params=litellm_params,
+        )  # type: ignore
+
+        azure_response = await azure_client.audio.speech.create(
+            model=model,
+            voice=voice,  # type: ignore
+            input=input,
+            **optional_params,
+        )
+
+        return HttpxBinaryResponseContent(response=azure_response.response)
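+    # Both speech paths hand the SDK's raw httpx response to
+    # HttpxBinaryResponseContent, so callers receive the audio bytes rather than
+    # a parsed JSON object.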
+
+    def get_headers(
+        self,
+        model: Optional[str],
+        api_key: str,
+        api_base: str,
+        api_version: str,
+        timeout: float,
+        mode: str,
+        messages: Optional[list] = None,
+        input: Optional[list] = None,
+        prompt: Optional[str] = None,
+    ) -> dict:
+        client_session = litellm.client_session or httpx.Client()
+        if "gateway.ai.cloudflare.com" in api_base:
+            ## build base url - assume api base includes resource name
+            if not api_base.endswith("/"):
+                api_base += "/"
+            api_base += f"{model}"
+            client = AzureOpenAI(
+                base_url=api_base,
+                api_version=api_version,
+                api_key=api_key,
+                timeout=timeout,
+                http_client=client_session,
+            )
+            model = None
+            # cloudflare ai gateway, needs model=None
+        else:
+            client = AzureOpenAI(
+                api_version=api_version,
+                azure_endpoint=api_base,
+                api_key=api_key,
+                timeout=timeout,
+                http_client=client_session,
+            )
+
+            # only run this check if it's not cloudflare ai gateway
+            if model is None and mode != "image_generation":
+                raise Exception("model is not set")
+
+        if messages is None:
+            messages = [{"role": "user", "content": "Hey"}]
+
+        completion = client.chat.completions.with_raw_response.create(
+            model=model,  # type: ignore
+            messages=messages,  # type: ignore
+        )
+        response = {}
+
+        if completion is None or not hasattr(completion, "headers"):
+            raise Exception("invalid completion response")
+
+        if (
+            completion.headers.get("x-ratelimit-remaining-requests", None) is not None
+        ):  # not provided for dall-e requests
+            response["x-ratelimit-remaining-requests"] = completion.headers[
+                "x-ratelimit-remaining-requests"
+            ]
+
+        if completion.headers.get("x-ratelimit-remaining-tokens", None) is not None:
+            response["x-ratelimit-remaining-tokens"] = completion.headers[
+                "x-ratelimit-remaining-tokens"
+            ]
+
+        if completion.headers.get("x-ms-region", None) is not None:
+            response["x-ms-region"] = completion.headers["x-ms-region"]
+
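+        # Example return value (a sketch; which headers appear varies by
+        # deployment and call type):
+        #   {"x-ratelimit-remaining-requests": "99",
+        #    "x-ratelimit-remaining-tokens": "39800",
+        #    "x-ms-region": "East US"}
+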
+        return response