Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/oobabooga/chat')
-rw-r--r-- | .venv/lib/python3.12/site-packages/litellm/llms/oobabooga/chat/oobabooga.py      | 158 |
-rw-r--r-- | .venv/lib/python3.12/site-packages/litellm/llms/oobabooga/chat/transformation.py | 100 |
2 files changed, 258 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/oobabooga/chat/oobabooga.py b/.venv/lib/python3.12/site-packages/litellm/llms/oobabooga/chat/oobabooga.py
new file mode 100644
index 00000000..8829d223
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/oobabooga/chat/oobabooga.py
@@ -0,0 +1,158 @@
+import json
+from typing import Any, Callable, Optional
+
+import litellm
+from litellm.llms.custom_httpx.http_handler import _get_httpx_client
+from litellm.utils import EmbeddingResponse, ModelResponse, Usage
+
+from ..common_utils import OobaboogaError
+from .transformation import OobaboogaConfig
+
+oobabooga_config = OobaboogaConfig()
+
+
+def completion(
+    model: str,
+    messages: list,
+    api_base: Optional[str],
+    model_response: ModelResponse,
+    print_verbose: Callable,
+    encoding,
+    api_key,
+    logging_obj,
+    optional_params: dict,
+    litellm_params: dict,
+    custom_prompt_dict={},
+    logger_fn=None,
+    default_max_tokens_to_sample=None,
+):
+    headers = oobabooga_config.validate_environment(
+        api_key=api_key,
+        headers={},
+        model=model,
+        messages=messages,
+        optional_params=optional_params,
+    )
+    if "https" in model:
+        completion_url = model
+    elif api_base:
+        completion_url = api_base
+    else:
+        raise OobaboogaError(
+            status_code=404,
+            message="API Base not set. Set one via completion(..,api_base='your-api-url')",
+        )
+    model = model
+
+    completion_url = completion_url + "/v1/chat/completions"
+    data = oobabooga_config.transform_request(
+        model=model,
+        messages=messages,
+        optional_params=optional_params,
+        litellm_params=litellm_params,
+        headers=headers,
+    )
+    ## LOGGING
+
+    logging_obj.pre_call(
+        input=messages,
+        api_key=api_key,
+        additional_args={"complete_input_dict": data},
+    )
+    ## COMPLETION CALL
+    client = _get_httpx_client()
+    response = client.post(
+        completion_url,
+        headers=headers,
+        data=json.dumps(data),
+        stream=optional_params["stream"] if "stream" in optional_params else False,
+    )
+    if "stream" in optional_params and optional_params["stream"] is True:
+        return response.iter_lines()
+    else:
+        return oobabooga_config.transform_response(
+            model=model,
+            raw_response=response,
+            model_response=model_response,
+            logging_obj=logging_obj,
+            api_key=api_key,
+            request_data=data,
+            messages=messages,
+            optional_params=optional_params,
+            litellm_params=litellm_params,
+            encoding=encoding,
+        )
+
+
+def embedding(
+    model: str,
+    input: list,
+    model_response: EmbeddingResponse,
+    api_key: Optional[str],
+    api_base: Optional[str],
+    logging_obj: Any,
+    optional_params: dict,
+    encoding=None,
+):
+    # Create completion URL
+    if "https" in model:
+        embeddings_url = model
+    elif api_base:
+        embeddings_url = f"{api_base}/v1/embeddings"
+    else:
+        raise OobaboogaError(
+            status_code=404,
+            message="API Base not set. Set one via completion(..,api_base='your-api-url')",
+        )
+
+    # Prepare request data
+    data = {"input": input}
+    if optional_params:
+        data.update(optional_params)
+
+    # Logging before API call
+    if logging_obj:
+        logging_obj.pre_call(
+            input=input, api_key=api_key, additional_args={"complete_input_dict": data}
+        )
+
+    # Send POST request
+    headers = oobabooga_config.validate_environment(
+        api_key=api_key,
+        headers={},
+        model=model,
+        messages=[],
+        optional_params=optional_params,
+    )
+    response = litellm.module_level_client.post(
+        embeddings_url, headers=headers, json=data
+    )
+    completion_response = response.json()
+
+    # Check for errors in response
+    if "error" in completion_response:
+        raise OobaboogaError(
+            message=completion_response["error"],
+            status_code=completion_response.get("status_code", 500),
+        )
+
+    # Process response data
+    model_response.data = [
+        {
+            "embedding": completion_response["data"][0]["embedding"],
+            "index": 0,
+            "object": "embedding",
+        }
+    ]
+
+    num_tokens = len(completion_response["data"][0]["embedding"])
+    # Adding metadata to response
+    setattr(
+        model_response,
+        "usage",
+        Usage(prompt_tokens=num_tokens, total_tokens=num_tokens),
+    )
+    model_response.object = "list"
+    model_response.model = model
+
+    return model_response
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/oobabooga/chat/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/oobabooga/chat/transformation.py
new file mode 100644
index 00000000..6fd56f93
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/oobabooga/chat/transformation.py
@@ -0,0 +1,100 @@
+import time
+from typing import TYPE_CHECKING, Any, List, Optional, Union
+
+import httpx
+
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
+from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
+from litellm.types.llms.openai import AllMessageValues
+from litellm.types.utils import ModelResponse, Usage
+
+from ..common_utils import OobaboogaError
+
+if TYPE_CHECKING:
+    from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+
+    LoggingClass = LiteLLMLoggingObj
+else:
+    LoggingClass = Any
+
+
+class OobaboogaConfig(OpenAIGPTConfig):
+    def get_error_class(
+        self,
+        error_message: str,
+        status_code: int,
+        headers: Optional[Union[dict, httpx.Headers]] = None,
+    ) -> BaseLLMException:
+        return OobaboogaError(
+            status_code=status_code, message=error_message, headers=headers
+        )
+
+    def transform_response(
+        self,
+        model: str,
+        raw_response: httpx.Response,
+        model_response: ModelResponse,
+        logging_obj: LoggingClass,
+        request_data: dict,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        encoding: Any,
+        api_key: Optional[str] = None,
+        json_mode: Optional[bool] = None,
+    ) -> ModelResponse:
+        ## LOGGING
+        logging_obj.post_call(
+            input=messages,
+            api_key=api_key,
+            original_response=raw_response.text,
+            additional_args={"complete_input_dict": request_data},
+        )
+
+        ## RESPONSE OBJECT
+        try:
+            completion_response = raw_response.json()
+        except Exception:
+            raise OobaboogaError(
+                message=raw_response.text, status_code=raw_response.status_code
+            )
+        if "error" in completion_response:
+            raise OobaboogaError(
+                message=completion_response["error"],
+                status_code=raw_response.status_code,
+            )
+        else:
+            try:
+                model_response.choices[0].message.content = completion_response["choices"][0]["message"]["content"]  # type: ignore
+            except Exception as e:
+                raise OobaboogaError(
+                    message=str(e),
+                    status_code=raw_response.status_code,
+                )
+
+        model_response.created = int(time.time())
+        model_response.model = model
+        usage = Usage(
+            prompt_tokens=completion_response["usage"]["prompt_tokens"],
+            completion_tokens=completion_response["usage"]["completion_tokens"],
+            total_tokens=completion_response["usage"]["total_tokens"],
+        )
+        setattr(model_response, "usage", usage)
+        return model_response
+
+    def validate_environment(
+        self,
+        headers: dict,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+    ) -> dict:
+        headers = {
+            "accept": "application/json",
+            "content-type": "application/json",
+        }
+        if api_key is not None:
+            headers["Authorization"] = f"Token {api_key}"
+        return headers
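
For context, a minimal sketch of how these handlers are typically reached through litellm's public API. The model name and endpoint below are placeholders for a local text-generation-webui server, not values taken from this diff:

    import litellm

    # Placeholder model name and endpoint; the "oobabooga/" prefix is what
    # routes the request to the completion() handler added above, which then
    # appends "/v1/chat/completions" to the api_base.
    response = litellm.completion(
        model="oobabooga/my-local-model",
        messages=[{"role": "user", "content": "Hello, how are you?"}],
        api_base="http://localhost:5000",
    )
    print(response.choices[0].message.content)

Note that validate_environment() in transformation.py sends the key as "Authorization: Token <api_key>" rather than the more common Bearer scheme, so any auth layer in front of the server must accept that header format.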