author    S. Solomon Darnell    2025-03-28 21:52:21 -0500
committer S. Solomon Darnell    2025-03-28 21:52:21 -0500
commit    4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree      ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/litellm/llms/predibase/chat/transformation.py
parent    cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
two versions of R2R are here (HEAD, master)
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/predibase/chat/transformation.py')
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/predibase/chat/transformation.py  180
1 file changed, 180 insertions(+), 0 deletions(-)
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/predibase/chat/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/predibase/chat/transformation.py
new file mode 100644
index 00000000..f5742386
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/predibase/chat/transformation.py
@@ -0,0 +1,180 @@
+from typing import TYPE_CHECKING, Any, List, Literal, Optional, Union
+
+from httpx import Headers, Response
+
+from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
+from litellm.types.llms.openai import AllMessageValues
+from litellm.types.utils import ModelResponse
+
+from ..common_utils import PredibaseError
+
+if TYPE_CHECKING:
+    from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
+
+    LiteLLMLoggingObj = _LiteLLMLoggingObj
+else:
+    LiteLLMLoggingObj = Any
+
+
+class PredibaseConfig(BaseConfig):
+ """
+ Reference: https://docs.predibase.com/user-guide/inference/rest_api
+ """
+
+ adapter_id: Optional[str] = None
+ adapter_source: Optional[Literal["pbase", "hub", "s3"]] = None
+ best_of: Optional[int] = None
+ decoder_input_details: Optional[bool] = None
+ details: bool = True # enables returning logprobs + best of
+ max_new_tokens: int = (
+ 256 # openai default - requests hang if max_new_tokens not given
+ )
+ repetition_penalty: Optional[float] = None
+ return_full_text: Optional[bool] = (
+ False # by default don't return the input as part of the output
+ )
+ seed: Optional[int] = None
+ stop: Optional[List[str]] = None
+ temperature: Optional[float] = None
+ top_k: Optional[int] = None
+ top_p: Optional[int] = None
+ truncate: Optional[int] = None
+ typical_p: Optional[float] = None
+ watermark: Optional[bool] = None
+
+    def __init__(
+        self,
+        best_of: Optional[int] = None,
+        decoder_input_details: Optional[bool] = None,
+        details: Optional[bool] = None,
+        max_new_tokens: Optional[int] = None,
+        repetition_penalty: Optional[float] = None,
+        return_full_text: Optional[bool] = None,
+        seed: Optional[int] = None,
+        stop: Optional[List[str]] = None,
+        temperature: Optional[float] = None,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+        truncate: Optional[int] = None,
+        typical_p: Optional[float] = None,
+        watermark: Optional[bool] = None,
+    ) -> None:
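+        # litellm config pattern: every non-None constructor argument is
+        # promoted to a class-level attribute, so it acts as a process-wide
+        # default for subsequent requests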
+        locals_ = locals().copy()
+        for key, value in locals_.items():
+            if key != "self" and value is not None:
+                setattr(self.__class__, key, value)
+
+    @classmethod
+    def get_config(cls):
+        return super().get_config()
+
+    def get_supported_openai_params(self, model: str):
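+        # OpenAI params litellm can translate for Predibase requests
+        # (see map_openai_params below for the actual mapping)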
+        return [
+            "stream",
+            "temperature",
+            "max_completion_tokens",
+            "max_tokens",
+            "top_p",
+            "stop",
+            "n",
+            "response_format",
+        ]
+
+    def map_openai_params(
+        self,
+        non_default_params: dict,
+        optional_params: dict,
+        model: str,
+        drop_params: bool,
+    ) -> dict:
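+        """
+        Translate OpenAI-style params into Predibase/TGI generation fields:
+        max_tokens / max_completion_tokens -> max_new_tokens (0 is bumped to 1),
+        n -> best_of (with do_sample=True), temperature 0 -> 0.01,
+        echo -> decoder_input_details=True.
+        """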
+        for param, value in non_default_params.items():
+            # supported params default to None; only map those the caller set
+            if param == "temperature":
+                if value == 0.0 or value == 0:
+                    # the backend rejects temperature == 0 with:
+                    # "HuggingfaceException - Input validation error: `temperature` must be strictly positive"
+                    value = 0.01
+                optional_params["temperature"] = value
+            elif param == "top_p":
+                optional_params["top_p"] = value
+            elif param == "n":
+                optional_params["best_of"] = value
+                # best_of requires sampling to be enabled on TGI-style endpoints
+                optional_params["do_sample"] = True
+            elif param == "stream":
+                optional_params["stream"] = value
+            elif param == "stop":
+                optional_params["stop"] = value
+            elif param == "max_tokens" or param == "max_completion_tokens":
+                # the backend rejects max_new_tokens == 0 with:
+                # "HuggingfaceException - Input validation error: `max_new_tokens` must be strictly positive"
+                if value == 0:
+                    value = 1
+                optional_params["max_new_tokens"] = value
+            elif param == "echo":
+                # https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation.decoder_input_details
+                # return the decoder input token logprobs and ids; details=True must also be set for it to take effect
+                optional_params["decoder_input_details"] = True
+            elif param == "response_format":
+                optional_params["response_format"] = value
+        return optional_params
+
+    def transform_response(
+        self,
+        model: str,
+        raw_response: Response,
+        model_response: ModelResponse,
+        logging_obj: LiteLLMLoggingObj,
+        request_data: dict,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        encoding: str,
+        api_key: Optional[str] = None,
+        json_mode: Optional[bool] = None,
+    ) -> ModelResponse:
+        raise NotImplementedError(
+            "Predibase response transformation is currently done in handler.py and still needs to be migrated to this file."
+        )
+
+    def transform_request(
+        self,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        headers: dict,
+    ) -> dict:
+        raise NotImplementedError(
+            "Predibase request transformation is currently done in handler.py and still needs to be migrated to this file."
+        )
+
+    def get_error_class(
+        self, error_message: str, status_code: int, headers: Union[dict, Headers]
+    ) -> BaseLLMException:
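+        # surface provider failures as PredibaseError (a BaseLLMException) so
+        # litellm's shared error handling can map them to OpenAI-style exceptions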
+        return PredibaseError(
+            status_code=status_code, message=error_message, headers=headers
+        )
+
+    def validate_environment(
+        self,
+        headers: dict,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+    ) -> dict:
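+        # Predibase authenticates with a Bearer token, so an API key is required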
+        if api_key is None:
+            raise ValueError(
+                "Missing Predibase API key - a call is being made to Predibase, but no key is set in the environment variables or passed via params."
+            )
+
+        default_headers = {
+            "content-type": "application/json",
+            "Authorization": "Bearer {}".format(api_key),
+        }
+        # always send the defaults; caller-supplied headers take precedence on conflict
+        headers = {**default_headers, **(headers or {})}
+        return headers