author    S. Solomon Darnell    2025-03-28 21:52:21 -0500
committer S. Solomon Darnell    2025-03-28 21:52:21 -0500
commit    4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree      ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/litellm/llms/azure/completion
parent    cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
two versions of R2R are here (HEAD, master)
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/azure/completion')
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/azure/completion/handler.py         378
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/azure/completion/transformation.py   53
2 files changed, 431 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/azure/completion/handler.py b/.venv/lib/python3.12/site-packages/litellm/llms/azure/completion/handler.py
new file mode 100644
index 00000000..8301c4d6
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/azure/completion/handler.py
@@ -0,0 +1,378 @@
+from typing import Any, Callable, Optional
+
+from openai import AsyncAzureOpenAI, AzureOpenAI
+
+from litellm.litellm_core_utils.prompt_templates.factory import prompt_factory
+from litellm.utils import CustomStreamWrapper, ModelResponse, TextCompletionResponse
+
+from ...openai.completion.transformation import OpenAITextCompletionConfig
+from ..common_utils import AzureOpenAIError, BaseAzureLLM
+
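+# Shared translator used below to reshape raw text-completion responses into
+# litellm's chat-style ModelResponse objects.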
+openai_text_completion_config = OpenAITextCompletionConfig()
+
+
+class AzureTextCompletion(BaseAzureLLM):
+    def __init__(self) -> None:
+        super().__init__()
+
+    def validate_environment(self, api_key, azure_ad_token):
+        headers = {
+            "content-type": "application/json",
+        }
+        if api_key is not None:
+            headers["api-key"] = api_key
+        elif azure_ad_token is not None:
+            headers["Authorization"] = f"Bearer {azure_ad_token}"
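+        # The result carries either key-based auth ({"api-key": ...}) or
+        # Azure AD bearer auth ({"Authorization": "Bearer ..."}).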
+        return headers
+
+    def completion(  # noqa: PLR0915
+        self,
+        model: str,
+        messages: list,
+        model_response: ModelResponse,
+        api_key: str,
+        api_base: str,
+        api_version: str,
+        api_type: str,
+        azure_ad_token: str,
+        azure_ad_token_provider: Optional[Callable],
+        print_verbose: Callable,
+        timeout,
+        logging_obj,
+        optional_params,
+        litellm_params,
+        logger_fn,
+        acompletion: bool = False,
+        headers: Optional[dict] = None,
+        client=None,
+    ):
+        try:
+            if model is None or messages is None:
+                raise AzureOpenAIError(
+                    status_code=422, message="Missing model or messages"
+                )
+
+            max_retries = optional_params.pop("max_retries", 2)
+            prompt = prompt_factory(
+                messages=messages, model=model, custom_llm_provider="azure_text"
+            )
+
+            ### CHECK IF CLOUDFLARE AI GATEWAY ###
+            ### if so - set the model as part of the base url
+            if "gateway.ai.cloudflare.com" in api_base:
+                ## build base url - assume api base includes resource name
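+                ## e.g. https://gateway.ai.cloudflare.com/v1/<account>/<gateway>/azure-openai/<resource>
+                ## (illustrative URL shape only, not taken from this repo)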
+                client = self._init_azure_client_for_cloudflare_ai_gateway(
+                    api_key=api_key,
+                    api_version=api_version,
+                    api_base=api_base,
+                    model=model,
+                    client=client,
+                    max_retries=max_retries,
+                    timeout=timeout,
+                    azure_ad_token=azure_ad_token,
+                    azure_ad_token_provider=azure_ad_token_provider,
+                    acompletion=acompletion,
+                )
+
+                data = {"model": None, "prompt": prompt, **optional_params}
+            else:
+                data = {
+                    "model": model,  # type: ignore
+                    "prompt": prompt,
+                    **optional_params,
+                }
+
+            if acompletion is True:
+                if optional_params.get("stream", False):
+                    return self.async_streaming(
+                        logging_obj=logging_obj,
+                        api_base=api_base,
+                        data=data,
+                        model=model,
+                        api_key=api_key,
+                        api_version=api_version,
+                        azure_ad_token=azure_ad_token,
+                        timeout=timeout,
+                        client=client,
+                        litellm_params=litellm_params,
+                    )
+                else:
+                    return self.acompletion(
+                        api_base=api_base,
+                        data=data,
+                        model_response=model_response,
+                        api_key=api_key,
+                        api_version=api_version,
+                        model=model,
+                        azure_ad_token=azure_ad_token,
+                        timeout=timeout,
+                        client=client,
+                        logging_obj=logging_obj,
+                        max_retries=max_retries,
+                        litellm_params=litellm_params,
+                    )
+            elif "stream" in optional_params and optional_params["stream"] is True:
+                return self.streaming(
+                    logging_obj=logging_obj,
+                    api_base=api_base,
+                    data=data,
+                    model=model,
+                    api_key=api_key,
+                    api_version=api_version,
+                    azure_ad_token=azure_ad_token,
+                    timeout=timeout,
+                    client=client,
+                )
+            else:
+                ## LOGGING
+                logging_obj.pre_call(
+                    input=prompt,
+                    api_key=api_key,
+                    additional_args={
+                        "headers": {
+                            "api_key": api_key,
+                            "azure_ad_token": azure_ad_token,
+                        },
+                        "api_version": api_version,
+                        "api_base": api_base,
+                        "complete_input_dict": data,
+                    },
+                )
+                if not isinstance(max_retries, int):
+                    raise AzureOpenAIError(
+                        status_code=422, message="max retries must be an int"
+                    )
+                # init AzureOpenAI Client
+                azure_client = self.get_azure_openai_client(
+                    api_key=api_key,
+                    api_base=api_base,
+                    api_version=api_version,
+                    client=client,
+                    litellm_params=litellm_params,
+                    _is_async=False,
+                    model=model,
+                )
+
+                if not isinstance(azure_client, AzureOpenAI):
+                    raise AzureOpenAIError(
+                        status_code=500,
+                        message="azure_client is not an instance of AzureOpenAI",
+                    )
+
+                raw_response = azure_client.completions.with_raw_response.create(
+                    **data, timeout=timeout
+                )
+                response = raw_response.parse()
+                stringified_response = response.model_dump()
+                ## LOGGING
+                logging_obj.post_call(
+                    input=prompt,
+                    api_key=api_key,
+                    original_response=stringified_response,
+                    additional_args={
+                        "headers": headers,
+                        "api_version": api_version,
+                        "api_base": api_base,
+                    },
+                )
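+                # Reshape the raw text-completion payload into litellm's
+                # chat-style ModelResponse, the handler's common return type.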
+                return (
+                    openai_text_completion_config.convert_to_chat_model_response_object(
+                        response_object=TextCompletionResponse(**stringified_response),
+                        model_response_object=model_response,
+                    )
+                )
+        except AzureOpenAIError as e:
+            raise e
+        except Exception as e:
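+            # Normalize any other failure into AzureOpenAIError, preserving
+            # the status code and response headers when the SDK exposes them.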
+            status_code = getattr(e, "status_code", 500)
+            error_headers = getattr(e, "headers", None)
+            error_response = getattr(e, "response", None)
+            if error_headers is None and error_response:
+                error_headers = getattr(error_response, "headers", None)
+            raise AzureOpenAIError(
+                status_code=status_code, message=str(e), headers=error_headers
+            )
+
+    async def acompletion(
+        self,
+        api_key: str,
+        api_version: str,
+        model: str,
+        api_base: str,
+        data: dict,
+        timeout: Any,
+        model_response: ModelResponse,
+        logging_obj: Any,
+        max_retries: int,
+        azure_ad_token: Optional[str] = None,
+        client=None,  # this is the AsyncAzureOpenAI
+        litellm_params: dict = {},
+    ):
+        response = None
+        try:
+            # init AzureOpenAI Client
+            # setting Azure client
+            azure_client = self.get_azure_openai_client(
+                api_version=api_version,
+                api_base=api_base,
+                api_key=api_key,
+                model=model,
+                _is_async=True,
+                client=client,
+                litellm_params=litellm_params,
+            )
+            if not isinstance(azure_client, AsyncAzureOpenAI):
+                raise AzureOpenAIError(
+                    status_code=500,
+                    message="azure_client is not an instance of AsyncAzureOpenAI",
+                )
+
+            ## LOGGING
+            logging_obj.pre_call(
+                input=data["prompt"],
+                api_key=azure_client.api_key,
+                additional_args={
+                    "headers": {"Authorization": f"Bearer {azure_client.api_key}"},
+                    "api_base": azure_client._base_url._uri_reference,
+                    "acompletion": True,
+                    "complete_input_dict": data,
+                },
+            )
+            raw_response = await azure_client.completions.with_raw_response.create(
+                **data, timeout=timeout
+            )
+            response = raw_response.parse()
+            return openai_text_completion_config.convert_to_chat_model_response_object(
+                response_object=response.model_dump(),
+                model_response_object=model_response,
+            )
+        except AzureOpenAIError as e:
+            raise e
+        except Exception as e:
+            status_code = getattr(e, "status_code", 500)
+            error_headers = getattr(e, "headers", None)
+            error_response = getattr(e, "response", None)
+            if error_headers is None and error_response:
+                error_headers = getattr(error_response, "headers", None)
+            raise AzureOpenAIError(
+                status_code=status_code, message=str(e), headers=error_headers
+            )
+
+    def streaming(
+        self,
+        logging_obj,
+        api_base: str,
+        api_key: str,
+        api_version: str,
+        data: dict,
+        model: str,
+        timeout: Any,
+        azure_ad_token: Optional[str] = None,
+        client=None,
+        litellm_params: dict = {},
+    ):
+        max_retries = data.pop("max_retries", 2)
+        if not isinstance(max_retries, int):
+            raise AzureOpenAIError(
+                status_code=422, message="max retries must be an int"
+            )
+        # init AzureOpenAI Client
+        azure_client = self.get_azure_openai_client(
+            api_version=api_version,
+            api_base=api_base,
+            api_key=api_key,
+            model=model,
+            _is_async=False,
+            client=client,
+            litellm_params=litellm_params,
+        )
+        if not isinstance(azure_client, AzureOpenAI):
+            raise AzureOpenAIError(
+                status_code=500,
+                message="azure_client is not an instance of AzureOpenAI",
+            )
+
+        ## LOGGING
+        logging_obj.pre_call(
+            input=data["prompt"],
+            api_key=azure_client.api_key,
+            additional_args={
+                "headers": {"Authorization": f"Bearer {azure_client.api_key}"},
+                "api_base": azure_client._base_url._uri_reference,
+                "acompletion": False,
+                "complete_input_dict": data,
+            },
+        )
+        raw_response = azure_client.completions.with_raw_response.create(
+            **data, timeout=timeout
+        )
+        response = raw_response.parse()
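+        # Wrap the raw stream so litellm can normalize chunks and run logging.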
+        streamwrapper = CustomStreamWrapper(
+            completion_stream=response,
+            model=model,
+            custom_llm_provider="azure_text",
+            logging_obj=logging_obj,
+        )
+        return streamwrapper
+
+    async def async_streaming(
+        self,
+        logging_obj,
+        api_base: str,
+        api_key: str,
+        api_version: str,
+        data: dict,
+        model: str,
+        timeout: Any,
+        azure_ad_token: Optional[str] = None,
+        client=None,
+        litellm_params: dict = {},
+    ):
+        try:
+            # init AzureOpenAI Client
+            azure_client = self.get_azure_openai_client(
+                api_version=api_version,
+                api_base=api_base,
+                api_key=api_key,
+                model=model,
+                _is_async=True,
+                client=client,
+                litellm_params=litellm_params,
+            )
+            if not isinstance(azure_client, AsyncAzureOpenAI):
+                raise AzureOpenAIError(
+                    status_code=500,
+                    message="azure_client is not an instance of AsyncAzureOpenAI",
+                )
+            ## LOGGING
+            logging_obj.pre_call(
+                input=data["prompt"],
+                api_key=azure_client.api_key,
+                additional_args={
+                    "headers": {"Authorization": f"Bearer {azure_client.api_key}"},
+                    "api_base": azure_client._base_url._uri_reference,
+                    "acompletion": True,
+                    "complete_input_dict": data,
+                },
+            )
+            raw_response = await azure_client.completions.with_raw_response.create(
+                **data, timeout=timeout
+            )
+            response = raw_response.parse()
+            # return response
+            streamwrapper = CustomStreamWrapper(
+                completion_stream=response,
+                model=model,
+                custom_llm_provider="azure_text",
+                logging_obj=logging_obj,
+            )
+            return streamwrapper  ## DO NOT make this into an async for ... loop, it will yield an async generator, which won't raise errors if the response fails
+        except Exception as e:
+            status_code = getattr(e, "status_code", 500)
+            error_headers = getattr(e, "headers", None)
+            error_response = getattr(e, "response", None)
+            if error_headers is None and error_response:
+                error_headers = getattr(error_response, "headers", None)
+            raise AzureOpenAIError(
+                status_code=status_code, message=str(e), headers=error_headers
+            )
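
Usage sketch: this handler is normally reached through litellm's public
entrypoint rather than by constructing AzureTextCompletion directly. The
"azure_text/" model prefix is inferred from the custom_llm_provider strings
above; the deployment, endpoint, API version, and key below are placeholders.

    import litellm

    # Placeholder deployment/endpoint/key -- substitute real values.
    response = litellm.completion(
        model="azure_text/my-instruct-deployment",
        messages=[{"role": "user", "content": "Say hello."}],
        api_base="https://my-resource.openai.azure.com",
        api_version="2024-02-01",
        api_key="<AZURE_API_KEY>",
    )
    print(response.choices[0].message.content)
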
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/azure/completion/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/azure/completion/transformation.py
new file mode 100644
index 00000000..bc7b97c6
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/azure/completion/transformation.py
@@ -0,0 +1,53 @@
+from typing import Optional, Union
+
+from ...openai.completion.transformation import OpenAITextCompletionConfig
+
+
+class AzureOpenAITextConfig(OpenAITextCompletionConfig):
+    """
+    Reference: https://platform.openai.com/docs/api-reference/completions/create
+
+    The class `AzureOpenAITextConfig` provides configuration for OpenAI's legacy text completions API, for use with Azure. It inherits from `OpenAITextCompletionConfig`. Below are the parameters:
+
+    - `frequency_penalty` (number or null): Defaults to 0. Allows a value between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, thereby minimizing repetition.
+
+    - `logit_bias` (map): This optional parameter modifies the likelihood of specified tokens appearing in the completion.
+
+    - `max_tokens` (integer or null): This optional parameter helps to set the maximum number of tokens to generate in the chat completion.
+
+    - `n` (integer or null): This optional parameter helps to set how many chat completion choices to generate for each input message.
+
+    - `presence_penalty` (number or null): Defaults to 0. It penalizes new tokens based on if they appear in the text so far, hence increasing the model's likelihood to talk about new topics.
+
+    - `stop` (string / array / null): Specifies up to 4 sequences where the API will stop generating further tokens.
+
+    - `temperature` (number or null): Defines the sampling temperature to use, varying between 0 and 2.
+
+    - `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling.
+    """
+
+    def __init__(
+        self,
+        frequency_penalty: Optional[float] = None,
+        logit_bias: Optional[dict] = None,
+        max_tokens: Optional[int] = None,
+        n: Optional[int] = None,
+        presence_penalty: Optional[float] = None,
+        stop: Optional[Union[str, list]] = None,
+        temperature: Optional[float] = None,
+        top_p: Optional[float] = None,
+    ) -> None:
+        super().__init__(
+            frequency_penalty=frequency_penalty,
+            logit_bias=logit_bias,
+            max_tokens=max_tokens,
+            n=n,
+            presence_penalty=presence_penalty,
+            stop=stop,
+            temperature=temperature,
+            top_p=top_p,
+        )
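
Usage sketch (illustrative values; every parameter is optional and defaults
to None):

    from litellm.llms.azure.completion.transformation import AzureOpenAITextConfig

    # Pin only the sampling parameters you care about.
    config = AzureOpenAITextConfig(
        max_tokens=256,
        temperature=0.7,
        stop=["\n\n"],
    )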