aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/prompt_injection_detection.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/litellm/proxy/hooks/prompt_injection_detection.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-4a52a71956a8d46fcb7294ac71734504bb09bcc2.tar.gz
two version of R2R are hereHEADmaster
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/proxy/hooks/prompt_injection_detection.py')
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/proxy/hooks/prompt_injection_detection.py280
1 files changed, 280 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/prompt_injection_detection.py b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/prompt_injection_detection.py
new file mode 100644
index 00000000..b1b2bbee
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/prompt_injection_detection.py
@@ -0,0 +1,280 @@
+# +------------------------------------+
+#
+# Prompt Injection Detection
+#
+# +------------------------------------+
+# Thank you users! We ❤️ you! - Krrish & Ishaan
+## Reject a call if it contains a prompt injection attack.
+
+
+from difflib import SequenceMatcher
+from typing import List, Literal, Optional
+
+from fastapi import HTTPException
+
+import litellm
+from litellm._logging import verbose_proxy_logger
+from litellm.caching.caching import DualCache
+from litellm.integrations.custom_logger import CustomLogger
+from litellm.litellm_core_utils.prompt_templates.factory import (
+ prompt_injection_detection_default_pt,
+)
+from litellm.proxy._types import LiteLLMPromptInjectionParams, UserAPIKeyAuth
+from litellm.router import Router
+from litellm.utils import get_formatted_prompt
+
+
+class _OPTIONAL_PromptInjectionDetection(CustomLogger):
+ # Class variables or attributes
+ def __init__(
+ self,
+ prompt_injection_params: Optional[LiteLLMPromptInjectionParams] = None,
+ ):
+ self.prompt_injection_params = prompt_injection_params
+ self.llm_router: Optional[Router] = None
+
+ self.verbs = [
+ "Ignore",
+ "Disregard",
+ "Skip",
+ "Forget",
+ "Neglect",
+ "Overlook",
+ "Omit",
+ "Bypass",
+ "Pay no attention to",
+ "Do not follow",
+ "Do not obey",
+ ]
+ self.adjectives = [
+ "",
+ "prior",
+ "previous",
+ "preceding",
+ "above",
+ "foregoing",
+ "earlier",
+ "initial",
+ ]
+ self.prepositions = [
+ "",
+ "and start over",
+ "and start anew",
+ "and begin afresh",
+ "and start from scratch",
+ ]
+
+ def print_verbose(self, print_statement, level: Literal["INFO", "DEBUG"] = "DEBUG"):
+ if level == "INFO":
+ verbose_proxy_logger.info(print_statement)
+ elif level == "DEBUG":
+ verbose_proxy_logger.debug(print_statement)
+
+ if litellm.set_verbose is True:
+ print(print_statement) # noqa
+
+ def update_environment(self, router: Optional[Router] = None):
+ self.llm_router = router
+
+ if (
+ self.prompt_injection_params is not None
+ and self.prompt_injection_params.llm_api_check is True
+ ):
+ if self.llm_router is None:
+ raise Exception(
+ "PromptInjectionDetection: Model List not set. Required for Prompt Injection detection."
+ )
+
+ self.print_verbose(
+ f"model_names: {self.llm_router.model_names}; self.prompt_injection_params.llm_api_name: {self.prompt_injection_params.llm_api_name}"
+ )
+ if (
+ self.prompt_injection_params.llm_api_name is None
+ or self.prompt_injection_params.llm_api_name
+ not in self.llm_router.model_names
+ ):
+ raise Exception(
+ "PromptInjectionDetection: Invalid LLM API Name. LLM API Name must be a 'model_name' in 'model_list'."
+ )
+
+ def generate_injection_keywords(self) -> List[str]:
+ combinations = []
+ for verb in self.verbs:
+ for adj in self.adjectives:
+ for prep in self.prepositions:
+ phrase = " ".join(filter(None, [verb, adj, prep])).strip()
+ if (
+ len(phrase.split()) > 2
+ ): # additional check to ensure more than 2 words
+ combinations.append(phrase.lower())
+ return combinations
+
+ def check_user_input_similarity(
+ self, user_input: str, similarity_threshold: float = 0.7
+ ) -> bool:
+ user_input_lower = user_input.lower()
+ keywords = self.generate_injection_keywords()
+
+ for keyword in keywords:
+ # Calculate the length of the keyword to extract substrings of the same length from user input
+ keyword_length = len(keyword)
+
+ for i in range(len(user_input_lower) - keyword_length + 1):
+ # Extract a substring of the same length as the keyword
+ substring = user_input_lower[i : i + keyword_length]
+
+ # Calculate similarity
+ match_ratio = SequenceMatcher(None, substring, keyword).ratio()
+ if match_ratio > similarity_threshold:
+ self.print_verbose(
+ print_statement=f"Rejected user input - {user_input}. {match_ratio} similar to {keyword}",
+ level="INFO",
+ )
+ return True # Found a highly similar substring
+ return False # No substring crossed the threshold
+
+ async def async_pre_call_hook(
+ self,
+ user_api_key_dict: UserAPIKeyAuth,
+ cache: DualCache,
+ data: dict,
+ call_type: str, # "completion", "embeddings", "image_generation", "moderation"
+ ):
+ try:
+ """
+ - check if user id part of call
+ - check if user id part of blocked list
+ """
+ self.print_verbose("Inside Prompt Injection Detection Pre-Call Hook")
+ try:
+ assert call_type in [
+ "completion",
+ "text_completion",
+ "embeddings",
+ "image_generation",
+ "moderation",
+ "audio_transcription",
+ ]
+ except Exception:
+ self.print_verbose(
+ f"Call Type - {call_type}, not in accepted list - ['completion','embeddings','image_generation','moderation','audio_transcription']"
+ )
+ return data
+ formatted_prompt = get_formatted_prompt(data=data, call_type=call_type) # type: ignore
+
+ is_prompt_attack = False
+
+ if self.prompt_injection_params is not None:
+ # 1. check if heuristics check turned on
+ if self.prompt_injection_params.heuristics_check is True:
+ is_prompt_attack = self.check_user_input_similarity(
+ user_input=formatted_prompt
+ )
+ if is_prompt_attack is True:
+ raise HTTPException(
+ status_code=400,
+ detail={
+ "error": "Rejected message. This is a prompt injection attack."
+ },
+ )
+ # 2. check if vector db similarity check turned on [TODO] Not Implemented yet
+ if self.prompt_injection_params.vector_db_check is True:
+ pass
+ else:
+ is_prompt_attack = self.check_user_input_similarity(
+ user_input=formatted_prompt
+ )
+
+ if is_prompt_attack is True:
+ raise HTTPException(
+ status_code=400,
+ detail={
+ "error": "Rejected message. This is a prompt injection attack."
+ },
+ )
+
+ return data
+
+ except HTTPException as e:
+
+ if (
+ e.status_code == 400
+ and isinstance(e.detail, dict)
+ and "error" in e.detail # type: ignore
+ and self.prompt_injection_params is not None
+ and self.prompt_injection_params.reject_as_response
+ ):
+ return e.detail.get("error")
+ raise e
+ except Exception as e:
+ verbose_proxy_logger.exception(
+ "litellm.proxy.hooks.prompt_injection_detection.py::async_pre_call_hook(): Exception occured - {}".format(
+ str(e)
+ )
+ )
+
+ async def async_moderation_hook( # type: ignore
+ self,
+ data: dict,
+ user_api_key_dict: UserAPIKeyAuth,
+ call_type: Literal[
+ "completion",
+ "embeddings",
+ "image_generation",
+ "moderation",
+ "audio_transcription",
+ ],
+ ) -> Optional[bool]:
+ self.print_verbose(
+ f"IN ASYNC MODERATION HOOK - self.prompt_injection_params = {self.prompt_injection_params}"
+ )
+
+ if self.prompt_injection_params is None:
+ return None
+
+ formatted_prompt = get_formatted_prompt(data=data, call_type=call_type) # type: ignore
+ is_prompt_attack = False
+
+ prompt_injection_system_prompt = getattr(
+ self.prompt_injection_params,
+ "llm_api_system_prompt",
+ prompt_injection_detection_default_pt(),
+ )
+
+ # 3. check if llm api check turned on
+ if (
+ self.prompt_injection_params.llm_api_check is True
+ and self.prompt_injection_params.llm_api_name is not None
+ and self.llm_router is not None
+ ):
+ # make a call to the llm api
+ response = await self.llm_router.acompletion(
+ model=self.prompt_injection_params.llm_api_name,
+ messages=[
+ {
+ "role": "system",
+ "content": prompt_injection_system_prompt,
+ },
+ {"role": "user", "content": formatted_prompt},
+ ],
+ )
+
+ self.print_verbose(f"Received LLM Moderation response: {response}")
+ self.print_verbose(
+ f"llm_api_fail_call_string: {self.prompt_injection_params.llm_api_fail_call_string}"
+ )
+ if isinstance(response, litellm.ModelResponse) and isinstance(
+ response.choices[0], litellm.Choices
+ ):
+ if self.prompt_injection_params.llm_api_fail_call_string in response.choices[0].message.content: # type: ignore
+ is_prompt_attack = True
+
+ if is_prompt_attack is True:
+ raise HTTPException(
+ status_code=400,
+ detail={
+ "error": "Rejected message. This is a prompt injection attack."
+ },
+ )
+
+ return is_prompt_attack