author      S. Solomon Darnell    2025-03-28 21:52:21 -0500
committer   S. Solomon Darnell    2025-03-28 21:52:21 -0500
commit      4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree        ee3dc5af3b6313e921cd920906356f5d4febc4ed
parent      cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
download    gn-ai-master.tar.gz

two versions of R2R are here (HEAD, master)
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/azure/azure.py')

-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/azure/azure.py  1347
1 file changed, 1347 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/azure/azure.py b/.venv/lib/python3.12/site-packages/litellm/llms/azure/azure.py
new file mode 100644
index 00000000..03c5cc09
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/azure/azure.py
@@ -0,0 +1,1347 @@
+import asyncio
+import json
+import time
+from typing import Any, Callable, Coroutine, Dict, List, Optional, Union
+
+import httpx # type: ignore
+from openai import APITimeoutError, AsyncAzureOpenAI, AzureOpenAI
+
+import litellm
+from litellm.constants import DEFAULT_MAX_RETRIES
+from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+from litellm.litellm_core_utils.logging_utils import track_llm_api_timing
+from litellm.llms.custom_httpx.http_handler import (
+ AsyncHTTPHandler,
+ HTTPHandler,
+ get_async_httpx_client,
+)
+from litellm.types.utils import (
+ EmbeddingResponse,
+ ImageResponse,
+ LlmProviders,
+ ModelResponse,
+)
+from litellm.utils import (
+ CustomStreamWrapper,
+ convert_to_model_response_object,
+ modify_url,
+)
+
+from ...types.llms.openai import HttpxBinaryResponseContent
+from ..base import BaseLLM
+from .common_utils import (
+ AzureOpenAIError,
+ BaseAzureLLM,
+ get_azure_ad_token_from_oidc,
+ process_azure_headers,
+ select_azure_base_url_or_endpoint,
+)
+
+
+class AzureOpenAIAssistantsAPIConfig:
+ """
+ Reference: https://learn.microsoft.com/en-us/azure/ai-services/openai/assistants-reference-messages?tabs=python#create-message
+ """
+
+ def __init__(
+ self,
+ ) -> None:
+ pass
+
+ def get_supported_openai_create_message_params(self):
+ return [
+ "role",
+ "content",
+ "attachments",
+ "metadata",
+ ]
+
+ def map_openai_params_create_message_params(
+ self, non_default_params: dict, optional_params: dict
+ ):
+ for param, value in non_default_params.items():
+ if param == "role":
+ optional_params["role"] = value
+            elif param == "metadata":
+ optional_params["metadata"] = value
+ elif param == "content": # only string accepted
+ if isinstance(value, str):
+ optional_params["content"] = value
+ else:
+ raise litellm.utils.UnsupportedParamsError(
+ message="Azure only accepts content as a string.",
+ status_code=400,
+ )
+ elif (
+ param == "attachments"
+ ): # this is a v2 param. Azure currently supports the old 'file_id's param
+ file_ids: List[str] = []
+ if isinstance(value, list):
+ for item in value:
+ if "file_id" in item:
+ file_ids.append(item["file_id"])
+ else:
+ if litellm.drop_params is True:
+ pass
+ else:
+                                raise litellm.utils.UnsupportedParamsError(
+                                    message="Azure doesn't support {}. To drop it from the call, set `litellm.drop_params = True`.".format(
+                                        value
+                                    ),
+                                    status_code=400,
+                                )
+                    if file_ids:
+                        # pass the collected ids through via the legacy
+                        # 'file_ids' param that Azure still expects (see
+                        # the note above)
+                        optional_params["file_ids"] = file_ids
+ else:
+ raise litellm.utils.UnsupportedParamsError(
+ message="Invalid param. attachments should always be a list. Got={}, Expected=List. Raw value={}".format(
+ type(value), value
+ ),
+ status_code=400,
+ )
+ return optional_params
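+
+    # Usage sketch with hypothetical values: supported params pass through,
+    # and v2-style 'attachments' entries are flattened to legacy file ids.
+    #
+    #   AzureOpenAIAssistantsAPIConfig().map_openai_params_create_message_params(
+    #       non_default_params={
+    #           "role": "user",
+    #           "content": "hello",
+    #           "attachments": [{"file_id": "file-abc123"}],
+    #       },
+    #       optional_params={},
+    #   )
+    #   # -> {"role": "user", "content": "hello", "file_ids": ["file-abc123"]}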
+
+
+def _check_dynamic_azure_params(
+ azure_client_params: dict,
+ azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]],
+) -> bool:
+ """
+    Return True if the user-supplied client params differ from the ones the
+    azure client was initialized with (or if there is no client yet).
+
+    Currently only the api version is compared.
+ """
+ if azure_client is None:
+ return True
+
+ dynamic_params = ["api_version"]
+ for k, v in azure_client_params.items():
+ if k in dynamic_params and k == "api_version":
+ if v is not None and v != azure_client._custom_query["api-version"]:
+ return True
+
+ return False
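+
+# Comparison sketch, assuming a client initialized with api version
+# "2024-02-01" (hypothetical values):
+#
+#   client = AzureOpenAI(
+#       api_key="...",
+#       api_version="2024-02-01",
+#       azure_endpoint="https://example-endpoint.openai.azure.com",
+#   )
+#   _check_dynamic_azure_params({"api_version": "2024-06-01"}, client)  # True
+#   _check_dynamic_azure_params({"api_version": "2024-02-01"}, client)  # False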
+
+
+class AzureChatCompletion(BaseAzureLLM, BaseLLM):
+ def __init__(self) -> None:
+ super().__init__()
+
+ def validate_environment(self, api_key, azure_ad_token, azure_ad_token_provider):
+ headers = {
+ "content-type": "application/json",
+ }
+ if api_key is not None:
+ headers["api-key"] = api_key
+ elif azure_ad_token is not None:
+ if azure_ad_token.startswith("oidc/"):
+ azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
+ headers["Authorization"] = f"Bearer {azure_ad_token}"
+ elif azure_ad_token_provider is not None:
+ azure_ad_token = azure_ad_token_provider()
+ headers["Authorization"] = f"Bearer {azure_ad_token}"
+
+ return headers
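+
+    # Header construction sketch (hypothetical key): an explicit api key wins
+    # over Azure AD tokens, and "oidc/" tokens are exchanged before use.
+    #
+    #   validate_environment(
+    #       api_key="my-azure-key",
+    #       azure_ad_token=None,
+    #       azure_ad_token_provider=None,
+    #   )
+    #   # -> {"content-type": "application/json", "api-key": "my-azure-key"}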
+
+ def make_sync_azure_openai_chat_completion_request(
+ self,
+ azure_client: AzureOpenAI,
+ data: dict,
+ timeout: Union[float, httpx.Timeout],
+ ):
+ """
+        Helper that calls chat.completions.with_raw_response.create so the
+        HTTP response headers can be returned alongside the parsed response.
+        """
+        raw_response = azure_client.chat.completions.with_raw_response.create(
+            **data, timeout=timeout
+        )
+
+        headers = dict(raw_response.headers)
+        response = raw_response.parse()
+        return headers, response
+
+ @track_llm_api_timing()
+ async def make_azure_openai_chat_completion_request(
+ self,
+ azure_client: AsyncAzureOpenAI,
+ data: dict,
+ timeout: Union[float, httpx.Timeout],
+ logging_obj: LiteLLMLoggingObj,
+ ):
+ """
+        Async helper that calls chat.completions.with_raw_response.create so
+        the HTTP response headers can be returned alongside the parsed
+        response. Timeout errors are annotated with the elapsed time.
+        """
+ start_time = time.time()
+ try:
+ raw_response = await azure_client.chat.completions.with_raw_response.create(
+ **data, timeout=timeout
+ )
+
+ headers = dict(raw_response.headers)
+ response = raw_response.parse()
+ return headers, response
+ except APITimeoutError as e:
+ end_time = time.time()
+ time_delta = round(end_time - start_time, 2)
+ e.message += f" - timeout value={timeout}, time taken={time_delta} seconds"
+ raise e
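+
+    # Both raw-response helpers return a (headers, response) pair, used like
+    # this at the call sites below:
+    #
+    #   headers, response = await self.make_azure_openai_chat_completion_request(
+    #       azure_client=azure_client,
+    #       data=data,
+    #       timeout=timeout,
+    #       logging_obj=logging_obj,
+    #   )
+    #   # headers  -> dict of raw HTTP response headers (rate limits, region, ...)
+    #   # response -> parsed ChatCompletion object from the openai SDK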
+
+ def completion( # noqa: PLR0915
+ self,
+ model: str,
+ messages: list,
+ model_response: ModelResponse,
+ api_key: str,
+ api_base: str,
+ api_version: str,
+ api_type: str,
+ azure_ad_token: str,
+ azure_ad_token_provider: Callable,
+ dynamic_params: bool,
+ print_verbose: Callable,
+ timeout: Union[float, httpx.Timeout],
+ logging_obj: LiteLLMLoggingObj,
+ optional_params,
+ litellm_params,
+ logger_fn,
+ acompletion: bool = False,
+ headers: Optional[dict] = None,
+ client=None,
+ ):
+ if headers:
+ optional_params["extra_headers"] = headers
+ try:
+ if model is None or messages is None:
+ raise AzureOpenAIError(
+ status_code=422, message="Missing model or messages"
+ )
+
+ max_retries = optional_params.pop("max_retries", None)
+ if max_retries is None:
+ max_retries = DEFAULT_MAX_RETRIES
+ json_mode: Optional[bool] = optional_params.pop("json_mode", False)
+
+ ### CHECK IF CLOUDFLARE AI GATEWAY ###
+ ### if so - set the model as part of the base url
+ if "gateway.ai.cloudflare.com" in api_base:
+ client = self._init_azure_client_for_cloudflare_ai_gateway(
+ api_base=api_base,
+ model=model,
+ api_version=api_version,
+ max_retries=max_retries,
+ timeout=timeout,
+ api_key=api_key,
+ azure_ad_token=azure_ad_token,
+ azure_ad_token_provider=azure_ad_token_provider,
+ acompletion=acompletion,
+ client=client,
+ )
+
+ data = {"model": None, "messages": messages, **optional_params}
+ else:
+ data = litellm.AzureOpenAIConfig().transform_request(
+ model=model,
+ messages=messages,
+ optional_params=optional_params,
+ litellm_params=litellm_params,
+ headers=headers or {},
+ )
+
+ if acompletion is True:
+ if optional_params.get("stream", False):
+ return self.async_streaming(
+ logging_obj=logging_obj,
+ api_base=api_base,
+ dynamic_params=dynamic_params,
+ data=data,
+ model=model,
+ api_key=api_key,
+ api_version=api_version,
+ azure_ad_token=azure_ad_token,
+ azure_ad_token_provider=azure_ad_token_provider,
+ timeout=timeout,
+ client=client,
+ max_retries=max_retries,
+ litellm_params=litellm_params,
+ )
+ else:
+ return self.acompletion(
+ api_base=api_base,
+ data=data,
+ model_response=model_response,
+ api_key=api_key,
+ api_version=api_version,
+ model=model,
+ azure_ad_token=azure_ad_token,
+ azure_ad_token_provider=azure_ad_token_provider,
+ dynamic_params=dynamic_params,
+ timeout=timeout,
+ client=client,
+ logging_obj=logging_obj,
+ max_retries=max_retries,
+ convert_tool_call_to_json_mode=json_mode,
+ litellm_params=litellm_params,
+ )
+ elif "stream" in optional_params and optional_params["stream"] is True:
+ return self.streaming(
+ logging_obj=logging_obj,
+ api_base=api_base,
+ dynamic_params=dynamic_params,
+ data=data,
+ model=model,
+ api_key=api_key,
+ api_version=api_version,
+ azure_ad_token=azure_ad_token,
+ azure_ad_token_provider=azure_ad_token_provider,
+ timeout=timeout,
+ client=client,
+ max_retries=max_retries,
+ litellm_params=litellm_params,
+ )
+ else:
+ ## LOGGING
+ logging_obj.pre_call(
+ input=messages,
+ api_key=api_key,
+ additional_args={
+ "headers": {
+ "api_key": api_key,
+ "azure_ad_token": azure_ad_token,
+ },
+ "api_version": api_version,
+ "api_base": api_base,
+ "complete_input_dict": data,
+ },
+ )
+ if not isinstance(max_retries, int):
+ raise AzureOpenAIError(
+ status_code=422, message="max retries must be an int"
+ )
+ # init AzureOpenAI Client
+ azure_client = self.get_azure_openai_client(
+ api_version=api_version,
+ api_base=api_base,
+ api_key=api_key,
+ model=model,
+ client=client,
+ _is_async=False,
+ litellm_params=litellm_params,
+ )
+ if not isinstance(azure_client, AzureOpenAI):
+ raise AzureOpenAIError(
+ status_code=500,
+ message="azure_client is not an instance of AzureOpenAI",
+ )
+
+ headers, response = self.make_sync_azure_openai_chat_completion_request(
+ azure_client=azure_client, data=data, timeout=timeout
+ )
+ stringified_response = response.model_dump()
+ ## LOGGING
+ logging_obj.post_call(
+ input=messages,
+ api_key=api_key,
+ original_response=stringified_response,
+ additional_args={
+ "headers": headers,
+ "api_version": api_version,
+ "api_base": api_base,
+ },
+ )
+ return convert_to_model_response_object(
+ response_object=stringified_response,
+ model_response_object=model_response,
+ convert_tool_call_to_json_mode=json_mode,
+ _response_headers=headers,
+ )
+ except AzureOpenAIError as e:
+ raise e
+ except Exception as e:
+ status_code = getattr(e, "status_code", 500)
+ error_headers = getattr(e, "headers", None)
+ error_response = getattr(e, "response", None)
+ error_body = getattr(e, "body", None)
+ if error_headers is None and error_response:
+ error_headers = getattr(error_response, "headers", None)
+ raise AzureOpenAIError(
+ status_code=status_code,
+ message=str(e),
+ headers=error_headers,
+ body=error_body,
+ )
+
+ async def acompletion(
+ self,
+ api_key: str,
+ api_version: str,
+ model: str,
+ api_base: str,
+ data: dict,
+ timeout: Any,
+ dynamic_params: bool,
+ model_response: ModelResponse,
+ logging_obj: LiteLLMLoggingObj,
+ max_retries: int,
+ azure_ad_token: Optional[str] = None,
+ azure_ad_token_provider: Optional[Callable] = None,
+ convert_tool_call_to_json_mode: Optional[bool] = None,
+ client=None, # this is the AsyncAzureOpenAI
+ litellm_params: Optional[dict] = {},
+ ):
+ response = None
+ try:
+ # setting Azure client
+ azure_client = self.get_azure_openai_client(
+ api_version=api_version,
+ api_base=api_base,
+ api_key=api_key,
+ model=model,
+ client=client,
+ _is_async=True,
+ litellm_params=litellm_params,
+ )
+ if not isinstance(azure_client, AsyncAzureOpenAI):
+ raise ValueError("Azure client is not an instance of AsyncAzureOpenAI")
+ ## LOGGING
+ logging_obj.pre_call(
+ input=data["messages"],
+ api_key=azure_client.api_key,
+ additional_args={
+ "headers": {
+ "api_key": api_key,
+ "azure_ad_token": azure_ad_token,
+ },
+ "api_base": azure_client._base_url._uri_reference,
+ "acompletion": True,
+ "complete_input_dict": data,
+ },
+ )
+
+ headers, response = await self.make_azure_openai_chat_completion_request(
+ azure_client=azure_client,
+ data=data,
+ timeout=timeout,
+ logging_obj=logging_obj,
+ )
+ logging_obj.model_call_details["response_headers"] = headers
+
+ stringified_response = response.model_dump()
+ logging_obj.post_call(
+ input=data["messages"],
+ api_key=api_key,
+ original_response=stringified_response,
+ additional_args={"complete_input_dict": data},
+ )
+
+ return convert_to_model_response_object(
+ response_object=stringified_response,
+ model_response_object=model_response,
+ hidden_params={"headers": headers},
+ _response_headers=headers,
+ convert_tool_call_to_json_mode=convert_tool_call_to_json_mode,
+ )
+ except AzureOpenAIError as e:
+ ## LOGGING
+ logging_obj.post_call(
+ input=data["messages"],
+ api_key=api_key,
+ additional_args={"complete_input_dict": data},
+ original_response=str(e),
+ )
+ raise e
+ except asyncio.CancelledError as e:
+ ## LOGGING
+ logging_obj.post_call(
+ input=data["messages"],
+ api_key=api_key,
+ additional_args={"complete_input_dict": data},
+ original_response=str(e),
+ )
+ raise AzureOpenAIError(status_code=500, message=str(e))
+ except Exception as e:
+ message = getattr(e, "message", str(e))
+ body = getattr(e, "body", None)
+ ## LOGGING
+ logging_obj.post_call(
+ input=data["messages"],
+ api_key=api_key,
+ additional_args={"complete_input_dict": data},
+ original_response=str(e),
+ )
+ if hasattr(e, "status_code"):
+ raise e
+ else:
+ raise AzureOpenAIError(status_code=500, message=message, body=body)
+
+ def streaming(
+ self,
+ logging_obj,
+ api_base: str,
+ api_key: str,
+ api_version: str,
+ dynamic_params: bool,
+ data: dict,
+ model: str,
+ timeout: Any,
+ max_retries: int,
+ azure_ad_token: Optional[str] = None,
+ azure_ad_token_provider: Optional[Callable] = None,
+ client=None,
+ litellm_params: Optional[dict] = {},
+ ):
+ # init AzureOpenAI Client
+ azure_client_params = {
+ "api_version": api_version,
+ "azure_endpoint": api_base,
+ "azure_deployment": model,
+ "http_client": litellm.client_session,
+ "max_retries": max_retries,
+ "timeout": timeout,
+ }
+ azure_client_params = select_azure_base_url_or_endpoint(
+ azure_client_params=azure_client_params
+ )
+ if api_key is not None:
+ azure_client_params["api_key"] = api_key
+ elif azure_ad_token is not None:
+ if azure_ad_token.startswith("oidc/"):
+ azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
+ azure_client_params["azure_ad_token"] = azure_ad_token
+ elif azure_ad_token_provider is not None:
+ azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider
+
+ azure_client = self.get_azure_openai_client(
+ api_version=api_version,
+ api_base=api_base,
+ api_key=api_key,
+ model=model,
+ client=client,
+ _is_async=False,
+ litellm_params=litellm_params,
+ )
+ if not isinstance(azure_client, AzureOpenAI):
+ raise AzureOpenAIError(
+ status_code=500,
+ message="azure_client is not an instance of AzureOpenAI",
+ )
+ ## LOGGING
+ logging_obj.pre_call(
+ input=data["messages"],
+ api_key=azure_client.api_key,
+            additional_args={
+                "headers": {
+                    "api_key": api_key,
+                    "azure_ad_token": azure_ad_token,
+                },
+                "api_base": azure_client._base_url._uri_reference,
+                "acompletion": False,
+                "complete_input_dict": data,
+            },
+ )
+ headers, response = self.make_sync_azure_openai_chat_completion_request(
+ azure_client=azure_client, data=data, timeout=timeout
+ )
+ streamwrapper = CustomStreamWrapper(
+ completion_stream=response,
+ model=model,
+ custom_llm_provider="azure",
+ logging_obj=logging_obj,
+ stream_options=data.get("stream_options", None),
+ _response_headers=process_azure_headers(headers),
+ )
+ return streamwrapper
+
+ async def async_streaming(
+ self,
+ logging_obj: LiteLLMLoggingObj,
+ api_base: str,
+ api_key: str,
+ api_version: str,
+ dynamic_params: bool,
+ data: dict,
+ model: str,
+ timeout: Any,
+ max_retries: int,
+ azure_ad_token: Optional[str] = None,
+ azure_ad_token_provider: Optional[Callable] = None,
+ client=None,
+ litellm_params: Optional[dict] = {},
+ ):
+ try:
+ azure_client = self.get_azure_openai_client(
+ api_version=api_version,
+ api_base=api_base,
+ api_key=api_key,
+ model=model,
+ client=client,
+ _is_async=True,
+ litellm_params=litellm_params,
+ )
+ if not isinstance(azure_client, AsyncAzureOpenAI):
+ raise ValueError("Azure client is not an instance of AsyncAzureOpenAI")
+
+ ## LOGGING
+ logging_obj.pre_call(
+ input=data["messages"],
+ api_key=azure_client.api_key,
+ additional_args={
+ "headers": {
+ "api_key": api_key,
+ "azure_ad_token": azure_ad_token,
+ },
+ "api_base": azure_client._base_url._uri_reference,
+ "acompletion": True,
+ "complete_input_dict": data,
+ },
+ )
+
+ headers, response = await self.make_azure_openai_chat_completion_request(
+ azure_client=azure_client,
+ data=data,
+ timeout=timeout,
+ logging_obj=logging_obj,
+ )
+ logging_obj.model_call_details["response_headers"] = headers
+
+ # return response
+ streamwrapper = CustomStreamWrapper(
+ completion_stream=response,
+ model=model,
+ custom_llm_provider="azure",
+ logging_obj=logging_obj,
+ stream_options=data.get("stream_options", None),
+ _response_headers=headers,
+ )
+            return streamwrapper  ## DO NOT convert this into an `async for` loop; that would yield an async generator, which won't raise errors if the response fails
+ except Exception as e:
+ status_code = getattr(e, "status_code", 500)
+ error_headers = getattr(e, "headers", None)
+ error_response = getattr(e, "response", None)
+ message = getattr(e, "message", str(e))
+ error_body = getattr(e, "body", None)
+ if error_headers is None and error_response:
+ error_headers = getattr(error_response, "headers", None)
+ raise AzureOpenAIError(
+ status_code=status_code,
+ message=message,
+ headers=error_headers,
+ body=error_body,
+ )
+
+ async def aembedding(
+ self,
+ model: str,
+ data: dict,
+ model_response: EmbeddingResponse,
+ input: list,
+ logging_obj: LiteLLMLoggingObj,
+ api_base: str,
+ api_key: Optional[str] = None,
+ api_version: Optional[str] = None,
+ client: Optional[AsyncAzureOpenAI] = None,
+ timeout: Optional[Union[float, httpx.Timeout]] = None,
+ max_retries: Optional[int] = None,
+ azure_ad_token: Optional[str] = None,
+ azure_ad_token_provider: Optional[Callable] = None,
+ litellm_params: Optional[dict] = {},
+ ) -> EmbeddingResponse:
+ response = None
+ try:
+
+ openai_aclient = self.get_azure_openai_client(
+ api_version=api_version,
+ api_base=api_base,
+ api_key=api_key,
+ model=model,
+ _is_async=True,
+ client=client,
+ litellm_params=litellm_params,
+ )
+ if not isinstance(openai_aclient, AsyncAzureOpenAI):
+ raise ValueError("Azure client is not an instance of AsyncAzureOpenAI")
+
+ raw_response = await openai_aclient.embeddings.with_raw_response.create(
+ **data, timeout=timeout
+ )
+ headers = dict(raw_response.headers)
+ response = raw_response.parse()
+ stringified_response = response.model_dump()
+ ## LOGGING
+ logging_obj.post_call(
+ input=input,
+ api_key=api_key,
+ additional_args={"complete_input_dict": data},
+ original_response=stringified_response,
+ )
+ embedding_response = convert_to_model_response_object(
+ response_object=stringified_response,
+ model_response_object=model_response,
+ hidden_params={"headers": headers},
+ _response_headers=process_azure_headers(headers),
+ response_type="embedding",
+ )
+ if not isinstance(embedding_response, EmbeddingResponse):
+ raise AzureOpenAIError(
+ status_code=500,
+ message="embedding_response is not an instance of EmbeddingResponse",
+ )
+ return embedding_response
+ except Exception as e:
+ ## LOGGING
+ logging_obj.post_call(
+ input=input,
+ api_key=api_key,
+ additional_args={"complete_input_dict": data},
+ original_response=str(e),
+ )
+ raise e
+
+ def embedding(
+ self,
+ model: str,
+ input: list,
+ api_base: str,
+ api_version: str,
+ timeout: float,
+ logging_obj: LiteLLMLoggingObj,
+ model_response: EmbeddingResponse,
+ optional_params: dict,
+ api_key: Optional[str] = None,
+ azure_ad_token: Optional[str] = None,
+ azure_ad_token_provider: Optional[Callable] = None,
+ max_retries: Optional[int] = None,
+ client=None,
+ aembedding=None,
+ headers: Optional[dict] = None,
+ litellm_params: Optional[dict] = None,
+ ) -> Union[EmbeddingResponse, Coroutine[Any, Any, EmbeddingResponse]]:
+ if headers:
+ optional_params["extra_headers"] = headers
+ if self._client_session is None:
+ self._client_session = self.create_client_session()
+ try:
+ data = {"model": model, "input": input, **optional_params}
+ if max_retries is None:
+                max_retries = DEFAULT_MAX_RETRIES
+ ## LOGGING
+ logging_obj.pre_call(
+ input=input,
+ api_key=api_key,
+ additional_args={
+ "complete_input_dict": data,
+ "headers": {"api_key": api_key, "azure_ad_token": azure_ad_token},
+ },
+ )
+
+ if aembedding is True:
+ return self.aembedding(
+ data=data,
+ input=input,
+ model=model,
+ logging_obj=logging_obj,
+ api_key=api_key,
+ model_response=model_response,
+ timeout=timeout,
+ client=client,
+ litellm_params=litellm_params,
+ api_base=api_base,
+ )
+ azure_client = self.get_azure_openai_client(
+ api_version=api_version,
+ api_base=api_base,
+ api_key=api_key,
+ model=model,
+ _is_async=False,
+ client=client,
+ litellm_params=litellm_params,
+ )
+ if not isinstance(azure_client, AzureOpenAI):
+ raise AzureOpenAIError(
+ status_code=500,
+ message="azure_client is not an instance of AzureOpenAI",
+ )
+
+ ## COMPLETION CALL
+ raw_response = azure_client.embeddings.with_raw_response.create(**data, timeout=timeout) # type: ignore
+ headers = dict(raw_response.headers)
+ response = raw_response.parse()
+ ## LOGGING
+ logging_obj.post_call(
+ input=input,
+ api_key=api_key,
+ additional_args={"complete_input_dict": data, "api_base": api_base},
+ original_response=response,
+ )
+
+ return convert_to_model_response_object(response_object=response.model_dump(), model_response_object=model_response, response_type="embedding", _response_headers=process_azure_headers(headers)) # type: ignore
+ except AzureOpenAIError as e:
+ raise e
+ except Exception as e:
+ status_code = getattr(e, "status_code", 500)
+ error_headers = getattr(e, "headers", None)
+ error_response = getattr(e, "response", None)
+ if error_headers is None and error_response:
+ error_headers = getattr(error_response, "headers", None)
+ raise AzureOpenAIError(
+ status_code=status_code, message=str(e), headers=error_headers
+ )
+
+ async def make_async_azure_httpx_request(
+ self,
+ client: Optional[AsyncHTTPHandler],
+ timeout: Optional[Union[float, httpx.Timeout]],
+ api_base: str,
+ api_version: str,
+ api_key: str,
+ data: dict,
+ headers: dict,
+ ) -> httpx.Response:
+ """
+ Implemented for azure dall-e-2 image gen calls
+
+ Alternative to needing a custom transport implementation
+ """
+ if client is None:
+ _params = {}
+ if timeout is not None:
+ if isinstance(timeout, float) or isinstance(timeout, int):
+ _httpx_timeout = httpx.Timeout(timeout)
+ _params["timeout"] = _httpx_timeout
+ else:
+ _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
+
+ async_handler = get_async_httpx_client(
+ llm_provider=LlmProviders.AZURE,
+ params=_params,
+ )
+ else:
+ async_handler = client # type: ignore
+
+ if (
+ "images/generations" in api_base
+ and api_version
+ in [ # dall-e-3 starts from `2023-12-01-preview` so we should be able to avoid conflict
+ "2023-06-01-preview",
+ "2023-07-01-preview",
+ "2023-08-01-preview",
+ "2023-09-01-preview",
+ "2023-10-01-preview",
+ ]
+ ): # CREATE + POLL for azure dall-e-2 calls
+
+ api_base = modify_url(
+ original_url=api_base, new_path="/openai/images/generations:submit"
+ )
+
+ data.pop(
+ "model", None
+ ) # REMOVE 'model' from dall-e-2 arg https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#request-a-generated-image-dall-e-2-preview
+ response = await async_handler.post(
+ url=api_base,
+ data=json.dumps(data),
+ headers=headers,
+ )
+ if "operation-location" in response.headers:
+ operation_location_url = response.headers["operation-location"]
+ else:
+ raise AzureOpenAIError(status_code=500, message=response.text)
+ response = await async_handler.get(
+ url=operation_location_url,
+ headers=headers,
+ )
+
+ await response.aread()
+
+ timeout_secs: int = 120
+ start_time = time.time()
+ if "status" not in response.json():
+ raise Exception(
+ "Expected 'status' in response. Got={}".format(response.json())
+ )
+ while response.json()["status"] not in ["succeeded", "failed"]:
+ if time.time() - start_time > timeout_secs:
+
+ raise AzureOpenAIError(
+ status_code=408, message="Operation polling timed out."
+ )
+
+ await asyncio.sleep(int(response.headers.get("retry-after") or 10))
+ response = await async_handler.get(
+ url=operation_location_url,
+ headers=headers,
+ )
+ await response.aread()
+
+ if response.json()["status"] == "failed":
+ error_data = response.json()
+ raise AzureOpenAIError(status_code=400, message=json.dumps(error_data))
+
+ result = response.json()["result"]
+ return httpx.Response(
+ status_code=200,
+ headers=response.headers,
+ content=json.dumps(result).encode("utf-8"),
+ request=httpx.Request(method="POST", url="https://api.openai.com/v1"),
+ )
+ return await async_handler.post(
+ url=api_base,
+ json=data,
+ headers=headers,
+ )
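+
+    # Rough shape of the dall-e-2 preview create+poll exchange implemented
+    # here and in the sync twin below; the header and field names come from
+    # the code above, while the concrete values are hypothetical:
+    #
+    #   POST {api_base}/openai/images/generations:submit
+    #        -> headers: operation-location: https://.../operations/<id>
+    #   GET  <operation-location>   (repeated until status is terminal)
+    #        -> body: {"id": "<id>", "status": "succeeded", "result": {...}}
+    #
+    # The returned httpx.Response is rebuilt around just the "result" payload
+    # so callers can parse it like a non-polling image generation response.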
+
+ def make_sync_azure_httpx_request(
+ self,
+ client: Optional[HTTPHandler],
+ timeout: Optional[Union[float, httpx.Timeout]],
+ api_base: str,
+ api_version: str,
+ api_key: str,
+ data: dict,
+ headers: dict,
+ ) -> httpx.Response:
+ """
+ Implemented for azure dall-e-2 image gen calls
+
+ Alternative to needing a custom transport implementation
+ """
+ if client is None:
+ _params = {}
+ if timeout is not None:
+ if isinstance(timeout, float) or isinstance(timeout, int):
+ _httpx_timeout = httpx.Timeout(timeout)
+ _params["timeout"] = _httpx_timeout
+ else:
+ _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
+
+ sync_handler = HTTPHandler(**_params, client=litellm.client_session) # type: ignore
+ else:
+ sync_handler = client # type: ignore
+
+ if (
+ "images/generations" in api_base
+ and api_version
+ in [ # dall-e-3 starts from `2023-12-01-preview` so we should be able to avoid conflict
+ "2023-06-01-preview",
+ "2023-07-01-preview",
+ "2023-08-01-preview",
+ "2023-09-01-preview",
+ "2023-10-01-preview",
+ ]
+ ): # CREATE + POLL for azure dall-e-2 calls
+
+ api_base = modify_url(
+ original_url=api_base, new_path="/openai/images/generations:submit"
+ )
+
+ data.pop(
+ "model", None
+ ) # REMOVE 'model' from dall-e-2 arg https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#request-a-generated-image-dall-e-2-preview
+ response = sync_handler.post(
+ url=api_base,
+ data=json.dumps(data),
+ headers=headers,
+ )
+ if "operation-location" in response.headers:
+ operation_location_url = response.headers["operation-location"]
+ else:
+ raise AzureOpenAIError(status_code=500, message=response.text)
+ response = sync_handler.get(
+ url=operation_location_url,
+ headers=headers,
+ )
+
+ response.read()
+
+ timeout_secs: int = 120
+ start_time = time.time()
+ if "status" not in response.json():
+ raise Exception(
+ "Expected 'status' in response. Got={}".format(response.json())
+ )
+ while response.json()["status"] not in ["succeeded", "failed"]:
+ if time.time() - start_time > timeout_secs:
+ raise AzureOpenAIError(
+ status_code=408, message="Operation polling timed out."
+ )
+
+ time.sleep(int(response.headers.get("retry-after") or 10))
+ response = sync_handler.get(
+ url=operation_location_url,
+ headers=headers,
+ )
+ response.read()
+
+ if response.json()["status"] == "failed":
+ error_data = response.json()
+ raise AzureOpenAIError(status_code=400, message=json.dumps(error_data))
+
+ result = response.json()["result"]
+ return httpx.Response(
+ status_code=200,
+ headers=response.headers,
+ content=json.dumps(result).encode("utf-8"),
+ request=httpx.Request(method="POST", url="https://api.openai.com/v1"),
+ )
+ return sync_handler.post(
+ url=api_base,
+ json=data,
+ headers=headers,
+ )
+
+ def create_azure_base_url(
+ self, azure_client_params: dict, model: Optional[str]
+ ) -> str:
+ api_base: str = azure_client_params.get(
+ "azure_endpoint", ""
+ ) # "https://example-endpoint.openai.azure.com"
+ if api_base.endswith("/"):
+ api_base = api_base.rstrip("/")
+ api_version: str = azure_client_params.get("api_version", "")
+ if model is None:
+ model = ""
+
+ if "/openai/deployments/" in api_base:
+ base_url_with_deployment = api_base
+ else:
+ base_url_with_deployment = api_base + "/openai/deployments/" + model
+
+ base_url_with_deployment += "/images/generations"
+ base_url_with_deployment += "?api-version=" + api_version
+
+ return base_url_with_deployment
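+
+    # URL construction sketch (hypothetical endpoint, deployment, and version):
+    #
+    #   create_azure_base_url(
+    #       azure_client_params={
+    #           "azure_endpoint": "https://example-endpoint.openai.azure.com",
+    #           "api_version": "2023-09-01-preview",
+    #       },
+    #       model="dall-e-2",
+    #   )
+    #   # -> "https://example-endpoint.openai.azure.com/openai/deployments/dall-e-2
+    #   #     /images/generations?api-version=2023-09-01-preview"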
+
+ async def aimage_generation(
+ self,
+ data: dict,
+        model_response: ImageResponse,
+ azure_client_params: dict,
+ api_key: str,
+ input: list,
+ logging_obj: LiteLLMLoggingObj,
+ headers: dict,
+ client=None,
+ timeout=None,
+ ) -> litellm.ImageResponse:
+ response: Optional[dict] = None
+ try:
+ # response = await azure_client.images.generate(**data, timeout=timeout)
+ api_base: str = azure_client_params.get(
+ "api_base", ""
+ ) # "https://example-endpoint.openai.azure.com"
+ if api_base.endswith("/"):
+ api_base = api_base.rstrip("/")
+ api_version: str = azure_client_params.get("api_version", "")
+ img_gen_api_base = self.create_azure_base_url(
+ azure_client_params=azure_client_params, model=data.get("model", "")
+ )
+
+ ## LOGGING
+ logging_obj.pre_call(
+ input=data["prompt"],
+ api_key=api_key,
+ additional_args={
+ "complete_input_dict": data,
+ "api_base": img_gen_api_base,
+ "headers": headers,
+ },
+ )
+ httpx_response: httpx.Response = await self.make_async_azure_httpx_request(
+ client=None,
+ timeout=timeout,
+ api_base=img_gen_api_base,
+ api_version=api_version,
+ api_key=api_key,
+ data=data,
+ headers=headers,
+ )
+ response = httpx_response.json()
+
+ stringified_response = response
+ ## LOGGING
+ logging_obj.post_call(
+ input=input,
+ api_key=api_key,
+ additional_args={"complete_input_dict": data},
+ original_response=stringified_response,
+ )
+ return convert_to_model_response_object( # type: ignore
+ response_object=stringified_response,
+ model_response_object=model_response,
+ response_type="image_generation",
+ )
+ except Exception as e:
+ ## LOGGING
+ logging_obj.post_call(
+ input=input,
+ api_key=api_key,
+ additional_args={"complete_input_dict": data},
+ original_response=str(e),
+ )
+ raise e
+
+ def image_generation(
+ self,
+ prompt: str,
+ timeout: float,
+ optional_params: dict,
+ logging_obj: LiteLLMLoggingObj,
+ headers: dict,
+ model: Optional[str] = None,
+ api_key: Optional[str] = None,
+ api_base: Optional[str] = None,
+ api_version: Optional[str] = None,
+ model_response: Optional[ImageResponse] = None,
+ azure_ad_token: Optional[str] = None,
+ azure_ad_token_provider: Optional[Callable] = None,
+ client=None,
+ aimg_generation=None,
+ litellm_params: Optional[dict] = None,
+ ) -> ImageResponse:
+ try:
+            if not model:
+                model = None
+
+ ## BASE MODEL CHECK
+ if (
+ model_response is not None
+ and optional_params.get("base_model", None) is not None
+ ):
+ model_response._hidden_params["model"] = optional_params.pop(
+ "base_model"
+ )
+
+ data = {"model": model, "prompt": prompt, **optional_params}
+ max_retries = data.pop("max_retries", 2)
+ if not isinstance(max_retries, int):
+ raise AzureOpenAIError(
+ status_code=422, message="max retries must be an int"
+ )
+
+ # init AzureOpenAI Client
+ azure_client_params: Dict[str, Any] = self.initialize_azure_sdk_client(
+ litellm_params=litellm_params or {},
+ api_key=api_key,
+ model_name=model or "",
+ api_version=api_version,
+ api_base=api_base,
+ is_async=False,
+ )
+ if aimg_generation is True:
+                return self.aimage_generation(data=data, input=prompt, logging_obj=logging_obj, model_response=model_response, api_key=api_key, client=client, azure_client_params=azure_client_params, timeout=timeout, headers=headers)  # type: ignore
+
+ img_gen_api_base = self.create_azure_base_url(
+ azure_client_params=azure_client_params, model=data.get("model", "")
+ )
+
+ ## LOGGING
+ logging_obj.pre_call(
+ input=data["prompt"],
+ api_key=api_key,
+ additional_args={
+ "complete_input_dict": data,
+ "api_base": img_gen_api_base,
+ "headers": headers,
+ },
+ )
+ httpx_response: httpx.Response = self.make_sync_azure_httpx_request(
+ client=None,
+ timeout=timeout,
+ api_base=img_gen_api_base,
+ api_version=api_version or "",
+ api_key=api_key or "",
+ data=data,
+ headers=headers,
+ )
+ response = httpx_response.json()
+
+ ## LOGGING
+ logging_obj.post_call(
+ input=prompt,
+ api_key=api_key,
+ additional_args={"complete_input_dict": data},
+ original_response=response,
+ )
+ # return response
+ return convert_to_model_response_object(response_object=response, model_response_object=model_response, response_type="image_generation") # type: ignore
+ except AzureOpenAIError as e:
+ raise e
+ except Exception as e:
+ error_code = getattr(e, "status_code", None)
+ if error_code is not None:
+ raise AzureOpenAIError(status_code=error_code, message=str(e))
+ else:
+ raise AzureOpenAIError(status_code=500, message=str(e))
+
+ def audio_speech(
+ self,
+ model: str,
+ input: str,
+ voice: str,
+ optional_params: dict,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ api_version: Optional[str],
+ organization: Optional[str],
+ max_retries: int,
+ timeout: Union[float, httpx.Timeout],
+ azure_ad_token: Optional[str] = None,
+ azure_ad_token_provider: Optional[Callable] = None,
+ aspeech: Optional[bool] = None,
+ client=None,
+ litellm_params: Optional[dict] = None,
+ ) -> HttpxBinaryResponseContent:
+
+ max_retries = optional_params.pop("max_retries", 2)
+
+ if aspeech is not None and aspeech is True:
+ return self.async_audio_speech(
+ model=model,
+ input=input,
+ voice=voice,
+ optional_params=optional_params,
+ api_key=api_key,
+ api_base=api_base,
+ api_version=api_version,
+ azure_ad_token=azure_ad_token,
+ azure_ad_token_provider=azure_ad_token_provider,
+ max_retries=max_retries,
+ timeout=timeout,
+ client=client,
+ litellm_params=litellm_params,
+ ) # type: ignore
+
+ azure_client: AzureOpenAI = self.get_azure_openai_client(
+ api_base=api_base,
+ api_version=api_version,
+ api_key=api_key,
+ model=model,
+ _is_async=False,
+ client=client,
+ litellm_params=litellm_params,
+ ) # type: ignore
+
+ response = azure_client.audio.speech.create(
+ model=model,
+ voice=voice, # type: ignore
+ input=input,
+ **optional_params,
+ )
+ return HttpxBinaryResponseContent(response=response.response)
+
+ async def async_audio_speech(
+ self,
+ model: str,
+ input: str,
+ voice: str,
+ optional_params: dict,
+ api_key: Optional[str],
+ api_base: Optional[str],
+ api_version: Optional[str],
+ azure_ad_token: Optional[str],
+ azure_ad_token_provider: Optional[Callable],
+ max_retries: int,
+ timeout: Union[float, httpx.Timeout],
+ client=None,
+ litellm_params: Optional[dict] = None,
+ ) -> HttpxBinaryResponseContent:
+
+ azure_client: AsyncAzureOpenAI = self.get_azure_openai_client(
+ api_base=api_base,
+ api_version=api_version,
+ api_key=api_key,
+ model=model,
+ _is_async=True,
+ client=client,
+ litellm_params=litellm_params,
+ ) # type: ignore
+
+ azure_response = await azure_client.audio.speech.create(
+ model=model,
+ voice=voice, # type: ignore
+ input=input,
+ **optional_params,
+ )
+
+ return HttpxBinaryResponseContent(response=azure_response.response)
+
+ def get_headers(
+ self,
+ model: Optional[str],
+ api_key: str,
+ api_base: str,
+ api_version: str,
+ timeout: float,
+ mode: str,
+ messages: Optional[list] = None,
+ input: Optional[list] = None,
+ prompt: Optional[str] = None,
+ ) -> dict:
+ client_session = litellm.client_session or httpx.Client()
+ if "gateway.ai.cloudflare.com" in api_base:
+ ## build base url - assume api base includes resource name
+ if not api_base.endswith("/"):
+ api_base += "/"
+ api_base += f"{model}"
+ client = AzureOpenAI(
+ base_url=api_base,
+ api_version=api_version,
+ api_key=api_key,
+ timeout=timeout,
+ http_client=client_session,
+ )
+ model = None
+ # cloudflare ai gateway, needs model=None
+ else:
+ client = AzureOpenAI(
+ api_version=api_version,
+ azure_endpoint=api_base,
+ api_key=api_key,
+ timeout=timeout,
+ http_client=client_session,
+ )
+
+        # only run this check if it's not cloudflare ai gateway
+        if (
+            model is None
+            and mode != "image_generation"
+            and "gateway.ai.cloudflare.com" not in api_base
+        ):
+            raise Exception("model is not set")
+
+ completion = None
+
+ if messages is None:
+ messages = [{"role": "user", "content": "Hey"}]
+        completion = client.chat.completions.with_raw_response.create(
+            model=model, # type: ignore
+            messages=messages, # type: ignore
+        )
+ response = {}
+
+ if completion is None or not hasattr(completion, "headers"):
+ raise Exception("invalid completion response")
+
+ if (
+ completion.headers.get("x-ratelimit-remaining-requests", None) is not None
+ ): # not provided for dall-e requests
+ response["x-ratelimit-remaining-requests"] = completion.headers[
+ "x-ratelimit-remaining-requests"
+ ]
+
+ if completion.headers.get("x-ratelimit-remaining-tokens", None) is not None:
+ response["x-ratelimit-remaining-tokens"] = completion.headers[
+ "x-ratelimit-remaining-tokens"
+ ]
+
+ if completion.headers.get("x-ms-region", None) is not None:
+ response["x-ms-region"] = completion.headers["x-ms-region"]
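+
+        # Shape of the returned dict when Azure supplies all three headers
+        # (hypothetical values):
+        #
+        #   {
+        #       "x-ratelimit-remaining-requests": "99",
+        #       "x-ratelimit-remaining-tokens": "59973",
+        #       "x-ms-region": "East US",
+        #   }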
+
+ return response