From 4a52a71956a8d46fcb7294ac71734504bb09bcc2 Mon Sep 17 00:00:00 2001
From: S. Solomon Darnell
Date: Fri, 28 Mar 2025 21:52:21 -0500
Subject: two versions of R2R are here

---
 .../litellm/llms/predibase/chat/transformation.py | 180 +++++++++++++++++++++
 1 file changed, 180 insertions(+)
 create mode 100644 .venv/lib/python3.12/site-packages/litellm/llms/predibase/chat/transformation.py

diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/predibase/chat/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/predibase/chat/transformation.py
new file mode 100644
index 00000000..f5742386
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/predibase/chat/transformation.py
@@ -0,0 +1,180 @@
+from typing import TYPE_CHECKING, Any, List, Literal, Optional, Union
+
+from httpx import Headers, Response
+
+from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
+from litellm.types.llms.openai import AllMessageValues
+from litellm.types.utils import ModelResponse
+
+from ..common_utils import PredibaseError
+
+if TYPE_CHECKING:
+    from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
+
+    LiteLLMLoggingObj = _LiteLLMLoggingObj
+else:
+    LiteLLMLoggingObj = Any
+
+
+class PredibaseConfig(BaseConfig):
+    """
+    Reference: https://docs.predibase.com/user-guide/inference/rest_api
+    """
+
+    adapter_id: Optional[str] = None
+    adapter_source: Optional[Literal["pbase", "hub", "s3"]] = None
+    best_of: Optional[int] = None
+    decoder_input_details: Optional[bool] = None
+    details: bool = True  # enables returning logprobs + best_of
+    max_new_tokens: int = (
+        256  # openai default - requests hang if max_new_tokens is not given
+    )
+    repetition_penalty: Optional[float] = None
+    return_full_text: Optional[bool] = (
+        False  # by default, don't return the input as part of the output
+    )
+    seed: Optional[int] = None
+    stop: Optional[List[str]] = None
+    temperature: Optional[float] = None
+    top_k: Optional[int] = None
+    top_p: Optional[float] = None
+    truncate: Optional[int] = None
+    typical_p: Optional[float] = None
+    watermark: Optional[bool] = None
+
+    def __init__(
+        self,
+        best_of: Optional[int] = None,
+        decoder_input_details: Optional[bool] = None,
+        details: Optional[bool] = None,
+        max_new_tokens: Optional[int] = None,
+        repetition_penalty: Optional[float] = None,
+        return_full_text: Optional[bool] = None,
+        seed: Optional[int] = None,
+        stop: Optional[List[str]] = None,
+        temperature: Optional[float] = None,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+        truncate: Optional[int] = None,
+        typical_p: Optional[float] = None,
+        watermark: Optional[bool] = None,
+    ) -> None:
+        locals_ = locals().copy()
+        for key, value in locals_.items():
+            if key != "self" and value is not None:
+                setattr(self.__class__, key, value)
+
+    @classmethod
+    def get_config(cls):
+        return super().get_config()
+
+    def get_supported_openai_params(self, model: str):
+        return [
+            "stream",
+            "temperature",
+            "max_completion_tokens",
+            "max_tokens",
+            "top_p",
+            "stop",
+            "n",
+            "response_format",
+        ]
+
+    def map_openai_params(
+        self,
+        non_default_params: dict,
+        optional_params: dict,
+        model: str,
+        drop_params: bool,
+    ) -> dict:
+        for param, value in non_default_params.items():
+            # temperature, top_p, n, stream, stop, max_tokens, presence_penalty default to None
+            if param == "temperature":
+                if value == 0.0 or value == 0:
+                    # Hugging Face raises an exception when temperature == 0:
+                    # Failed: Error occurred: HuggingfaceException - Input validation error: `temperature` must be strictly positive
+                    value = 0.01
+                optional_params["temperature"] = value
+            if param == "top_p":
+                optional_params["top_p"] = value
+            if param == "n":
+                optional_params["best_of"] = value
+                optional_params["do_sample"] = (
+                    True  # need to sample when using best_of on HF inference endpoints
+                )
+            if param == "stream":
+                optional_params["stream"] = value
+            if param == "stop":
+                optional_params["stop"] = value
+            if param == "max_tokens" or param == "max_completion_tokens":
+                # HF TGI raises the following exception when max_new_tokens==0:
+                # Failed: Error occurred: HuggingfaceException - Input validation error: `max_new_tokens` must be strictly positive
+                if value == 0:
+                    value = 1
+                optional_params["max_new_tokens"] = value
+            if param == "echo":
+                # https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation.decoder_input_details
+                # Returns the decoder input token logprobs and ids. details=True must also be set for this to take effect. Defaults to False.
+                optional_params["decoder_input_details"] = True
+            if param == "response_format":
+                optional_params["response_format"] = value
+        return optional_params
+
+    def transform_response(
+        self,
+        model: str,
+        raw_response: Response,
+        model_response: ModelResponse,
+        logging_obj: LiteLLMLoggingObj,
+        request_data: dict,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        encoding: str,
+        api_key: Optional[str] = None,
+        json_mode: Optional[bool] = None,
+    ) -> ModelResponse:
+        raise NotImplementedError(
+            "Predibase transformation is currently done in handler.py; it still needs to be migrated to this file."
+        )
+
+    def transform_request(
+        self,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        headers: dict,
+    ) -> dict:
+        raise NotImplementedError(
+            "Predibase transformation is currently done in handler.py; it still needs to be migrated to this file."
+        )
+
+    def get_error_class(
+        self, error_message: str, status_code: int, headers: Union[dict, Headers]
+    ) -> BaseLLMException:
+        return PredibaseError(
+            status_code=status_code, message=error_message, headers=headers
+        )
+
+    def validate_environment(
+        self,
+        headers: dict,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+    ) -> dict:
+        if api_key is None:
+            raise ValueError(
+                "Missing Predibase API key - a call is being made to Predibase, but no key is set in the environment variables or via params"
+            )
+
+        default_headers = {
+            "content-type": "application/json",
+            "Authorization": "Bearer {}".format(api_key),
+        }
+        if headers is not None and isinstance(headers, dict):
+            return {**default_headers, **headers}
+        return default_headers
--
cgit v1.2.3
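For reference, below is a minimal sketch of how the map_openai_params mapping in this patch behaves, assuming litellm is installed with this file in place. The model name and parameter values are illustrative only, not taken from the patch.

    # Illustrative only: model name and values are hypothetical.
    from litellm.llms.predibase.chat.transformation import PredibaseConfig

    config = PredibaseConfig()

    # OpenAI-style parameters as a caller might pass them.
    openai_params = {
        "temperature": 0,   # clamped to 0.01 (TGI rejects temperature == 0)
        "max_tokens": 0,    # clamped to 1 and renamed to max_new_tokens
        "n": 2,             # mapped to best_of; also forces do_sample=True
        "stop": ["\n\n"],
    }

    mapped = config.map_openai_params(
        non_default_params=openai_params,
        optional_params={},
        model="predibase-model",  # hypothetical model name
        drop_params=False,
    )
    print(mapped)
    # {'temperature': 0.01, 'max_new_tokens': 1, 'best_of': 2,
    #  'do_sample': True, 'stop': ['\n\n']}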
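Similarly, a short sketch of the header handling in validate_environment, under the same assumptions; the API key and extra header below are placeholders, not real values. Caller-supplied headers are merged over the defaults, so the Authorization header is always present.

    # Placeholders only: the key and the extra header are not real values.
    headers = config.validate_environment(
        headers={"x-request-id": "abc123"},  # hypothetical caller-supplied header
        model="predibase-model",             # hypothetical model name
        messages=[{"role": "user", "content": "hi"}],
        optional_params={},
        api_key="pb_XXXX",                   # placeholder key
    )
    # {'content-type': 'application/json',
    #  'Authorization': 'Bearer pb_XXXX',
    #  'x-request-id': 'abc123'}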