aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/litellm/llms/huggingface/common_utils.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/huggingface/common_utils.py')
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/llms/huggingface/common_utils.py45
1 files changed, 45 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/huggingface/common_utils.py b/.venv/lib/python3.12/site-packages/litellm/llms/huggingface/common_utils.py
new file mode 100644
index 00000000..d793b298
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/huggingface/common_utils.py
@@ -0,0 +1,45 @@
+from typing import Literal, Optional, Union
+
+import httpx
+
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
+
+
+class HuggingfaceError(BaseLLMException):
+ def __init__(
+ self,
+ status_code: int,
+ message: str,
+ headers: Optional[Union[dict, httpx.Headers]] = None,
+ ):
+ super().__init__(status_code=status_code, message=message, headers=headers)
+
+
+hf_tasks = Literal[
+ "text-generation-inference",
+ "conversational",
+ "text-classification",
+ "text-generation",
+]
+
+hf_task_list = [
+ "text-generation-inference",
+ "conversational",
+ "text-classification",
+ "text-generation",
+]
+
+
+def output_parser(generated_text: str):
+ """
+ Parse the output text to remove any special characters. In our current approach we just check for ChatML tokens.
+
+ Initial issue that prompted this - https://github.com/BerriAI/litellm/issues/763
+ """
+ chat_template_tokens = ["<|assistant|>", "<|system|>", "<|user|>", "<s>", "</s>"]
+ for token in chat_template_tokens:
+ if generated_text.strip().startswith(token):
+ generated_text = generated_text.replace(token, "", 1)
+ if generated_text.endswith(token):
+ generated_text = generated_text[::-1].replace(token[::-1], "", 1)[::-1]
+ return generated_text