diff options
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/integrations/custom_guardrail.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/litellm/integrations/custom_guardrail.py | 274 |
1 files changed, 274 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/custom_guardrail.py b/.venv/lib/python3.12/site-packages/litellm/integrations/custom_guardrail.py new file mode 100644 index 00000000..4421664b --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/integrations/custom_guardrail.py @@ -0,0 +1,274 @@ +from typing import Dict, List, Literal, Optional, Union + +from litellm._logging import verbose_logger +from litellm.integrations.custom_logger import CustomLogger +from litellm.types.guardrails import DynamicGuardrailParams, GuardrailEventHooks +from litellm.types.utils import StandardLoggingGuardrailInformation + + +class CustomGuardrail(CustomLogger): + + def __init__( + self, + guardrail_name: Optional[str] = None, + supported_event_hooks: Optional[List[GuardrailEventHooks]] = None, + event_hook: Optional[ + Union[GuardrailEventHooks, List[GuardrailEventHooks]] + ] = None, + default_on: bool = False, + **kwargs, + ): + """ + Initialize the CustomGuardrail class + + Args: + guardrail_name: The name of the guardrail. This is the name used in your requests. + supported_event_hooks: The event hooks that the guardrail supports + event_hook: The event hook to run the guardrail on + default_on: If True, the guardrail will be run by default on all requests + """ + self.guardrail_name = guardrail_name + self.supported_event_hooks = supported_event_hooks + self.event_hook: Optional[ + Union[GuardrailEventHooks, List[GuardrailEventHooks]] + ] = event_hook + self.default_on: bool = default_on + + if supported_event_hooks: + ## validate event_hook is in supported_event_hooks + self._validate_event_hook(event_hook, supported_event_hooks) + super().__init__(**kwargs) + + def _validate_event_hook( + self, + event_hook: Optional[Union[GuardrailEventHooks, List[GuardrailEventHooks]]], + supported_event_hooks: List[GuardrailEventHooks], + ) -> None: + if event_hook is None: + return + if isinstance(event_hook, list): + for hook in event_hook: + if hook not in supported_event_hooks: + raise ValueError( + f"Event hook {hook} is not in the supported event hooks {supported_event_hooks}" + ) + elif isinstance(event_hook, GuardrailEventHooks): + if event_hook not in supported_event_hooks: + raise ValueError( + f"Event hook {event_hook} is not in the supported event hooks {supported_event_hooks}" + ) + + def get_guardrail_from_metadata( + self, data: dict + ) -> Union[List[str], List[Dict[str, DynamicGuardrailParams]]]: + """ + Returns the guardrail(s) to be run from the metadata + """ + metadata = data.get("metadata") or {} + requested_guardrails = metadata.get("guardrails") or [] + return requested_guardrails + + def _guardrail_is_in_requested_guardrails( + self, + requested_guardrails: Union[List[str], List[Dict[str, DynamicGuardrailParams]]], + ) -> bool: + for _guardrail in requested_guardrails: + if isinstance(_guardrail, dict): + if self.guardrail_name in _guardrail: + return True + elif isinstance(_guardrail, str): + if self.guardrail_name == _guardrail: + return True + return False + + def should_run_guardrail(self, data, event_type: GuardrailEventHooks) -> bool: + """ + Returns True if the guardrail should be run on the event_type + """ + requested_guardrails = self.get_guardrail_from_metadata(data) + + verbose_logger.debug( + "inside should_run_guardrail for guardrail=%s event_type= %s guardrail_supported_event_hooks= %s requested_guardrails= %s self.default_on= %s", + self.guardrail_name, + event_type, + self.event_hook, + requested_guardrails, + self.default_on, + ) + + if self.default_on is True: + if self._event_hook_is_event_type(event_type): + return True + return False + + if ( + self.event_hook + and not self._guardrail_is_in_requested_guardrails(requested_guardrails) + and event_type.value != "logging_only" + ): + return False + + if not self._event_hook_is_event_type(event_type): + return False + + return True + + def _event_hook_is_event_type(self, event_type: GuardrailEventHooks) -> bool: + """ + Returns True if the event_hook is the same as the event_type + + eg. if `self.event_hook == "pre_call" and event_type == "pre_call"` -> then True + eg. if `self.event_hook == "pre_call" and event_type == "post_call"` -> then False + """ + + if self.event_hook is None: + return True + if isinstance(self.event_hook, list): + return event_type.value in self.event_hook + return self.event_hook == event_type.value + + def get_guardrail_dynamic_request_body_params(self, request_data: dict) -> dict: + """ + Returns `extra_body` to be added to the request body for the Guardrail API call + + Use this to pass dynamic params to the guardrail API call - eg. success_threshold, failure_threshold, etc. + + ``` + [{"lakera_guard": {"extra_body": {"foo": "bar"}}}] + ``` + + Will return: for guardrail=`lakera-guard`: + { + "foo": "bar" + } + + Args: + request_data: The original `request_data` passed to LiteLLM Proxy + """ + requested_guardrails = self.get_guardrail_from_metadata(request_data) + + # Look for the guardrail configuration matching self.guardrail_name + for guardrail in requested_guardrails: + if isinstance(guardrail, dict) and self.guardrail_name in guardrail: + # Get the configuration for this guardrail + guardrail_config: DynamicGuardrailParams = DynamicGuardrailParams( + **guardrail[self.guardrail_name] + ) + if self._validate_premium_user() is not True: + return {} + + # Return the extra_body if it exists, otherwise empty dict + return guardrail_config.get("extra_body", {}) + + return {} + + def _validate_premium_user(self) -> bool: + """ + Returns True if the user is a premium user + """ + from litellm.proxy.proxy_server import CommonProxyErrors, premium_user + + if premium_user is not True: + verbose_logger.warning( + f"Trying to use premium guardrail without premium user {CommonProxyErrors.not_premium_user.value}" + ) + return False + return True + + def add_standard_logging_guardrail_information_to_request_data( + self, + guardrail_json_response: Union[Exception, str, dict], + request_data: dict, + guardrail_status: Literal["success", "failure"], + ) -> None: + """ + Builds `StandardLoggingGuardrailInformation` and adds it to the request metadata so it can be used for logging to DataDog, Langfuse, etc. + """ + from litellm.proxy.proxy_server import premium_user + + if premium_user is not True: + verbose_logger.warning( + f"Guardrail Tracing is only available for premium users. Skipping guardrail logging for guardrail={self.guardrail_name} event_hook={self.event_hook}" + ) + return + if isinstance(guardrail_json_response, Exception): + guardrail_json_response = str(guardrail_json_response) + slg = StandardLoggingGuardrailInformation( + guardrail_name=self.guardrail_name, + guardrail_mode=self.event_hook, + guardrail_response=guardrail_json_response, + guardrail_status=guardrail_status, + ) + if "metadata" in request_data: + request_data["metadata"]["standard_logging_guardrail_information"] = slg + elif "litellm_metadata" in request_data: + request_data["litellm_metadata"][ + "standard_logging_guardrail_information" + ] = slg + else: + verbose_logger.warning( + "unable to log guardrail information. No metadata found in request_data" + ) + + +def log_guardrail_information(func): + """ + Decorator to add standard logging guardrail information to any function + + Add this decorator to ensure your guardrail response is logged to DataDog, OTEL, s3, GCS etc. + + Logs for: + - pre_call + - during_call + - TODO: log post_call. This is more involved since the logs are sent to DD, s3 before the guardrail is even run + """ + import asyncio + import functools + + def process_response(self, response, request_data): + self.add_standard_logging_guardrail_information_to_request_data( + guardrail_json_response=response, + request_data=request_data, + guardrail_status="success", + ) + return response + + def process_error(self, e, request_data): + self.add_standard_logging_guardrail_information_to_request_data( + guardrail_json_response=e, + request_data=request_data, + guardrail_status="failure", + ) + raise e + + @functools.wraps(func) + async def async_wrapper(*args, **kwargs): + self: CustomGuardrail = args[0] + request_data: Optional[dict] = ( + kwargs.get("data") or kwargs.get("request_data") or {} + ) + try: + response = await func(*args, **kwargs) + return process_response(self, response, request_data) + except Exception as e: + return process_error(self, e, request_data) + + @functools.wraps(func) + def sync_wrapper(*args, **kwargs): + self: CustomGuardrail = args[0] + request_data: Optional[dict] = ( + kwargs.get("data") or kwargs.get("request_data") or {} + ) + try: + response = func(*args, **kwargs) + return process_response(self, response, request_data) + except Exception as e: + return process_error(self, e, request_data) + + @functools.wraps(func) + def wrapper(*args, **kwargs): + if asyncio.iscoroutinefunction(func): + return async_wrapper(*args, **kwargs) + return sync_wrapper(*args, **kwargs) + + return wrapper |