author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500
---|---|---
committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500
commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) |
tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/litellm/llms/azure |
parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff) |
download | gn-ai-master.tar.gz |
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/azure')
14 files changed, 4544 insertions, 0 deletions
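Before the raw diff, here is a minimal usage sketch of the assistants flow these vendored files implement. It assumes litellm's module-level assistants helpers (`get_assistants`, `create_thread`, `run_thread`) dispatch to the `AzureAssistantsAPI` class added below when `custom_llm_provider="azure"`, and that credentials come from the usual `AZURE_API_KEY` / `AZURE_API_BASE` / `AZURE_API_VERSION` environment variables; the endpoint, key, and api-version values are placeholders, not values from this commit.

```
import os

import litellm

# Placeholder credentials for an Azure OpenAI resource (assumed env var names).
os.environ["AZURE_API_KEY"] = "my-azure-key"
os.environ["AZURE_API_BASE"] = "https://example-endpoint.openai.azure.com"
os.environ["AZURE_API_VERSION"] = "2024-02-15-preview"

# List assistants (backed by AzureAssistantsAPI.get_assistants below).
assistants = litellm.get_assistants(custom_llm_provider="azure")

# Create a thread seeded with one user message (create_thread below).
thread = litellm.create_thread(
    custom_llm_provider="azure",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
)

# Run the thread against the first assistant and poll to completion
# (run_thread -> beta.threads.runs.create_and_poll below).
run = litellm.run_thread(
    custom_llm_provider="azure",
    thread_id=thread.id,
    assistant_id=assistants.data[0].id,
)
print(run.status)
```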
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/azure/assistants.py b/.venv/lib/python3.12/site-packages/litellm/llms/azure/assistants.py new file mode 100644 index 00000000..2e8c78b2 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/llms/azure/assistants.py @@ -0,0 +1,1024 @@ +from typing import Any, Coroutine, Dict, Iterable, Literal, Optional, Union + +import httpx +from openai import AsyncAzureOpenAI, AzureOpenAI +from typing_extensions import overload + +from ...types.llms.openai import ( + Assistant, + AssistantEventHandler, + AssistantStreamManager, + AssistantToolParam, + AsyncAssistantEventHandler, + AsyncAssistantStreamManager, + AsyncCursorPage, + OpenAICreateThreadParamsMessage, + OpenAIMessage, + Run, + SyncCursorPage, + Thread, +) +from .common_utils import BaseAzureLLM + + +class AzureAssistantsAPI(BaseAzureLLM): + def __init__(self) -> None: + super().__init__() + + def get_azure_client( + self, + api_key: Optional[str], + api_base: Optional[str], + api_version: Optional[str], + azure_ad_token: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + client: Optional[AzureOpenAI] = None, + litellm_params: Optional[dict] = None, + ) -> AzureOpenAI: + if client is None: + azure_client_params = self.initialize_azure_sdk_client( + litellm_params=litellm_params or {}, + api_key=api_key, + api_base=api_base, + model_name="", + api_version=api_version, + is_async=False, + ) + azure_openai_client = AzureOpenAI(**azure_client_params) # type: ignore + else: + azure_openai_client = client + + return azure_openai_client + + def async_get_azure_client( + self, + api_key: Optional[str], + api_base: Optional[str], + api_version: Optional[str], + azure_ad_token: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + client: Optional[AsyncAzureOpenAI] = None, + litellm_params: Optional[dict] = None, + ) -> AsyncAzureOpenAI: + if client is None: + azure_client_params = self.initialize_azure_sdk_client( + litellm_params=litellm_params or {}, + api_key=api_key, + api_base=api_base, + model_name="", + api_version=api_version, + is_async=True, + ) + + azure_openai_client = AsyncAzureOpenAI(**azure_client_params) + # azure_openai_client = AsyncAzureOpenAI(**data) # type: ignore + else: + azure_openai_client = client + + return azure_openai_client + + ### ASSISTANTS ### + + async def async_get_assistants( + self, + api_key: Optional[str], + api_base: Optional[str], + api_version: Optional[str], + azure_ad_token: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + client: Optional[AsyncAzureOpenAI], + litellm_params: Optional[dict] = None, + ) -> AsyncCursorPage[Assistant]: + azure_openai_client = self.async_get_azure_client( + api_key=api_key, + api_base=api_base, + api_version=api_version, + azure_ad_token=azure_ad_token, + timeout=timeout, + max_retries=max_retries, + client=client, + litellm_params=litellm_params, + ) + + response = await azure_openai_client.beta.assistants.list() + + return response + + # fmt: off + + @overload + def get_assistants( + self, + api_key: Optional[str], + api_base: Optional[str], + api_version: Optional[str], + azure_ad_token: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + client: Optional[AsyncAzureOpenAI], + aget_assistants: Literal[True], + ) -> Coroutine[None, None, AsyncCursorPage[Assistant]]: + ... 
+ + @overload + def get_assistants( + self, + api_key: Optional[str], + api_base: Optional[str], + api_version: Optional[str], + azure_ad_token: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + client: Optional[AzureOpenAI], + aget_assistants: Optional[Literal[False]], + ) -> SyncCursorPage[Assistant]: + ... + + # fmt: on + + def get_assistants( + self, + api_key: Optional[str], + api_base: Optional[str], + api_version: Optional[str], + azure_ad_token: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + client=None, + aget_assistants=None, + litellm_params: Optional[dict] = None, + ): + if aget_assistants is not None and aget_assistants is True: + return self.async_get_assistants( + api_key=api_key, + api_base=api_base, + api_version=api_version, + azure_ad_token=azure_ad_token, + timeout=timeout, + max_retries=max_retries, + client=client, + litellm_params=litellm_params, + ) + azure_openai_client = self.get_azure_client( + api_key=api_key, + api_base=api_base, + azure_ad_token=azure_ad_token, + timeout=timeout, + max_retries=max_retries, + client=client, + api_version=api_version, + litellm_params=litellm_params, + ) + + response = azure_openai_client.beta.assistants.list() + + return response + + ### MESSAGES ### + + async def a_add_message( + self, + thread_id: str, + message_data: dict, + api_key: Optional[str], + api_base: Optional[str], + api_version: Optional[str], + azure_ad_token: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + client: Optional[AsyncAzureOpenAI] = None, + litellm_params: Optional[dict] = None, + ) -> OpenAIMessage: + openai_client = self.async_get_azure_client( + api_key=api_key, + api_base=api_base, + api_version=api_version, + azure_ad_token=azure_ad_token, + timeout=timeout, + max_retries=max_retries, + client=client, + litellm_params=litellm_params, + ) + + thread_message: OpenAIMessage = await openai_client.beta.threads.messages.create( # type: ignore + thread_id, **message_data # type: ignore + ) + + response_obj: Optional[OpenAIMessage] = None + if getattr(thread_message, "status", None) is None: + thread_message.status = "completed" + response_obj = OpenAIMessage(**thread_message.dict()) + else: + response_obj = OpenAIMessage(**thread_message.dict()) + return response_obj + + # fmt: off + + @overload + def add_message( + self, + thread_id: str, + message_data: dict, + api_key: Optional[str], + api_base: Optional[str], + api_version: Optional[str], + azure_ad_token: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + client: Optional[AsyncAzureOpenAI], + a_add_message: Literal[True], + litellm_params: Optional[dict] = None, + ) -> Coroutine[None, None, OpenAIMessage]: + ... + + @overload + def add_message( + self, + thread_id: str, + message_data: dict, + api_key: Optional[str], + api_base: Optional[str], + api_version: Optional[str], + azure_ad_token: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + client: Optional[AzureOpenAI], + a_add_message: Optional[Literal[False]], + litellm_params: Optional[dict] = None, + ) -> OpenAIMessage: + ... 
+ + # fmt: on + + def add_message( + self, + thread_id: str, + message_data: dict, + api_key: Optional[str], + api_base: Optional[str], + api_version: Optional[str], + azure_ad_token: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + client=None, + a_add_message: Optional[bool] = None, + litellm_params: Optional[dict] = None, + ): + if a_add_message is not None and a_add_message is True: + return self.a_add_message( + thread_id=thread_id, + message_data=message_data, + api_key=api_key, + api_base=api_base, + api_version=api_version, + azure_ad_token=azure_ad_token, + timeout=timeout, + max_retries=max_retries, + client=client, + litellm_params=litellm_params, + ) + openai_client = self.get_azure_client( + api_key=api_key, + api_base=api_base, + api_version=api_version, + azure_ad_token=azure_ad_token, + timeout=timeout, + max_retries=max_retries, + client=client, + ) + + thread_message: OpenAIMessage = openai_client.beta.threads.messages.create( # type: ignore + thread_id, **message_data # type: ignore + ) + + response_obj: Optional[OpenAIMessage] = None + if getattr(thread_message, "status", None) is None: + thread_message.status = "completed" + response_obj = OpenAIMessage(**thread_message.dict()) + else: + response_obj = OpenAIMessage(**thread_message.dict()) + return response_obj + + async def async_get_messages( + self, + thread_id: str, + api_key: Optional[str], + api_base: Optional[str], + api_version: Optional[str], + azure_ad_token: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + client: Optional[AsyncAzureOpenAI] = None, + litellm_params: Optional[dict] = None, + ) -> AsyncCursorPage[OpenAIMessage]: + openai_client = self.async_get_azure_client( + api_key=api_key, + api_base=api_base, + api_version=api_version, + azure_ad_token=azure_ad_token, + timeout=timeout, + max_retries=max_retries, + client=client, + litellm_params=litellm_params, + ) + + response = await openai_client.beta.threads.messages.list(thread_id=thread_id) + + return response + + # fmt: off + + @overload + def get_messages( + self, + thread_id: str, + api_key: Optional[str], + api_base: Optional[str], + api_version: Optional[str], + azure_ad_token: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + client: Optional[AsyncAzureOpenAI], + aget_messages: Literal[True], + litellm_params: Optional[dict] = None, + ) -> Coroutine[None, None, AsyncCursorPage[OpenAIMessage]]: + ... + + @overload + def get_messages( + self, + thread_id: str, + api_key: Optional[str], + api_base: Optional[str], + api_version: Optional[str], + azure_ad_token: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + client: Optional[AzureOpenAI], + aget_messages: Optional[Literal[False]], + litellm_params: Optional[dict] = None, + ) -> SyncCursorPage[OpenAIMessage]: + ... 
+ + # fmt: on + + def get_messages( + self, + thread_id: str, + api_key: Optional[str], + api_base: Optional[str], + api_version: Optional[str], + azure_ad_token: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + client=None, + aget_messages=None, + litellm_params: Optional[dict] = None, + ): + if aget_messages is not None and aget_messages is True: + return self.async_get_messages( + thread_id=thread_id, + api_key=api_key, + api_base=api_base, + api_version=api_version, + azure_ad_token=azure_ad_token, + timeout=timeout, + max_retries=max_retries, + client=client, + litellm_params=litellm_params, + ) + openai_client = self.get_azure_client( + api_key=api_key, + api_base=api_base, + api_version=api_version, + azure_ad_token=azure_ad_token, + timeout=timeout, + max_retries=max_retries, + client=client, + litellm_params=litellm_params, + ) + + response = openai_client.beta.threads.messages.list(thread_id=thread_id) + + return response + + ### THREADS ### + + async def async_create_thread( + self, + metadata: Optional[dict], + api_key: Optional[str], + api_base: Optional[str], + api_version: Optional[str], + azure_ad_token: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + client: Optional[AsyncAzureOpenAI], + messages: Optional[Iterable[OpenAICreateThreadParamsMessage]], + litellm_params: Optional[dict] = None, + ) -> Thread: + openai_client = self.async_get_azure_client( + api_key=api_key, + api_base=api_base, + api_version=api_version, + azure_ad_token=azure_ad_token, + timeout=timeout, + max_retries=max_retries, + client=client, + litellm_params=litellm_params, + ) + + data = {} + if messages is not None: + data["messages"] = messages # type: ignore + if metadata is not None: + data["metadata"] = metadata # type: ignore + + message_thread = await openai_client.beta.threads.create(**data) # type: ignore + + return Thread(**message_thread.dict()) + + # fmt: off + + @overload + def create_thread( + self, + metadata: Optional[dict], + api_key: Optional[str], + api_base: Optional[str], + api_version: Optional[str], + azure_ad_token: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + messages: Optional[Iterable[OpenAICreateThreadParamsMessage]], + client: Optional[AsyncAzureOpenAI], + acreate_thread: Literal[True], + litellm_params: Optional[dict] = None, + ) -> Coroutine[None, None, Thread]: + ... + + @overload + def create_thread( + self, + metadata: Optional[dict], + api_key: Optional[str], + api_base: Optional[str], + api_version: Optional[str], + azure_ad_token: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + messages: Optional[Iterable[OpenAICreateThreadParamsMessage]], + client: Optional[AzureOpenAI], + acreate_thread: Optional[Literal[False]], + litellm_params: Optional[dict] = None, + ) -> Thread: + ... 
+ + # fmt: on + + def create_thread( + self, + metadata: Optional[dict], + api_key: Optional[str], + api_base: Optional[str], + api_version: Optional[str], + azure_ad_token: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + messages: Optional[Iterable[OpenAICreateThreadParamsMessage]], + client=None, + acreate_thread=None, + litellm_params: Optional[dict] = None, + ): + """ + Here's an example: + ``` + from litellm.llms.openai.openai import OpenAIAssistantsAPI, MessageData + + # create thread + message: MessageData = {"role": "user", "content": "Hey, how's it going?"} + openai_api.create_thread(messages=[message]) + ``` + """ + if acreate_thread is not None and acreate_thread is True: + return self.async_create_thread( + metadata=metadata, + api_key=api_key, + api_base=api_base, + api_version=api_version, + azure_ad_token=azure_ad_token, + timeout=timeout, + max_retries=max_retries, + client=client, + messages=messages, + litellm_params=litellm_params, + ) + azure_openai_client = self.get_azure_client( + api_key=api_key, + api_base=api_base, + api_version=api_version, + azure_ad_token=azure_ad_token, + timeout=timeout, + max_retries=max_retries, + client=client, + litellm_params=litellm_params, + ) + + data = {} + if messages is not None: + data["messages"] = messages # type: ignore + if metadata is not None: + data["metadata"] = metadata # type: ignore + + message_thread = azure_openai_client.beta.threads.create(**data) # type: ignore + + return Thread(**message_thread.dict()) + + async def async_get_thread( + self, + thread_id: str, + api_key: Optional[str], + api_base: Optional[str], + api_version: Optional[str], + azure_ad_token: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + client: Optional[AsyncAzureOpenAI], + litellm_params: Optional[dict] = None, + ) -> Thread: + openai_client = self.async_get_azure_client( + api_key=api_key, + api_base=api_base, + api_version=api_version, + azure_ad_token=azure_ad_token, + timeout=timeout, + max_retries=max_retries, + client=client, + litellm_params=litellm_params, + ) + + response = await openai_client.beta.threads.retrieve(thread_id=thread_id) + + return Thread(**response.dict()) + + # fmt: off + + @overload + def get_thread( + self, + thread_id: str, + api_key: Optional[str], + api_base: Optional[str], + api_version: Optional[str], + azure_ad_token: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + client: Optional[AsyncAzureOpenAI], + aget_thread: Literal[True], + litellm_params: Optional[dict] = None, + ) -> Coroutine[None, None, Thread]: + ... + + @overload + def get_thread( + self, + thread_id: str, + api_key: Optional[str], + api_base: Optional[str], + api_version: Optional[str], + azure_ad_token: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + client: Optional[AzureOpenAI], + aget_thread: Optional[Literal[False]], + litellm_params: Optional[dict] = None, + ) -> Thread: + ... 
+ + # fmt: on + + def get_thread( + self, + thread_id: str, + api_key: Optional[str], + api_base: Optional[str], + api_version: Optional[str], + azure_ad_token: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + client=None, + aget_thread=None, + litellm_params: Optional[dict] = None, + ): + if aget_thread is not None and aget_thread is True: + return self.async_get_thread( + thread_id=thread_id, + api_key=api_key, + api_base=api_base, + api_version=api_version, + azure_ad_token=azure_ad_token, + timeout=timeout, + max_retries=max_retries, + client=client, + litellm_params=litellm_params, + ) + openai_client = self.get_azure_client( + api_key=api_key, + api_base=api_base, + api_version=api_version, + azure_ad_token=azure_ad_token, + timeout=timeout, + max_retries=max_retries, + client=client, + litellm_params=litellm_params, + ) + + response = openai_client.beta.threads.retrieve(thread_id=thread_id) + + return Thread(**response.dict()) + + # def delete_thread(self): + # pass + + ### RUNS ### + + async def arun_thread( + self, + thread_id: str, + assistant_id: str, + additional_instructions: Optional[str], + instructions: Optional[str], + metadata: Optional[Dict], + model: Optional[str], + stream: Optional[bool], + tools: Optional[Iterable[AssistantToolParam]], + api_key: Optional[str], + api_base: Optional[str], + api_version: Optional[str], + azure_ad_token: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + client: Optional[AsyncAzureOpenAI], + litellm_params: Optional[dict] = None, + ) -> Run: + openai_client = self.async_get_azure_client( + api_key=api_key, + api_base=api_base, + timeout=timeout, + max_retries=max_retries, + api_version=api_version, + azure_ad_token=azure_ad_token, + client=client, + litellm_params=litellm_params, + ) + + response = await openai_client.beta.threads.runs.create_and_poll( # type: ignore + thread_id=thread_id, + assistant_id=assistant_id, + additional_instructions=additional_instructions, + instructions=instructions, + metadata=metadata, # type: ignore + model=model, + tools=tools, + ) + + return response + + def async_run_thread_stream( + self, + client: AsyncAzureOpenAI, + thread_id: str, + assistant_id: str, + additional_instructions: Optional[str], + instructions: Optional[str], + metadata: Optional[Dict], + model: Optional[str], + tools: Optional[Iterable[AssistantToolParam]], + event_handler: Optional[AssistantEventHandler], + litellm_params: Optional[dict] = None, + ) -> AsyncAssistantStreamManager[AsyncAssistantEventHandler]: + data: Dict[str, Any] = { + "thread_id": thread_id, + "assistant_id": assistant_id, + "additional_instructions": additional_instructions, + "instructions": instructions, + "metadata": metadata, + "model": model, + "tools": tools, + } + if event_handler is not None: + data["event_handler"] = event_handler + return client.beta.threads.runs.stream(**data) # type: ignore + + def run_thread_stream( + self, + client: AzureOpenAI, + thread_id: str, + assistant_id: str, + additional_instructions: Optional[str], + instructions: Optional[str], + metadata: Optional[Dict], + model: Optional[str], + tools: Optional[Iterable[AssistantToolParam]], + event_handler: Optional[AssistantEventHandler], + litellm_params: Optional[dict] = None, + ) -> AssistantStreamManager[AssistantEventHandler]: + data: Dict[str, Any] = { + "thread_id": thread_id, + "assistant_id": assistant_id, + "additional_instructions": additional_instructions, + "instructions": instructions, + 
"metadata": metadata, + "model": model, + "tools": tools, + } + if event_handler is not None: + data["event_handler"] = event_handler + return client.beta.threads.runs.stream(**data) # type: ignore + + # fmt: off + + @overload + def run_thread( + self, + thread_id: str, + assistant_id: str, + additional_instructions: Optional[str], + instructions: Optional[str], + metadata: Optional[Dict], + model: Optional[str], + stream: Optional[bool], + tools: Optional[Iterable[AssistantToolParam]], + api_key: Optional[str], + api_base: Optional[str], + api_version: Optional[str], + azure_ad_token: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + client: Optional[AsyncAzureOpenAI], + arun_thread: Literal[True], + ) -> Coroutine[None, None, Run]: + ... + + @overload + def run_thread( + self, + thread_id: str, + assistant_id: str, + additional_instructions: Optional[str], + instructions: Optional[str], + metadata: Optional[Dict], + model: Optional[str], + stream: Optional[bool], + tools: Optional[Iterable[AssistantToolParam]], + api_key: Optional[str], + api_base: Optional[str], + api_version: Optional[str], + azure_ad_token: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + client: Optional[AzureOpenAI], + arun_thread: Optional[Literal[False]], + ) -> Run: + ... + + # fmt: on + + def run_thread( + self, + thread_id: str, + assistant_id: str, + additional_instructions: Optional[str], + instructions: Optional[str], + metadata: Optional[Dict], + model: Optional[str], + stream: Optional[bool], + tools: Optional[Iterable[AssistantToolParam]], + api_key: Optional[str], + api_base: Optional[str], + api_version: Optional[str], + azure_ad_token: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + client=None, + arun_thread=None, + event_handler: Optional[AssistantEventHandler] = None, + litellm_params: Optional[dict] = None, + ): + if arun_thread is not None and arun_thread is True: + if stream is not None and stream is True: + azure_client = self.async_get_azure_client( + api_key=api_key, + api_base=api_base, + api_version=api_version, + azure_ad_token=azure_ad_token, + timeout=timeout, + max_retries=max_retries, + client=client, + litellm_params=litellm_params, + ) + return self.async_run_thread_stream( + client=azure_client, + thread_id=thread_id, + assistant_id=assistant_id, + additional_instructions=additional_instructions, + instructions=instructions, + metadata=metadata, + model=model, + tools=tools, + event_handler=event_handler, + litellm_params=litellm_params, + ) + return self.arun_thread( + thread_id=thread_id, + assistant_id=assistant_id, + additional_instructions=additional_instructions, + instructions=instructions, + metadata=metadata, # type: ignore + model=model, + stream=stream, + tools=tools, + api_key=api_key, + api_base=api_base, + api_version=api_version, + azure_ad_token=azure_ad_token, + timeout=timeout, + max_retries=max_retries, + client=client, + litellm_params=litellm_params, + ) + openai_client = self.get_azure_client( + api_key=api_key, + api_base=api_base, + api_version=api_version, + azure_ad_token=azure_ad_token, + timeout=timeout, + max_retries=max_retries, + client=client, + litellm_params=litellm_params, + ) + + if stream is not None and stream is True: + return self.run_thread_stream( + client=openai_client, + thread_id=thread_id, + assistant_id=assistant_id, + additional_instructions=additional_instructions, + instructions=instructions, + metadata=metadata, + 
model=model, + tools=tools, + event_handler=event_handler, + litellm_params=litellm_params, + ) + + response = openai_client.beta.threads.runs.create_and_poll( # type: ignore + thread_id=thread_id, + assistant_id=assistant_id, + additional_instructions=additional_instructions, + instructions=instructions, + metadata=metadata, # type: ignore + model=model, + tools=tools, + ) + + return response + + # Create Assistant + async def async_create_assistants( + self, + api_key: Optional[str], + api_base: Optional[str], + api_version: Optional[str], + azure_ad_token: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + client: Optional[AsyncAzureOpenAI], + create_assistant_data: dict, + litellm_params: Optional[dict] = None, + ) -> Assistant: + azure_openai_client = self.async_get_azure_client( + api_key=api_key, + api_base=api_base, + api_version=api_version, + azure_ad_token=azure_ad_token, + timeout=timeout, + max_retries=max_retries, + client=client, + litellm_params=litellm_params, + ) + + response = await azure_openai_client.beta.assistants.create( + **create_assistant_data + ) + return response + + def create_assistants( + self, + api_key: Optional[str], + api_base: Optional[str], + api_version: Optional[str], + azure_ad_token: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + create_assistant_data: dict, + client=None, + async_create_assistants=None, + litellm_params: Optional[dict] = None, + ): + if async_create_assistants is not None and async_create_assistants is True: + return self.async_create_assistants( + api_key=api_key, + api_base=api_base, + api_version=api_version, + azure_ad_token=azure_ad_token, + timeout=timeout, + max_retries=max_retries, + client=client, + create_assistant_data=create_assistant_data, + litellm_params=litellm_params, + ) + azure_openai_client = self.get_azure_client( + api_key=api_key, + api_base=api_base, + api_version=api_version, + azure_ad_token=azure_ad_token, + timeout=timeout, + max_retries=max_retries, + client=client, + litellm_params=litellm_params, + ) + + response = azure_openai_client.beta.assistants.create(**create_assistant_data) + return response + + # Delete Assistant + async def async_delete_assistant( + self, + api_key: Optional[str], + api_base: Optional[str], + api_version: Optional[str], + azure_ad_token: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + client: Optional[AsyncAzureOpenAI], + assistant_id: str, + litellm_params: Optional[dict] = None, + ): + azure_openai_client = self.async_get_azure_client( + api_key=api_key, + api_base=api_base, + api_version=api_version, + azure_ad_token=azure_ad_token, + timeout=timeout, + max_retries=max_retries, + client=client, + litellm_params=litellm_params, + ) + + response = await azure_openai_client.beta.assistants.delete( + assistant_id=assistant_id + ) + return response + + def delete_assistant( + self, + api_key: Optional[str], + api_base: Optional[str], + api_version: Optional[str], + azure_ad_token: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + assistant_id: str, + async_delete_assistants: Optional[bool] = None, + client=None, + litellm_params: Optional[dict] = None, + ): + if async_delete_assistants is not None and async_delete_assistants is True: + return self.async_delete_assistant( + api_key=api_key, + api_base=api_base, + api_version=api_version, + azure_ad_token=azure_ad_token, + timeout=timeout, + max_retries=max_retries, + 
client=client, + assistant_id=assistant_id, + litellm_params=litellm_params, + ) + azure_openai_client = self.get_azure_client( + api_key=api_key, + api_base=api_base, + api_version=api_version, + azure_ad_token=azure_ad_token, + timeout=timeout, + max_retries=max_retries, + client=client, + litellm_params=litellm_params, + ) + + response = azure_openai_client.beta.assistants.delete(assistant_id=assistant_id) + return response diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/azure/audio_transcriptions.py b/.venv/lib/python3.12/site-packages/litellm/llms/azure/audio_transcriptions.py new file mode 100644 index 00000000..be7d0fa3 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/llms/azure/audio_transcriptions.py @@ -0,0 +1,198 @@ +import uuid +from typing import Any, Coroutine, Optional, Union + +from openai import AsyncAzureOpenAI, AzureOpenAI +from pydantic import BaseModel + +from litellm.litellm_core_utils.audio_utils.utils import get_audio_file_name +from litellm.types.utils import FileTypes +from litellm.utils import ( + TranscriptionResponse, + convert_to_model_response_object, + extract_duration_from_srt_or_vtt, +) + +from .azure import AzureChatCompletion +from .common_utils import AzureOpenAIError + + +class AzureAudioTranscription(AzureChatCompletion): + def audio_transcriptions( + self, + model: str, + audio_file: FileTypes, + optional_params: dict, + logging_obj: Any, + model_response: TranscriptionResponse, + timeout: float, + max_retries: int, + api_key: Optional[str] = None, + api_base: Optional[str] = None, + api_version: Optional[str] = None, + client=None, + azure_ad_token: Optional[str] = None, + atranscription: bool = False, + litellm_params: Optional[dict] = None, + ) -> Union[TranscriptionResponse, Coroutine[Any, Any, TranscriptionResponse]]: + data = {"model": model, "file": audio_file, **optional_params} + + if atranscription is True: + return self.async_audio_transcriptions( + audio_file=audio_file, + data=data, + model_response=model_response, + timeout=timeout, + api_key=api_key, + api_base=api_base, + client=client, + max_retries=max_retries, + logging_obj=logging_obj, + model=model, + litellm_params=litellm_params, + ) + + azure_client = self.get_azure_openai_client( + api_version=api_version, + api_base=api_base, + api_key=api_key, + model=model, + _is_async=False, + client=client, + litellm_params=litellm_params, + ) + if not isinstance(azure_client, AzureOpenAI): + raise AzureOpenAIError( + status_code=500, + message="azure_client is not an instance of AzureOpenAI", + ) + + ## LOGGING + logging_obj.pre_call( + input=f"audio_file_{uuid.uuid4()}", + api_key=azure_client.api_key, + additional_args={ + "headers": {"Authorization": f"Bearer {azure_client.api_key}"}, + "api_base": azure_client._base_url._uri_reference, + "atranscription": True, + "complete_input_dict": data, + }, + ) + + response = azure_client.audio.transcriptions.create( + **data, timeout=timeout # type: ignore + ) + + if isinstance(response, BaseModel): + stringified_response = response.model_dump() + else: + stringified_response = TranscriptionResponse(text=response).model_dump() + + ## LOGGING + logging_obj.post_call( + input=get_audio_file_name(audio_file), + api_key=api_key, + additional_args={"complete_input_dict": data}, + original_response=stringified_response, + ) + hidden_params = {"model": "whisper-1", "custom_llm_provider": "azure"} + final_response: TranscriptionResponse = convert_to_model_response_object(response_object=stringified_response, 
model_response_object=model_response, hidden_params=hidden_params, response_type="audio_transcription") # type: ignore + return final_response + + async def async_audio_transcriptions( + self, + audio_file: FileTypes, + model: str, + data: dict, + model_response: TranscriptionResponse, + timeout: float, + logging_obj: Any, + api_version: Optional[str] = None, + api_key: Optional[str] = None, + api_base: Optional[str] = None, + client=None, + max_retries=None, + litellm_params: Optional[dict] = None, + ) -> TranscriptionResponse: + response = None + try: + async_azure_client = self.get_azure_openai_client( + api_version=api_version, + api_base=api_base, + api_key=api_key, + model=model, + _is_async=True, + client=client, + litellm_params=litellm_params, + ) + if not isinstance(async_azure_client, AsyncAzureOpenAI): + raise AzureOpenAIError( + status_code=500, + message="async_azure_client is not an instance of AsyncAzureOpenAI", + ) + + ## LOGGING + logging_obj.pre_call( + input=f"audio_file_{uuid.uuid4()}", + api_key=async_azure_client.api_key, + additional_args={ + "headers": { + "Authorization": f"Bearer {async_azure_client.api_key}" + }, + "api_base": async_azure_client._base_url._uri_reference, + "atranscription": True, + "complete_input_dict": data, + }, + ) + + raw_response = ( + await async_azure_client.audio.transcriptions.with_raw_response.create( + **data, timeout=timeout + ) + ) # type: ignore + + headers = dict(raw_response.headers) + response = raw_response.parse() + + if isinstance(response, BaseModel): + stringified_response = response.model_dump() + else: + stringified_response = TranscriptionResponse(text=response).model_dump() + duration = extract_duration_from_srt_or_vtt(response) + stringified_response["duration"] = duration + + ## LOGGING + logging_obj.post_call( + input=get_audio_file_name(audio_file), + api_key=api_key, + additional_args={ + "headers": { + "Authorization": f"Bearer {async_azure_client.api_key}" + }, + "api_base": async_azure_client._base_url._uri_reference, + "atranscription": True, + "complete_input_dict": data, + }, + original_response=stringified_response, + ) + hidden_params = {"model": "whisper-1", "custom_llm_provider": "azure"} + response = convert_to_model_response_object( + _response_headers=headers, + response_object=stringified_response, + model_response_object=model_response, + hidden_params=hidden_params, + response_type="audio_transcription", + ) + if not isinstance(response, TranscriptionResponse): + raise AzureOpenAIError( + status_code=500, + message="response is not an instance of TranscriptionResponse", + ) + return response + except Exception as e: + ## LOGGING + logging_obj.post_call( + input=input, + api_key=api_key, + original_response=str(e), + ) + raise e diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/azure/azure.py b/.venv/lib/python3.12/site-packages/litellm/llms/azure/azure.py new file mode 100644 index 00000000..03c5cc09 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/llms/azure/azure.py @@ -0,0 +1,1347 @@ +import asyncio +import json +import time +from typing import Any, Callable, Coroutine, Dict, List, Optional, Union + +import httpx # type: ignore +from openai import APITimeoutError, AsyncAzureOpenAI, AzureOpenAI + +import litellm +from litellm.constants import DEFAULT_MAX_RETRIES +from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj +from litellm.litellm_core_utils.logging_utils import track_llm_api_timing +from litellm.llms.custom_httpx.http_handler import ( 
+ AsyncHTTPHandler, + HTTPHandler, + get_async_httpx_client, +) +from litellm.types.utils import ( + EmbeddingResponse, + ImageResponse, + LlmProviders, + ModelResponse, +) +from litellm.utils import ( + CustomStreamWrapper, + convert_to_model_response_object, + modify_url, +) + +from ...types.llms.openai import HttpxBinaryResponseContent +from ..base import BaseLLM +from .common_utils import ( + AzureOpenAIError, + BaseAzureLLM, + get_azure_ad_token_from_oidc, + process_azure_headers, + select_azure_base_url_or_endpoint, +) + + +class AzureOpenAIAssistantsAPIConfig: + """ + Reference: https://learn.microsoft.com/en-us/azure/ai-services/openai/assistants-reference-messages?tabs=python#create-message + """ + + def __init__( + self, + ) -> None: + pass + + def get_supported_openai_create_message_params(self): + return [ + "role", + "content", + "attachments", + "metadata", + ] + + def map_openai_params_create_message_params( + self, non_default_params: dict, optional_params: dict + ): + for param, value in non_default_params.items(): + if param == "role": + optional_params["role"] = value + if param == "metadata": + optional_params["metadata"] = value + elif param == "content": # only string accepted + if isinstance(value, str): + optional_params["content"] = value + else: + raise litellm.utils.UnsupportedParamsError( + message="Azure only accepts content as a string.", + status_code=400, + ) + elif ( + param == "attachments" + ): # this is a v2 param. Azure currently supports the old 'file_id's param + file_ids: List[str] = [] + if isinstance(value, list): + for item in value: + if "file_id" in item: + file_ids.append(item["file_id"]) + else: + if litellm.drop_params is True: + pass + else: + raise litellm.utils.UnsupportedParamsError( + message="Azure doesn't support {}. To drop it from the call, set `litellm.drop_params = True.".format( + value + ), + status_code=400, + ) + else: + raise litellm.utils.UnsupportedParamsError( + message="Invalid param. attachments should always be a list. Got={}, Expected=List. 
Raw value={}".format( + type(value), value + ), + status_code=400, + ) + return optional_params + + +def _check_dynamic_azure_params( + azure_client_params: dict, + azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]], +) -> bool: + """ + Returns True if user passed in client params != initialized azure client + + Currently only implemented for api version + """ + if azure_client is None: + return True + + dynamic_params = ["api_version"] + for k, v in azure_client_params.items(): + if k in dynamic_params and k == "api_version": + if v is not None and v != azure_client._custom_query["api-version"]: + return True + + return False + + +class AzureChatCompletion(BaseAzureLLM, BaseLLM): + def __init__(self) -> None: + super().__init__() + + def validate_environment(self, api_key, azure_ad_token, azure_ad_token_provider): + headers = { + "content-type": "application/json", + } + if api_key is not None: + headers["api-key"] = api_key + elif azure_ad_token is not None: + if azure_ad_token.startswith("oidc/"): + azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token) + headers["Authorization"] = f"Bearer {azure_ad_token}" + elif azure_ad_token_provider is not None: + azure_ad_token = azure_ad_token_provider() + headers["Authorization"] = f"Bearer {azure_ad_token}" + + return headers + + def make_sync_azure_openai_chat_completion_request( + self, + azure_client: AzureOpenAI, + data: dict, + timeout: Union[float, httpx.Timeout], + ): + """ + Helper to: + - call chat.completions.create.with_raw_response when litellm.return_response_headers is True + - call chat.completions.create by default + """ + try: + raw_response = azure_client.chat.completions.with_raw_response.create( + **data, timeout=timeout + ) + + headers = dict(raw_response.headers) + response = raw_response.parse() + return headers, response + except Exception as e: + raise e + + @track_llm_api_timing() + async def make_azure_openai_chat_completion_request( + self, + azure_client: AsyncAzureOpenAI, + data: dict, + timeout: Union[float, httpx.Timeout], + logging_obj: LiteLLMLoggingObj, + ): + """ + Helper to: + - call chat.completions.create.with_raw_response when litellm.return_response_headers is True + - call chat.completions.create by default + """ + start_time = time.time() + try: + raw_response = await azure_client.chat.completions.with_raw_response.create( + **data, timeout=timeout + ) + + headers = dict(raw_response.headers) + response = raw_response.parse() + return headers, response + except APITimeoutError as e: + end_time = time.time() + time_delta = round(end_time - start_time, 2) + e.message += f" - timeout value={timeout}, time taken={time_delta} seconds" + raise e + except Exception as e: + raise e + + def completion( # noqa: PLR0915 + self, + model: str, + messages: list, + model_response: ModelResponse, + api_key: str, + api_base: str, + api_version: str, + api_type: str, + azure_ad_token: str, + azure_ad_token_provider: Callable, + dynamic_params: bool, + print_verbose: Callable, + timeout: Union[float, httpx.Timeout], + logging_obj: LiteLLMLoggingObj, + optional_params, + litellm_params, + logger_fn, + acompletion: bool = False, + headers: Optional[dict] = None, + client=None, + ): + if headers: + optional_params["extra_headers"] = headers + try: + if model is None or messages is None: + raise AzureOpenAIError( + status_code=422, message="Missing model or messages" + ) + + max_retries = optional_params.pop("max_retries", None) + if max_retries is None: + max_retries = DEFAULT_MAX_RETRIES + json_mode: 
Optional[bool] = optional_params.pop("json_mode", False) + + ### CHECK IF CLOUDFLARE AI GATEWAY ### + ### if so - set the model as part of the base url + if "gateway.ai.cloudflare.com" in api_base: + client = self._init_azure_client_for_cloudflare_ai_gateway( + api_base=api_base, + model=model, + api_version=api_version, + max_retries=max_retries, + timeout=timeout, + api_key=api_key, + azure_ad_token=azure_ad_token, + azure_ad_token_provider=azure_ad_token_provider, + acompletion=acompletion, + client=client, + ) + + data = {"model": None, "messages": messages, **optional_params} + else: + data = litellm.AzureOpenAIConfig().transform_request( + model=model, + messages=messages, + optional_params=optional_params, + litellm_params=litellm_params, + headers=headers or {}, + ) + + if acompletion is True: + if optional_params.get("stream", False): + return self.async_streaming( + logging_obj=logging_obj, + api_base=api_base, + dynamic_params=dynamic_params, + data=data, + model=model, + api_key=api_key, + api_version=api_version, + azure_ad_token=azure_ad_token, + azure_ad_token_provider=azure_ad_token_provider, + timeout=timeout, + client=client, + max_retries=max_retries, + litellm_params=litellm_params, + ) + else: + return self.acompletion( + api_base=api_base, + data=data, + model_response=model_response, + api_key=api_key, + api_version=api_version, + model=model, + azure_ad_token=azure_ad_token, + azure_ad_token_provider=azure_ad_token_provider, + dynamic_params=dynamic_params, + timeout=timeout, + client=client, + logging_obj=logging_obj, + max_retries=max_retries, + convert_tool_call_to_json_mode=json_mode, + litellm_params=litellm_params, + ) + elif "stream" in optional_params and optional_params["stream"] is True: + return self.streaming( + logging_obj=logging_obj, + api_base=api_base, + dynamic_params=dynamic_params, + data=data, + model=model, + api_key=api_key, + api_version=api_version, + azure_ad_token=azure_ad_token, + azure_ad_token_provider=azure_ad_token_provider, + timeout=timeout, + client=client, + max_retries=max_retries, + litellm_params=litellm_params, + ) + else: + ## LOGGING + logging_obj.pre_call( + input=messages, + api_key=api_key, + additional_args={ + "headers": { + "api_key": api_key, + "azure_ad_token": azure_ad_token, + }, + "api_version": api_version, + "api_base": api_base, + "complete_input_dict": data, + }, + ) + if not isinstance(max_retries, int): + raise AzureOpenAIError( + status_code=422, message="max retries must be an int" + ) + # init AzureOpenAI Client + azure_client = self.get_azure_openai_client( + api_version=api_version, + api_base=api_base, + api_key=api_key, + model=model, + client=client, + _is_async=False, + litellm_params=litellm_params, + ) + if not isinstance(azure_client, AzureOpenAI): + raise AzureOpenAIError( + status_code=500, + message="azure_client is not an instance of AzureOpenAI", + ) + + headers, response = self.make_sync_azure_openai_chat_completion_request( + azure_client=azure_client, data=data, timeout=timeout + ) + stringified_response = response.model_dump() + ## LOGGING + logging_obj.post_call( + input=messages, + api_key=api_key, + original_response=stringified_response, + additional_args={ + "headers": headers, + "api_version": api_version, + "api_base": api_base, + }, + ) + return convert_to_model_response_object( + response_object=stringified_response, + model_response_object=model_response, + convert_tool_call_to_json_mode=json_mode, + _response_headers=headers, + ) + except AzureOpenAIError as e: + raise e + 
except Exception as e: + status_code = getattr(e, "status_code", 500) + error_headers = getattr(e, "headers", None) + error_response = getattr(e, "response", None) + error_body = getattr(e, "body", None) + if error_headers is None and error_response: + error_headers = getattr(error_response, "headers", None) + raise AzureOpenAIError( + status_code=status_code, + message=str(e), + headers=error_headers, + body=error_body, + ) + + async def acompletion( + self, + api_key: str, + api_version: str, + model: str, + api_base: str, + data: dict, + timeout: Any, + dynamic_params: bool, + model_response: ModelResponse, + logging_obj: LiteLLMLoggingObj, + max_retries: int, + azure_ad_token: Optional[str] = None, + azure_ad_token_provider: Optional[Callable] = None, + convert_tool_call_to_json_mode: Optional[bool] = None, + client=None, # this is the AsyncAzureOpenAI + litellm_params: Optional[dict] = {}, + ): + response = None + try: + # setting Azure client + azure_client = self.get_azure_openai_client( + api_version=api_version, + api_base=api_base, + api_key=api_key, + model=model, + client=client, + _is_async=True, + litellm_params=litellm_params, + ) + if not isinstance(azure_client, AsyncAzureOpenAI): + raise ValueError("Azure client is not an instance of AsyncAzureOpenAI") + ## LOGGING + logging_obj.pre_call( + input=data["messages"], + api_key=azure_client.api_key, + additional_args={ + "headers": { + "api_key": api_key, + "azure_ad_token": azure_ad_token, + }, + "api_base": azure_client._base_url._uri_reference, + "acompletion": True, + "complete_input_dict": data, + }, + ) + + headers, response = await self.make_azure_openai_chat_completion_request( + azure_client=azure_client, + data=data, + timeout=timeout, + logging_obj=logging_obj, + ) + logging_obj.model_call_details["response_headers"] = headers + + stringified_response = response.model_dump() + logging_obj.post_call( + input=data["messages"], + api_key=api_key, + original_response=stringified_response, + additional_args={"complete_input_dict": data}, + ) + + return convert_to_model_response_object( + response_object=stringified_response, + model_response_object=model_response, + hidden_params={"headers": headers}, + _response_headers=headers, + convert_tool_call_to_json_mode=convert_tool_call_to_json_mode, + ) + except AzureOpenAIError as e: + ## LOGGING + logging_obj.post_call( + input=data["messages"], + api_key=api_key, + additional_args={"complete_input_dict": data}, + original_response=str(e), + ) + raise e + except asyncio.CancelledError as e: + ## LOGGING + logging_obj.post_call( + input=data["messages"], + api_key=api_key, + additional_args={"complete_input_dict": data}, + original_response=str(e), + ) + raise AzureOpenAIError(status_code=500, message=str(e)) + except Exception as e: + message = getattr(e, "message", str(e)) + body = getattr(e, "body", None) + ## LOGGING + logging_obj.post_call( + input=data["messages"], + api_key=api_key, + additional_args={"complete_input_dict": data}, + original_response=str(e), + ) + if hasattr(e, "status_code"): + raise e + else: + raise AzureOpenAIError(status_code=500, message=message, body=body) + + def streaming( + self, + logging_obj, + api_base: str, + api_key: str, + api_version: str, + dynamic_params: bool, + data: dict, + model: str, + timeout: Any, + max_retries: int, + azure_ad_token: Optional[str] = None, + azure_ad_token_provider: Optional[Callable] = None, + client=None, + litellm_params: Optional[dict] = {}, + ): + # init AzureOpenAI Client + azure_client_params = { + 
"api_version": api_version, + "azure_endpoint": api_base, + "azure_deployment": model, + "http_client": litellm.client_session, + "max_retries": max_retries, + "timeout": timeout, + } + azure_client_params = select_azure_base_url_or_endpoint( + azure_client_params=azure_client_params + ) + if api_key is not None: + azure_client_params["api_key"] = api_key + elif azure_ad_token is not None: + if azure_ad_token.startswith("oidc/"): + azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token) + azure_client_params["azure_ad_token"] = azure_ad_token + elif azure_ad_token_provider is not None: + azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider + + azure_client = self.get_azure_openai_client( + api_version=api_version, + api_base=api_base, + api_key=api_key, + model=model, + client=client, + _is_async=False, + litellm_params=litellm_params, + ) + if not isinstance(azure_client, AzureOpenAI): + raise AzureOpenAIError( + status_code=500, + message="azure_client is not an instance of AzureOpenAI", + ) + ## LOGGING + logging_obj.pre_call( + input=data["messages"], + api_key=azure_client.api_key, + additional_args={ + "headers": { + "api_key": api_key, + "azure_ad_token": azure_ad_token, + }, + "api_base": azure_client._base_url._uri_reference, + "acompletion": True, + "complete_input_dict": data, + }, + ) + headers, response = self.make_sync_azure_openai_chat_completion_request( + azure_client=azure_client, data=data, timeout=timeout + ) + streamwrapper = CustomStreamWrapper( + completion_stream=response, + model=model, + custom_llm_provider="azure", + logging_obj=logging_obj, + stream_options=data.get("stream_options", None), + _response_headers=process_azure_headers(headers), + ) + return streamwrapper + + async def async_streaming( + self, + logging_obj: LiteLLMLoggingObj, + api_base: str, + api_key: str, + api_version: str, + dynamic_params: bool, + data: dict, + model: str, + timeout: Any, + max_retries: int, + azure_ad_token: Optional[str] = None, + azure_ad_token_provider: Optional[Callable] = None, + client=None, + litellm_params: Optional[dict] = {}, + ): + try: + azure_client = self.get_azure_openai_client( + api_version=api_version, + api_base=api_base, + api_key=api_key, + model=model, + client=client, + _is_async=True, + litellm_params=litellm_params, + ) + if not isinstance(azure_client, AsyncAzureOpenAI): + raise ValueError("Azure client is not an instance of AsyncAzureOpenAI") + + ## LOGGING + logging_obj.pre_call( + input=data["messages"], + api_key=azure_client.api_key, + additional_args={ + "headers": { + "api_key": api_key, + "azure_ad_token": azure_ad_token, + }, + "api_base": azure_client._base_url._uri_reference, + "acompletion": True, + "complete_input_dict": data, + }, + ) + + headers, response = await self.make_azure_openai_chat_completion_request( + azure_client=azure_client, + data=data, + timeout=timeout, + logging_obj=logging_obj, + ) + logging_obj.model_call_details["response_headers"] = headers + + # return response + streamwrapper = CustomStreamWrapper( + completion_stream=response, + model=model, + custom_llm_provider="azure", + logging_obj=logging_obj, + stream_options=data.get("stream_options", None), + _response_headers=headers, + ) + return streamwrapper ## DO NOT make this into an async for ... 
loop, it will yield an async generator, which won't raise errors if the response fails + except Exception as e: + status_code = getattr(e, "status_code", 500) + error_headers = getattr(e, "headers", None) + error_response = getattr(e, "response", None) + message = getattr(e, "message", str(e)) + error_body = getattr(e, "body", None) + if error_headers is None and error_response: + error_headers = getattr(error_response, "headers", None) + raise AzureOpenAIError( + status_code=status_code, + message=message, + headers=error_headers, + body=error_body, + ) + + async def aembedding( + self, + model: str, + data: dict, + model_response: EmbeddingResponse, + input: list, + logging_obj: LiteLLMLoggingObj, + api_base: str, + api_key: Optional[str] = None, + api_version: Optional[str] = None, + client: Optional[AsyncAzureOpenAI] = None, + timeout: Optional[Union[float, httpx.Timeout]] = None, + max_retries: Optional[int] = None, + azure_ad_token: Optional[str] = None, + azure_ad_token_provider: Optional[Callable] = None, + litellm_params: Optional[dict] = {}, + ) -> EmbeddingResponse: + response = None + try: + + openai_aclient = self.get_azure_openai_client( + api_version=api_version, + api_base=api_base, + api_key=api_key, + model=model, + _is_async=True, + client=client, + litellm_params=litellm_params, + ) + if not isinstance(openai_aclient, AsyncAzureOpenAI): + raise ValueError("Azure client is not an instance of AsyncAzureOpenAI") + + raw_response = await openai_aclient.embeddings.with_raw_response.create( + **data, timeout=timeout + ) + headers = dict(raw_response.headers) + response = raw_response.parse() + stringified_response = response.model_dump() + ## LOGGING + logging_obj.post_call( + input=input, + api_key=api_key, + additional_args={"complete_input_dict": data}, + original_response=stringified_response, + ) + embedding_response = convert_to_model_response_object( + response_object=stringified_response, + model_response_object=model_response, + hidden_params={"headers": headers}, + _response_headers=process_azure_headers(headers), + response_type="embedding", + ) + if not isinstance(embedding_response, EmbeddingResponse): + raise AzureOpenAIError( + status_code=500, + message="embedding_response is not an instance of EmbeddingResponse", + ) + return embedding_response + except Exception as e: + ## LOGGING + logging_obj.post_call( + input=input, + api_key=api_key, + additional_args={"complete_input_dict": data}, + original_response=str(e), + ) + raise e + + def embedding( + self, + model: str, + input: list, + api_base: str, + api_version: str, + timeout: float, + logging_obj: LiteLLMLoggingObj, + model_response: EmbeddingResponse, + optional_params: dict, + api_key: Optional[str] = None, + azure_ad_token: Optional[str] = None, + azure_ad_token_provider: Optional[Callable] = None, + max_retries: Optional[int] = None, + client=None, + aembedding=None, + headers: Optional[dict] = None, + litellm_params: Optional[dict] = None, + ) -> Union[EmbeddingResponse, Coroutine[Any, Any, EmbeddingResponse]]: + if headers: + optional_params["extra_headers"] = headers + if self._client_session is None: + self._client_session = self.create_client_session() + try: + data = {"model": model, "input": input, **optional_params} + if max_retries is None: + max_retries = litellm.DEFAULT_MAX_RETRIES + ## LOGGING + logging_obj.pre_call( + input=input, + api_key=api_key, + additional_args={ + "complete_input_dict": data, + "headers": {"api_key": api_key, "azure_ad_token": azure_ad_token}, + }, + ) + + if 
aembedding is True: + return self.aembedding( + data=data, + input=input, + model=model, + logging_obj=logging_obj, + api_key=api_key, + model_response=model_response, + timeout=timeout, + client=client, + litellm_params=litellm_params, + api_base=api_base, + ) + azure_client = self.get_azure_openai_client( + api_version=api_version, + api_base=api_base, + api_key=api_key, + model=model, + _is_async=False, + client=client, + litellm_params=litellm_params, + ) + if not isinstance(azure_client, AzureOpenAI): + raise AzureOpenAIError( + status_code=500, + message="azure_client is not an instance of AzureOpenAI", + ) + + ## COMPLETION CALL + raw_response = azure_client.embeddings.with_raw_response.create(**data, timeout=timeout) # type: ignore + headers = dict(raw_response.headers) + response = raw_response.parse() + ## LOGGING + logging_obj.post_call( + input=input, + api_key=api_key, + additional_args={"complete_input_dict": data, "api_base": api_base}, + original_response=response, + ) + + return convert_to_model_response_object(response_object=response.model_dump(), model_response_object=model_response, response_type="embedding", _response_headers=process_azure_headers(headers)) # type: ignore + except AzureOpenAIError as e: + raise e + except Exception as e: + status_code = getattr(e, "status_code", 500) + error_headers = getattr(e, "headers", None) + error_response = getattr(e, "response", None) + if error_headers is None and error_response: + error_headers = getattr(error_response, "headers", None) + raise AzureOpenAIError( + status_code=status_code, message=str(e), headers=error_headers + ) + + async def make_async_azure_httpx_request( + self, + client: Optional[AsyncHTTPHandler], + timeout: Optional[Union[float, httpx.Timeout]], + api_base: str, + api_version: str, + api_key: str, + data: dict, + headers: dict, + ) -> httpx.Response: + """ + Implemented for azure dall-e-2 image gen calls + + Alternative to needing a custom transport implementation + """ + if client is None: + _params = {} + if timeout is not None: + if isinstance(timeout, float) or isinstance(timeout, int): + _httpx_timeout = httpx.Timeout(timeout) + _params["timeout"] = _httpx_timeout + else: + _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0) + + async_handler = get_async_httpx_client( + llm_provider=LlmProviders.AZURE, + params=_params, + ) + else: + async_handler = client # type: ignore + + if ( + "images/generations" in api_base + and api_version + in [ # dall-e-3 starts from `2023-12-01-preview` so we should be able to avoid conflict + "2023-06-01-preview", + "2023-07-01-preview", + "2023-08-01-preview", + "2023-09-01-preview", + "2023-10-01-preview", + ] + ): # CREATE + POLL for azure dall-e-2 calls + + api_base = modify_url( + original_url=api_base, new_path="/openai/images/generations:submit" + ) + + data.pop( + "model", None + ) # REMOVE 'model' from dall-e-2 arg https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#request-a-generated-image-dall-e-2-preview + response = await async_handler.post( + url=api_base, + data=json.dumps(data), + headers=headers, + ) + if "operation-location" in response.headers: + operation_location_url = response.headers["operation-location"] + else: + raise AzureOpenAIError(status_code=500, message=response.text) + response = await async_handler.get( + url=operation_location_url, + headers=headers, + ) + + await response.aread() + + timeout_secs: int = 120 + start_time = time.time() + if "status" not in response.json(): + raise Exception( + "Expected 
'status' in response. Got={}".format(response.json()) + ) + while response.json()["status"] not in ["succeeded", "failed"]: + if time.time() - start_time > timeout_secs: + + raise AzureOpenAIError( + status_code=408, message="Operation polling timed out." + ) + + await asyncio.sleep(int(response.headers.get("retry-after") or 10)) + response = await async_handler.get( + url=operation_location_url, + headers=headers, + ) + await response.aread() + + if response.json()["status"] == "failed": + error_data = response.json() + raise AzureOpenAIError(status_code=400, message=json.dumps(error_data)) + + result = response.json()["result"] + return httpx.Response( + status_code=200, + headers=response.headers, + content=json.dumps(result).encode("utf-8"), + request=httpx.Request(method="POST", url="https://api.openai.com/v1"), + ) + return await async_handler.post( + url=api_base, + json=data, + headers=headers, + ) + + def make_sync_azure_httpx_request( + self, + client: Optional[HTTPHandler], + timeout: Optional[Union[float, httpx.Timeout]], + api_base: str, + api_version: str, + api_key: str, + data: dict, + headers: dict, + ) -> httpx.Response: + """ + Implemented for azure dall-e-2 image gen calls + + Alternative to needing a custom transport implementation + """ + if client is None: + _params = {} + if timeout is not None: + if isinstance(timeout, float) or isinstance(timeout, int): + _httpx_timeout = httpx.Timeout(timeout) + _params["timeout"] = _httpx_timeout + else: + _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0) + + sync_handler = HTTPHandler(**_params, client=litellm.client_session) # type: ignore + else: + sync_handler = client # type: ignore + + if ( + "images/generations" in api_base + and api_version + in [ # dall-e-3 starts from `2023-12-01-preview` so we should be able to avoid conflict + "2023-06-01-preview", + "2023-07-01-preview", + "2023-08-01-preview", + "2023-09-01-preview", + "2023-10-01-preview", + ] + ): # CREATE + POLL for azure dall-e-2 calls + + api_base = modify_url( + original_url=api_base, new_path="/openai/images/generations:submit" + ) + + data.pop( + "model", None + ) # REMOVE 'model' from dall-e-2 arg https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#request-a-generated-image-dall-e-2-preview + response = sync_handler.post( + url=api_base, + data=json.dumps(data), + headers=headers, + ) + if "operation-location" in response.headers: + operation_location_url = response.headers["operation-location"] + else: + raise AzureOpenAIError(status_code=500, message=response.text) + response = sync_handler.get( + url=operation_location_url, + headers=headers, + ) + + response.read() + + timeout_secs: int = 120 + start_time = time.time() + if "status" not in response.json(): + raise Exception( + "Expected 'status' in response. Got={}".format(response.json()) + ) + while response.json()["status"] not in ["succeeded", "failed"]: + if time.time() - start_time > timeout_secs: + raise AzureOpenAIError( + status_code=408, message="Operation polling timed out." 
+ ) + + time.sleep(int(response.headers.get("retry-after") or 10)) + response = sync_handler.get( + url=operation_location_url, + headers=headers, + ) + response.read() + + if response.json()["status"] == "failed": + error_data = response.json() + raise AzureOpenAIError(status_code=400, message=json.dumps(error_data)) + + result = response.json()["result"] + return httpx.Response( + status_code=200, + headers=response.headers, + content=json.dumps(result).encode("utf-8"), + request=httpx.Request(method="POST", url="https://api.openai.com/v1"), + ) + return sync_handler.post( + url=api_base, + json=data, + headers=headers, + ) + + def create_azure_base_url( + self, azure_client_params: dict, model: Optional[str] + ) -> str: + api_base: str = azure_client_params.get( + "azure_endpoint", "" + ) # "https://example-endpoint.openai.azure.com" + if api_base.endswith("/"): + api_base = api_base.rstrip("/") + api_version: str = azure_client_params.get("api_version", "") + if model is None: + model = "" + + if "/openai/deployments/" in api_base: + base_url_with_deployment = api_base + else: + base_url_with_deployment = api_base + "/openai/deployments/" + model + + base_url_with_deployment += "/images/generations" + base_url_with_deployment += "?api-version=" + api_version + + return base_url_with_deployment + + async def aimage_generation( + self, + data: dict, + model_response: ModelResponse, + azure_client_params: dict, + api_key: str, + input: list, + logging_obj: LiteLLMLoggingObj, + headers: dict, + client=None, + timeout=None, + ) -> litellm.ImageResponse: + response: Optional[dict] = None + try: + # response = await azure_client.images.generate(**data, timeout=timeout) + api_base: str = azure_client_params.get( + "api_base", "" + ) # "https://example-endpoint.openai.azure.com" + if api_base.endswith("/"): + api_base = api_base.rstrip("/") + api_version: str = azure_client_params.get("api_version", "") + img_gen_api_base = self.create_azure_base_url( + azure_client_params=azure_client_params, model=data.get("model", "") + ) + + ## LOGGING + logging_obj.pre_call( + input=data["prompt"], + api_key=api_key, + additional_args={ + "complete_input_dict": data, + "api_base": img_gen_api_base, + "headers": headers, + }, + ) + httpx_response: httpx.Response = await self.make_async_azure_httpx_request( + client=None, + timeout=timeout, + api_base=img_gen_api_base, + api_version=api_version, + api_key=api_key, + data=data, + headers=headers, + ) + response = httpx_response.json() + + stringified_response = response + ## LOGGING + logging_obj.post_call( + input=input, + api_key=api_key, + additional_args={"complete_input_dict": data}, + original_response=stringified_response, + ) + return convert_to_model_response_object( # type: ignore + response_object=stringified_response, + model_response_object=model_response, + response_type="image_generation", + ) + except Exception as e: + ## LOGGING + logging_obj.post_call( + input=input, + api_key=api_key, + additional_args={"complete_input_dict": data}, + original_response=str(e), + ) + raise e + + def image_generation( + self, + prompt: str, + timeout: float, + optional_params: dict, + logging_obj: LiteLLMLoggingObj, + headers: dict, + model: Optional[str] = None, + api_key: Optional[str] = None, + api_base: Optional[str] = None, + api_version: Optional[str] = None, + model_response: Optional[ImageResponse] = None, + azure_ad_token: Optional[str] = None, + azure_ad_token_provider: Optional[Callable] = None, + client=None, + aimg_generation=None, + 
litellm_params: Optional[dict] = None, + ) -> ImageResponse: + try: + if model and len(model) > 0: + model = model + else: + model = None + + ## BASE MODEL CHECK + if ( + model_response is not None + and optional_params.get("base_model", None) is not None + ): + model_response._hidden_params["model"] = optional_params.pop( + "base_model" + ) + + data = {"model": model, "prompt": prompt, **optional_params} + max_retries = data.pop("max_retries", 2) + if not isinstance(max_retries, int): + raise AzureOpenAIError( + status_code=422, message="max retries must be an int" + ) + + # init AzureOpenAI Client + azure_client_params: Dict[str, Any] = self.initialize_azure_sdk_client( + litellm_params=litellm_params or {}, + api_key=api_key, + model_name=model or "", + api_version=api_version, + api_base=api_base, + is_async=False, + ) + if aimg_generation is True: + return self.aimage_generation(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_key=api_key, client=client, azure_client_params=azure_client_params, timeout=timeout, headers=headers) # type: ignore + + img_gen_api_base = self.create_azure_base_url( + azure_client_params=azure_client_params, model=data.get("model", "") + ) + + ## LOGGING + logging_obj.pre_call( + input=data["prompt"], + api_key=api_key, + additional_args={ + "complete_input_dict": data, + "api_base": img_gen_api_base, + "headers": headers, + }, + ) + httpx_response: httpx.Response = self.make_sync_azure_httpx_request( + client=None, + timeout=timeout, + api_base=img_gen_api_base, + api_version=api_version or "", + api_key=api_key or "", + data=data, + headers=headers, + ) + response = httpx_response.json() + + ## LOGGING + logging_obj.post_call( + input=prompt, + api_key=api_key, + additional_args={"complete_input_dict": data}, + original_response=response, + ) + # return response + return convert_to_model_response_object(response_object=response, model_response_object=model_response, response_type="image_generation") # type: ignore + except AzureOpenAIError as e: + raise e + except Exception as e: + error_code = getattr(e, "status_code", None) + if error_code is not None: + raise AzureOpenAIError(status_code=error_code, message=str(e)) + else: + raise AzureOpenAIError(status_code=500, message=str(e)) + + def audio_speech( + self, + model: str, + input: str, + voice: str, + optional_params: dict, + api_key: Optional[str], + api_base: Optional[str], + api_version: Optional[str], + organization: Optional[str], + max_retries: int, + timeout: Union[float, httpx.Timeout], + azure_ad_token: Optional[str] = None, + azure_ad_token_provider: Optional[Callable] = None, + aspeech: Optional[bool] = None, + client=None, + litellm_params: Optional[dict] = None, + ) -> HttpxBinaryResponseContent: + + max_retries = optional_params.pop("max_retries", 2) + + if aspeech is not None and aspeech is True: + return self.async_audio_speech( + model=model, + input=input, + voice=voice, + optional_params=optional_params, + api_key=api_key, + api_base=api_base, + api_version=api_version, + azure_ad_token=azure_ad_token, + azure_ad_token_provider=azure_ad_token_provider, + max_retries=max_retries, + timeout=timeout, + client=client, + litellm_params=litellm_params, + ) # type: ignore + + azure_client: AzureOpenAI = self.get_azure_openai_client( + api_base=api_base, + api_version=api_version, + api_key=api_key, + model=model, + _is_async=False, + client=client, + litellm_params=litellm_params, + ) # type: ignore + + response = azure_client.audio.speech.create( + 
model=model, + voice=voice, # type: ignore + input=input, + **optional_params, + ) + return HttpxBinaryResponseContent(response=response.response) + + async def async_audio_speech( + self, + model: str, + input: str, + voice: str, + optional_params: dict, + api_key: Optional[str], + api_base: Optional[str], + api_version: Optional[str], + azure_ad_token: Optional[str], + azure_ad_token_provider: Optional[Callable], + max_retries: int, + timeout: Union[float, httpx.Timeout], + client=None, + litellm_params: Optional[dict] = None, + ) -> HttpxBinaryResponseContent: + + azure_client: AsyncAzureOpenAI = self.get_azure_openai_client( + api_base=api_base, + api_version=api_version, + api_key=api_key, + model=model, + _is_async=True, + client=client, + litellm_params=litellm_params, + ) # type: ignore + + azure_response = await azure_client.audio.speech.create( + model=model, + voice=voice, # type: ignore + input=input, + **optional_params, + ) + + return HttpxBinaryResponseContent(response=azure_response.response) + + def get_headers( + self, + model: Optional[str], + api_key: str, + api_base: str, + api_version: str, + timeout: float, + mode: str, + messages: Optional[list] = None, + input: Optional[list] = None, + prompt: Optional[str] = None, + ) -> dict: + client_session = litellm.client_session or httpx.Client() + if "gateway.ai.cloudflare.com" in api_base: + ## build base url - assume api base includes resource name + if not api_base.endswith("/"): + api_base += "/" + api_base += f"{model}" + client = AzureOpenAI( + base_url=api_base, + api_version=api_version, + api_key=api_key, + timeout=timeout, + http_client=client_session, + ) + model = None + # cloudflare ai gateway, needs model=None + else: + client = AzureOpenAI( + api_version=api_version, + azure_endpoint=api_base, + api_key=api_key, + timeout=timeout, + http_client=client_session, + ) + + # only run this check if it's not cloudflare ai gateway + if model is None and mode != "image_generation": + raise Exception("model is not set") + + completion = None + + if messages is None: + messages = [{"role": "user", "content": "Hey"}] + try: + completion = client.chat.completions.with_raw_response.create( + model=model, # type: ignore + messages=messages, # type: ignore + ) + except Exception as e: + raise e + response = {} + + if completion is None or not hasattr(completion, "headers"): + raise Exception("invalid completion response") + + if ( + completion.headers.get("x-ratelimit-remaining-requests", None) is not None + ): # not provided for dall-e requests + response["x-ratelimit-remaining-requests"] = completion.headers[ + "x-ratelimit-remaining-requests" + ] + + if completion.headers.get("x-ratelimit-remaining-tokens", None) is not None: + response["x-ratelimit-remaining-tokens"] = completion.headers[ + "x-ratelimit-remaining-tokens" + ] + + if completion.headers.get("x-ms-region", None) is not None: + response["x-ms-region"] = completion.headers["x-ms-region"] + + return response diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/azure/batches/handler.py b/.venv/lib/python3.12/site-packages/litellm/llms/azure/batches/handler.py new file mode 100644 index 00000000..1b93c526 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/llms/azure/batches/handler.py @@ -0,0 +1,210 @@ +""" +Azure Batches API Handler +""" + +from typing import Any, Coroutine, Optional, Union, cast + +import httpx + +from litellm.llms.azure.azure import AsyncAzureOpenAI, AzureOpenAI +from litellm.types.llms.openai import ( + Batch, + 
CancelBatchRequest, + CreateBatchRequest, + RetrieveBatchRequest, +) +from litellm.types.utils import LiteLLMBatch + +from ..common_utils import BaseAzureLLM + + +class AzureBatchesAPI(BaseAzureLLM): + """ + Azure methods to support for batches + - create_batch() + - retrieve_batch() + - cancel_batch() + - list_batch() + """ + + def __init__(self) -> None: + super().__init__() + + async def acreate_batch( + self, + create_batch_data: CreateBatchRequest, + azure_client: AsyncAzureOpenAI, + ) -> LiteLLMBatch: + response = await azure_client.batches.create(**create_batch_data) + return LiteLLMBatch(**response.model_dump()) + + def create_batch( + self, + _is_async: bool, + create_batch_data: CreateBatchRequest, + api_key: Optional[str], + api_base: Optional[str], + api_version: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None, + litellm_params: Optional[dict] = None, + ) -> Union[LiteLLMBatch, Coroutine[Any, Any, LiteLLMBatch]]: + azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = ( + self.get_azure_openai_client( + api_key=api_key, + api_base=api_base, + api_version=api_version, + client=client, + _is_async=_is_async, + litellm_params=litellm_params or {}, + ) + ) + if azure_client is None: + raise ValueError( + "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment." + ) + + if _is_async is True: + if not isinstance(azure_client, AsyncAzureOpenAI): + raise ValueError( + "OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client." + ) + return self.acreate_batch( # type: ignore + create_batch_data=create_batch_data, azure_client=azure_client + ) + response = cast(AzureOpenAI, azure_client).batches.create(**create_batch_data) + return LiteLLMBatch(**response.model_dump()) + + async def aretrieve_batch( + self, + retrieve_batch_data: RetrieveBatchRequest, + client: AsyncAzureOpenAI, + ) -> LiteLLMBatch: + response = await client.batches.retrieve(**retrieve_batch_data) + return LiteLLMBatch(**response.model_dump()) + + def retrieve_batch( + self, + _is_async: bool, + retrieve_batch_data: RetrieveBatchRequest, + api_key: Optional[str], + api_base: Optional[str], + api_version: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + client: Optional[AzureOpenAI] = None, + litellm_params: Optional[dict] = None, + ): + azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = ( + self.get_azure_openai_client( + api_key=api_key, + api_base=api_base, + api_version=api_version, + client=client, + _is_async=_is_async, + litellm_params=litellm_params or {}, + ) + ) + if azure_client is None: + raise ValueError( + "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment." + ) + + if _is_async is True: + if not isinstance(azure_client, AsyncAzureOpenAI): + raise ValueError( + "OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client." 
+ ) + return self.aretrieve_batch( # type: ignore + retrieve_batch_data=retrieve_batch_data, client=azure_client + ) + response = cast(AzureOpenAI, azure_client).batches.retrieve( + **retrieve_batch_data + ) + return LiteLLMBatch(**response.model_dump()) + + async def acancel_batch( + self, + cancel_batch_data: CancelBatchRequest, + client: AsyncAzureOpenAI, + ) -> Batch: + response = await client.batches.cancel(**cancel_batch_data) + return response + + def cancel_batch( + self, + _is_async: bool, + cancel_batch_data: CancelBatchRequest, + api_key: Optional[str], + api_base: Optional[str], + api_version: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + client: Optional[AzureOpenAI] = None, + litellm_params: Optional[dict] = None, + ): + azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = ( + self.get_azure_openai_client( + api_key=api_key, + api_base=api_base, + api_version=api_version, + client=client, + _is_async=_is_async, + litellm_params=litellm_params or {}, + ) + ) + if azure_client is None: + raise ValueError( + "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment." + ) + response = azure_client.batches.cancel(**cancel_batch_data) + return response + + async def alist_batches( + self, + client: AsyncAzureOpenAI, + after: Optional[str] = None, + limit: Optional[int] = None, + ): + response = await client.batches.list(after=after, limit=limit) # type: ignore + return response + + def list_batches( + self, + _is_async: bool, + api_key: Optional[str], + api_base: Optional[str], + api_version: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + after: Optional[str] = None, + limit: Optional[int] = None, + client: Optional[AzureOpenAI] = None, + litellm_params: Optional[dict] = None, + ): + azure_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = ( + self.get_azure_openai_client( + api_key=api_key, + api_base=api_base, + api_version=api_version, + client=client, + _is_async=_is_async, + litellm_params=litellm_params or {}, + ) + ) + if azure_client is None: + raise ValueError( + "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment." + ) + + if _is_async is True: + if not isinstance(azure_client, AsyncAzureOpenAI): + raise ValueError( + "OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client." 
+ ) + return self.alist_batches( # type: ignore + client=azure_client, after=after, limit=limit + ) + response = azure_client.batches.list(after=after, limit=limit) # type: ignore + return response diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/azure/chat/gpt_transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/azure/chat/gpt_transformation.py new file mode 100644 index 00000000..ee85517e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/llms/azure/chat/gpt_transformation.py @@ -0,0 +1,294 @@ +from typing import TYPE_CHECKING, Any, List, Optional, Union + +from httpx._models import Headers, Response + +import litellm +from litellm.litellm_core_utils.prompt_templates.factory import ( + convert_to_azure_openai_messages, +) +from litellm.llms.base_llm.chat.transformation import BaseLLMException +from litellm.types.utils import ModelResponse +from litellm.utils import supports_response_schema + +from ....exceptions import UnsupportedParamsError +from ....types.llms.openai import AllMessageValues +from ...base_llm.chat.transformation import BaseConfig +from ..common_utils import AzureOpenAIError + +if TYPE_CHECKING: + from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj + + LoggingClass = LiteLLMLoggingObj +else: + LoggingClass = Any + + +class AzureOpenAIConfig(BaseConfig): + """ + Reference: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions + + The class `AzureOpenAIConfig` provides configuration for the OpenAI's Chat API interface, for use with Azure. Below are the parameters:: + + - `frequency_penalty` (number or null): Defaults to 0. Allows a value between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, thereby minimizing repetition. + + - `function_call` (string or object): This optional parameter controls how the model calls functions. + + - `functions` (array): An optional parameter. It is a list of functions for which the model may generate JSON inputs. + + - `logit_bias` (map): This optional parameter modifies the likelihood of specified tokens appearing in the completion. + + - `max_tokens` (integer or null): This optional parameter helps to set the maximum number of tokens to generate in the chat completion. + + - `n` (integer or null): This optional parameter helps to set how many chat completion choices to generate for each input message. + + - `presence_penalty` (number or null): Defaults to 0. It penalizes new tokens based on if they appear in the text so far, hence increasing the model's likelihood to talk about new topics. + + - `stop` (string / array / null): Specifies up to 4 sequences where the API will stop generating further tokens. + + - `temperature` (number or null): Defines the sampling temperature to use, varying between 0 and 2. + + - `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling. 
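+
+    A minimal, illustrative sketch of how these settings are applied via
+    `map_openai_params` (the specific values are example assumptions, not
+    defaults):
+
+        config = AzureOpenAIConfig()
+        params = config.map_openai_params(
+            non_default_params={"temperature": 0.2, "tool_choice": "auto"},
+            optional_params={},
+            model="gpt-4o",
+            drop_params=False,
+            api_version="2024-06-01",
+        )
+        # params == {"temperature": 0.2, "tool_choice": "auto"}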
+ """ + + def __init__( + self, + frequency_penalty: Optional[int] = None, + function_call: Optional[Union[str, dict]] = None, + functions: Optional[list] = None, + logit_bias: Optional[dict] = None, + max_tokens: Optional[int] = None, + n: Optional[int] = None, + presence_penalty: Optional[int] = None, + stop: Optional[Union[str, list]] = None, + temperature: Optional[int] = None, + top_p: Optional[int] = None, + ) -> None: + locals_ = locals().copy() + for key, value in locals_.items(): + if key != "self" and value is not None: + setattr(self.__class__, key, value) + + @classmethod + def get_config(cls): + return super().get_config() + + def get_supported_openai_params(self, model: str) -> List[str]: + return [ + "temperature", + "n", + "stream", + "stream_options", + "stop", + "max_tokens", + "max_completion_tokens", + "tools", + "tool_choice", + "presence_penalty", + "frequency_penalty", + "logit_bias", + "user", + "function_call", + "functions", + "tools", + "tool_choice", + "top_p", + "logprobs", + "top_logprobs", + "response_format", + "seed", + "extra_headers", + "parallel_tool_calls", + "prediction", + "modalities", + "audio", + ] + + def _is_response_format_supported_model(self, model: str) -> bool: + """ + - all 4o models are supported + - check if 'supports_response_format' is True from get_model_info + - [TODO] support smart retries for 3.5 models (some supported, some not) + """ + if "4o" in model: + return True + elif supports_response_schema(model): + return True + + return False + + def _is_response_format_supported_api_version( + self, api_version_year: str, api_version_month: str + ) -> bool: + """ + - check if api_version is supported for response_format + """ + + is_supported = int(api_version_year) <= 2024 and int(api_version_month) >= 8 + + return is_supported + + def map_openai_params( + self, + non_default_params: dict, + optional_params: dict, + model: str, + drop_params: bool, + api_version: str = "", + ) -> dict: + supported_openai_params = self.get_supported_openai_params(model) + + api_version_times = api_version.split("-") + api_version_year = api_version_times[0] + api_version_month = api_version_times[1] + api_version_day = api_version_times[2] + for param, value in non_default_params.items(): + if param == "tool_choice": + """ + This parameter requires API version 2023-12-01-preview or later + + tool_choice='required' is not supported as of 2024-05-01-preview + """ + ## check if api version supports this param ## + if ( + api_version_year < "2023" + or (api_version_year == "2023" and api_version_month < "12") + or ( + api_version_year == "2023" + and api_version_month == "12" + and api_version_day < "01" + ) + ): + if litellm.drop_params is True or ( + drop_params is not None and drop_params is True + ): + pass + else: + raise UnsupportedParamsError( + status_code=400, + message=f"""Azure does not support 'tool_choice', for api_version={api_version}. Bump your API version to '2023-12-01-preview' or later. This parameter requires 'api_version="2023-12-01-preview"' or later. 
Azure API Reference: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions""", + ) + elif value == "required" and ( + api_version_year == "2024" and api_version_month <= "05" + ): ## check if tool_choice value is supported ## + if litellm.drop_params is True or ( + drop_params is not None and drop_params is True + ): + pass + else: + raise UnsupportedParamsError( + status_code=400, + message=f"Azure does not support '{value}' as a {param} param, for api_version={api_version}. To drop 'tool_choice=required' for calls with this Azure API version, set `litellm.drop_params=True` or for proxy:\n\n`litellm_settings:\n drop_params: true`\nAzure API Reference: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions", + ) + else: + optional_params["tool_choice"] = value + elif param == "response_format" and isinstance(value, dict): + _is_response_format_supported_model = ( + self._is_response_format_supported_model(model) + ) + + is_response_format_supported_api_version = ( + self._is_response_format_supported_api_version( + api_version_year, api_version_month + ) + ) + is_response_format_supported = ( + is_response_format_supported_api_version + and _is_response_format_supported_model + ) + optional_params = self._add_response_format_to_tools( + optional_params=optional_params, + value=value, + is_response_format_supported=is_response_format_supported, + ) + elif param == "tools" and isinstance(value, list): + optional_params.setdefault("tools", []) + optional_params["tools"].extend(value) + elif param in supported_openai_params: + optional_params[param] = value + + return optional_params + + def transform_request( + self, + model: str, + messages: List[AllMessageValues], + optional_params: dict, + litellm_params: dict, + headers: dict, + ) -> dict: + messages = convert_to_azure_openai_messages(messages) + return { + "model": model, + "messages": messages, + **optional_params, + } + + def transform_response( + self, + model: str, + raw_response: Response, + model_response: ModelResponse, + logging_obj: LoggingClass, + request_data: dict, + messages: List[AllMessageValues], + optional_params: dict, + litellm_params: dict, + encoding: Any, + api_key: Optional[str] = None, + json_mode: Optional[bool] = None, + ) -> ModelResponse: + raise NotImplementedError( + "Azure OpenAI handler.py has custom logic for transforming response, as it uses the OpenAI SDK." 
+ ) + + def get_mapped_special_auth_params(self) -> dict: + return {"token": "azure_ad_token"} + + def map_special_auth_params(self, non_default_params: dict, optional_params: dict): + for param, value in non_default_params.items(): + if param == "token": + optional_params["azure_ad_token"] = value + return optional_params + + def get_eu_regions(self) -> List[str]: + """ + Source: https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models#gpt-4-and-gpt-4-turbo-model-availability + """ + return ["europe", "sweden", "switzerland", "france", "uk"] + + def get_us_regions(self) -> List[str]: + """ + Source: https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models#gpt-4-and-gpt-4-turbo-model-availability + """ + return [ + "us", + "eastus", + "eastus2", + "eastus2euap", + "eastus3", + "southcentralus", + "westus", + "westus2", + "westus3", + "westus4", + ] + + def get_error_class( + self, error_message: str, status_code: int, headers: Union[dict, Headers] + ) -> BaseLLMException: + return AzureOpenAIError( + message=error_message, status_code=status_code, headers=headers + ) + + def validate_environment( + self, + headers: dict, + model: str, + messages: List[AllMessageValues], + optional_params: dict, + api_key: Optional[str] = None, + api_base: Optional[str] = None, + ) -> dict: + raise NotImplementedError( + "Azure OpenAI has custom logic for validating environment, as it uses the OpenAI SDK." + ) diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/azure/chat/o_series_handler.py b/.venv/lib/python3.12/site-packages/litellm/llms/azure/chat/o_series_handler.py new file mode 100644 index 00000000..2f3e9e63 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/llms/azure/chat/o_series_handler.py @@ -0,0 +1,72 @@ +""" +Handler file for calls to Azure OpenAI's o1/o3 family of models + +Written separately to handle faking streaming for o1 and o3 models. 
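+
+"Faking" streaming here means the request is sent as a single non-streaming
+call and the complete response is then replayed to the caller as stream
+chunks; whether to fake is decided by AzureOpenAIO1Config.should_fake_stream
+(see o_series_transformation.py in this package).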
+""" + +from typing import Any, Callable, Optional, Union + +import httpx + +from litellm.types.utils import ModelResponse + +from ...openai.openai import OpenAIChatCompletion +from ..common_utils import BaseAzureLLM + + +class AzureOpenAIO1ChatCompletion(BaseAzureLLM, OpenAIChatCompletion): + def completion( + self, + model_response: ModelResponse, + timeout: Union[float, httpx.Timeout], + optional_params: dict, + litellm_params: dict, + logging_obj: Any, + model: Optional[str] = None, + messages: Optional[list] = None, + print_verbose: Optional[Callable] = None, + api_key: Optional[str] = None, + api_base: Optional[str] = None, + api_version: Optional[str] = None, + dynamic_params: Optional[bool] = None, + azure_ad_token: Optional[str] = None, + acompletion: bool = False, + logger_fn=None, + headers: Optional[dict] = None, + custom_prompt_dict: dict = {}, + client=None, + organization: Optional[str] = None, + custom_llm_provider: Optional[str] = None, + drop_params: Optional[bool] = None, + ): + client = self.get_azure_openai_client( + litellm_params=litellm_params, + api_key=api_key, + api_base=api_base, + api_version=api_version, + client=client, + _is_async=acompletion, + ) + return super().completion( + model_response=model_response, + timeout=timeout, + optional_params=optional_params, + litellm_params=litellm_params, + logging_obj=logging_obj, + model=model, + messages=messages, + print_verbose=print_verbose, + api_key=api_key, + api_base=api_base, + api_version=api_version, + dynamic_params=dynamic_params, + azure_ad_token=azure_ad_token, + acompletion=acompletion, + logger_fn=logger_fn, + headers=headers, + custom_prompt_dict=custom_prompt_dict, + client=client, + organization=organization, + custom_llm_provider=custom_llm_provider, + drop_params=drop_params, + ) diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/azure/chat/o_series_transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/azure/chat/o_series_transformation.py new file mode 100644 index 00000000..0ca3a28d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/llms/azure/chat/o_series_transformation.py @@ -0,0 +1,75 @@ +""" +Support for o1 and o3 model families + +https://platform.openai.com/docs/guides/reasoning + +Translations handled by LiteLLM: +- modalities: image => drop param (if user opts in to dropping param) +- role: system ==> translate to role 'user' +- streaming => faked by LiteLLM +- Tools, response_format => drop param (if user opts in to dropping param) +- Logprobs => drop param (if user opts in to dropping param) +- Temperature => drop param (if user opts in to dropping param) +""" + +from typing import List, Optional + +from litellm import verbose_logger +from litellm.types.llms.openai import AllMessageValues +from litellm.utils import get_model_info + +from ...openai.chat.o_series_transformation import OpenAIOSeriesConfig + + +class AzureOpenAIO1Config(OpenAIOSeriesConfig): + def should_fake_stream( + self, + model: Optional[str], + stream: Optional[bool], + custom_llm_provider: Optional[str] = None, + ) -> bool: + """ + Currently no Azure O Series models support native streaming. 
+ """ + + if stream is not True: + return False + + if ( + model and "o3" in model + ): # o3 models support streaming - https://github.com/BerriAI/litellm/issues/8274 + return False + + if model is not None: + try: + model_info = get_model_info( + model=model, custom_llm_provider=custom_llm_provider + ) # allow user to override default with model_info={"supports_native_streaming": true} + + if ( + model_info.get("supports_native_streaming") is True + ): # allow user to override default with model_info={"supports_native_streaming": true} + return False + except Exception as e: + verbose_logger.debug( + f"Error getting model info in AzureOpenAIO1Config: {e}" + ) + return True + + def is_o_series_model(self, model: str) -> bool: + return "o1" in model or "o3" in model or "o_series/" in model + + def transform_request( + self, + model: str, + messages: List[AllMessageValues], + optional_params: dict, + litellm_params: dict, + headers: dict, + ) -> dict: + model = model.replace( + "o_series/", "" + ) # handle o_series/my-random-deployment-name + return super().transform_request( + model, messages, optional_params, litellm_params, headers + ) diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/azure/common_utils.py b/.venv/lib/python3.12/site-packages/litellm/llms/azure/common_utils.py new file mode 100644 index 00000000..71092c8b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/llms/azure/common_utils.py @@ -0,0 +1,426 @@ +import json +import os +from typing import Any, Callable, Dict, Optional, Union + +import httpx +from openai import AsyncAzureOpenAI, AzureOpenAI + +import litellm +from litellm._logging import verbose_logger +from litellm.caching.caching import DualCache +from litellm.llms.base_llm.chat.transformation import BaseLLMException +from litellm.llms.openai.common_utils import BaseOpenAILLM +from litellm.secret_managers.get_azure_ad_token_provider import ( + get_azure_ad_token_provider, +) +from litellm.secret_managers.main import get_secret_str + +azure_ad_cache = DualCache() + + +class AzureOpenAIError(BaseLLMException): + def __init__( + self, + status_code, + message, + request: Optional[httpx.Request] = None, + response: Optional[httpx.Response] = None, + headers: Optional[Union[httpx.Headers, dict]] = None, + body: Optional[dict] = None, + ): + super().__init__( + status_code=status_code, + message=message, + request=request, + response=response, + headers=headers, + body=body, + ) + + +def process_azure_headers(headers: Union[httpx.Headers, dict]) -> dict: + openai_headers = {} + if "x-ratelimit-limit-requests" in headers: + openai_headers["x-ratelimit-limit-requests"] = headers[ + "x-ratelimit-limit-requests" + ] + if "x-ratelimit-remaining-requests" in headers: + openai_headers["x-ratelimit-remaining-requests"] = headers[ + "x-ratelimit-remaining-requests" + ] + if "x-ratelimit-limit-tokens" in headers: + openai_headers["x-ratelimit-limit-tokens"] = headers["x-ratelimit-limit-tokens"] + if "x-ratelimit-remaining-tokens" in headers: + openai_headers["x-ratelimit-remaining-tokens"] = headers[ + "x-ratelimit-remaining-tokens" + ] + llm_response_headers = { + "{}-{}".format("llm_provider", k): v for k, v in headers.items() + } + + return {**llm_response_headers, **openai_headers} + + +def get_azure_ad_token_from_entrata_id( + tenant_id: str, + client_id: str, + client_secret: str, + scope: str = "https://cognitiveservices.azure.com/.default", +) -> Callable[[], str]: + """ + Get Azure AD token provider from `client_id`, `client_secret`, and `tenant_id` + 
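+    (Despite the spelling in the function name, this refers to Microsoft Entra
+    ID, formerly Azure Active Directory.) Each of `tenant_id`, `client_id`, and
+    `client_secret` may also be passed as an "os.environ/<VAR_NAME>" reference,
+    which is resolved via `get_secret_str` before the credential is built.
+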
+ Args: + tenant_id: str + client_id: str + client_secret: str + scope: str + + Returns: + callable that returns a bearer token. + """ + from azure.identity import ClientSecretCredential, get_bearer_token_provider + + verbose_logger.debug("Getting Azure AD Token from Entrata ID") + + if tenant_id.startswith("os.environ/"): + _tenant_id = get_secret_str(tenant_id) + else: + _tenant_id = tenant_id + + if client_id.startswith("os.environ/"): + _client_id = get_secret_str(client_id) + else: + _client_id = client_id + + if client_secret.startswith("os.environ/"): + _client_secret = get_secret_str(client_secret) + else: + _client_secret = client_secret + + verbose_logger.debug( + "tenant_id %s, client_id %s, client_secret %s", + _tenant_id, + _client_id, + _client_secret, + ) + if _tenant_id is None or _client_id is None or _client_secret is None: + raise ValueError("tenant_id, client_id, and client_secret must be provided") + credential = ClientSecretCredential(_tenant_id, _client_id, _client_secret) + + verbose_logger.debug("credential %s", credential) + + token_provider = get_bearer_token_provider(credential, scope) + + verbose_logger.debug("token_provider %s", token_provider) + + return token_provider + + +def get_azure_ad_token_from_username_password( + client_id: str, + azure_username: str, + azure_password: str, + scope: str = "https://cognitiveservices.azure.com/.default", +) -> Callable[[], str]: + """ + Get Azure AD token provider from `client_id`, `azure_username`, and `azure_password` + + Args: + client_id: str + azure_username: str + azure_password: str + scope: str + + Returns: + callable that returns a bearer token. + """ + from azure.identity import UsernamePasswordCredential, get_bearer_token_provider + + verbose_logger.debug( + "client_id %s, azure_username %s, azure_password %s", + client_id, + azure_username, + azure_password, + ) + credential = UsernamePasswordCredential( + client_id=client_id, + username=azure_username, + password=azure_password, + ) + + verbose_logger.debug("credential %s", credential) + + token_provider = get_bearer_token_provider(credential, scope) + + verbose_logger.debug("token_provider %s", token_provider) + + return token_provider + + +def get_azure_ad_token_from_oidc(azure_ad_token: str): + azure_client_id = os.getenv("AZURE_CLIENT_ID", None) + azure_tenant_id = os.getenv("AZURE_TENANT_ID", None) + azure_authority_host = os.getenv( + "AZURE_AUTHORITY_HOST", "https://login.microsoftonline.com" + ) + + if azure_client_id is None or azure_tenant_id is None: + raise AzureOpenAIError( + status_code=422, + message="AZURE_CLIENT_ID and AZURE_TENANT_ID must be set", + ) + + oidc_token = get_secret_str(azure_ad_token) + + if oidc_token is None: + raise AzureOpenAIError( + status_code=401, + message="OIDC token could not be retrieved from secret manager.", + ) + + azure_ad_token_cache_key = json.dumps( + { + "azure_client_id": azure_client_id, + "azure_tenant_id": azure_tenant_id, + "azure_authority_host": azure_authority_host, + "oidc_token": oidc_token, + } + ) + + azure_ad_token_access_token = azure_ad_cache.get_cache(azure_ad_token_cache_key) + if azure_ad_token_access_token is not None: + return azure_ad_token_access_token + + client = litellm.module_level_client + req_token = client.post( + f"{azure_authority_host}/{azure_tenant_id}/oauth2/v2.0/token", + data={ + "client_id": azure_client_id, + "grant_type": "client_credentials", + "scope": "https://cognitiveservices.azure.com/.default", + "client_assertion_type": 
"urn:ietf:params:oauth:client-assertion-type:jwt-bearer", + "client_assertion": oidc_token, + }, + ) + + if req_token.status_code != 200: + raise AzureOpenAIError( + status_code=req_token.status_code, + message=req_token.text, + ) + + azure_ad_token_json = req_token.json() + azure_ad_token_access_token = azure_ad_token_json.get("access_token", None) + azure_ad_token_expires_in = azure_ad_token_json.get("expires_in", None) + + if azure_ad_token_access_token is None: + raise AzureOpenAIError( + status_code=422, message="Azure AD Token access_token not returned" + ) + + if azure_ad_token_expires_in is None: + raise AzureOpenAIError( + status_code=422, message="Azure AD Token expires_in not returned" + ) + + azure_ad_cache.set_cache( + key=azure_ad_token_cache_key, + value=azure_ad_token_access_token, + ttl=azure_ad_token_expires_in, + ) + + return azure_ad_token_access_token + + +def select_azure_base_url_or_endpoint(azure_client_params: dict): + azure_endpoint = azure_client_params.get("azure_endpoint", None) + if azure_endpoint is not None: + # see : https://github.com/openai/openai-python/blob/3d61ed42aba652b547029095a7eb269ad4e1e957/src/openai/lib/azure.py#L192 + if "/openai/deployments" in azure_endpoint: + # this is base_url, not an azure_endpoint + azure_client_params["base_url"] = azure_endpoint + azure_client_params.pop("azure_endpoint") + + return azure_client_params + + +class BaseAzureLLM(BaseOpenAILLM): + def get_azure_openai_client( + self, + api_key: Optional[str], + api_base: Optional[str], + api_version: Optional[str] = None, + client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None, + litellm_params: Optional[dict] = None, + _is_async: bool = False, + model: Optional[str] = None, + ) -> Optional[Union[AzureOpenAI, AsyncAzureOpenAI]]: + openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None + client_initialization_params: dict = locals() + if client is None: + cached_client = self.get_cached_openai_client( + client_initialization_params=client_initialization_params, + client_type="azure", + ) + if cached_client: + if isinstance(cached_client, AzureOpenAI) or isinstance( + cached_client, AsyncAzureOpenAI + ): + return cached_client + + azure_client_params = self.initialize_azure_sdk_client( + litellm_params=litellm_params or {}, + api_key=api_key, + api_base=api_base, + model_name=model, + api_version=api_version, + is_async=_is_async, + ) + if _is_async is True: + openai_client = AsyncAzureOpenAI(**azure_client_params) + else: + openai_client = AzureOpenAI(**azure_client_params) # type: ignore + else: + openai_client = client + if api_version is not None and isinstance( + openai_client._custom_query, dict + ): + # set api_version to version passed by user + openai_client._custom_query.setdefault("api-version", api_version) + + # save client in-memory cache + self.set_cached_openai_client( + openai_client=openai_client, + client_initialization_params=client_initialization_params, + client_type="azure", + ) + return openai_client + + def initialize_azure_sdk_client( + self, + litellm_params: dict, + api_key: Optional[str], + api_base: Optional[str], + model_name: Optional[str], + api_version: Optional[str], + is_async: bool, + ) -> dict: + + azure_ad_token_provider: Optional[Callable[[], str]] = None + # If we have api_key, then we have higher priority + azure_ad_token = litellm_params.get("azure_ad_token") + tenant_id = litellm_params.get("tenant_id") + client_id = litellm_params.get("client_id") + client_secret = litellm_params.get("client_secret") + 
azure_username = litellm_params.get("azure_username") + azure_password = litellm_params.get("azure_password") + max_retries = litellm_params.get("max_retries") + timeout = litellm_params.get("timeout") + if not api_key and tenant_id and client_id and client_secret: + verbose_logger.debug("Using Azure AD Token Provider for Azure Auth") + azure_ad_token_provider = get_azure_ad_token_from_entrata_id( + tenant_id=tenant_id, + client_id=client_id, + client_secret=client_secret, + ) + if azure_username and azure_password and client_id: + azure_ad_token_provider = get_azure_ad_token_from_username_password( + azure_username=azure_username, + azure_password=azure_password, + client_id=client_id, + ) + + if azure_ad_token is not None and azure_ad_token.startswith("oidc/"): + azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token) + elif ( + not api_key + and azure_ad_token_provider is None + and litellm.enable_azure_ad_token_refresh is True + ): + try: + azure_ad_token_provider = get_azure_ad_token_provider() + except ValueError: + verbose_logger.debug("Azure AD Token Provider could not be used.") + if api_version is None: + api_version = os.getenv( + "AZURE_API_VERSION", litellm.AZURE_DEFAULT_API_VERSION + ) + + _api_key = api_key + if _api_key is not None and isinstance(_api_key, str): + # only show first 5 chars of api_key + _api_key = _api_key[:8] + "*" * 15 + verbose_logger.debug( + f"Initializing Azure OpenAI Client for {model_name}, Api Base: {str(api_base)}, Api Key:{_api_key}" + ) + azure_client_params = { + "api_key": api_key, + "azure_endpoint": api_base, + "api_version": api_version, + "azure_ad_token": azure_ad_token, + "azure_ad_token_provider": azure_ad_token_provider, + } + # init http client + SSL Verification settings + if is_async is True: + azure_client_params["http_client"] = self._get_async_http_client() + else: + azure_client_params["http_client"] = self._get_sync_http_client() + + if max_retries is not None: + azure_client_params["max_retries"] = max_retries + if timeout is not None: + azure_client_params["timeout"] = timeout + + if azure_ad_token_provider is not None: + azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider + # this decides if we should set azure_endpoint or base_url on Azure OpenAI Client + # required to support GPT-4 vision enhancements, since base_url needs to be set on Azure OpenAI Client + + azure_client_params = select_azure_base_url_or_endpoint( + azure_client_params=azure_client_params + ) + + return azure_client_params + + def _init_azure_client_for_cloudflare_ai_gateway( + self, + api_base: str, + model: str, + api_version: str, + max_retries: int, + timeout: Union[float, httpx.Timeout], + api_key: Optional[str], + azure_ad_token: Optional[str], + azure_ad_token_provider: Optional[Callable[[], str]], + acompletion: bool, + client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None, + ) -> Union[AzureOpenAI, AsyncAzureOpenAI]: + ## build base url - assume api base includes resource name + if client is None: + if not api_base.endswith("/"): + api_base += "/" + api_base += f"{model}" + + azure_client_params: Dict[str, Any] = { + "api_version": api_version, + "base_url": f"{api_base}", + "http_client": litellm.client_session, + "max_retries": max_retries, + "timeout": timeout, + } + if api_key is not None: + azure_client_params["api_key"] = api_key + elif azure_ad_token is not None: + if azure_ad_token.startswith("oidc/"): + azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token) + + 
azure_client_params["azure_ad_token"] = azure_ad_token + if azure_ad_token_provider is not None: + azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider + + if acompletion is True: + client = AsyncAzureOpenAI(**azure_client_params) # type: ignore + else: + client = AzureOpenAI(**azure_client_params) # type: ignore + return client diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/azure/completion/handler.py b/.venv/lib/python3.12/site-packages/litellm/llms/azure/completion/handler.py new file mode 100644 index 00000000..8301c4d6 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/llms/azure/completion/handler.py @@ -0,0 +1,378 @@ +from typing import Any, Callable, Optional + +from openai import AsyncAzureOpenAI, AzureOpenAI + +from litellm.litellm_core_utils.prompt_templates.factory import prompt_factory +from litellm.utils import CustomStreamWrapper, ModelResponse, TextCompletionResponse + +from ...openai.completion.transformation import OpenAITextCompletionConfig +from ..common_utils import AzureOpenAIError, BaseAzureLLM + +openai_text_completion_config = OpenAITextCompletionConfig() + + +class AzureTextCompletion(BaseAzureLLM): + def __init__(self) -> None: + super().__init__() + + def validate_environment(self, api_key, azure_ad_token): + headers = { + "content-type": "application/json", + } + if api_key is not None: + headers["api-key"] = api_key + elif azure_ad_token is not None: + headers["Authorization"] = f"Bearer {azure_ad_token}" + return headers + + def completion( # noqa: PLR0915 + self, + model: str, + messages: list, + model_response: ModelResponse, + api_key: str, + api_base: str, + api_version: str, + api_type: str, + azure_ad_token: str, + azure_ad_token_provider: Optional[Callable], + print_verbose: Callable, + timeout, + logging_obj, + optional_params, + litellm_params, + logger_fn, + acompletion: bool = False, + headers: Optional[dict] = None, + client=None, + ): + try: + if model is None or messages is None: + raise AzureOpenAIError( + status_code=422, message="Missing model or messages" + ) + + max_retries = optional_params.pop("max_retries", 2) + prompt = prompt_factory( + messages=messages, model=model, custom_llm_provider="azure_text" + ) + + ### CHECK IF CLOUDFLARE AI GATEWAY ### + ### if so - set the model as part of the base url + if "gateway.ai.cloudflare.com" in api_base: + ## build base url - assume api base includes resource name + client = self._init_azure_client_for_cloudflare_ai_gateway( + api_key=api_key, + api_version=api_version, + api_base=api_base, + model=model, + client=client, + max_retries=max_retries, + timeout=timeout, + azure_ad_token=azure_ad_token, + azure_ad_token_provider=azure_ad_token_provider, + acompletion=acompletion, + ) + + data = {"model": None, "prompt": prompt, **optional_params} + else: + data = { + "model": model, # type: ignore + "prompt": prompt, + **optional_params, + } + + if acompletion is True: + if optional_params.get("stream", False): + return self.async_streaming( + logging_obj=logging_obj, + api_base=api_base, + data=data, + model=model, + api_key=api_key, + api_version=api_version, + azure_ad_token=azure_ad_token, + timeout=timeout, + client=client, + litellm_params=litellm_params, + ) + else: + return self.acompletion( + api_base=api_base, + data=data, + model_response=model_response, + api_key=api_key, + api_version=api_version, + model=model, + azure_ad_token=azure_ad_token, + timeout=timeout, + client=client, + logging_obj=logging_obj, + max_retries=max_retries, + 
litellm_params=litellm_params, + ) + elif "stream" in optional_params and optional_params["stream"] is True: + return self.streaming( + logging_obj=logging_obj, + api_base=api_base, + data=data, + model=model, + api_key=api_key, + api_version=api_version, + azure_ad_token=azure_ad_token, + timeout=timeout, + client=client, + ) + else: + ## LOGGING + logging_obj.pre_call( + input=prompt, + api_key=api_key, + additional_args={ + "headers": { + "api_key": api_key, + "azure_ad_token": azure_ad_token, + }, + "api_version": api_version, + "api_base": api_base, + "complete_input_dict": data, + }, + ) + if not isinstance(max_retries, int): + raise AzureOpenAIError( + status_code=422, message="max retries must be an int" + ) + # init AzureOpenAI Client + azure_client = self.get_azure_openai_client( + api_key=api_key, + api_base=api_base, + api_version=api_version, + client=client, + litellm_params=litellm_params, + _is_async=False, + model=model, + ) + + if not isinstance(azure_client, AzureOpenAI): + raise AzureOpenAIError( + status_code=500, + message="azure_client is not an instance of AzureOpenAI", + ) + + raw_response = azure_client.completions.with_raw_response.create( + **data, timeout=timeout + ) + response = raw_response.parse() + stringified_response = response.model_dump() + ## LOGGING + logging_obj.post_call( + input=prompt, + api_key=api_key, + original_response=stringified_response, + additional_args={ + "headers": headers, + "api_version": api_version, + "api_base": api_base, + }, + ) + return ( + openai_text_completion_config.convert_to_chat_model_response_object( + response_object=TextCompletionResponse(**stringified_response), + model_response_object=model_response, + ) + ) + except AzureOpenAIError as e: + raise e + except Exception as e: + status_code = getattr(e, "status_code", 500) + error_headers = getattr(e, "headers", None) + error_response = getattr(e, "response", None) + if error_headers is None and error_response: + error_headers = getattr(error_response, "headers", None) + raise AzureOpenAIError( + status_code=status_code, message=str(e), headers=error_headers + ) + + async def acompletion( + self, + api_key: str, + api_version: str, + model: str, + api_base: str, + data: dict, + timeout: Any, + model_response: ModelResponse, + logging_obj: Any, + max_retries: int, + azure_ad_token: Optional[str] = None, + client=None, # this is the AsyncAzureOpenAI + litellm_params: dict = {}, + ): + response = None + try: + # init AzureOpenAI Client + # setting Azure client + azure_client = self.get_azure_openai_client( + api_version=api_version, + api_base=api_base, + api_key=api_key, + model=model, + _is_async=True, + client=client, + litellm_params=litellm_params, + ) + if not isinstance(azure_client, AsyncAzureOpenAI): + raise AzureOpenAIError( + status_code=500, + message="azure_client is not an instance of AsyncAzureOpenAI", + ) + + ## LOGGING + logging_obj.pre_call( + input=data["prompt"], + api_key=azure_client.api_key, + additional_args={ + "headers": {"Authorization": f"Bearer {azure_client.api_key}"}, + "api_base": azure_client._base_url._uri_reference, + "acompletion": True, + "complete_input_dict": data, + }, + ) + raw_response = await azure_client.completions.with_raw_response.create( + **data, timeout=timeout + ) + response = raw_response.parse() + return openai_text_completion_config.convert_to_chat_model_response_object( + response_object=response.model_dump(), + model_response_object=model_response, + ) + except AzureOpenAIError as e: + raise e + except Exception as 
e: + status_code = getattr(e, "status_code", 500) + error_headers = getattr(e, "headers", None) + error_response = getattr(e, "response", None) + if error_headers is None and error_response: + error_headers = getattr(error_response, "headers", None) + raise AzureOpenAIError( + status_code=status_code, message=str(e), headers=error_headers + ) + + def streaming( + self, + logging_obj, + api_base: str, + api_key: str, + api_version: str, + data: dict, + model: str, + timeout: Any, + azure_ad_token: Optional[str] = None, + client=None, + litellm_params: dict = {}, + ): + max_retries = data.pop("max_retries", 2) + if not isinstance(max_retries, int): + raise AzureOpenAIError( + status_code=422, message="max retries must be an int" + ) + # init AzureOpenAI Client + azure_client = self.get_azure_openai_client( + api_version=api_version, + api_base=api_base, + api_key=api_key, + model=model, + _is_async=False, + client=client, + litellm_params=litellm_params, + ) + if not isinstance(azure_client, AzureOpenAI): + raise AzureOpenAIError( + status_code=500, + message="azure_client is not an instance of AzureOpenAI", + ) + + ## LOGGING + logging_obj.pre_call( + input=data["prompt"], + api_key=azure_client.api_key, + additional_args={ + "headers": {"Authorization": f"Bearer {azure_client.api_key}"}, + "api_base": azure_client._base_url._uri_reference, + "acompletion": True, + "complete_input_dict": data, + }, + ) + raw_response = azure_client.completions.with_raw_response.create( + **data, timeout=timeout + ) + response = raw_response.parse() + streamwrapper = CustomStreamWrapper( + completion_stream=response, + model=model, + custom_llm_provider="azure_text", + logging_obj=logging_obj, + ) + return streamwrapper + + async def async_streaming( + self, + logging_obj, + api_base: str, + api_key: str, + api_version: str, + data: dict, + model: str, + timeout: Any, + azure_ad_token: Optional[str] = None, + client=None, + litellm_params: dict = {}, + ): + try: + # init AzureOpenAI Client + azure_client = self.get_azure_openai_client( + api_version=api_version, + api_base=api_base, + api_key=api_key, + model=model, + _is_async=True, + client=client, + litellm_params=litellm_params, + ) + if not isinstance(azure_client, AsyncAzureOpenAI): + raise AzureOpenAIError( + status_code=500, + message="azure_client is not an instance of AsyncAzureOpenAI", + ) + ## LOGGING + logging_obj.pre_call( + input=data["prompt"], + api_key=azure_client.api_key, + additional_args={ + "headers": {"Authorization": f"Bearer {azure_client.api_key}"}, + "api_base": azure_client._base_url._uri_reference, + "acompletion": True, + "complete_input_dict": data, + }, + ) + raw_response = await azure_client.completions.with_raw_response.create( + **data, timeout=timeout + ) + response = raw_response.parse() + # return response + streamwrapper = CustomStreamWrapper( + completion_stream=response, + model=model, + custom_llm_provider="azure_text", + logging_obj=logging_obj, + ) + return streamwrapper ## DO NOT make this into an async for ... 
+ except Exception as e:
+ status_code = getattr(e, "status_code", 500)
+ error_headers = getattr(e, "headers", None)
+ error_response = getattr(e, "response", None)
+ if error_headers is None and error_response:
+ error_headers = getattr(error_response, "headers", None)
+ raise AzureOpenAIError(
+ status_code=status_code, message=str(e), headers=error_headers
+ )
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/azure/completion/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/azure/completion/transformation.py
new file mode 100644
index 00000000..bc7b97c6
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/azure/completion/transformation.py
@@ -0,0 +1,53 @@
+from typing import Optional, Union
+
+from ...openai.completion.transformation import OpenAITextCompletionConfig
+
+
+class AzureOpenAITextConfig(OpenAITextCompletionConfig):
+ """
+ Reference: https://platform.openai.com/docs/api-reference/completions/create
+
+ The class `AzureOpenAITextConfig` provides configuration for the legacy text-completions API, for use with Azure OpenAI. It inherits from `OpenAITextCompletionConfig`. Below are the parameters:
+
+ - `frequency_penalty` (number or null): Defaults to 0. Allows a value between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, thereby minimizing repetition.
+
+ - `logit_bias` (map): This optional parameter modifies the likelihood of specified tokens appearing in the completion.
+
+ - `max_tokens` (integer or null): This optional parameter sets the maximum number of tokens to generate in the completion.
+
+ - `n` (integer or null): This optional parameter sets how many completion choices to generate for each input prompt.
+
+ - `presence_penalty` (number or null): Defaults to 0. It penalizes new tokens based on whether they appear in the text so far, hence increasing the model's likelihood to talk about new topics.
+
+ - `stop` (string / array / null): Specifies up to 4 sequences where the API will stop generating further tokens.
+
+ - `temperature` (number or null): Defines the sampling temperature to use, varying between 0 and 2.
+
+ - `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling.
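+
+ Example (illustrative; the parameter values below are arbitrary):
+ config = AzureOpenAITextConfig(max_tokens=256, temperature=0.7)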
+ """ + + def __init__( + self, + frequency_penalty: Optional[int] = None, + logit_bias: Optional[dict] = None, + max_tokens: Optional[int] = None, + n: Optional[int] = None, + presence_penalty: Optional[int] = None, + stop: Optional[Union[str, list]] = None, + temperature: Optional[int] = None, + top_p: Optional[int] = None, + ) -> None: + super().__init__( + frequency_penalty=frequency_penalty, + logit_bias=logit_bias, + max_tokens=max_tokens, + n=n, + presence_penalty=presence_penalty, + stop=stop, + temperature=temperature, + top_p=top_p, + ) diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/azure/cost_calculation.py b/.venv/lib/python3.12/site-packages/litellm/llms/azure/cost_calculation.py new file mode 100644 index 00000000..96c58d95 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/llms/azure/cost_calculation.py @@ -0,0 +1,61 @@ +""" +Helper util for handling azure openai-specific cost calculation +- e.g.: prompt caching +""" + +from typing import Optional, Tuple + +from litellm._logging import verbose_logger +from litellm.types.utils import Usage +from litellm.utils import get_model_info + + +def cost_per_token( + model: str, usage: Usage, response_time_ms: Optional[float] = 0.0 +) -> Tuple[float, float]: + """ + Calculates the cost per token for a given model, prompt tokens, and completion tokens. + + Input: + - model: str, the model name without provider prefix + - usage: LiteLLM Usage block, containing anthropic caching information + + Returns: + Tuple[float, float] - prompt_cost_in_usd, completion_cost_in_usd + """ + ## GET MODEL INFO + model_info = get_model_info(model=model, custom_llm_provider="azure") + cached_tokens: Optional[int] = None + ## CALCULATE INPUT COST + non_cached_text_tokens = usage.prompt_tokens + if usage.prompt_tokens_details and usage.prompt_tokens_details.cached_tokens: + cached_tokens = usage.prompt_tokens_details.cached_tokens + non_cached_text_tokens = non_cached_text_tokens - cached_tokens + prompt_cost: float = non_cached_text_tokens * model_info["input_cost_per_token"] + + ## CALCULATE OUTPUT COST + completion_cost: float = ( + usage["completion_tokens"] * model_info["output_cost_per_token"] + ) + + ## Prompt Caching cost calculation + if model_info.get("cache_read_input_token_cost") is not None and cached_tokens: + # Note: We read ._cache_read_input_tokens from the Usage - since cost_calculator.py standardizes the cache read tokens on usage._cache_read_input_tokens + prompt_cost += cached_tokens * ( + model_info.get("cache_read_input_token_cost", 0) or 0 + ) + + ## Speech / Audio cost calculation + if ( + "output_cost_per_second" in model_info + and model_info["output_cost_per_second"] is not None + and response_time_ms is not None + ): + verbose_logger.debug( + f"For model={model} - output_cost_per_second: {model_info.get('output_cost_per_second')}; response time: {response_time_ms}" + ) + ## COST PER SECOND ## + prompt_cost = 0 + completion_cost = model_info["output_cost_per_second"] * response_time_ms / 1000 + + return prompt_cost, completion_cost diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/azure/files/handler.py b/.venv/lib/python3.12/site-packages/litellm/llms/azure/files/handler.py new file mode 100644 index 00000000..d45ac9a3 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/llms/azure/files/handler.py @@ -0,0 +1,284 @@ +from typing import Any, Coroutine, Optional, Union, cast + +import httpx +from openai import AsyncAzureOpenAI, AzureOpenAI +from openai.types.file_deleted import 
+
+from litellm._logging import verbose_logger
+from litellm.types.llms.openai import *
+
+from ..common_utils import BaseAzureLLM
+
+
+class AzureOpenAIFilesAPI(BaseAzureLLM):
+ """
+ Azure OpenAI methods to support the files API:
+ - create_file()
+ - retrieve_file()
+ - list_files()
+ - delete_file()
+ - file_content()
+ """
+
+ def __init__(self) -> None:
+ super().__init__()
+
+ async def acreate_file(
+ self,
+ create_file_data: CreateFileRequest,
+ openai_client: AsyncAzureOpenAI,
+ ) -> FileObject:
+ verbose_logger.debug("create_file_data=%s", create_file_data)
+ response = await openai_client.files.create(**create_file_data)
+ verbose_logger.debug("create_file_response=%s", response)
+ return response
+
+ def create_file(
+ self,
+ _is_async: bool,
+ create_file_data: CreateFileRequest,
+ api_base: Optional[str],
+ api_key: Optional[str],
+ api_version: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
+ litellm_params: Optional[dict] = None,
+ ) -> Union[FileObject, Coroutine[Any, Any, FileObject]]:
+ openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
+ self.get_azure_openai_client(
+ litellm_params=litellm_params or {},
+ api_key=api_key,
+ api_base=api_base,
+ api_version=api_version,
+ client=client,
+ _is_async=_is_async,
+ )
+ )
+ if openai_client is None:
+ raise ValueError(
+ "AzureOpenAI client is not initialized. Make sure api_key is passed or AZURE_API_KEY is set in the environment."
+ )
+
+ if _is_async is True:
+ if not isinstance(openai_client, AsyncAzureOpenAI):
+ raise ValueError(
+ "AzureOpenAI client is not an instance of AsyncAzureOpenAI. Make sure you passed an AsyncAzureOpenAI client."
+ )
+ return self.acreate_file(  # type: ignore
+ create_file_data=create_file_data, openai_client=openai_client
+ )
+ response = openai_client.files.create(**create_file_data)
+ return response
+
+ async def afile_content(
+ self,
+ file_content_request: FileContentRequest,
+ openai_client: AsyncAzureOpenAI,
+ ) -> HttpxBinaryResponseContent:
+ response = await openai_client.files.content(**file_content_request)
+ return HttpxBinaryResponseContent(response=response.response)
+
+ def file_content(
+ self,
+ _is_async: bool,
+ file_content_request: FileContentRequest,
+ api_base: Optional[str],
+ api_key: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ api_version: Optional[str] = None,
+ client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
+ litellm_params: Optional[dict] = None,
+ ) -> Union[
+ HttpxBinaryResponseContent, Coroutine[Any, Any, HttpxBinaryResponseContent]
+ ]:
+ openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
+ self.get_azure_openai_client(
+ litellm_params=litellm_params or {},
+ api_key=api_key,
+ api_base=api_base,
+ api_version=api_version,
+ client=client,
+ _is_async=_is_async,
+ )
+ )
+ if openai_client is None:
+ raise ValueError(
+ "AzureOpenAI client is not initialized. Make sure api_key is passed or AZURE_API_KEY is set in the environment."
+ )
+
+ if _is_async is True:
+ if not isinstance(openai_client, AsyncAzureOpenAI):
+ raise ValueError(
+ "AzureOpenAI client is not an instance of AsyncAzureOpenAI. Make sure you passed an AsyncAzureOpenAI client."
+ )
+ return self.afile_content(  # type: ignore
+ file_content_request=file_content_request,
+ openai_client=openai_client,
+ )
+ response = cast(AzureOpenAI, openai_client).files.content(
+ **file_content_request
+ )
+
+ return HttpxBinaryResponseContent(response=response.response)
+
+ async def aretrieve_file(
+ self,
+ file_id: str,
+ openai_client: AsyncAzureOpenAI,
+ ) -> FileObject:
+ response = await openai_client.files.retrieve(file_id=file_id)
+ return response
+
+ def retrieve_file(
+ self,
+ _is_async: bool,
+ file_id: str,
+ api_base: Optional[str],
+ api_key: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ api_version: Optional[str] = None,
+ client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
+ litellm_params: Optional[dict] = None,
+ ):
+ openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
+ self.get_azure_openai_client(
+ litellm_params=litellm_params or {},
+ api_key=api_key,
+ api_base=api_base,
+ api_version=api_version,
+ client=client,
+ _is_async=_is_async,
+ )
+ )
+ if openai_client is None:
+ raise ValueError(
+ "AzureOpenAI client is not initialized. Make sure api_key is passed or AZURE_API_KEY is set in the environment."
+ )
+
+ if _is_async is True:
+ if not isinstance(openai_client, AsyncAzureOpenAI):
+ raise ValueError(
+ "AzureOpenAI client is not an instance of AsyncAzureOpenAI. Make sure you passed an AsyncAzureOpenAI client."
+ )
+ return self.aretrieve_file(  # type: ignore
+ file_id=file_id,
+ openai_client=openai_client,
+ )
+ response = openai_client.files.retrieve(file_id=file_id)
+
+ return response
+
+ async def adelete_file(
+ self,
+ file_id: str,
+ openai_client: AsyncAzureOpenAI,
+ ) -> FileDeleted:
+ response = await openai_client.files.delete(file_id=file_id)
+
+ if not isinstance(response, FileDeleted):  # Azure returns an empty string
+ return FileDeleted(id=file_id, deleted=True, object="file")
+ return response
+
+ def delete_file(
+ self,
+ _is_async: bool,
+ file_id: str,
+ api_base: Optional[str],
+ api_key: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ organization: Optional[str] = None,
+ api_version: Optional[str] = None,
+ client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
+ litellm_params: Optional[dict] = None,
+ ):
+ openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
+ self.get_azure_openai_client(
+ litellm_params=litellm_params or {},
+ api_key=api_key,
+ api_base=api_base,
+ api_version=api_version,
+ client=client,
+ _is_async=_is_async,
+ )
+ )
+ if openai_client is None:
+ raise ValueError(
+ "AzureOpenAI client is not initialized. Make sure api_key is passed or AZURE_API_KEY is set in the environment."
+ )
+
+ if _is_async is True:
+ if not isinstance(openai_client, AsyncAzureOpenAI):
+ raise ValueError(
+ "AzureOpenAI client is not an instance of AsyncAzureOpenAI. Make sure you passed an AsyncAzureOpenAI client."
+ )
+ return self.adelete_file(  # type: ignore
+ file_id=file_id,
+ openai_client=openai_client,
+ )
+ response = openai_client.files.delete(file_id=file_id)
+
+ if not isinstance(response, FileDeleted):  # Azure returns an empty string
+ return FileDeleted(id=file_id, deleted=True, object="file")
+
+ return response
+
+ async def alist_files(
+ self,
+ openai_client: AsyncAzureOpenAI,
+ purpose: Optional[str] = None,
+ ):
+ if isinstance(purpose, str):
+ response = await openai_client.files.list(purpose=purpose)
+ else:
+ response = await openai_client.files.list()
+ return response
+
+ def list_files(
+ self,
+ _is_async: bool,
+ api_base: Optional[str],
+ api_key: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ purpose: Optional[str] = None,
+ api_version: Optional[str] = None,
+ client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
+ litellm_params: Optional[dict] = None,
+ ):
+ openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
+ self.get_azure_openai_client(
+ litellm_params=litellm_params or {},
+ api_key=api_key,
+ api_base=api_base,
+ api_version=api_version,
+ client=client,
+ _is_async=_is_async,
+ )
+ )
+ if openai_client is None:
+ raise ValueError(
+ "AzureOpenAI client is not initialized. Make sure api_key is passed or AZURE_API_KEY is set in the environment."
+ )
+
+ if _is_async is True:
+ if not isinstance(openai_client, AsyncAzureOpenAI):
+ raise ValueError(
+ "AzureOpenAI client is not an instance of AsyncAzureOpenAI. Make sure you passed an AsyncAzureOpenAI client."
+ )
+ return self.alist_files(  # type: ignore
+ purpose=purpose,
+ openai_client=openai_client,
+ )
+
+ if isinstance(purpose, str):
+ response = openai_client.files.list(purpose=purpose)
+ else:
+ response = openai_client.files.list()
+
+ return response
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/azure/fine_tuning/handler.py b/.venv/lib/python3.12/site-packages/litellm/llms/azure/fine_tuning/handler.py
new file mode 100644
index 00000000..3d7cc336
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/azure/fine_tuning/handler.py
@@ -0,0 +1,47 @@
+from typing import Optional, Union
+
+import httpx
+from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI
+
+from litellm.llms.azure.common_utils import BaseAzureLLM
+from litellm.llms.openai.fine_tuning.handler import OpenAIFineTuningAPI
+
+
+class AzureOpenAIFineTuningAPI(OpenAIFineTuningAPI, BaseAzureLLM):
+ """
+ Azure OpenAI methods to support fine-tuning; inherits from OpenAIFineTuningAPI.
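+
+ Example (illustrative; the argument values below are placeholders):
+ api = AzureOpenAIFineTuningAPI()
+ client = api.get_openai_client(
+ api_key=None, api_base=None, timeout=600.0,
+ max_retries=2, organization=None, _is_async=True,
+ )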
+ """ + + def get_openai_client( + self, + api_key: Optional[str], + api_base: Optional[str], + timeout: Union[float, httpx.Timeout], + max_retries: Optional[int], + organization: Optional[str], + client: Optional[ + Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI] + ] = None, + _is_async: bool = False, + api_version: Optional[str] = None, + litellm_params: Optional[dict] = None, + ) -> Optional[ + Union[ + OpenAI, + AsyncOpenAI, + AzureOpenAI, + AsyncAzureOpenAI, + ] + ]: + # Override to use Azure-specific client initialization + if isinstance(client, OpenAI) or isinstance(client, AsyncOpenAI): + client = None + + return self.get_azure_openai_client( + litellm_params=litellm_params or {}, + api_key=api_key, + api_base=api_base, + api_version=api_version, + client=client, + _is_async=_is_async, + ) diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/azure/realtime/handler.py b/.venv/lib/python3.12/site-packages/litellm/llms/azure/realtime/handler.py new file mode 100644 index 00000000..5a4865e7 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/llms/azure/realtime/handler.py @@ -0,0 +1,75 @@ +""" +This file contains the calling Azure OpenAI's `/openai/realtime` endpoint. + +This requires websockets, and is currently only supported on LiteLLM Proxy. +""" + +from typing import Any, Optional + +from ....litellm_core_utils.litellm_logging import Logging as LiteLLMLogging +from ....litellm_core_utils.realtime_streaming import RealTimeStreaming +from ..azure import AzureChatCompletion + +# BACKEND_WS_URL = "ws://localhost:8080/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01" + + +async def forward_messages(client_ws: Any, backend_ws: Any): + import websockets + + try: + while True: + message = await backend_ws.recv() + await client_ws.send_text(message) + except websockets.exceptions.ConnectionClosed: # type: ignore + pass + + +class AzureOpenAIRealtime(AzureChatCompletion): + def _construct_url(self, api_base: str, model: str, api_version: str) -> str: + """ + Example output: + "wss://my-endpoint-sweden-berri992.openai.azure.com/openai/realtime?api-version=2024-10-01-preview&deployment=gpt-4o-realtime-preview"; + + """ + api_base = api_base.replace("https://", "wss://") + return ( + f"{api_base}/openai/realtime?api-version={api_version}&deployment={model}" + ) + + async def async_realtime( + self, + model: str, + websocket: Any, + api_base: Optional[str] = None, + api_key: Optional[str] = None, + api_version: Optional[str] = None, + azure_ad_token: Optional[str] = None, + client: Optional[Any] = None, + logging_obj: Optional[LiteLLMLogging] = None, + timeout: Optional[float] = None, + ): + import websockets + + if api_base is None: + raise ValueError("api_base is required for Azure OpenAI calls") + if api_version is None: + raise ValueError("api_version is required for Azure OpenAI calls") + + url = self._construct_url(api_base, model, api_version) + + try: + async with websockets.connect( # type: ignore + url, + extra_headers={ + "api-key": api_key, # type: ignore + }, + ) as backend_ws: + realtime_streaming = RealTimeStreaming( + websocket, backend_ws, logging_obj + ) + await realtime_streaming.bidirectional_forward() + + except websockets.exceptions.InvalidStatusCode as e: # type: ignore + await websocket.close(code=e.status_code, reason=str(e)) + except Exception: + pass |