aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/litellm/integrations/custom_guardrail.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/integrations/custom_guardrail.py')
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/integrations/custom_guardrail.py274
1 files changed, 274 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/custom_guardrail.py b/.venv/lib/python3.12/site-packages/litellm/integrations/custom_guardrail.py
new file mode 100644
index 00000000..4421664b
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/integrations/custom_guardrail.py
@@ -0,0 +1,274 @@
+from typing import Dict, List, Literal, Optional, Union
+
+from litellm._logging import verbose_logger
+from litellm.integrations.custom_logger import CustomLogger
+from litellm.types.guardrails import DynamicGuardrailParams, GuardrailEventHooks
+from litellm.types.utils import StandardLoggingGuardrailInformation
+
+
+class CustomGuardrail(CustomLogger):
+
+ def __init__(
+ self,
+ guardrail_name: Optional[str] = None,
+ supported_event_hooks: Optional[List[GuardrailEventHooks]] = None,
+ event_hook: Optional[
+ Union[GuardrailEventHooks, List[GuardrailEventHooks]]
+ ] = None,
+ default_on: bool = False,
+ **kwargs,
+ ):
+ """
+ Initialize the CustomGuardrail class
+
+ Args:
+ guardrail_name: The name of the guardrail. This is the name used in your requests.
+ supported_event_hooks: The event hooks that the guardrail supports
+ event_hook: The event hook to run the guardrail on
+ default_on: If True, the guardrail will be run by default on all requests
+ """
+ self.guardrail_name = guardrail_name
+ self.supported_event_hooks = supported_event_hooks
+ self.event_hook: Optional[
+ Union[GuardrailEventHooks, List[GuardrailEventHooks]]
+ ] = event_hook
+ self.default_on: bool = default_on
+
+ if supported_event_hooks:
+ ## validate event_hook is in supported_event_hooks
+ self._validate_event_hook(event_hook, supported_event_hooks)
+ super().__init__(**kwargs)
+
+ def _validate_event_hook(
+ self,
+ event_hook: Optional[Union[GuardrailEventHooks, List[GuardrailEventHooks]]],
+ supported_event_hooks: List[GuardrailEventHooks],
+ ) -> None:
+ if event_hook is None:
+ return
+ if isinstance(event_hook, list):
+ for hook in event_hook:
+ if hook not in supported_event_hooks:
+ raise ValueError(
+ f"Event hook {hook} is not in the supported event hooks {supported_event_hooks}"
+ )
+ elif isinstance(event_hook, GuardrailEventHooks):
+ if event_hook not in supported_event_hooks:
+ raise ValueError(
+ f"Event hook {event_hook} is not in the supported event hooks {supported_event_hooks}"
+ )
+
+ def get_guardrail_from_metadata(
+ self, data: dict
+ ) -> Union[List[str], List[Dict[str, DynamicGuardrailParams]]]:
+ """
+ Returns the guardrail(s) to be run from the metadata
+ """
+ metadata = data.get("metadata") or {}
+ requested_guardrails = metadata.get("guardrails") or []
+ return requested_guardrails
+
+ def _guardrail_is_in_requested_guardrails(
+ self,
+ requested_guardrails: Union[List[str], List[Dict[str, DynamicGuardrailParams]]],
+ ) -> bool:
+ for _guardrail in requested_guardrails:
+ if isinstance(_guardrail, dict):
+ if self.guardrail_name in _guardrail:
+ return True
+ elif isinstance(_guardrail, str):
+ if self.guardrail_name == _guardrail:
+ return True
+ return False
+
+ def should_run_guardrail(self, data, event_type: GuardrailEventHooks) -> bool:
+ """
+ Returns True if the guardrail should be run on the event_type
+ """
+ requested_guardrails = self.get_guardrail_from_metadata(data)
+
+ verbose_logger.debug(
+ "inside should_run_guardrail for guardrail=%s event_type= %s guardrail_supported_event_hooks= %s requested_guardrails= %s self.default_on= %s",
+ self.guardrail_name,
+ event_type,
+ self.event_hook,
+ requested_guardrails,
+ self.default_on,
+ )
+
+ if self.default_on is True:
+ if self._event_hook_is_event_type(event_type):
+ return True
+ return False
+
+ if (
+ self.event_hook
+ and not self._guardrail_is_in_requested_guardrails(requested_guardrails)
+ and event_type.value != "logging_only"
+ ):
+ return False
+
+ if not self._event_hook_is_event_type(event_type):
+ return False
+
+ return True
+
+ def _event_hook_is_event_type(self, event_type: GuardrailEventHooks) -> bool:
+ """
+ Returns True if the event_hook is the same as the event_type
+
+ eg. if `self.event_hook == "pre_call" and event_type == "pre_call"` -> then True
+ eg. if `self.event_hook == "pre_call" and event_type == "post_call"` -> then False
+ """
+
+ if self.event_hook is None:
+ return True
+ if isinstance(self.event_hook, list):
+ return event_type.value in self.event_hook
+ return self.event_hook == event_type.value
+
+ def get_guardrail_dynamic_request_body_params(self, request_data: dict) -> dict:
+ """
+ Returns `extra_body` to be added to the request body for the Guardrail API call
+
+ Use this to pass dynamic params to the guardrail API call - eg. success_threshold, failure_threshold, etc.
+
+ ```
+ [{"lakera_guard": {"extra_body": {"foo": "bar"}}}]
+ ```
+
+ Will return: for guardrail=`lakera-guard`:
+ {
+ "foo": "bar"
+ }
+
+ Args:
+ request_data: The original `request_data` passed to LiteLLM Proxy
+ """
+ requested_guardrails = self.get_guardrail_from_metadata(request_data)
+
+ # Look for the guardrail configuration matching self.guardrail_name
+ for guardrail in requested_guardrails:
+ if isinstance(guardrail, dict) and self.guardrail_name in guardrail:
+ # Get the configuration for this guardrail
+ guardrail_config: DynamicGuardrailParams = DynamicGuardrailParams(
+ **guardrail[self.guardrail_name]
+ )
+ if self._validate_premium_user() is not True:
+ return {}
+
+ # Return the extra_body if it exists, otherwise empty dict
+ return guardrail_config.get("extra_body", {})
+
+ return {}
+
+ def _validate_premium_user(self) -> bool:
+ """
+ Returns True if the user is a premium user
+ """
+ from litellm.proxy.proxy_server import CommonProxyErrors, premium_user
+
+ if premium_user is not True:
+ verbose_logger.warning(
+ f"Trying to use premium guardrail without premium user {CommonProxyErrors.not_premium_user.value}"
+ )
+ return False
+ return True
+
+ def add_standard_logging_guardrail_information_to_request_data(
+ self,
+ guardrail_json_response: Union[Exception, str, dict],
+ request_data: dict,
+ guardrail_status: Literal["success", "failure"],
+ ) -> None:
+ """
+ Builds `StandardLoggingGuardrailInformation` and adds it to the request metadata so it can be used for logging to DataDog, Langfuse, etc.
+ """
+ from litellm.proxy.proxy_server import premium_user
+
+ if premium_user is not True:
+ verbose_logger.warning(
+ f"Guardrail Tracing is only available for premium users. Skipping guardrail logging for guardrail={self.guardrail_name} event_hook={self.event_hook}"
+ )
+ return
+ if isinstance(guardrail_json_response, Exception):
+ guardrail_json_response = str(guardrail_json_response)
+ slg = StandardLoggingGuardrailInformation(
+ guardrail_name=self.guardrail_name,
+ guardrail_mode=self.event_hook,
+ guardrail_response=guardrail_json_response,
+ guardrail_status=guardrail_status,
+ )
+ if "metadata" in request_data:
+ request_data["metadata"]["standard_logging_guardrail_information"] = slg
+ elif "litellm_metadata" in request_data:
+ request_data["litellm_metadata"][
+ "standard_logging_guardrail_information"
+ ] = slg
+ else:
+ verbose_logger.warning(
+ "unable to log guardrail information. No metadata found in request_data"
+ )
+
+
+def log_guardrail_information(func):
+ """
+ Decorator to add standard logging guardrail information to any function
+
+ Add this decorator to ensure your guardrail response is logged to DataDog, OTEL, s3, GCS etc.
+
+ Logs for:
+ - pre_call
+ - during_call
+ - TODO: log post_call. This is more involved since the logs are sent to DD, s3 before the guardrail is even run
+ """
+ import asyncio
+ import functools
+
+ def process_response(self, response, request_data):
+ self.add_standard_logging_guardrail_information_to_request_data(
+ guardrail_json_response=response,
+ request_data=request_data,
+ guardrail_status="success",
+ )
+ return response
+
+ def process_error(self, e, request_data):
+ self.add_standard_logging_guardrail_information_to_request_data(
+ guardrail_json_response=e,
+ request_data=request_data,
+ guardrail_status="failure",
+ )
+ raise e
+
+ @functools.wraps(func)
+ async def async_wrapper(*args, **kwargs):
+ self: CustomGuardrail = args[0]
+ request_data: Optional[dict] = (
+ kwargs.get("data") or kwargs.get("request_data") or {}
+ )
+ try:
+ response = await func(*args, **kwargs)
+ return process_response(self, response, request_data)
+ except Exception as e:
+ return process_error(self, e, request_data)
+
+ @functools.wraps(func)
+ def sync_wrapper(*args, **kwargs):
+ self: CustomGuardrail = args[0]
+ request_data: Optional[dict] = (
+ kwargs.get("data") or kwargs.get("request_data") or {}
+ )
+ try:
+ response = func(*args, **kwargs)
+ return process_response(self, response, request_data)
+ except Exception as e:
+ return process_error(self, e, request_data)
+
+ @functools.wraps(func)
+ def wrapper(*args, **kwargs):
+ if asyncio.iscoroutinefunction(func):
+ return async_wrapper(*args, **kwargs)
+ return sync_wrapper(*args, **kwargs)
+
+ return wrapper