aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/litellm/llms/azure/completion/handler.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/azure/completion/handler.py')
-rw-r--r-- .venv/lib/python3.12/site-packages/litellm/llms/azure/completion/handler.py 378
1 files changed, 378 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/azure/completion/handler.py b/.venv/lib/python3.12/site-packages/litellm/llms/azure/completion/handler.py
new file mode 100644
index 00000000..8301c4d6
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/azure/completion/handler.py
@@ -0,0 +1,378 @@
+from typing import Any, Callable, Optional
+
+from openai import AsyncAzureOpenAI, AzureOpenAI
+
+from litellm.litellm_core_utils.prompt_templates.factory import prompt_factory
+from litellm.utils import CustomStreamWrapper, ModelResponse, TextCompletionResponse
+
+from ...openai.completion.transformation import OpenAITextCompletionConfig
+from ..common_utils import AzureOpenAIError, BaseAzureLLM
+
+openai_text_completion_config = OpenAITextCompletionConfig()  # module-level instance; its convert_to_chat_model_response_object() is used below to fill the ModelResponse
+
+
class AzureTextCompletion(BaseAzureLLM):
    """Handler for Azure OpenAI text ``completions`` requests.

    ``completion`` is the entry point. Depending on ``acompletion`` and the
    ``stream`` entry of ``optional_params`` it dispatches to one of four
    request paths: sync, sync streaming (``streaming``), async
    (``acompletion``) or async streaming (``async_streaming``).
    """

    def __init__(self) -> None:
        super().__init__()

    @staticmethod
    def _to_azure_error(exc: Exception) -> AzureOpenAIError:
        """Build an ``AzureOpenAIError`` mirroring ``exc``.

        The status code is read from ``exc.status_code`` (500 when absent).
        Headers are read from ``exc.headers`` and, failing that, from
        ``exc.response.headers`` when a response object is attached.
        """
        status_code = getattr(exc, "status_code", 500)
        error_headers = getattr(exc, "headers", None)
        error_response = getattr(exc, "response", None)
        if error_headers is None and error_response:
            error_headers = getattr(error_response, "headers", None)
        return AzureOpenAIError(
            status_code=status_code, message=str(exc), headers=error_headers
        )

    def validate_environment(self, api_key, azure_ad_token):
        """Return the request headers for the given credentials.

        ``api_key`` takes precedence: the ``Authorization`` bearer header is
        only added when ``api_key`` is ``None`` and ``azure_ad_token`` is set.
        """
        headers = {
            "content-type": "application/json",
        }
        if api_key is not None:
            headers["api-key"] = api_key
        elif azure_ad_token is not None:
            headers["Authorization"] = f"Bearer {azure_ad_token}"
        return headers

    def completion(  # noqa: PLR0915
        self,
        model: str,
        messages: list,
        model_response: ModelResponse,
        api_key: str,
        api_base: str,
        api_version: str,
        api_type: str,
        azure_ad_token: str,
        azure_ad_token_provider: Optional[Callable],
        print_verbose: Callable,
        timeout,
        logging_obj,
        optional_params,
        litellm_params,
        logger_fn,
        acompletion: bool = False,
        headers: Optional[dict] = None,
        client=None,
    ):
        """Run a text completion against Azure OpenAI.

        Args:
            model: Model name placed in the request body (``None`` is sent
                instead when ``api_base`` points at the Cloudflare AI gateway).
            messages: Chat-style messages; turned into a single prompt via
                ``prompt_factory`` with ``custom_llm_provider="azure_text"``.
            model_response: Object populated and returned on the sync,
                non-streaming path.
            api_key, api_base, api_version: Connection settings forwarded to
                the client factory.
            azure_ad_token, azure_ad_token_provider: Forwarded to the
                Cloudflare gateway client initialiser and to logging.
            timeout: Passed to ``completions.create``.
            logging_obj: Receives ``pre_call`` / ``post_call`` notifications.
            optional_params: Extra request parameters. NOTE: ``max_retries``
                is popped from this dict (default 2), mutating the caller's
                object.
            litellm_params: Forwarded to ``get_azure_openai_client``.
            acompletion: When true, an async path is selected.
            headers: Only used in the ``post_call`` log payload.
            client: Optional pre-built client to reuse.
            api_type, print_verbose, logger_fn: Accepted but not used here.

        Returns:
            A coroutine when ``acompletion`` is true (from ``acompletion`` or
            ``async_streaming``), a ``CustomStreamWrapper`` for sync
            streaming, otherwise the converted model response.

        Raises:
            AzureOpenAIError: For validation failures (422), an unexpected
                client type (500), or any other exception, which is wrapped.
        """
        try:
            if model is None or messages is None:
                raise AzureOpenAIError(
                    status_code=422, message="Missing model or messages"
                )

            max_retries = optional_params.pop("max_retries", 2)
            prompt = prompt_factory(
                messages=messages, model=model, custom_llm_provider="azure_text"
            )

            ### CHECK IF CLOUDFLARE AI GATEWAY ###
            ### if so - set the model as part of the base url
            if "gateway.ai.cloudflare.com" in api_base:
                ## build base url - assume api base includes resource name
                client = self._init_azure_client_for_cloudflare_ai_gateway(
                    api_key=api_key,
                    api_version=api_version,
                    api_base=api_base,
                    model=model,
                    client=client,
                    max_retries=max_retries,
                    timeout=timeout,
                    azure_ad_token=azure_ad_token,
                    azure_ad_token_provider=azure_ad_token_provider,
                    acompletion=acompletion,
                )

                data = {"model": None, "prompt": prompt, **optional_params}
            else:
                data = {
                    "model": model,  # type: ignore
                    "prompt": prompt,
                    **optional_params,
                }

            if acompletion is True:
                if optional_params.get("stream", False):
                    return self.async_streaming(
                        logging_obj=logging_obj,
                        api_base=api_base,
                        data=data,
                        model=model,
                        api_key=api_key,
                        api_version=api_version,
                        azure_ad_token=azure_ad_token,
                        timeout=timeout,
                        client=client,
                        litellm_params=litellm_params,
                    )
                else:
                    return self.acompletion(
                        api_base=api_base,
                        data=data,
                        model_response=model_response,
                        api_key=api_key,
                        api_version=api_version,
                        model=model,
                        azure_ad_token=azure_ad_token,
                        timeout=timeout,
                        client=client,
                        logging_obj=logging_obj,
                        max_retries=max_retries,
                        litellm_params=litellm_params,
                    )
            elif "stream" in optional_params and optional_params["stream"] is True:
                return self.streaming(
                    logging_obj=logging_obj,
                    api_base=api_base,
                    data=data,
                    model=model,
                    api_key=api_key,
                    api_version=api_version,
                    azure_ad_token=azure_ad_token,
                    timeout=timeout,
                    client=client,
                    # Forwarded like on every other path, so the sync
                    # streaming client is built from the same settings.
                    litellm_params=litellm_params,
                )
            else:
                ## LOGGING
                logging_obj.pre_call(
                    input=prompt,
                    api_key=api_key,
                    additional_args={
                        "headers": {
                            "api_key": api_key,
                            "azure_ad_token": azure_ad_token,
                        },
                        "api_version": api_version,
                        "api_base": api_base,
                        "complete_input_dict": data,
                    },
                )
                if not isinstance(max_retries, int):
                    raise AzureOpenAIError(
                        status_code=422, message="max retries must be an int"
                    )
                # init AzureOpenAI Client
                azure_client = self.get_azure_openai_client(
                    api_key=api_key,
                    api_base=api_base,
                    api_version=api_version,
                    client=client,
                    litellm_params=litellm_params,
                    _is_async=False,
                    model=model,
                )

                if not isinstance(azure_client, AzureOpenAI):
                    raise AzureOpenAIError(
                        status_code=500,
                        message="azure_client is not an instance of AzureOpenAI",
                    )

                raw_response = azure_client.completions.with_raw_response.create(
                    **data, timeout=timeout
                )
                response = raw_response.parse()
                stringified_response = response.model_dump()
                ## LOGGING
                logging_obj.post_call(
                    input=prompt,
                    api_key=api_key,
                    original_response=stringified_response,
                    additional_args={
                        "headers": headers,
                        "api_version": api_version,
                        "api_base": api_base,
                    },
                )
                return (
                    openai_text_completion_config.convert_to_chat_model_response_object(
                        response_object=TextCompletionResponse(**stringified_response),
                        model_response_object=model_response,
                    )
                )
        except AzureOpenAIError:
            # Already in the handler's error type; propagate unchanged.
            raise
        except Exception as e:
            raise self._to_azure_error(e) from e

    async def acompletion(
        self,
        api_key: str,
        api_version: str,
        model: str,
        api_base: str,
        data: dict,
        timeout: Any,
        model_response: ModelResponse,
        logging_obj: Any,
        max_retries: int,
        azure_ad_token: Optional[str] = None,
        client=None,  # this is the AsyncAzureOpenAI
        litellm_params: Optional[dict] = None,
    ):
        """Async, non-streaming text completion.

        Args:
            data: Request body passed as keyword arguments to
                ``completions.create``; ``data["prompt"]`` is logged as input.
            model_response: Object populated from the parsed response.
            litellm_params: Forwarded to ``get_azure_openai_client``;
                ``None`` is treated as an empty dict.
            max_retries, azure_ad_token: Accepted but not used here.

        Returns:
            The result of ``convert_to_chat_model_response_object``.

        Raises:
            AzureOpenAIError: If the client is not an ``AsyncAzureOpenAI``
                (500) or the request fails (wrapped original error).
        """
        # Fresh dict per call instead of a shared mutable default.
        litellm_params = {} if litellm_params is None else litellm_params
        try:
            # init AzureOpenAI Client
            # setting Azure client
            azure_client = self.get_azure_openai_client(
                api_version=api_version,
                api_base=api_base,
                api_key=api_key,
                model=model,
                _is_async=True,
                client=client,
                litellm_params=litellm_params,
            )
            if not isinstance(azure_client, AsyncAzureOpenAI):
                raise AzureOpenAIError(
                    status_code=500,
                    message="azure_client is not an instance of AsyncAzureOpenAI",
                )

            ## LOGGING
            logging_obj.pre_call(
                input=data["prompt"],
                api_key=azure_client.api_key,
                additional_args={
                    "headers": {"Authorization": f"Bearer {azure_client.api_key}"},
                    "api_base": azure_client._base_url._uri_reference,
                    "acompletion": True,
                    "complete_input_dict": data,
                },
            )
            raw_response = await azure_client.completions.with_raw_response.create(
                **data, timeout=timeout
            )
            response = raw_response.parse()
            # NOTE(review): the sync path wraps the dump in
            # TextCompletionResponse; here the plain dict is passed through.
            # Confirm the difference is intended.
            return openai_text_completion_config.convert_to_chat_model_response_object(
                response_object=response.model_dump(),
                model_response_object=model_response,
            )
        except AzureOpenAIError:
            raise
        except Exception as e:
            raise self._to_azure_error(e) from e

    def streaming(
        self,
        logging_obj,
        api_base: str,
        api_key: str,
        api_version: str,
        data: dict,
        model: str,
        timeout: Any,
        azure_ad_token: Optional[str] = None,
        client=None,
        litellm_params: Optional[dict] = None,
    ):
        """Sync streaming text completion.

        Args:
            data: Request body; any ``max_retries`` entry is popped (and
                type-checked) before the request is sent.
            litellm_params: Forwarded to ``get_azure_openai_client``;
                ``None`` is treated as an empty dict.
            azure_ad_token: Accepted but not used here.

        Returns:
            A ``CustomStreamWrapper`` around the parsed response.

        Raises:
            AzureOpenAIError: If ``max_retries`` is not an int (422) or the
                client is not an ``AzureOpenAI`` (500). Other exceptions are
                not wrapped here; ``completion`` wraps them when it is the
                caller.
        """
        # Fresh dict per call instead of a shared mutable default.
        litellm_params = {} if litellm_params is None else litellm_params
        max_retries = data.pop("max_retries", 2)
        if not isinstance(max_retries, int):
            raise AzureOpenAIError(
                status_code=422, message="max retries must be an int"
            )
        # init AzureOpenAI Client
        azure_client = self.get_azure_openai_client(
            api_version=api_version,
            api_base=api_base,
            api_key=api_key,
            model=model,
            _is_async=False,
            client=client,
            litellm_params=litellm_params,
        )
        if not isinstance(azure_client, AzureOpenAI):
            raise AzureOpenAIError(
                status_code=500,
                message="azure_client is not an instance of AzureOpenAI",
            )

        ## LOGGING
        logging_obj.pre_call(
            input=data["prompt"],
            api_key=azure_client.api_key,
            additional_args={
                "headers": {"Authorization": f"Bearer {azure_client.api_key}"},
                "api_base": azure_client._base_url._uri_reference,
                # NOTE(review): this is the sync path yet the flag is logged
                # as True; left unchanged in case log consumers depend on it.
                "acompletion": True,
                "complete_input_dict": data,
            },
        )
        raw_response = azure_client.completions.with_raw_response.create(
            **data, timeout=timeout
        )
        response = raw_response.parse()
        streamwrapper = CustomStreamWrapper(
            completion_stream=response,
            model=model,
            custom_llm_provider="azure_text",
            logging_obj=logging_obj,
        )
        return streamwrapper

    async def async_streaming(
        self,
        logging_obj,
        api_base: str,
        api_key: str,
        api_version: str,
        data: dict,
        model: str,
        timeout: Any,
        azure_ad_token: Optional[str] = None,
        client=None,
        litellm_params: Optional[dict] = None,
    ):
        """Async streaming text completion.

        Args:
            data: Request body passed as keyword arguments to
                ``completions.create``; ``data["prompt"]`` is logged as input.
            litellm_params: Forwarded to ``get_azure_openai_client``;
                ``None`` is treated as an empty dict.
            azure_ad_token: Accepted but not used here.

        Returns:
            A ``CustomStreamWrapper`` around the parsed response.

        Raises:
            AzureOpenAIError: If the client is not an ``AsyncAzureOpenAI``
                (500) or the request fails (wrapped original error).
        """
        # Fresh dict per call instead of a shared mutable default.
        litellm_params = {} if litellm_params is None else litellm_params
        try:
            # init AzureOpenAI Client
            azure_client = self.get_azure_openai_client(
                api_version=api_version,
                api_base=api_base,
                api_key=api_key,
                model=model,
                _is_async=True,
                client=client,
                litellm_params=litellm_params,
            )
            if not isinstance(azure_client, AsyncAzureOpenAI):
                raise AzureOpenAIError(
                    status_code=500,
                    message="azure_client is not an instance of AsyncAzureOpenAI",
                )
            ## LOGGING
            logging_obj.pre_call(
                input=data["prompt"],
                api_key=azure_client.api_key,
                additional_args={
                    "headers": {"Authorization": f"Bearer {azure_client.api_key}"},
                    "api_base": azure_client._base_url._uri_reference,
                    "acompletion": True,
                    "complete_input_dict": data,
                },
            )
            raw_response = await azure_client.completions.with_raw_response.create(
                **data, timeout=timeout
            )
            response = raw_response.parse()
            streamwrapper = CustomStreamWrapper(
                completion_stream=response,
                model=model,
                custom_llm_provider="azure_text",
                logging_obj=logging_obj,
            )
            return streamwrapper  ## DO NOT make this into an async for ... loop, it will yield an async generator, which won't raise errors if the response fails
        except AzureOpenAIError:
            # Same passthrough as completion()/acompletion(): do not re-wrap
            # errors that are already AzureOpenAIError.
            raise
        except Exception as e:
            raise self._to_azure_error(e) from e