Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/huggingface/chat/transformation.py')
-rw-r--r-- .venv/lib/python3.12/site-packages/litellm/llms/huggingface/chat/transformation.py | 589
1 file changed, 589 insertions(+), 0 deletions(-)
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/huggingface/chat/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/huggingface/chat/transformation.py
new file mode 100644
index 00000000..858fda47
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/huggingface/chat/transformation.py
@@ -0,0 +1,589 @@
+import json
+import os
+import time
+from copy import deepcopy
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+
+import httpx
+
+import litellm
+from litellm.litellm_core_utils.prompt_templates.common_utils import (
+ convert_content_list_to_str,
+)
+from litellm.litellm_core_utils.prompt_templates.factory import (
+ custom_prompt,
+ prompt_factory,
+)
+from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
+from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
+from litellm.secret_managers.main import get_secret_str
+from litellm.types.llms.openai import AllMessageValues
+from litellm.types.utils import Choices, Message, ModelResponse, Usage
+from litellm.utils import token_counter
+
+from ..common_utils import HuggingfaceError, hf_task_list, hf_tasks, output_parser
+
+if TYPE_CHECKING:
+ from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+
+ LoggingClass = LiteLLMLoggingObj
+else:
+ LoggingClass = Any
+
+
+tgi_models_cache = None
+conv_models_cache = None
+
+
+class HuggingfaceChatConfig(BaseConfig):
+ """
+ Reference: https://huggingface.github.io/text-generation-inference/#/Text%20Generation%20Inference/compat_generate
+ """
+
+ hf_task: Optional[hf_tasks] = (
+        None  # litellm-specific param, used to determine which API spec to use when calling the Hugging Face API
+ )
+ best_of: Optional[int] = None
+ decoder_input_details: Optional[bool] = None
+ details: Optional[bool] = True # enables returning logprobs + best of
+ max_new_tokens: Optional[int] = None
+ repetition_penalty: Optional[float] = None
+ return_full_text: Optional[bool] = (
+ False # by default don't return the input as part of the output
+ )
+ seed: Optional[int] = None
+ temperature: Optional[float] = None
+ top_k: Optional[int] = None
+ top_n_tokens: Optional[int] = None
+ top_p: Optional[int] = None
+ truncate: Optional[int] = None
+ typical_p: Optional[float] = None
+ watermark: Optional[bool] = None
+
+ def __init__(
+ self,
+ best_of: Optional[int] = None,
+ decoder_input_details: Optional[bool] = None,
+ details: Optional[bool] = None,
+ max_new_tokens: Optional[int] = None,
+ repetition_penalty: Optional[float] = None,
+ return_full_text: Optional[bool] = None,
+ seed: Optional[int] = None,
+ temperature: Optional[float] = None,
+ top_k: Optional[int] = None,
+ top_n_tokens: Optional[int] = None,
+ top_p: Optional[int] = None,
+ truncate: Optional[int] = None,
+ typical_p: Optional[float] = None,
+ watermark: Optional[bool] = None,
+ ) -> None:
+ locals_ = locals().copy()
+ for key, value in locals_.items():
+ if key != "self" and value is not None:
+ setattr(self.__class__, key, value)
+
+ @classmethod
+ def get_config(cls):
+ return super().get_config()
+
+ def get_special_options_params(self):
+ return ["use_cache", "wait_for_model"]
+
+ def get_supported_openai_params(self, model: str):
+ return [
+ "stream",
+ "temperature",
+ "max_tokens",
+ "max_completion_tokens",
+ "top_p",
+ "stop",
+ "n",
+ "echo",
+ ]
+
+ def map_openai_params(
+ self,
+ non_default_params: Dict,
+ optional_params: Dict,
+ model: str,
+ drop_params: bool,
+ ) -> Dict:
+ for param, value in non_default_params.items():
+            # temperature, top_p, n, stream, stop, max_tokens, presence_penalty default to None
+ if param == "temperature":
+ if value == 0.0 or value == 0:
+                    # Hugging Face raises an exception when temperature == 0, e.g.:
+ # Failed: Error occurred: HuggingfaceException - Input validation error: `temperature` must be strictly positive
+ value = 0.01
+ optional_params["temperature"] = value
+ if param == "top_p":
+ optional_params["top_p"] = value
+ if param == "n":
+ optional_params["best_of"] = value
+ optional_params["do_sample"] = (
+                    True  # sampling must be enabled to use best_of on HF inference endpoints
+ )
+ if param == "stream":
+ optional_params["stream"] = value
+ if param == "stop":
+ optional_params["stop"] = value
+ if param == "max_tokens" or param == "max_completion_tokens":
+ # HF TGI raises the following exception when max_new_tokens==0
+ # Failed: Error occurred: HuggingfaceException - Input validation error: `max_new_tokens` must be strictly positive
+ if value == 0:
+ value = 1
+ optional_params["max_new_tokens"] = value
+ if param == "echo":
+ # https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation.decoder_input_details
+ # Return the decoder input token logprobs and ids. You must set details=True as well for it to be taken into account. Defaults to False
+ optional_params["decoder_input_details"] = True
+
+ return optional_params
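+
+    # Illustrative sketch of the mapping above (example values, not exhaustive):
+    #   {"temperature": 0, "max_tokens": 0, "n": 2, "stream": True}
+    #   roughly becomes
+    #   {"temperature": 0.01, "max_new_tokens": 1, "best_of": 2, "do_sample": True, "stream": True}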
+
+ def get_hf_api_key(self) -> Optional[str]:
+ return get_secret_str("HUGGINGFACE_API_KEY")
+
+ def read_tgi_conv_models(self):
+ try:
+ global tgi_models_cache, conv_models_cache
+            # Check if the cache is already populated,
+            # so we don't re-read the txt files on every request
+ if (tgi_models_cache is not None) and (conv_models_cache is not None):
+ return tgi_models_cache, conv_models_cache
+ # If not, read the file and populate the cache
+ tgi_models = set()
+ script_directory = os.path.dirname(os.path.abspath(__file__))
+ script_directory = os.path.dirname(script_directory)
+ # Construct the file path relative to the script's directory
+ file_path = os.path.join(
+ script_directory,
+ "huggingface_llms_metadata",
+ "hf_text_generation_models.txt",
+ )
+
+ with open(file_path, "r") as file:
+ for line in file:
+ tgi_models.add(line.strip())
+
+ # Cache the set for future use
+ tgi_models_cache = tgi_models
+
+            # Read the conversational models file and populate the cache
+ file_path = os.path.join(
+ script_directory,
+ "huggingface_llms_metadata",
+ "hf_conversational_models.txt",
+ )
+ conv_models = set()
+ with open(file_path, "r") as file:
+ for line in file:
+ conv_models.add(line.strip())
+ # Cache the set for future use
+ conv_models_cache = conv_models
+ return tgi_models, conv_models
+ except Exception:
+ return set(), set()
+
+ def get_hf_task_for_model(self, model: str) -> Tuple[hf_tasks, str]:
+        # If the model is prefixed with a task (e.g. "text-generation/<model>"), use that task.
+        # Otherwise, look the model up in the cached sets read from huggingface_llms_metadata/.
+ if model.split("/")[0] in hf_task_list:
+ split_model = model.split("/", 1)
+ return split_model[0], split_model[1] # type: ignore
+ tgi_models, conversational_models = self.read_tgi_conv_models()
+
+ if model in tgi_models:
+ return "text-generation-inference", model
+ elif model in conversational_models:
+ return "conversational", model
+ elif "roneneldan/TinyStories" in model:
+ return "text-generation", model
+ else:
+ return "text-generation-inference", model # default to tgi
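+
+    # Illustrative examples (hypothetical model ids), roughly:
+    #   "conversational/facebook/blenderbot-400M-distill" -> ("conversational", "facebook/blenderbot-400M-distill")
+    #   "some-org/some-model" (not in either cached list)  -> ("text-generation-inference", "some-org/some-model")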
+
+ def transform_request(
+ self,
+ model: str,
+ messages: List[AllMessageValues],
+ optional_params: dict,
+ litellm_params: dict,
+ headers: dict,
+ ) -> dict:
+ task = litellm_params.get("task", None)
+ ## VALIDATE API FORMAT
+ if task is None or not isinstance(task, str) or task not in hf_task_list:
+ raise Exception(
+                "Invalid hf task - {}. Valid tasks - {}.".format(task, hf_task_list)
+ )
+
+ ## Load Config
+ config = litellm.HuggingfaceConfig.get_config()
+ for k, v in config.items():
+ if (
+ k not in optional_params
+            ):  # completion(top_k=3) > HuggingfaceConfig(top_k=3) <- params passed at call time override config defaults
+ optional_params[k] = v
+
+ ### MAP INPUT PARAMS
+ #### HANDLE SPECIAL PARAMS
+ special_params = self.get_special_options_params()
+ special_params_dict = {}
+ # Create a list of keys to pop after iteration
+ keys_to_pop = []
+
+ for k, v in optional_params.items():
+ if k in special_params:
+ special_params_dict[k] = v
+ keys_to_pop.append(k)
+
+ # Pop the keys from the dictionary after iteration
+ for k in keys_to_pop:
+ optional_params.pop(k)
+ if task == "conversational":
+ inference_params = deepcopy(optional_params)
+ inference_params.pop("details")
+ inference_params.pop("return_full_text")
+ past_user_inputs = []
+ generated_responses = []
+ text = ""
+ for message in messages:
+ if message["role"] == "user":
+ if text != "":
+ past_user_inputs.append(text)
+ text = convert_content_list_to_str(message)
+ elif message["role"] == "assistant" or message["role"] == "system":
+ generated_responses.append(convert_content_list_to_str(message))
+ data = {
+ "inputs": {
+ "text": text,
+ "past_user_inputs": past_user_inputs,
+ "generated_responses": generated_responses,
+ },
+ "parameters": inference_params,
+ }
+
+ elif task == "text-generation-inference":
+ # always send "details" and "return_full_text" as params
+ if model in litellm.custom_prompt_dict:
+ # check if the model has a registered custom prompt
+ model_prompt_details = litellm.custom_prompt_dict[model]
+ prompt = custom_prompt(
+ role_dict=model_prompt_details.get("roles", None),
+ initial_prompt_value=model_prompt_details.get(
+ "initial_prompt_value", ""
+ ),
+ final_prompt_value=model_prompt_details.get(
+ "final_prompt_value", ""
+ ),
+ messages=messages,
+ )
+ else:
+ prompt = prompt_factory(model=model, messages=messages)
+ data = {
+ "inputs": prompt, # type: ignore
+ "parameters": optional_params,
+ "stream": ( # type: ignore
+ True
+ if "stream" in optional_params
+ and isinstance(optional_params["stream"], bool)
+ and optional_params["stream"] is True # type: ignore
+ else False
+ ),
+ }
+ else:
+            # Non-TGI, non-conversational LLMs
+            # This branch is needed: it removes 'details' and 'return_full_text' from the params
+ if model in litellm.custom_prompt_dict:
+ # check if the model has a registered custom prompt
+ model_prompt_details = litellm.custom_prompt_dict[model]
+ prompt = custom_prompt(
+ role_dict=model_prompt_details.get("roles", {}),
+ initial_prompt_value=model_prompt_details.get(
+ "initial_prompt_value", ""
+ ),
+ final_prompt_value=model_prompt_details.get(
+ "final_prompt_value", ""
+ ),
+ bos_token=model_prompt_details.get("bos_token", ""),
+ eos_token=model_prompt_details.get("eos_token", ""),
+ messages=messages,
+ )
+ else:
+ prompt = prompt_factory(model=model, messages=messages)
+ inference_params = deepcopy(optional_params)
+ inference_params.pop("details")
+ inference_params.pop("return_full_text")
+ data = {
+ "inputs": prompt, # type: ignore
+ }
+ if task == "text-generation-inference":
+ data["parameters"] = inference_params
+ data["stream"] = ( # type: ignore
+ True # type: ignore
+ if "stream" in optional_params and optional_params["stream"] is True
+ else False
+ )
+
+ ### RE-ADD SPECIAL PARAMS
+ if len(special_params_dict.keys()) > 0:
+ data.update({"options": special_params_dict})
+
+ return data
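+
+    # Illustrative request body for task == "text-generation-inference" (assumed example values):
+    #   {
+    #       "inputs": "<rendered prompt>",
+    #       "parameters": {"details": True, "return_full_text": False, "max_new_tokens": 256},
+    #       "stream": False,
+    #       "options": {"wait_for_model": True},  # only present if special params were passed
+    #   }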
+
+ def get_api_base(self, api_base: Optional[str], model: str) -> str:
+ """
+ Get the API base for the Huggingface API.
+
+ Do not add the chat/embedding/rerank extension here. Let the handler do this.
+ """
+ if "https" in model:
+ completion_url = model
+ elif api_base is not None:
+ completion_url = api_base
+ elif "HF_API_BASE" in os.environ:
+ completion_url = os.getenv("HF_API_BASE", "")
+ elif "HUGGINGFACE_API_BASE" in os.environ:
+ completion_url = os.getenv("HUGGINGFACE_API_BASE", "")
+ else:
+ completion_url = f"https://api-inference.huggingface.co/models/{model}"
+
+ return completion_url
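+
+    # Resolution order (sketch): model URL ("https" in model) > api_base argument
+    # > HF_API_BASE env var > HUGGINGFACE_API_BASE env var
+    # > https://api-inference.huggingface.co/models/{model}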
+
+ def validate_environment(
+ self,
+ headers: Dict,
+ model: str,
+ messages: List[AllMessageValues],
+ optional_params: Dict,
+ api_key: Optional[str] = None,
+ api_base: Optional[str] = None,
+ ) -> Dict:
+ default_headers = {
+ "content-type": "application/json",
+ }
+ if api_key is not None:
+ default_headers["Authorization"] = (
+                f"Bearer {api_key}"  # Hugging Face Inference Endpoints accept bearer tokens by default
+ )
+
+ headers = {**headers, **default_headers}
+ return headers
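+
+    # Example (assumed values): headers={"x-custom": "1"}, api_key="hf_xxx" roughly yields
+    #   {"x-custom": "1", "content-type": "application/json", "Authorization": "Bearer hf_xxx"}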
+
+ def get_error_class(
+ self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
+ ) -> BaseLLMException:
+ return HuggingfaceError(
+ status_code=status_code, message=error_message, headers=headers
+ )
+
+ def _convert_streamed_response_to_complete_response(
+ self,
+ response: httpx.Response,
+ logging_obj: LoggingClass,
+ model: str,
+ data: dict,
+ api_key: Optional[str] = None,
+ ) -> List[Dict[str, Any]]:
+ streamed_response = CustomStreamWrapper(
+ completion_stream=response.iter_lines(),
+ model=model,
+ custom_llm_provider="huggingface",
+ logging_obj=logging_obj,
+ )
+ content = ""
+ for chunk in streamed_response:
+ content += chunk["choices"][0]["delta"]["content"]
+ completion_response: List[Dict[str, Any]] = [{"generated_text": content}]
+ ## LOGGING
+ logging_obj.post_call(
+ input=data,
+ api_key=api_key,
+ original_response=completion_response,
+ additional_args={"complete_input_dict": data},
+ )
+ return completion_response
+
+ def convert_to_model_response_object( # noqa: PLR0915
+ self,
+ completion_response: Union[List[Dict[str, Any]], Dict[str, Any]],
+ model_response: ModelResponse,
+ task: Optional[hf_tasks],
+ optional_params: dict,
+ encoding: Any,
+ messages: List[AllMessageValues],
+ model: str,
+ ):
+ if task is None:
+ task = "text-generation-inference" # default to tgi
+
+ if task == "conversational":
+ if len(completion_response["generated_text"]) > 0: # type: ignore
+ model_response.choices[0].message.content = completion_response[ # type: ignore
+ "generated_text"
+ ]
+ elif task == "text-generation-inference":
+ if (
+ not isinstance(completion_response, list)
+ or not isinstance(completion_response[0], dict)
+ or "generated_text" not in completion_response[0]
+ ):
+ raise HuggingfaceError(
+ status_code=422,
+ message=f"response is not in expected format - {completion_response}",
+ headers=None,
+ )
+
+ if len(completion_response[0]["generated_text"]) > 0:
+ model_response.choices[0].message.content = output_parser( # type: ignore
+ completion_response[0]["generated_text"]
+ )
+ ## GETTING LOGPROBS + FINISH REASON
+ if (
+ "details" in completion_response[0]
+ and "tokens" in completion_response[0]["details"]
+ ):
+ model_response.choices[0].finish_reason = completion_response[0][
+ "details"
+ ]["finish_reason"]
+ sum_logprob = 0
+ for token in completion_response[0]["details"]["tokens"]:
+ if token["logprob"] is not None:
+ sum_logprob += token["logprob"]
+ setattr(model_response.choices[0].message, "_logprob", sum_logprob) # type: ignore
+ if "best_of" in optional_params and optional_params["best_of"] > 1:
+ if (
+ "details" in completion_response[0]
+ and "best_of_sequences" in completion_response[0]["details"]
+ ):
+ choices_list = []
+ for idx, item in enumerate(
+ completion_response[0]["details"]["best_of_sequences"]
+ ):
+ sum_logprob = 0
+ for token in item["tokens"]:
+ if token["logprob"] is not None:
+ sum_logprob += token["logprob"]
+ if len(item["generated_text"]) > 0:
+ message_obj = Message(
+ content=output_parser(item["generated_text"]),
+ logprobs=sum_logprob,
+ )
+ else:
+ message_obj = Message(content=None)
+ choice_obj = Choices(
+ finish_reason=item["finish_reason"],
+ index=idx + 1,
+ message=message_obj,
+ )
+ choices_list.append(choice_obj)
+ model_response.choices.extend(choices_list)
+ elif task == "text-classification":
+ model_response.choices[0].message.content = json.dumps( # type: ignore
+ completion_response
+ )
+ else:
+ if (
+ isinstance(completion_response, list)
+ and len(completion_response[0]["generated_text"]) > 0
+ ):
+ model_response.choices[0].message.content = output_parser( # type: ignore
+ completion_response[0]["generated_text"]
+ )
+ ## CALCULATING USAGE
+ prompt_tokens = 0
+ try:
+ prompt_tokens = token_counter(model=model, messages=messages)
+ except Exception:
+            # this should remain non-blocking; we should not block a response from returning if usage calculation fails
+ pass
+ output_text = model_response["choices"][0]["message"].get("content", "")
+ if output_text is not None and len(output_text) > 0:
+ completion_tokens = 0
+ try:
+ completion_tokens = len(
+ encoding.encode(
+ model_response["choices"][0]["message"].get("content", "")
+ )
+ ) ##[TODO] use the llama2 tokenizer here
+ except Exception:
+                # this should remain non-blocking; we should not block a response from returning if usage calculation fails
+ pass
+ else:
+ completion_tokens = 0
+
+ model_response.created = int(time.time())
+ model_response.model = model
+ usage = Usage(
+ prompt_tokens=prompt_tokens,
+ completion_tokens=completion_tokens,
+ total_tokens=prompt_tokens + completion_tokens,
+ )
+ setattr(model_response, "usage", usage)
+ model_response._hidden_params["original_response"] = completion_response
+ return model_response
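+
+    # Illustrative TGI response handled above (example values):
+    #   [{"generated_text": "Hello!", "details": {"finish_reason": "length", "tokens": [...]}}]
+    # roughly maps to choices[0].message.content == "Hello!", finish_reason == "length",
+    # with usage estimated via token_counter() for the prompt and `encoding` for the completion.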
+
+ def transform_response(
+ self,
+ model: str,
+ raw_response: httpx.Response,
+ model_response: ModelResponse,
+ logging_obj: LoggingClass,
+ request_data: Dict,
+ messages: List[AllMessageValues],
+ optional_params: Dict,
+ litellm_params: Dict,
+ encoding: Any,
+ api_key: Optional[str] = None,
+ json_mode: Optional[bool] = None,
+ ) -> ModelResponse:
+        ## Some servers (e.g. Baseten) might return streaming responses even though stream was not set to true.
+ task = litellm_params.get("task", None)
+ is_streamed = False
+ if (
+ raw_response.__dict__["headers"].get("Content-Type", "")
+ == "text/event-stream"
+ ):
+ is_streamed = True
+
+ # iterate over the complete streamed response, and return the final answer
+ if is_streamed:
+ completion_response = self._convert_streamed_response_to_complete_response(
+ response=raw_response,
+ logging_obj=logging_obj,
+ model=model,
+ data=request_data,
+ api_key=api_key,
+ )
+ else:
+ ## LOGGING
+ logging_obj.post_call(
+ input=request_data,
+ api_key=api_key,
+ original_response=raw_response.text,
+ additional_args={"complete_input_dict": request_data},
+ )
+ ## RESPONSE OBJECT
+ try:
+ completion_response = raw_response.json()
+ if isinstance(completion_response, dict):
+ completion_response = [completion_response]
+ except Exception:
+ raise HuggingfaceError(
+ message=f"Original Response received: {raw_response.text}",
+ status_code=raw_response.status_code,
+ )
+
+ if isinstance(completion_response, dict) and "error" in completion_response:
+ raise HuggingfaceError(
+ message=completion_response["error"], # type: ignore
+ status_code=raw_response.status_code,
+ )
+ return self.convert_to_model_response_object(
+ completion_response=completion_response,
+ model_response=model_response,
+ task=task if task is not None and task in hf_task_list else None,
+ optional_params=optional_params,
+ encoding=encoding,
+ messages=messages,
+ model=model,
+ )
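+
+
+# Usage sketch (illustrative; the model id is an example): this config is exercised when
+# calling litellm.completion with a "huggingface/" model prefix, roughly:
+#
+#   import litellm
+#   response = litellm.completion(
+#       model="huggingface/mistralai/Mistral-7B-Instruct-v0.2",
+#       messages=[{"role": "user", "content": "Hello"}],
+#   )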