aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/litellm/proxy/guardrails/guardrail_hooks/custom_guardrail.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/proxy/guardrails/guardrail_hooks/custom_guardrail.py')
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/proxy/guardrails/guardrail_hooks/custom_guardrail.py117
1 files changed, 117 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/guardrails/guardrail_hooks/custom_guardrail.py b/.venv/lib/python3.12/site-packages/litellm/proxy/guardrails/guardrail_hooks/custom_guardrail.py
new file mode 100644
index 00000000..87860477
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/guardrails/guardrail_hooks/custom_guardrail.py
@@ -0,0 +1,117 @@
+from typing import Literal, Optional, Union
+
+import litellm
+from litellm._logging import verbose_proxy_logger
+from litellm.caching.caching import DualCache
+from litellm.integrations.custom_guardrail import (
+ CustomGuardrail,
+ log_guardrail_information,
+)
+from litellm.proxy._types import UserAPIKeyAuth
+
+
+class myCustomGuardrail(CustomGuardrail):
+ def __init__(
+ self,
+ **kwargs,
+ ):
+ # store kwargs as optional_params
+ self.optional_params = kwargs
+
+ super().__init__(**kwargs)
+
+ @log_guardrail_information
+ async def async_pre_call_hook(
+ self,
+ user_api_key_dict: UserAPIKeyAuth,
+ cache: DualCache,
+ data: dict,
+ call_type: Literal[
+ "completion",
+ "text_completion",
+ "embeddings",
+ "image_generation",
+ "moderation",
+ "audio_transcription",
+ "pass_through_endpoint",
+ "rerank",
+ ],
+ ) -> Optional[Union[Exception, str, dict]]:
+ """
+ Runs before the LLM API call
+ Runs on only Input
+ Use this if you want to MODIFY the input
+ """
+
+ # In this guardrail, if a user inputs `litellm` we will mask it and then send it to the LLM
+ _messages = data.get("messages")
+ if _messages:
+ for message in _messages:
+ _content = message.get("content")
+ if isinstance(_content, str):
+ if "litellm" in _content.lower():
+ _content = _content.replace("litellm", "********")
+ message["content"] = _content
+
+ verbose_proxy_logger.debug(
+ "async_pre_call_hook: Message after masking %s", _messages
+ )
+
+ return data
+
+ @log_guardrail_information
+ async def async_moderation_hook(
+ self,
+ data: dict,
+ user_api_key_dict: UserAPIKeyAuth,
+ call_type: Literal[
+ "completion",
+ "embeddings",
+ "image_generation",
+ "moderation",
+ "audio_transcription",
+ "responses",
+ ],
+ ):
+ """
+ Runs in parallel to LLM API call
+ Runs on only Input
+
+ This can NOT modify the input, only used to reject or accept a call before going to LLM API
+ """
+
+ # this works the same as async_pre_call_hook, but just runs in parallel as the LLM API Call
+ # In this guardrail, if a user inputs `litellm` we will mask it.
+ _messages = data.get("messages")
+ if _messages:
+ for message in _messages:
+ _content = message.get("content")
+ if isinstance(_content, str):
+ if "litellm" in _content.lower():
+ raise ValueError("Guardrail failed words - `litellm` detected")
+
+ @log_guardrail_information
+ async def async_post_call_success_hook(
+ self,
+ data: dict,
+ user_api_key_dict: UserAPIKeyAuth,
+ response,
+ ):
+ """
+ Runs on response from LLM API call
+
+ It can be used to reject a response
+
+ If a response contains the word "coffee" -> we will raise an exception
+ """
+ verbose_proxy_logger.debug("async_pre_call_hook response: %s", response)
+ if isinstance(response, litellm.ModelResponse):
+ for choice in response.choices:
+ if isinstance(choice, litellm.Choices):
+ verbose_proxy_logger.debug("async_pre_call_hook choice: %s", choice)
+ if (
+ choice.message.content
+ and isinstance(choice.message.content, str)
+ and "coffee" in choice.message.content
+ ):
+ raise ValueError("Guardrail failed Coffee Detected")