diff options
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/petals/completion/transformation.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/litellm/llms/petals/completion/transformation.py | 137 |
1 files changed, 137 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/petals/completion/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/petals/completion/transformation.py new file mode 100644 index 00000000..08ec15de --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/llms/petals/completion/transformation.py @@ -0,0 +1,137 @@ +from typing import Any, List, Optional, Union + +from httpx import Headers, Response + +import litellm +from litellm.llms.base_llm.chat.transformation import ( + BaseConfig, + BaseLLMException, + LiteLLMLoggingObj, +) +from litellm.types.llms.openai import AllMessageValues +from litellm.types.utils import ModelResponse + +from ..common_utils import PetalsError + + +class PetalsConfig(BaseConfig): + """ + Reference: https://github.com/petals-infra/chat.petals.dev#post-apiv1generate + The `PetalsConfig` class encapsulates the configuration for the Petals API. The properties of this class are described below: + + - `max_length` (integer): This represents the maximum length of the generated text (including the prefix) in tokens. + + - `max_new_tokens` (integer): This represents the maximum number of newly generated tokens (excluding the prefix). + + The generation parameters are compatible with `.generate()` from Hugging Face's Transformers library: + + - `do_sample` (boolean, optional): If set to 0 (default), the API runs greedy generation. If set to 1, the API performs sampling using the parameters below: + + - `temperature` (float, optional): This value sets the temperature for sampling. + + - `top_k` (integer, optional): This value sets the limit for top-k sampling. + + - `top_p` (float, optional): This value sets the limit for top-p (nucleus) sampling. + + - `repetition_penalty` (float, optional): This helps apply the repetition penalty during text generation, as discussed in this paper. + """ + + max_length: Optional[int] = None + max_new_tokens: Optional[int] = ( + litellm.max_tokens + ) # petals requires max tokens to be set + do_sample: Optional[bool] = None + temperature: Optional[float] = None + top_k: Optional[int] = None + top_p: Optional[float] = None + repetition_penalty: Optional[float] = None + + def __init__( + self, + max_length: Optional[int] = None, + max_new_tokens: Optional[ + int + ] = litellm.max_tokens, # petals requires max tokens to be set + do_sample: Optional[bool] = None, + temperature: Optional[float] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + repetition_penalty: Optional[float] = None, + ) -> None: + locals_ = locals().copy() + for key, value in locals_.items(): + if key != "self" and value is not None: + setattr(self.__class__, key, value) + + @classmethod + def get_config(cls): + return super().get_config() + + def get_error_class( + self, error_message: str, status_code: int, headers: Union[dict, Headers] + ) -> BaseLLMException: + return PetalsError( + status_code=status_code, message=error_message, headers=headers + ) + + def get_supported_openai_params(self, model: str) -> List: + return ["max_tokens", "temperature", "top_p", "stream"] + + def map_openai_params( + self, + non_default_params: dict, + optional_params: dict, + model: str, + drop_params: bool, + ) -> dict: + for param, value in non_default_params.items(): + if param == "max_tokens": + optional_params["max_new_tokens"] = value + if param == "temperature": + optional_params["temperature"] = value + if param == "top_p": + optional_params["top_p"] = value + if param == "stream": + optional_params["stream"] = value + return optional_params + + def transform_request( + self, + model: str, + messages: List[AllMessageValues], + optional_params: dict, + litellm_params: dict, + headers: dict, + ) -> dict: + raise NotImplementedError( + "Petals transformation currently done in handler.py. [TODO] Move to the transformation.py" + ) + + def transform_response( + self, + model: str, + raw_response: Response, + model_response: ModelResponse, + logging_obj: LiteLLMLoggingObj, + request_data: dict, + messages: List[AllMessageValues], + optional_params: dict, + litellm_params: dict, + encoding: Any, + api_key: Optional[str] = None, + json_mode: Optional[bool] = None, + ) -> ModelResponse: + raise NotImplementedError( + "Petals transformation currently done in handler.py. [TODO] Move to the transformation.py" + ) + + def validate_environment( + self, + headers: dict, + model: str, + messages: List[AllMessageValues], + optional_params: dict, + api_key: Optional[str] = None, + api_base: Optional[str] = None, + ) -> dict: + return {} |