about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/litellm/llms/petals
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/litellm/llms/petals
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two version of R2R are here HEAD master
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/petals')
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/llms/petals/common_utils.py10
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/llms/petals/completion/handler.py149
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/llms/petals/completion/transformation.py137
3 files changed, 296 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/petals/common_utils.py b/.venv/lib/python3.12/site-packages/litellm/llms/petals/common_utils.py
new file mode 100644
index 00000000..bffee338
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/petals/common_utils.py
@@ -0,0 +1,10 @@
+from typing import Union
+
+from httpx import Headers
+
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
+
+
+class PetalsError(BaseLLMException):
+    def __init__(self, status_code: int, message: str, headers: Union[dict, Headers]):
+        super().__init__(status_code=status_code, message=message, headers=headers)
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/petals/completion/handler.py b/.venv/lib/python3.12/site-packages/litellm/llms/petals/completion/handler.py
new file mode 100644
index 00000000..ae38baec
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/petals/completion/handler.py
@@ -0,0 +1,149 @@
+import time
+from typing import Callable, Optional, Union
+
+import litellm
+from litellm.litellm_core_utils.prompt_templates.factory import (
+    custom_prompt,
+    prompt_factory,
+)
+from litellm.llms.custom_httpx.http_handler import (
+    AsyncHTTPHandler,
+    HTTPHandler,
+    _get_httpx_client,
+)
+from litellm.utils import ModelResponse, Usage
+
+from ..common_utils import PetalsError
+
+
+def completion(
+    model: str,
+    messages: list,
+    api_base: Optional[str],
+    model_response: ModelResponse,
+    print_verbose: Callable,
+    encoding,
+    logging_obj,
+    optional_params: dict,
+    stream=False,
+    litellm_params=None,
+    logger_fn=None,
+    client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
+):
+    ## Load Config
+    config = litellm.PetalsConfig.get_config()
+    for k, v in config.items():
+        if (
+            k not in optional_params
+        ):  # completion(top_k=3) > petals_config(top_k=3) <- allows for dynamic variables to be passed in
+            optional_params[k] = v
+
+    if model in litellm.custom_prompt_dict:
+        # check if the model has a registered custom prompt
+        model_prompt_details = litellm.custom_prompt_dict[model]
+        prompt = custom_prompt(
+            role_dict=model_prompt_details["roles"],
+            initial_prompt_value=model_prompt_details["initial_prompt_value"],
+            final_prompt_value=model_prompt_details["final_prompt_value"],
+            messages=messages,
+        )
+    else:
+        prompt = prompt_factory(model=model, messages=messages)
+
+    output_text: Optional[str] = None
+    if api_base:
+        ## LOGGING
+        logging_obj.pre_call(
+            input=prompt,
+            api_key="",
+            additional_args={
+                "complete_input_dict": optional_params,
+                "api_base": api_base,
+            },
+        )
+        data = {"model": model, "inputs": prompt, **optional_params}
+
+        ## COMPLETION CALL
+        if client is None or not isinstance(client, HTTPHandler):
+            client = _get_httpx_client()
+        response = client.post(api_base, data=data)
+
+        ## LOGGING
+        logging_obj.post_call(
+            input=prompt,
+            api_key="",
+            original_response=response.text,
+            additional_args={"complete_input_dict": optional_params},
+        )
+
+        ## RESPONSE OBJECT
+        try:
+            output_text = response.json()["outputs"]
+        except Exception as e:
+            PetalsError(
+                status_code=response.status_code,
+                message=str(e),
+                headers=response.headers,
+            )
+
+    else:
+        try:
+            from petals import AutoDistributedModelForCausalLM  # type: ignore
+            from transformers import AutoTokenizer
+        except Exception:
+            raise Exception(
+                "Importing torch, transformers, petals failed\nTry pip installing petals \npip install git+https://github.com/bigscience-workshop/petals"
+            )
+
+        model = model
+
+        tokenizer = AutoTokenizer.from_pretrained(
+            model, use_fast=False, add_bos_token=False
+        )
+        model_obj = AutoDistributedModelForCausalLM.from_pretrained(model)
+
+        ## LOGGING
+        logging_obj.pre_call(
+            input=prompt,
+            api_key="",
+            additional_args={"complete_input_dict": optional_params},
+        )
+
+        ## COMPLETION CALL
+        inputs = tokenizer(prompt, return_tensors="pt")["input_ids"]
+
+        # optional params: max_new_tokens=1,temperature=0.9, top_p=0.6
+        outputs = model_obj.generate(inputs, **optional_params)
+
+        ## LOGGING
+        logging_obj.post_call(
+            input=prompt,
+            api_key="",
+            original_response=outputs,
+            additional_args={"complete_input_dict": optional_params},
+        )
+        ## RESPONSE OBJECT
+        output_text = tokenizer.decode(outputs[0])
+
+    if output_text is not None and len(output_text) > 0:
+        model_response.choices[0].message.content = output_text  # type: ignore
+
+    prompt_tokens = len(encoding.encode(prompt))
+    completion_tokens = len(
+        encoding.encode(model_response["choices"][0]["message"].get("content"))
+    )
+
+    model_response.created = int(time.time())
+    model_response.model = model
+    usage = Usage(
+        prompt_tokens=prompt_tokens,
+        completion_tokens=completion_tokens,
+        total_tokens=prompt_tokens + completion_tokens,
+    )
+    setattr(model_response, "usage", usage)
+    return model_response
+
+
+def embedding():
+    # logic for parsing in - calling - parsing out model embedding calls
+    pass
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/petals/completion/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/petals/completion/transformation.py
new file mode 100644
index 00000000..08ec15de
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/petals/completion/transformation.py
@@ -0,0 +1,137 @@
+from typing import Any, List, Optional, Union
+
+from httpx import Headers, Response
+
+import litellm
+from litellm.llms.base_llm.chat.transformation import (
+    BaseConfig,
+    BaseLLMException,
+    LiteLLMLoggingObj,
+)
+from litellm.types.llms.openai import AllMessageValues
+from litellm.types.utils import ModelResponse
+
+from ..common_utils import PetalsError
+
+
+class PetalsConfig(BaseConfig):
+    """
+    Reference: https://github.com/petals-infra/chat.petals.dev#post-apiv1generate
+    The `PetalsConfig` class encapsulates the configuration for the Petals API. The properties of this class are described below:
+
+    - `max_length` (integer): This represents the maximum length of the generated text (including the prefix) in tokens.
+
+    - `max_new_tokens` (integer): This represents the maximum number of newly generated tokens (excluding the prefix).
+
+    The generation parameters are compatible with `.generate()` from Hugging Face's Transformers library:
+
+    - `do_sample` (boolean, optional): If set to 0 (default), the API runs greedy generation. If set to 1, the API performs sampling using the parameters below:
+
+    - `temperature` (float, optional): This value sets the temperature for sampling.
+
+    - `top_k` (integer, optional): This value sets the limit for top-k sampling.
+
+    - `top_p` (float, optional): This value sets the limit for top-p (nucleus) sampling.
+
+    - `repetition_penalty` (float, optional): This helps apply the repetition penalty during text generation, as discussed in this paper.
+    """
+
+    max_length: Optional[int] = None
+    max_new_tokens: Optional[int] = (
+        litellm.max_tokens
+    )  # petals requires max tokens to be set
+    do_sample: Optional[bool] = None
+    temperature: Optional[float] = None
+    top_k: Optional[int] = None
+    top_p: Optional[float] = None
+    repetition_penalty: Optional[float] = None
+
+    def __init__(
+        self,
+        max_length: Optional[int] = None,
+        max_new_tokens: Optional[
+            int
+        ] = litellm.max_tokens,  # petals requires max tokens to be set
+        do_sample: Optional[bool] = None,
+        temperature: Optional[float] = None,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+        repetition_penalty: Optional[float] = None,
+    ) -> None:
+        locals_ = locals().copy()
+        for key, value in locals_.items():
+            if key != "self" and value is not None:
+                setattr(self.__class__, key, value)
+
+    @classmethod
+    def get_config(cls):
+        return super().get_config()
+
+    def get_error_class(
+        self, error_message: str, status_code: int, headers: Union[dict, Headers]
+    ) -> BaseLLMException:
+        return PetalsError(
+            status_code=status_code, message=error_message, headers=headers
+        )
+
+    def get_supported_openai_params(self, model: str) -> List:
+        return ["max_tokens", "temperature", "top_p", "stream"]
+
+    def map_openai_params(
+        self,
+        non_default_params: dict,
+        optional_params: dict,
+        model: str,
+        drop_params: bool,
+    ) -> dict:
+        for param, value in non_default_params.items():
+            if param == "max_tokens":
+                optional_params["max_new_tokens"] = value
+            if param == "temperature":
+                optional_params["temperature"] = value
+            if param == "top_p":
+                optional_params["top_p"] = value
+            if param == "stream":
+                optional_params["stream"] = value
+        return optional_params
+
+    def transform_request(
+        self,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        headers: dict,
+    ) -> dict:
+        raise NotImplementedError(
+            "Petals transformation currently done in handler.py. [TODO] Move to the transformation.py"
+        )
+
+    def transform_response(
+        self,
+        model: str,
+        raw_response: Response,
+        model_response: ModelResponse,
+        logging_obj: LiteLLMLoggingObj,
+        request_data: dict,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        encoding: Any,
+        api_key: Optional[str] = None,
+        json_mode: Optional[bool] = None,
+    ) -> ModelResponse:
+        raise NotImplementedError(
+            "Petals transformation currently done in handler.py. [TODO] Move to the transformation.py"
+        )
+
+    def validate_environment(
+        self,
+        headers: dict,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+    ) -> dict:
+        return {}