author     S. Solomon Darnell  2025-03-28 21:52:21 -0500
committer  S. Solomon Darnell  2025-03-28 21:52:21 -0500
commit     4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree       ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/litellm/llms/openai/completion/handler.py
parent     cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
two versions of R2R are here (HEAD, master)
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/openai/completion/handler.py')
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/openai/completion/handler.py  319
1 file changed, 319 insertions(+), 0 deletions(-)
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/openai/completion/handler.py b/.venv/lib/python3.12/site-packages/litellm/llms/openai/completion/handler.py
new file mode 100644
index 00000000..2e60f55b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/openai/completion/handler.py
@@ -0,0 +1,319 @@
+import json
+from typing import Callable, List, Optional, Union
+
+from openai import AsyncOpenAI, OpenAI
+
+import litellm
+from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
+from litellm.llms.base import BaseLLM
+from litellm.types.llms.openai import AllMessageValues, OpenAITextCompletionUserMessage
+from litellm.types.utils import LlmProviders, ModelResponse, TextCompletionResponse
+from litellm.utils import ProviderConfigManager
+
+from ..common_utils import OpenAIError
+from .transformation import OpenAITextCompletionConfig
+
+
+class OpenAITextCompletion(BaseLLM):
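+    """Handles OpenAI's legacy /v1/completions (text completion) endpoint
+    in sync, async, streaming, and async-streaming modes."""
+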
+    openai_text_completion_global_config = OpenAITextCompletionConfig()
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def validate_environment(self, api_key):
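+        """Build default JSON headers, adding a Bearer token when an api_key is provided."""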
+        headers = {
+            "content-type": "application/json",
+        }
+        if api_key:
+            headers["Authorization"] = f"Bearer {api_key}"
+        return headers
+
+    def completion(
+        self,
+        model_response: ModelResponse,
+        api_key: str,
+        model: str,
+        messages: Union[List[AllMessageValues], List[OpenAITextCompletionUserMessage]],
+        timeout: float,
+        custom_llm_provider: str,
+        logging_obj: LiteLLMLoggingObj,
+        optional_params: dict,
+        print_verbose: Optional[Callable] = None,
+        api_base: Optional[str] = None,
+        acompletion: bool = False,
+        litellm_params=None,
+        logger_fn=None,
+        client=None,
+        organization: Optional[str] = None,
+        headers: Optional[dict] = None,
+    ):
+        try:
+            if headers is None:
+                headers = self.validate_environment(api_key=api_key)
+            if model is None or messages is None:
+                raise OpenAIError(status_code=422, message="Missing model or messages")
+
+            # max_retries is popped from the request payload below so it configures the client, not the API call
+
+            provider_config = ProviderConfigManager.get_provider_text_completion_config(
+                model=model,
+                provider=LlmProviders(custom_llm_provider),
+            )
+
+            data = provider_config.transform_text_completion_request(
+                model=model,
+                messages=messages,
+                optional_params=optional_params,
+                headers=headers,
+            )
+            max_retries = data.pop("max_retries", 2)
+            ## LOGGING
+            logging_obj.pre_call(
+                input=messages,
+                api_key=api_key,
+                additional_args={
+                    "headers": headers,
+                    "api_base": api_base,
+                    "complete_input_dict": data,
+                },
+            )
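+            # four dispatch paths: async streaming, async, sync streaming, and sync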
+            if acompletion is True:
+                if optional_params.get("stream", False):
+                    return self.async_streaming(
+                        logging_obj=logging_obj,
+                        api_base=api_base,
+                        api_key=api_key,
+                        data=data,
+                        headers=headers,
+                        model_response=model_response,
+                        model=model,
+                        timeout=timeout,
+                        max_retries=max_retries,
+                        client=client,
+                        organization=organization,
+                    )
+                else:
+                    return self.acompletion(  # type: ignore
+                        api_base=api_base,
+                        data=data,
+                        headers=headers,
+                        model_response=model_response,
+                        api_key=api_key,
+                        logging_obj=logging_obj,
+                        model=model,
+                        timeout=timeout,
+                        max_retries=max_retries,
+                        organization=organization,
+                        client=client,
+                    )
+            elif optional_params.get("stream", False):
+                return self.streaming(
+                    logging_obj=logging_obj,
+                    api_base=api_base,
+                    api_key=api_key,
+                    data=data,
+                    headers=headers,
+                    model_response=model_response,
+                    model=model,
+                    timeout=timeout,
+                    max_retries=max_retries,  # type: ignore
+                    client=client,
+                    organization=organization,
+                )
+            else:
+                if client is None:
+                    openai_client = OpenAI(
+                        api_key=api_key,
+                        base_url=api_base,
+                        http_client=litellm.client_session,
+                        timeout=timeout,
+                        max_retries=max_retries,  # type: ignore
+                        organization=organization,
+                    )
+                else:
+                    openai_client = client
+
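+                # with_raw_response keeps HTTP-level details accessible; .parse() returns the typed completion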
+                raw_response = openai_client.completions.with_raw_response.create(**data)  # type: ignore
+                response = raw_response.parse()
+                response_json = response.model_dump()
+
+                ## LOGGING
+                logging_obj.post_call(
+                    api_key=api_key,
+                    original_response=response_json,
+                    additional_args={
+                        "headers": headers,
+                        "api_base": api_base,
+                    },
+                )
+
+                ## RESPONSE OBJECT
+                return TextCompletionResponse(**response_json)
+        except Exception as e:
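+            # normalize SDK/HTTP errors into OpenAIError, preserving status code and headers when available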
+            status_code = getattr(e, "status_code", 500)
+            error_headers = getattr(e, "headers", None)
+            error_text = getattr(e, "text", str(e))
+            error_response = getattr(e, "response", None)
+            if error_headers is None and error_response:
+                error_headers = getattr(error_response, "headers", None)
+            raise OpenAIError(
+                status_code=status_code, message=error_text, headers=error_headers
+            )
+
+    async def acompletion(
+        self,
+        logging_obj,
+        api_base: str,
+        data: dict,
+        headers: dict,
+        model_response: ModelResponse,
+        api_key: str,
+        model: str,
+        timeout: float,
+        max_retries: int,
+        organization: Optional[str] = None,
+        client=None,
+    ):
+        try:
+            if client is None:
+                openai_aclient = AsyncOpenAI(
+                    api_key=api_key,
+                    base_url=api_base,
+                    http_client=litellm.aclient_session,
+                    timeout=timeout,
+                    max_retries=max_retries,
+                    organization=organization,
+                )
+            else:
+                openai_aclient = client
+
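+            # same raw-response pattern as the sync path, via the async client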
+            raw_response = await openai_aclient.completions.with_raw_response.create(
+                **data
+            )
+            response = raw_response.parse()
+            response_json = response.model_dump()
+
+            ## LOGGING
+            logging_obj.post_call(
+                api_key=api_key,
+                original_response=response,
+                additional_args={
+                    "headers": headers,
+                    "api_base": api_base,
+                },
+            )
+            ## RESPONSE OBJECT
+            response_obj = TextCompletionResponse(**response_json)
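+            # keep the raw provider JSON available to callers through hidden params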
+            response_obj._hidden_params.original_response = json.dumps(response_json)
+            return response_obj
+        except Exception as e:
+            status_code = getattr(e, "status_code", 500)
+            error_headers = getattr(e, "headers", None)
+            error_text = getattr(e, "text", str(e))
+            error_response = getattr(e, "response", None)
+            if error_headers is None and error_response:
+                error_headers = getattr(error_response, "headers", None)
+            raise OpenAIError(
+                status_code=status_code, message=error_text, headers=error_headers
+            )
+
+    def streaming(
+        self,
+        logging_obj,
+        api_key: str,
+        data: dict,
+        headers: dict,
+        model_response: ModelResponse,
+        model: str,
+        timeout: float,
+        api_base: Optional[str] = None,
+        max_retries=None,
+        client=None,
+        organization=None,
+    ):
+
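+        # sync streaming: issue the request eagerly, then wrap the SDK stream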
+        if client is None:
+            openai_client = OpenAI(
+                api_key=api_key,
+                base_url=api_base,
+                http_client=litellm.client_session,
+                timeout=timeout,
+                max_retries=max_retries,  # type: ignore
+                organization=organization,
+            )
+        else:
+            openai_client = client
+
+        try:
+            raw_response = openai_client.completions.with_raw_response.create(**data)
+            response = raw_response.parse()
+        except Exception as e:
+            status_code = getattr(e, "status_code", 500)
+            error_headers = getattr(e, "headers", None)
+            error_text = getattr(e, "text", str(e))
+            error_response = getattr(e, "response", None)
+            if error_headers is None and error_response:
+                error_headers = getattr(error_response, "headers", None)
+            raise OpenAIError(
+                status_code=status_code, message=error_text, headers=error_headers
+            )
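+        # CustomStreamWrapper normalizes raw chunks into litellm's unified streaming format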
+        streamwrapper = CustomStreamWrapper(
+            completion_stream=response,
+            model=model,
+            custom_llm_provider="text-completion-openai",
+            logging_obj=logging_obj,
+            stream_options=data.get("stream_options", None),
+        )
+
+        try:
+            for chunk in streamwrapper:
+                yield chunk
+        except Exception as e:
+            status_code = getattr(e, "status_code", 500)
+            error_headers = getattr(e, "headers", None)
+            error_text = getattr(e, "text", str(e))
+            error_response = getattr(e, "response", None)
+            if error_headers is None and error_response:
+                error_headers = getattr(error_response, "headers", None)
+            raise OpenAIError(
+                status_code=status_code, message=error_text, headers=error_headers
+            )
+
+    async def async_streaming(
+        self,
+        logging_obj,
+        api_key: str,
+        data: dict,
+        headers: dict,
+        model_response: ModelResponse,
+        model: str,
+        timeout: float,
+        max_retries: int,
+        api_base: Optional[str] = None,
+        client=None,
+        organization=None,
+    ):
+        if client is None:
+            openai_client = AsyncOpenAI(
+                api_key=api_key,
+                base_url=api_base,
+                http_client=litellm.aclient_session,
+                timeout=timeout,
+                max_retries=max_retries,
+                organization=organization,
+            )
+        else:
+            openai_client = client
+
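+        # note: errors from this create() call propagate to the caller; only iteration errors are wrapped below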
+        raw_response = await openai_client.completions.with_raw_response.create(**data)
+        response = raw_response.parse()
+        streamwrapper = CustomStreamWrapper(
+            completion_stream=response,
+            model=model,
+            custom_llm_provider="text-completion-openai",
+            logging_obj=logging_obj,
+            stream_options=data.get("stream_options", None),
+        )
+
+        try:
+            async for transformed_chunk in streamwrapper:
+                yield transformed_chunk
+        except Exception as e:
+            status_code = getattr(e, "status_code", 500)
+            error_headers = getattr(e, "headers", None)
+            error_text = getattr(e, "text", str(e))
+            error_response = getattr(e, "response", None)
+            if error_headers is None and error_response:
+                error_headers = getattr(error_response, "headers", None)
+            raise OpenAIError(
+                status_code=status_code, message=error_text, headers=error_headers
+            )