Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/oobabooga/chat')
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/oobabooga/chat/oobabooga.py        158
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/oobabooga/chat/transformation.py   100
2 files changed, 258 insertions(+), 0 deletions(-)
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/oobabooga/chat/oobabooga.py b/.venv/lib/python3.12/site-packages/litellm/llms/oobabooga/chat/oobabooga.py
new file mode 100644
index 00000000..8829d223
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/oobabooga/chat/oobabooga.py
@@ -0,0 +1,158 @@
+import json
+from typing import Any, Callable, Optional
+
+import litellm
+from litellm.llms.custom_httpx.http_handler import _get_httpx_client
+from litellm.utils import EmbeddingResponse, ModelResponse, Usage
+
+from ..common_utils import OobaboogaError
+from .transformation import OobaboogaConfig
+
+oobabooga_config = OobaboogaConfig()
+
+
+def completion(
+    model: str,
+    messages: list,
+    api_base: Optional[str],
+    model_response: ModelResponse,
+    print_verbose: Callable,
+    encoding,
+    api_key,
+    logging_obj,
+    optional_params: dict,
+    litellm_params: dict,
+    custom_prompt_dict={},
+    logger_fn=None,
+    default_max_tokens_to_sample=None,
+):
+    headers = oobabooga_config.validate_environment(
+        api_key=api_key,
+        headers={},
+        model=model,
+        messages=messages,
+        optional_params=optional_params,
+    )
+    if "https" in model:
+        completion_url = model
+    elif api_base:
+        completion_url = api_base
+    else:
+        raise OobaboogaError(
+            status_code=404,
+            message="API Base not set. Set one via completion(..,api_base='your-api-url')",
+        )
+    model = model
+
+    completion_url = completion_url + "/v1/chat/completions"
+    data = oobabooga_config.transform_request(
+        model=model,
+        messages=messages,
+        optional_params=optional_params,
+        litellm_params=litellm_params,
+        headers=headers,
+    )
+    ## LOGGING
+
+    logging_obj.pre_call(
+        input=messages,
+        api_key=api_key,
+        additional_args={"complete_input_dict": data},
+    )
+    ## COMPLETION CALL
+    client = _get_httpx_client()
+    response = client.post(
+        completion_url,
+        headers=headers,
+        data=json.dumps(data),
+        stream=optional_params.get("stream", False),
+    )
+    if "stream" in optional_params and optional_params["stream"] is True:
+        return response.iter_lines()
+    else:
+        return oobabooga_config.transform_response(
+            model=model,
+            raw_response=response,
+            model_response=model_response,
+            logging_obj=logging_obj,
+            api_key=api_key,
+            request_data=data,
+            messages=messages,
+            optional_params=optional_params,
+            litellm_params=litellm_params,
+            encoding=encoding,
+        )
+
+
+def embedding(
+    model: str,
+    input: list,
+    model_response: EmbeddingResponse,
+    api_key: Optional[str],
+    api_base: Optional[str],
+    logging_obj: Any,
+    optional_params: dict,
+    encoding=None,
+):
+    # Build the embeddings URL: a model string that is itself a URL is used as-is
+    if "https" in model:
+        embeddings_url = model
+    elif api_base:
+        embeddings_url = f"{api_base}/v1/embeddings"
+    else:
+        raise OobaboogaError(
+            status_code=404,
+            message="API Base not set. Set one via embedding(..., api_base='your-api-url')",
+        )
+
+    # Prepare request data
+    data = {"input": input}
+    if optional_params:
+        data.update(optional_params)
+
+    # Logging before API call
+    if logging_obj:
+        logging_obj.pre_call(
+            input=input, api_key=api_key, additional_args={"complete_input_dict": data}
+        )
+
+    # Send POST request
+    headers = oobabooga_config.validate_environment(
+        api_key=api_key,
+        headers={},
+        model=model,
+        messages=[],
+        optional_params=optional_params,
+    )
+    response = litellm.module_level_client.post(
+        embeddings_url, headers=headers, json=data
+    )
+    completion_response = response.json()
+
+    # Check for errors in response
+    if "error" in completion_response:
+        raise OobaboogaError(
+            message=completion_response["error"],
+            status_code=completion_response.get("status_code", 500),
+        )
+
+    # Process response data: one entry per input item, not just the first
+    model_response.data = [
+        {
+            "embedding": entry["embedding"],
+            "index": idx,
+            "object": "embedding",
+        }
+        for idx, entry in enumerate(completion_response["data"])
+    ]
+    num_tokens = sum(len(entry["embedding"]) for entry in completion_response["data"])
+    # Adding metadata to response
+    setattr(
+        model_response,
+        "usage",
+        Usage(prompt_tokens=num_tokens, total_tokens=num_tokens),
+    )
+    model_response.object = "list"
+    model_response.model = model
+
+    return model_response
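
For context, a minimal usage sketch (not part of the diff) of how these handlers are reached through litellm's public API; the model name and local server address below are assumptions, not values from the source:

    import litellm

    # Chat completion: routed to completion() above via the "oobabooga/" prefix.
    response = litellm.completion(
        model="oobabooga/my-local-model",              # hypothetical model name
        messages=[{"role": "user", "content": "Hello!"}],
        api_base="http://localhost:5000",              # assumed text-generation-webui URL
    )
    print(response.choices[0].message.content)

    # Embeddings: routed to embedding() above.
    emb = litellm.embedding(
        model="oobabooga/my-local-model",              # hypothetical model name
        input=["some text to embed"],
        api_base="http://localhost:5000",              # assumed server URL
    )
    print(len(emb.data[0]["embedding"]))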
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/oobabooga/chat/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/oobabooga/chat/transformation.py
new file mode 100644
index 00000000..6fd56f93
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/oobabooga/chat/transformation.py
@@ -0,0 +1,100 @@
+import time
+from typing import TYPE_CHECKING, Any, List, Optional, Union
+
+import httpx
+
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
+from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
+from litellm.types.llms.openai import AllMessageValues
+from litellm.types.utils import ModelResponse, Usage
+
+from ..common_utils import OobaboogaError
+
+if TYPE_CHECKING:
+    from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+
+    LoggingClass = LiteLLMLoggingObj
+else:
+    LoggingClass = Any
+
+
+class OobaboogaConfig(OpenAIGPTConfig):
+    def get_error_class(
+        self,
+        error_message: str,
+        status_code: int,
+        headers: Optional[Union[dict, httpx.Headers]] = None,
+    ) -> BaseLLMException:
+        return OobaboogaError(
+            status_code=status_code, message=error_message, headers=headers
+        )
+
+    def transform_response(
+        self,
+        model: str,
+        raw_response: httpx.Response,
+        model_response: ModelResponse,
+        logging_obj: LoggingClass,
+        request_data: dict,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        encoding: Any,
+        api_key: Optional[str] = None,
+        json_mode: Optional[bool] = None,
+    ) -> ModelResponse:
+        ## LOGGING
+        logging_obj.post_call(
+            input=messages,
+            api_key=api_key,
+            original_response=raw_response.text,
+            additional_args={"complete_input_dict": request_data},
+        )
+
+        ## RESPONSE OBJECT
+        try:
+            completion_response = raw_response.json()
+        except Exception:
+            raise OobaboogaError(
+                message=raw_response.text, status_code=raw_response.status_code
+            )
+        if "error" in completion_response:
+            raise OobaboogaError(
+                message=completion_response["error"],
+                status_code=raw_response.status_code,
+            )
+        else:
+            try:
+                model_response.choices[0].message.content = completion_response["choices"][0]["message"]["content"]  # type: ignore
+            except Exception as e:
+                raise OobaboogaError(
+                    message=str(e),
+                    status_code=raw_response.status_code,
+                )
+
+        model_response.created = int(time.time())
+        model_response.model = model
+        usage = Usage(
+            prompt_tokens=completion_response["usage"]["prompt_tokens"],
+            completion_tokens=completion_response["usage"]["completion_tokens"],
+            total_tokens=completion_response["usage"]["total_tokens"],
+        )
+        setattr(model_response, "usage", usage)
+        return model_response
+
+    def validate_environment(
+        self,
+        headers: dict,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+    ) -> dict:
+        # Merge defaults into any caller-supplied headers rather than discarding them
+        headers.update(
+            {"accept": "application/json", "content-type": "application/json"}
+        )
+        if api_key is not None:
+            headers["Authorization"] = f"Token {api_key}"
+        return headers
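
A small sketch of the header handling above, with a made-up placeholder key: validate_environment always sets the JSON accept/content-type headers and adds a "Token"-style Authorization header only when an API key is supplied:

    from litellm.llms.oobabooga.chat.transformation import OobaboogaConfig

    config = OobaboogaConfig()
    headers = config.validate_environment(
        headers={},
        model="my-local-model",    # hypothetical model name
        messages=[],
        optional_params={},
        api_key="sk-example",      # placeholder key
    )
    # headers == {
    #     "accept": "application/json",
    #     "content-type": "application/json",
    #     "Authorization": "Token sk-example",
    # }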