about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/prompt_injection_detection.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/proxy/hooks/prompt_injection_detection.py')
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/proxy/hooks/prompt_injection_detection.py280
1 files changed, 280 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/prompt_injection_detection.py b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/prompt_injection_detection.py
new file mode 100644
index 00000000..b1b2bbee
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/prompt_injection_detection.py
@@ -0,0 +1,280 @@
+# +------------------------------------+
+#
+#        Prompt Injection Detection
+#
+# +------------------------------------+
+#  Thank you users! We ❤️ you! - Krrish & Ishaan
+## Reject a call if it contains a prompt injection attack.
+
+
+from difflib import SequenceMatcher
+from typing import List, Literal, Optional
+
+from fastapi import HTTPException
+
+import litellm
+from litellm._logging import verbose_proxy_logger
+from litellm.caching.caching import DualCache
+from litellm.integrations.custom_logger import CustomLogger
+from litellm.litellm_core_utils.prompt_templates.factory import (
+    prompt_injection_detection_default_pt,
+)
+from litellm.proxy._types import LiteLLMPromptInjectionParams, UserAPIKeyAuth
+from litellm.router import Router
+from litellm.utils import get_formatted_prompt
+
+
+class _OPTIONAL_PromptInjectionDetection(CustomLogger):
+    # Class variables or attributes
+    def __init__(
+        self,
+        prompt_injection_params: Optional[LiteLLMPromptInjectionParams] = None,
+    ):
+        self.prompt_injection_params = prompt_injection_params
+        self.llm_router: Optional[Router] = None
+
+        self.verbs = [
+            "Ignore",
+            "Disregard",
+            "Skip",
+            "Forget",
+            "Neglect",
+            "Overlook",
+            "Omit",
+            "Bypass",
+            "Pay no attention to",
+            "Do not follow",
+            "Do not obey",
+        ]
+        self.adjectives = [
+            "",
+            "prior",
+            "previous",
+            "preceding",
+            "above",
+            "foregoing",
+            "earlier",
+            "initial",
+        ]
+        self.prepositions = [
+            "",
+            "and start over",
+            "and start anew",
+            "and begin afresh",
+            "and start from scratch",
+        ]
+
+    def print_verbose(self, print_statement, level: Literal["INFO", "DEBUG"] = "DEBUG"):
+        if level == "INFO":
+            verbose_proxy_logger.info(print_statement)
+        elif level == "DEBUG":
+            verbose_proxy_logger.debug(print_statement)
+
+        if litellm.set_verbose is True:
+            print(print_statement)  # noqa
+
+    def update_environment(self, router: Optional[Router] = None):
+        self.llm_router = router
+
+        if (
+            self.prompt_injection_params is not None
+            and self.prompt_injection_params.llm_api_check is True
+        ):
+            if self.llm_router is None:
+                raise Exception(
+                    "PromptInjectionDetection: Model List not set. Required for Prompt Injection detection."
+                )
+
+            self.print_verbose(
+                f"model_names: {self.llm_router.model_names}; self.prompt_injection_params.llm_api_name: {self.prompt_injection_params.llm_api_name}"
+            )
+            if (
+                self.prompt_injection_params.llm_api_name is None
+                or self.prompt_injection_params.llm_api_name
+                not in self.llm_router.model_names
+            ):
+                raise Exception(
+                    "PromptInjectionDetection: Invalid LLM API Name. LLM API Name must be a 'model_name' in 'model_list'."
+                )
+
+    def generate_injection_keywords(self) -> List[str]:
+        combinations = []
+        for verb in self.verbs:
+            for adj in self.adjectives:
+                for prep in self.prepositions:
+                    phrase = " ".join(filter(None, [verb, adj, prep])).strip()
+                    if (
+                        len(phrase.split()) > 2
+                    ):  # additional check to ensure more than 2 words
+                        combinations.append(phrase.lower())
+        return combinations
+
+    def check_user_input_similarity(
+        self, user_input: str, similarity_threshold: float = 0.7
+    ) -> bool:
+        user_input_lower = user_input.lower()
+        keywords = self.generate_injection_keywords()
+
+        for keyword in keywords:
+            # Calculate the length of the keyword to extract substrings of the same length from user input
+            keyword_length = len(keyword)
+
+            for i in range(len(user_input_lower) - keyword_length + 1):
+                # Extract a substring of the same length as the keyword
+                substring = user_input_lower[i : i + keyword_length]
+
+                # Calculate similarity
+                match_ratio = SequenceMatcher(None, substring, keyword).ratio()
+                if match_ratio > similarity_threshold:
+                    self.print_verbose(
+                        print_statement=f"Rejected user input - {user_input}. {match_ratio} similar to {keyword}",
+                        level="INFO",
+                    )
+                    return True  # Found a highly similar substring
+        return False  # No substring crossed the threshold
+
+    async def async_pre_call_hook(
+        self,
+        user_api_key_dict: UserAPIKeyAuth,
+        cache: DualCache,
+        data: dict,
+        call_type: str,  # "completion", "embeddings", "image_generation", "moderation"
+    ):
+        try:
+            """
+            - check if user id part of call
+            - check if user id part of blocked list
+            """
+            self.print_verbose("Inside Prompt Injection Detection Pre-Call Hook")
+            try:
+                assert call_type in [
+                    "completion",
+                    "text_completion",
+                    "embeddings",
+                    "image_generation",
+                    "moderation",
+                    "audio_transcription",
+                ]
+            except Exception:
+                self.print_verbose(
+                    f"Call Type - {call_type}, not in accepted list - ['completion','embeddings','image_generation','moderation','audio_transcription']"
+                )
+                return data
+            formatted_prompt = get_formatted_prompt(data=data, call_type=call_type)  # type: ignore
+
+            is_prompt_attack = False
+
+            if self.prompt_injection_params is not None:
+                # 1. check if heuristics check turned on
+                if self.prompt_injection_params.heuristics_check is True:
+                    is_prompt_attack = self.check_user_input_similarity(
+                        user_input=formatted_prompt
+                    )
+                    if is_prompt_attack is True:
+                        raise HTTPException(
+                            status_code=400,
+                            detail={
+                                "error": "Rejected message. This is a prompt injection attack."
+                            },
+                        )
+                # 2. check if vector db similarity check turned on [TODO] Not Implemented yet
+                if self.prompt_injection_params.vector_db_check is True:
+                    pass
+            else:
+                is_prompt_attack = self.check_user_input_similarity(
+                    user_input=formatted_prompt
+                )
+
+            if is_prompt_attack is True:
+                raise HTTPException(
+                    status_code=400,
+                    detail={
+                        "error": "Rejected message. This is a prompt injection attack."
+                    },
+                )
+
+            return data
+
+        except HTTPException as e:
+
+            if (
+                e.status_code == 400
+                and isinstance(e.detail, dict)
+                and "error" in e.detail  # type: ignore
+                and self.prompt_injection_params is not None
+                and self.prompt_injection_params.reject_as_response
+            ):
+                return e.detail.get("error")
+            raise e
+        except Exception as e:
+            verbose_proxy_logger.exception(
+                "litellm.proxy.hooks.prompt_injection_detection.py::async_pre_call_hook(): Exception occured - {}".format(
+                    str(e)
+                )
+            )
+
+    async def async_moderation_hook(  # type: ignore
+        self,
+        data: dict,
+        user_api_key_dict: UserAPIKeyAuth,
+        call_type: Literal[
+            "completion",
+            "embeddings",
+            "image_generation",
+            "moderation",
+            "audio_transcription",
+        ],
+    ) -> Optional[bool]:
+        self.print_verbose(
+            f"IN ASYNC MODERATION HOOK - self.prompt_injection_params = {self.prompt_injection_params}"
+        )
+
+        if self.prompt_injection_params is None:
+            return None
+
+        formatted_prompt = get_formatted_prompt(data=data, call_type=call_type)  # type: ignore
+        is_prompt_attack = False
+
+        prompt_injection_system_prompt = getattr(
+            self.prompt_injection_params,
+            "llm_api_system_prompt",
+            prompt_injection_detection_default_pt(),
+        )
+
+        # 3. check if llm api check turned on
+        if (
+            self.prompt_injection_params.llm_api_check is True
+            and self.prompt_injection_params.llm_api_name is not None
+            and self.llm_router is not None
+        ):
+            # make a call to the llm api
+            response = await self.llm_router.acompletion(
+                model=self.prompt_injection_params.llm_api_name,
+                messages=[
+                    {
+                        "role": "system",
+                        "content": prompt_injection_system_prompt,
+                    },
+                    {"role": "user", "content": formatted_prompt},
+                ],
+            )
+
+            self.print_verbose(f"Received LLM Moderation response: {response}")
+            self.print_verbose(
+                f"llm_api_fail_call_string: {self.prompt_injection_params.llm_api_fail_call_string}"
+            )
+            if isinstance(response, litellm.ModelResponse) and isinstance(
+                response.choices[0], litellm.Choices
+            ):
+                if self.prompt_injection_params.llm_api_fail_call_string in response.choices[0].message.content:  # type: ignore
+                    is_prompt_attack = True
+
+        if is_prompt_attack is True:
+            raise HTTPException(
+                status_code=400,
+                detail={
+                    "error": "Rejected message. This is a prompt injection attack."
+                },
+            )
+
+        return is_prompt_attack