aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/litellm/proxy/guardrails/guardrail_hooks/guardrails_ai.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/proxy/guardrails/guardrail_hooks/guardrails_ai.py')
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/proxy/guardrails/guardrail_hooks/guardrails_ai.py114
1 files changed, 114 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/guardrails/guardrail_hooks/guardrails_ai.py b/.venv/lib/python3.12/site-packages/litellm/proxy/guardrails/guardrail_hooks/guardrails_ai.py
new file mode 100644
index 00000000..1a2c5a21
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/guardrails/guardrail_hooks/guardrails_ai.py
@@ -0,0 +1,114 @@
+# +-------------------------------------------------------------+
+#
+# Use GuardrailsAI for your LLM calls
+#
+# +-------------------------------------------------------------+
+# Thank you for using Litellm! - Krrish & Ishaan
+
+import json
+from typing import Optional, TypedDict
+
+from fastapi import HTTPException
+
+import litellm
+from litellm._logging import verbose_proxy_logger
+from litellm.integrations.custom_guardrail import (
+ CustomGuardrail,
+ log_guardrail_information,
+)
+from litellm.litellm_core_utils.prompt_templates.common_utils import (
+ get_content_from_model_response,
+)
+from litellm.proxy._types import UserAPIKeyAuth
+from litellm.proxy.common_utils.callback_utils import (
+ add_guardrail_to_applied_guardrails_header,
+)
+from litellm.types.guardrails import GuardrailEventHooks
+
+
+class GuardrailsAIResponse(TypedDict):
+ callId: str
+ rawLlmOutput: str
+ validatedOutput: str
+ validationPassed: bool
+
+
+class GuardrailsAI(CustomGuardrail):
+ def __init__(
+ self,
+ guard_name: str,
+ api_base: Optional[str] = None,
+ **kwargs,
+ ):
+ if guard_name is None:
+ raise Exception(
+ "GuardrailsAIException - Please pass the Guardrails AI guard name via 'litellm_params::guard_name'"
+ )
+ # store kwargs as optional_params
+ self.guardrails_ai_api_base = api_base or "http://0.0.0.0:8000"
+ self.guardrails_ai_guard_name = guard_name
+ self.optional_params = kwargs
+ supported_event_hooks = [GuardrailEventHooks.post_call]
+ super().__init__(supported_event_hooks=supported_event_hooks, **kwargs)
+
+ async def make_guardrails_ai_api_request(self, llm_output: str, request_data: dict):
+ from httpx import URL
+
+ data = {
+ "llmOutput": llm_output,
+ **self.get_guardrail_dynamic_request_body_params(request_data=request_data),
+ }
+ _json_data = json.dumps(data)
+ response = await litellm.module_level_aclient.post(
+ url=str(
+ URL(self.guardrails_ai_api_base).join(
+ f"guards/{self.guardrails_ai_guard_name}/validate"
+ )
+ ),
+ data=_json_data,
+ headers={
+ "Content-Type": "application/json",
+ },
+ )
+ verbose_proxy_logger.debug("guardrails_ai response: %s", response)
+ _json_response = GuardrailsAIResponse(**response.json()) # type: ignore
+ if _json_response.get("validationPassed") is False:
+ raise HTTPException(
+ status_code=400,
+ detail={
+ "error": "Violated guardrail policy",
+ "guardrails_ai_response": _json_response,
+ },
+ )
+ return _json_response
+
+ @log_guardrail_information
+ async def async_post_call_success_hook(
+ self,
+ data: dict,
+ user_api_key_dict: UserAPIKeyAuth,
+ response,
+ ):
+ """
+ Runs on response from LLM API call
+
+ It can be used to reject a response
+ """
+ event_type: GuardrailEventHooks = GuardrailEventHooks.post_call
+ if self.should_run_guardrail(data=data, event_type=event_type) is not True:
+ return
+
+ if not isinstance(response, litellm.ModelResponse):
+ return
+
+ response_str: str = get_content_from_model_response(response)
+ if response_str is not None and len(response_str) > 0:
+ await self.make_guardrails_ai_api_request(
+ llm_output=response_str, request_data=data
+ )
+
+ add_guardrail_to_applied_guardrails_header(
+ request_data=data, guardrail_name=self.guardrail_name
+ )
+
+ return