diff options
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/proxy/guardrails/guardrail_hooks/custom_guardrail.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/litellm/proxy/guardrails/guardrail_hooks/custom_guardrail.py | 117 |
1 files changed, 117 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/guardrails/guardrail_hooks/custom_guardrail.py b/.venv/lib/python3.12/site-packages/litellm/proxy/guardrails/guardrail_hooks/custom_guardrail.py new file mode 100644 index 00000000..87860477 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/proxy/guardrails/guardrail_hooks/custom_guardrail.py @@ -0,0 +1,117 @@ +from typing import Literal, Optional, Union + +import litellm +from litellm._logging import verbose_proxy_logger +from litellm.caching.caching import DualCache +from litellm.integrations.custom_guardrail import ( + CustomGuardrail, + log_guardrail_information, +) +from litellm.proxy._types import UserAPIKeyAuth + + +class myCustomGuardrail(CustomGuardrail): + def __init__( + self, + **kwargs, + ): + # store kwargs as optional_params + self.optional_params = kwargs + + super().__init__(**kwargs) + + @log_guardrail_information + async def async_pre_call_hook( + self, + user_api_key_dict: UserAPIKeyAuth, + cache: DualCache, + data: dict, + call_type: Literal[ + "completion", + "text_completion", + "embeddings", + "image_generation", + "moderation", + "audio_transcription", + "pass_through_endpoint", + "rerank", + ], + ) -> Optional[Union[Exception, str, dict]]: + """ + Runs before the LLM API call + Runs on only Input + Use this if you want to MODIFY the input + """ + + # In this guardrail, if a user inputs `litellm` we will mask it and then send it to the LLM + _messages = data.get("messages") + if _messages: + for message in _messages: + _content = message.get("content") + if isinstance(_content, str): + if "litellm" in _content.lower(): + _content = _content.replace("litellm", "********") + message["content"] = _content + + verbose_proxy_logger.debug( + "async_pre_call_hook: Message after masking %s", _messages + ) + + return data + + @log_guardrail_information + async def async_moderation_hook( + self, + data: dict, + user_api_key_dict: UserAPIKeyAuth, + call_type: Literal[ + "completion", + "embeddings", + "image_generation", + "moderation", + "audio_transcription", + "responses", + ], + ): + """ + Runs in parallel to LLM API call + Runs on only Input + + This can NOT modify the input, only used to reject or accept a call before going to LLM API + """ + + # this works the same as async_pre_call_hook, but just runs in parallel as the LLM API Call + # In this guardrail, if a user inputs `litellm` we will mask it. + _messages = data.get("messages") + if _messages: + for message in _messages: + _content = message.get("content") + if isinstance(_content, str): + if "litellm" in _content.lower(): + raise ValueError("Guardrail failed words - `litellm` detected") + + @log_guardrail_information + async def async_post_call_success_hook( + self, + data: dict, + user_api_key_dict: UserAPIKeyAuth, + response, + ): + """ + Runs on response from LLM API call + + It can be used to reject a response + + If a response contains the word "coffee" -> we will raise an exception + """ + verbose_proxy_logger.debug("async_pre_call_hook response: %s", response) + if isinstance(response, litellm.ModelResponse): + for choice in response.choices: + if isinstance(choice, litellm.Choices): + verbose_proxy_logger.debug("async_pre_call_hook choice: %s", choice) + if ( + choice.message.content + and isinstance(choice.message.content, str) + and "coffee" in choice.message.content + ): + raise ValueError("Guardrail failed Coffee Detected") |