Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/openai/completion/handler.py')
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/openai/completion/handler.py  319
1 file changed, 319 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/openai/completion/handler.py b/.venv/lib/python3.12/site-packages/litellm/llms/openai/completion/handler.py
new file mode 100644
index 00000000..2e60f55b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/openai/completion/handler.py
@@ -0,0 +1,319 @@
+import json
+from typing import Callable, List, Optional, Union
+
+from openai import AsyncOpenAI, OpenAI
+
+import litellm
+from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
+from litellm.llms.base import BaseLLM
+from litellm.types.llms.openai import AllMessageValues, OpenAITextCompletionUserMessage
+from litellm.types.utils import LlmProviders, ModelResponse, TextCompletionResponse
+from litellm.utils import ProviderConfigManager
+
+from ..common_utils import OpenAIError
+from .transformation import OpenAITextCompletionConfig
+
+
+class OpenAITextCompletion(BaseLLM):
+ openai_text_completion_global_config = OpenAITextCompletionConfig()
+
+ def __init__(self) -> None:
+ super().__init__()
+
+ def validate_environment(self, api_key):
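+        """Build the request headers, adding a Bearer token when an api_key is provided."""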
+ headers = {
+ "content-type": "application/json",
+ }
+ if api_key:
+ headers["Authorization"] = f"Bearer {api_key}"
+ return headers
+
+ def completion(
+ self,
+ model_response: ModelResponse,
+ api_key: str,
+ model: str,
+ messages: Union[List[AllMessageValues], List[OpenAITextCompletionUserMessage]],
+ timeout: float,
+ custom_llm_provider: str,
+ logging_obj: LiteLLMLoggingObj,
+ optional_params: dict,
+ print_verbose: Optional[Callable] = None,
+ api_base: Optional[str] = None,
+ acompletion: bool = False,
+ litellm_params=None,
+ logger_fn=None,
+ client=None,
+ organization: Optional[str] = None,
+ headers: Optional[dict] = None,
+ ):
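+        """Synchronous text-completion entry point.
+
+        Transforms the request via the provider config, then either returns a
+        TextCompletionResponse directly or dispatches to the async/streaming
+        variants based on `acompletion` and the `stream` optional param.
+        """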
+ try:
+ if headers is None:
+ headers = self.validate_environment(api_key=api_key)
+ if model is None or messages is None:
+ raise OpenAIError(status_code=422, message="Missing model or messages")
+
+            # max_retries is a client-side option, not an API parameter; it is popped from the request body below
+
+ provider_config = ProviderConfigManager.get_provider_text_completion_config(
+ model=model,
+ provider=LlmProviders(custom_llm_provider),
+ )
+
+ data = provider_config.transform_text_completion_request(
+ model=model,
+ messages=messages,
+ optional_params=optional_params,
+ headers=headers,
+ )
+ max_retries = data.pop("max_retries", 2)
+ ## LOGGING
+ logging_obj.pre_call(
+ input=messages,
+ api_key=api_key,
+ additional_args={
+ "headers": headers,
+ "api_base": api_base,
+ "complete_input_dict": data,
+ },
+ )
+ if acompletion is True:
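+                # async paths return awaitables/async generators for the caller to drive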
+ if optional_params.get("stream", False):
+ return self.async_streaming(
+ logging_obj=logging_obj,
+ api_base=api_base,
+ api_key=api_key,
+ data=data,
+ headers=headers,
+ model_response=model_response,
+ model=model,
+ timeout=timeout,
+ max_retries=max_retries,
+ client=client,
+ organization=organization,
+ )
+ else:
+                    return self.acompletion(  # type: ignore
+                        api_base=api_base,
+                        data=data,
+                        headers=headers,
+                        model_response=model_response,
+                        api_key=api_key,
+                        logging_obj=logging_obj,
+                        model=model,
+                        timeout=timeout,
+                        max_retries=max_retries,
+                        organization=organization,
+                        client=client,
+                    )
+ elif optional_params.get("stream", False):
+ return self.streaming(
+ logging_obj=logging_obj,
+ api_base=api_base,
+ api_key=api_key,
+ data=data,
+ headers=headers,
+ model_response=model_response,
+ model=model,
+ timeout=timeout,
+ max_retries=max_retries, # type: ignore
+ client=client,
+ organization=organization,
+ )
+ else:
+ if client is None:
+ openai_client = OpenAI(
+ api_key=api_key,
+ base_url=api_base,
+ http_client=litellm.client_session,
+ timeout=timeout,
+ max_retries=max_retries, # type: ignore
+ organization=organization,
+ )
+ else:
+ openai_client = client
+
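+                # with_raw_response keeps provider response headers accessible before parsing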
+ raw_response = openai_client.completions.with_raw_response.create(**data) # type: ignore
+ response = raw_response.parse()
+ response_json = response.model_dump()
+
+ ## LOGGING
+ logging_obj.post_call(
+ api_key=api_key,
+ original_response=response_json,
+ additional_args={
+ "headers": headers,
+ "api_base": api_base,
+ },
+ )
+
+ ## RESPONSE OBJECT
+ return TextCompletionResponse(**response_json)
+ except Exception as e:
+ status_code = getattr(e, "status_code", 500)
+ error_headers = getattr(e, "headers", None)
+ error_text = getattr(e, "text", str(e))
+ error_response = getattr(e, "response", None)
+ if error_headers is None and error_response:
+ error_headers = getattr(error_response, "headers", None)
+ raise OpenAIError(
+ status_code=status_code, message=error_text, headers=error_headers
+ )
+
+ async def acompletion(
+ self,
+ logging_obj,
+ api_base: str,
+ data: dict,
+ headers: dict,
+ model_response: ModelResponse,
+ api_key: str,
+ model: str,
+ timeout: float,
+ max_retries: int,
+ organization: Optional[str] = None,
+ client=None,
+ ):
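+        """Async counterpart of completion(); issues the request with AsyncOpenAI
+        and preserves the raw payload on the response's hidden params."""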
+ try:
+ if client is None:
+ openai_aclient = AsyncOpenAI(
+ api_key=api_key,
+ base_url=api_base,
+ http_client=litellm.aclient_session,
+ timeout=timeout,
+ max_retries=max_retries,
+ organization=organization,
+ )
+ else:
+ openai_aclient = client
+
+ raw_response = await openai_aclient.completions.with_raw_response.create(
+ **data
+ )
+ response = raw_response.parse()
+ response_json = response.model_dump()
+
+ ## LOGGING
+ logging_obj.post_call(
+ api_key=api_key,
+ original_response=response,
+ additional_args={
+ "headers": headers,
+ "api_base": api_base,
+ },
+ )
+ ## RESPONSE OBJECT
+ response_obj = TextCompletionResponse(**response_json)
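+            # stash the raw provider payload for downstream logging/debugging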
+ response_obj._hidden_params.original_response = json.dumps(response_json)
+ return response_obj
+ except Exception as e:
+ status_code = getattr(e, "status_code", 500)
+ error_headers = getattr(e, "headers", None)
+ error_text = getattr(e, "text", str(e))
+ error_response = getattr(e, "response", None)
+ if error_headers is None and error_response:
+ error_headers = getattr(error_response, "headers", None)
+ raise OpenAIError(
+ status_code=status_code, message=error_text, headers=error_headers
+ )
+
+ def streaming(
+ self,
+ logging_obj,
+ api_key: str,
+ data: dict,
+ headers: dict,
+ model_response: ModelResponse,
+ model: str,
+ timeout: float,
+ api_base: Optional[str] = None,
+ max_retries=None,
+ client=None,
+ organization=None,
+ ):
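+        """Synchronous streaming: wraps the OpenAI completion stream in a
+        CustomStreamWrapper and yields normalized chunks."""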
+
+ if client is None:
+ openai_client = OpenAI(
+ api_key=api_key,
+ base_url=api_base,
+ http_client=litellm.client_session,
+ timeout=timeout,
+ max_retries=max_retries, # type: ignore
+ organization=organization,
+ )
+ else:
+ openai_client = client
+
+ try:
+ raw_response = openai_client.completions.with_raw_response.create(**data)
+ response = raw_response.parse()
+ except Exception as e:
+ status_code = getattr(e, "status_code", 500)
+ error_headers = getattr(e, "headers", None)
+ error_text = getattr(e, "text", str(e))
+ error_response = getattr(e, "response", None)
+ if error_headers is None and error_response:
+ error_headers = getattr(error_response, "headers", None)
+ raise OpenAIError(
+ status_code=status_code, message=error_text, headers=error_headers
+ )
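+        # normalize raw OpenAI chunks into litellm's streaming format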
+ streamwrapper = CustomStreamWrapper(
+ completion_stream=response,
+ model=model,
+ custom_llm_provider="text-completion-openai",
+ logging_obj=logging_obj,
+ stream_options=data.get("stream_options", None),
+ )
+
+ try:
+ for chunk in streamwrapper:
+ yield chunk
+ except Exception as e:
+ status_code = getattr(e, "status_code", 500)
+ error_headers = getattr(e, "headers", None)
+ error_text = getattr(e, "text", str(e))
+ error_response = getattr(e, "response", None)
+ if error_headers is None and error_response:
+ error_headers = getattr(error_response, "headers", None)
+ raise OpenAIError(
+ status_code=status_code, message=error_text, headers=error_headers
+ )
+
+ async def async_streaming(
+ self,
+ logging_obj,
+ api_key: str,
+ data: dict,
+ headers: dict,
+ model_response: ModelResponse,
+ model: str,
+ timeout: float,
+ max_retries: int,
+ api_base: Optional[str] = None,
+ client=None,
+ organization=None,
+ ):
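+        """Async streaming counterpart; yields transformed chunks from CustomStreamWrapper."""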
+ if client is None:
+ openai_client = AsyncOpenAI(
+ api_key=api_key,
+ base_url=api_base,
+ http_client=litellm.aclient_session,
+ timeout=timeout,
+ max_retries=max_retries,
+ organization=organization,
+ )
+ else:
+ openai_client = client
+
+ raw_response = await openai_client.completions.with_raw_response.create(**data)
+ response = raw_response.parse()
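+        # same normalization as the sync path, driven as an async generator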
+ streamwrapper = CustomStreamWrapper(
+ completion_stream=response,
+ model=model,
+ custom_llm_provider="text-completion-openai",
+ logging_obj=logging_obj,
+ stream_options=data.get("stream_options", None),
+ )
+
+ try:
+ async for transformed_chunk in streamwrapper:
+ yield transformed_chunk
+ except Exception as e:
+ status_code = getattr(e, "status_code", 500)
+ error_headers = getattr(e, "headers", None)
+ error_text = getattr(e, "text", str(e))
+ error_response = getattr(e, "response", None)
+ if error_headers is None and error_response:
+ error_headers = getattr(error_response, "headers", None)
+ raise OpenAIError(
+ status_code=status_code, message=error_text, headers=error_headers
+ )
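
For reference, this handler is normally reached through litellm's public text_completion API rather than called directly. A minimal usage sketch, assuming OPENAI_API_KEY is set in the environment and that the model name shown routes to this text-completion handler (the model choice is illustrative):

import litellm

# routed internally to OpenAITextCompletion.completion() for OpenAI text-completion models
response = litellm.text_completion(
    model="gpt-3.5-turbo-instruct",
    prompt="Say hello",
    max_tokens=16,
)
print(response.choices[0].text)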