diff options
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/proxy/guardrails/guardrail_hooks/guardrails_ai.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/litellm/proxy/guardrails/guardrail_hooks/guardrails_ai.py | 114 |
1 files changed, 114 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/guardrails/guardrail_hooks/guardrails_ai.py b/.venv/lib/python3.12/site-packages/litellm/proxy/guardrails/guardrail_hooks/guardrails_ai.py new file mode 100644 index 00000000..1a2c5a21 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/proxy/guardrails/guardrail_hooks/guardrails_ai.py @@ -0,0 +1,114 @@ +# +-------------------------------------------------------------+ +# +# Use GuardrailsAI for your LLM calls +# +# +-------------------------------------------------------------+ +# Thank you for using Litellm! - Krrish & Ishaan + +import json +from typing import Optional, TypedDict + +from fastapi import HTTPException + +import litellm +from litellm._logging import verbose_proxy_logger +from litellm.integrations.custom_guardrail import ( + CustomGuardrail, + log_guardrail_information, +) +from litellm.litellm_core_utils.prompt_templates.common_utils import ( + get_content_from_model_response, +) +from litellm.proxy._types import UserAPIKeyAuth +from litellm.proxy.common_utils.callback_utils import ( + add_guardrail_to_applied_guardrails_header, +) +from litellm.types.guardrails import GuardrailEventHooks + + +class GuardrailsAIResponse(TypedDict): + callId: str + rawLlmOutput: str + validatedOutput: str + validationPassed: bool + + +class GuardrailsAI(CustomGuardrail): + def __init__( + self, + guard_name: str, + api_base: Optional[str] = None, + **kwargs, + ): + if guard_name is None: + raise Exception( + "GuardrailsAIException - Please pass the Guardrails AI guard name via 'litellm_params::guard_name'" + ) + # store kwargs as optional_params + self.guardrails_ai_api_base = api_base or "http://0.0.0.0:8000" + self.guardrails_ai_guard_name = guard_name + self.optional_params = kwargs + supported_event_hooks = [GuardrailEventHooks.post_call] + super().__init__(supported_event_hooks=supported_event_hooks, **kwargs) + + async def make_guardrails_ai_api_request(self, llm_output: str, request_data: dict): + from httpx import URL + + data = { + "llmOutput": llm_output, + **self.get_guardrail_dynamic_request_body_params(request_data=request_data), + } + _json_data = json.dumps(data) + response = await litellm.module_level_aclient.post( + url=str( + URL(self.guardrails_ai_api_base).join( + f"guards/{self.guardrails_ai_guard_name}/validate" + ) + ), + data=_json_data, + headers={ + "Content-Type": "application/json", + }, + ) + verbose_proxy_logger.debug("guardrails_ai response: %s", response) + _json_response = GuardrailsAIResponse(**response.json()) # type: ignore + if _json_response.get("validationPassed") is False: + raise HTTPException( + status_code=400, + detail={ + "error": "Violated guardrail policy", + "guardrails_ai_response": _json_response, + }, + ) + return _json_response + + @log_guardrail_information + async def async_post_call_success_hook( + self, + data: dict, + user_api_key_dict: UserAPIKeyAuth, + response, + ): + """ + Runs on response from LLM API call + + It can be used to reject a response + """ + event_type: GuardrailEventHooks = GuardrailEventHooks.post_call + if self.should_run_guardrail(data=data, event_type=event_type) is not True: + return + + if not isinstance(response, litellm.ModelResponse): + return + + response_str: str = get_content_from_model_response(response) + if response_str is not None and len(response_str) > 0: + await self.make_guardrails_ai_api_request( + llm_output=response_str, request_data=data + ) + + add_guardrail_to_applied_guardrails_header( + request_data=data, guardrail_name=self.guardrail_name + ) + + return |