diff options
author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
---|---|---|
committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) | |
tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/litellm/proxy/hooks/prompt_injection_detection.py | |
parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff) | |
download | gn-ai-4a52a71956a8d46fcb7294ac71734504bb09bcc2.tar.gz |
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/proxy/hooks/prompt_injection_detection.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/litellm/proxy/hooks/prompt_injection_detection.py | 280 |
1 files changed, 280 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/prompt_injection_detection.py b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/prompt_injection_detection.py new file mode 100644 index 00000000..b1b2bbee --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/prompt_injection_detection.py @@ -0,0 +1,280 @@ +# +------------------------------------+ +# +# Prompt Injection Detection +# +# +------------------------------------+ +# Thank you users! We ❤️ you! - Krrish & Ishaan +## Reject a call if it contains a prompt injection attack. + + +from difflib import SequenceMatcher +from typing import List, Literal, Optional + +from fastapi import HTTPException + +import litellm +from litellm._logging import verbose_proxy_logger +from litellm.caching.caching import DualCache +from litellm.integrations.custom_logger import CustomLogger +from litellm.litellm_core_utils.prompt_templates.factory import ( + prompt_injection_detection_default_pt, +) +from litellm.proxy._types import LiteLLMPromptInjectionParams, UserAPIKeyAuth +from litellm.router import Router +from litellm.utils import get_formatted_prompt + + +class _OPTIONAL_PromptInjectionDetection(CustomLogger): + # Class variables or attributes + def __init__( + self, + prompt_injection_params: Optional[LiteLLMPromptInjectionParams] = None, + ): + self.prompt_injection_params = prompt_injection_params + self.llm_router: Optional[Router] = None + + self.verbs = [ + "Ignore", + "Disregard", + "Skip", + "Forget", + "Neglect", + "Overlook", + "Omit", + "Bypass", + "Pay no attention to", + "Do not follow", + "Do not obey", + ] + self.adjectives = [ + "", + "prior", + "previous", + "preceding", + "above", + "foregoing", + "earlier", + "initial", + ] + self.prepositions = [ + "", + "and start over", + "and start anew", + "and begin afresh", + "and start from scratch", + ] + + def print_verbose(self, print_statement, level: Literal["INFO", "DEBUG"] = "DEBUG"): + if level == "INFO": + verbose_proxy_logger.info(print_statement) + elif level == "DEBUG": + verbose_proxy_logger.debug(print_statement) + + if litellm.set_verbose is True: + print(print_statement) # noqa + + def update_environment(self, router: Optional[Router] = None): + self.llm_router = router + + if ( + self.prompt_injection_params is not None + and self.prompt_injection_params.llm_api_check is True + ): + if self.llm_router is None: + raise Exception( + "PromptInjectionDetection: Model List not set. Required for Prompt Injection detection." + ) + + self.print_verbose( + f"model_names: {self.llm_router.model_names}; self.prompt_injection_params.llm_api_name: {self.prompt_injection_params.llm_api_name}" + ) + if ( + self.prompt_injection_params.llm_api_name is None + or self.prompt_injection_params.llm_api_name + not in self.llm_router.model_names + ): + raise Exception( + "PromptInjectionDetection: Invalid LLM API Name. LLM API Name must be a 'model_name' in 'model_list'." + ) + + def generate_injection_keywords(self) -> List[str]: + combinations = [] + for verb in self.verbs: + for adj in self.adjectives: + for prep in self.prepositions: + phrase = " ".join(filter(None, [verb, adj, prep])).strip() + if ( + len(phrase.split()) > 2 + ): # additional check to ensure more than 2 words + combinations.append(phrase.lower()) + return combinations + + def check_user_input_similarity( + self, user_input: str, similarity_threshold: float = 0.7 + ) -> bool: + user_input_lower = user_input.lower() + keywords = self.generate_injection_keywords() + + for keyword in keywords: + # Calculate the length of the keyword to extract substrings of the same length from user input + keyword_length = len(keyword) + + for i in range(len(user_input_lower) - keyword_length + 1): + # Extract a substring of the same length as the keyword + substring = user_input_lower[i : i + keyword_length] + + # Calculate similarity + match_ratio = SequenceMatcher(None, substring, keyword).ratio() + if match_ratio > similarity_threshold: + self.print_verbose( + print_statement=f"Rejected user input - {user_input}. {match_ratio} similar to {keyword}", + level="INFO", + ) + return True # Found a highly similar substring + return False # No substring crossed the threshold + + async def async_pre_call_hook( + self, + user_api_key_dict: UserAPIKeyAuth, + cache: DualCache, + data: dict, + call_type: str, # "completion", "embeddings", "image_generation", "moderation" + ): + try: + """ + - check if user id part of call + - check if user id part of blocked list + """ + self.print_verbose("Inside Prompt Injection Detection Pre-Call Hook") + try: + assert call_type in [ + "completion", + "text_completion", + "embeddings", + "image_generation", + "moderation", + "audio_transcription", + ] + except Exception: + self.print_verbose( + f"Call Type - {call_type}, not in accepted list - ['completion','embeddings','image_generation','moderation','audio_transcription']" + ) + return data + formatted_prompt = get_formatted_prompt(data=data, call_type=call_type) # type: ignore + + is_prompt_attack = False + + if self.prompt_injection_params is not None: + # 1. check if heuristics check turned on + if self.prompt_injection_params.heuristics_check is True: + is_prompt_attack = self.check_user_input_similarity( + user_input=formatted_prompt + ) + if is_prompt_attack is True: + raise HTTPException( + status_code=400, + detail={ + "error": "Rejected message. This is a prompt injection attack." + }, + ) + # 2. check if vector db similarity check turned on [TODO] Not Implemented yet + if self.prompt_injection_params.vector_db_check is True: + pass + else: + is_prompt_attack = self.check_user_input_similarity( + user_input=formatted_prompt + ) + + if is_prompt_attack is True: + raise HTTPException( + status_code=400, + detail={ + "error": "Rejected message. This is a prompt injection attack." + }, + ) + + return data + + except HTTPException as e: + + if ( + e.status_code == 400 + and isinstance(e.detail, dict) + and "error" in e.detail # type: ignore + and self.prompt_injection_params is not None + and self.prompt_injection_params.reject_as_response + ): + return e.detail.get("error") + raise e + except Exception as e: + verbose_proxy_logger.exception( + "litellm.proxy.hooks.prompt_injection_detection.py::async_pre_call_hook(): Exception occured - {}".format( + str(e) + ) + ) + + async def async_moderation_hook( # type: ignore + self, + data: dict, + user_api_key_dict: UserAPIKeyAuth, + call_type: Literal[ + "completion", + "embeddings", + "image_generation", + "moderation", + "audio_transcription", + ], + ) -> Optional[bool]: + self.print_verbose( + f"IN ASYNC MODERATION HOOK - self.prompt_injection_params = {self.prompt_injection_params}" + ) + + if self.prompt_injection_params is None: + return None + + formatted_prompt = get_formatted_prompt(data=data, call_type=call_type) # type: ignore + is_prompt_attack = False + + prompt_injection_system_prompt = getattr( + self.prompt_injection_params, + "llm_api_system_prompt", + prompt_injection_detection_default_pt(), + ) + + # 3. check if llm api check turned on + if ( + self.prompt_injection_params.llm_api_check is True + and self.prompt_injection_params.llm_api_name is not None + and self.llm_router is not None + ): + # make a call to the llm api + response = await self.llm_router.acompletion( + model=self.prompt_injection_params.llm_api_name, + messages=[ + { + "role": "system", + "content": prompt_injection_system_prompt, + }, + {"role": "user", "content": formatted_prompt}, + ], + ) + + self.print_verbose(f"Received LLM Moderation response: {response}") + self.print_verbose( + f"llm_api_fail_call_string: {self.prompt_injection_params.llm_api_fail_call_string}" + ) + if isinstance(response, litellm.ModelResponse) and isinstance( + response.choices[0], litellm.Choices + ): + if self.prompt_injection_params.llm_api_fail_call_string in response.choices[0].message.content: # type: ignore + is_prompt_attack = True + + if is_prompt_attack is True: + raise HTTPException( + status_code=400, + detail={ + "error": "Rejected message. This is a prompt injection attack." + }, + ) + + return is_prompt_attack |