author    S. Solomon Darnell  2025-03-28 21:52:21 -0500
committer S. Solomon Darnell  2025-03-28 21:52:21 -0500
commit    4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree      ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/litellm/proxy/guardrails/guardrail_hooks/aim.py
parent    cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
two versions of R2R are here (HEAD, master)
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/proxy/guardrails/guardrail_hooks/aim.py')
-rw-r--r-- .venv/lib/python3.12/site-packages/litellm/proxy/guardrails/guardrail_hooks/aim.py | 212
1 file changed, 212 insertions(+), 0 deletions(-)
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/guardrails/guardrail_hooks/aim.py b/.venv/lib/python3.12/site-packages/litellm/proxy/guardrails/guardrail_hooks/aim.py
new file mode 100644
index 00000000..e1298b63
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/guardrails/guardrail_hooks/aim.py
@@ -0,0 +1,212 @@
+# +-------------------------------------------------------------+
+#
+#           Use Aim Security Guardrails for your LLM calls
+#                   https://www.aim.security/
+#
+# +-------------------------------------------------------------+
+import asyncio
+import json
+import os
+from typing import Any, AsyncGenerator, Literal, Optional, Union
+
+from fastapi import HTTPException
+from pydantic import BaseModel
+from websockets.asyncio.client import ClientConnection, connect
+
+from litellm import DualCache
+from litellm._logging import verbose_proxy_logger
+from litellm.integrations.custom_guardrail import CustomGuardrail
+from litellm.llms.custom_httpx.http_handler import (
+    get_async_httpx_client,
+    httpxSpecialProvider,
+)
+from litellm.proxy._types import UserAPIKeyAuth
+from litellm.proxy.proxy_server import StreamingCallbackError
+from litellm.types.utils import (
+    Choices,
+    EmbeddingResponse,
+    ImageResponse,
+    ModelResponse,
+    ModelResponseStream,
+)
+
+
+class AimGuardrailMissingSecrets(Exception):
+    pass
+
+
+class AimGuardrail(CustomGuardrail):
+    def __init__(
+        self, api_key: Optional[str] = None, api_base: Optional[str] = None, **kwargs
+    ):
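+        # Reuse LiteLLM's shared async httpx client for guardrail callbacks.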
+        self.async_handler = get_async_httpx_client(
+            llm_provider=httpxSpecialProvider.GuardrailCallback
+        )
+        self.api_key = api_key or os.environ.get("AIM_API_KEY")
+        if not self.api_key:
+            msg = (
+                "Couldn't get the Aim API key. Either set `AIM_API_KEY` in the environment or "
+                "pass it as a parameter to the guardrail in the config file."
+            )
+            raise AimGuardrailMissingSecrets(msg)
+        self.api_base = (
+            api_base or os.environ.get("AIM_API_BASE") or "https://api.aim.security"
+        )
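+        # Mirror the HTTP base as a websocket URL for the streaming output check.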
+        self.ws_api_base = self.api_base.replace("http://", "ws://").replace(
+            "https://", "wss://"
+        )
+        super().__init__(**kwargs)
+
+    async def async_pre_call_hook(
+        self,
+        user_api_key_dict: UserAPIKeyAuth,
+        cache: DualCache,
+        data: dict,
+        call_type: Literal[
+            "completion",
+            "text_completion",
+            "embeddings",
+            "image_generation",
+            "moderation",
+            "audio_transcription",
+            "pass_through_endpoint",
+            "rerank",
+        ],
+    ) -> Union[Exception, str, dict, None]:
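+        # Screen the prompt with Aim before the request reaches the model.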
+        verbose_proxy_logger.debug("Inside AIM Pre-Call Hook")
+
+        await self.call_aim_guardrail(data, hook="pre_call")
+        return data
+
+    async def async_moderation_hook(
+        self,
+        data: dict,
+        user_api_key_dict: UserAPIKeyAuth,
+        call_type: Literal[
+            "completion",
+            "embeddings",
+            "image_generation",
+            "moderation",
+            "audio_transcription",
+            "responses",
+        ],
+    ) -> Union[Exception, str, dict, None]:
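+        # Same prompt screening, invoked from the proxy's moderation path.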
+        verbose_proxy_logger.debug("Inside AIM Moderation Hook")
+
+        await self.call_aim_guardrail(data, hook="moderation")
+        return data
+
+    async def call_aim_guardrail(self, data: dict, hook: str) -> None:
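+        # POST the chat messages to Aim's detection endpoint; a positive
+        # detection is surfaced to the client as an HTTP 400.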
+        user_email = data.get("metadata", {}).get("headers", {}).get("x-aim-user-email")
+        headers = {
+            "Authorization": f"Bearer {self.api_key}",
+            "x-aim-litellm-hook": hook,
+        } | ({"x-aim-user-email": user_email} if user_email else {})
+        response = await self.async_handler.post(
+            f"{self.api_base}/detect/openai",
+            headers=headers,
+            json={"messages": data.get("messages", [])},
+        )
+        response.raise_for_status()
+        res = response.json()
+        detected = res["detected"]
+        verbose_proxy_logger.info(
+            "Aim: detected: {detected}, enabled policies: {policies}".format(
+                detected=detected,
+                policies=list(res["details"].keys()),
+            ),
+        )
+        if detected:
+            raise HTTPException(status_code=400, detail=res["detection_message"])
+
+    async def call_aim_guardrail_on_output(
+        self, request_data: dict, output: str, hook: str
+    ) -> Optional[str]:
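+        # Check the model's output against Aim; returns the detection message
+        # (if any) instead of raising, so callers decide how to fail.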
+        user_email = (
+            request_data.get("metadata", {}).get("headers", {}).get("x-aim-user-email")
+        )
+        headers = {
+            "Authorization": f"Bearer {self.api_key}",
+            "x-aim-litellm-hook": hook,
+        } | ({"x-aim-user-email": user_email} if user_email else {})
+        response = await self.async_handler.post(
+            f"{self.api_base}/detect/output",
+            headers=headers,
+            json={"output": output, "messages": request_data.get("messages", [])},
+        )
+        response.raise_for_status()
+        res = response.json()
+        detected = res["detected"]
+        verbose_proxy_logger.info(
+            "Aim: detected: {detected}, enabled policies: {policies}".format(
+                detected=detected,
+                policies=list(res["details"].keys()),
+            ),
+        )
+        if detected:
+            return res["detection_message"]
+        return None
+
+    async def async_post_call_success_hook(
+        self,
+        data: dict,
+        user_api_key_dict: UserAPIKeyAuth,
+        response: Union[Any, ModelResponse, EmbeddingResponse, ImageResponse],
+    ) -> Any:
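+        # Only text completions are inspected; other response types pass through.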
+        if (
+            isinstance(response, ModelResponse)
+            and response.choices
+            and isinstance(response.choices[0], Choices)
+        ):
+            content = response.choices[0].message.content or ""
+            detection = await self.call_aim_guardrail_on_output(
+                data, content, hook="output"
+            )
+            if detection:
+                raise HTTPException(status_code=400, detail=detection)
+
+    async def async_post_call_streaming_iterator_hook(
+        self,
+        user_api_key_dict: UserAPIKeyAuth,
+        response,
+        request_data: dict,
+    ) -> AsyncGenerator[ModelResponseStream, None]:
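+        # Relay the stream through Aim over a websocket: a sender task forwards
+        # raw chunks while this loop yields back only Aim-verified chunks.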
+        user_email = (
+            request_data.get("metadata", {}).get("headers", {}).get("x-aim-user-email")
+        )
+        headers = {
+            "Authorization": f"Bearer {self.api_key}",
+        } | ({"x-aim-user-email": user_email} if user_email else {})
+        async with connect(
+            f"{self.ws_api_base}/detect/output/ws", additional_headers=headers
+        ) as websocket:
+            sender = asyncio.create_task(
+                self.forward_the_stream_to_aim(websocket, response)
+            )
+            while True:
+                result = json.loads(await websocket.recv())
+                if verified_chunk := result.get("verified_chunk"):
+                    yield ModelResponseStream.model_validate(verified_chunk)
+                else:
+                    sender.cancel()
+                    if result.get("done"):
+                        return
+                    if blocking_message := result.get("blocking_message"):
+                        raise StreamingCallbackError(blocking_message)
+                    verbose_proxy_logger.error(
+                        f"Unknown message received from AIM: {result}"
+                    )
+                    return
+
+    async def forward_the_stream_to_aim(
+        self,
+        websocket: ClientConnection,
+        response_iter,
+    ) -> None:
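+        # Serialize each chunk to JSON and signal end-of-stream with a final
+        # {"done": true} frame so Aim can close the verification session.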
+        async for chunk in response_iter:
+            if isinstance(chunk, BaseModel):
+                chunk = chunk.model_dump_json()
+            if isinstance(chunk, dict):
+                chunk = json.dumps(chunk)
+            await websocket.send(chunk)
+        await websocket.send(json.dumps({"done": True}))
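
A minimal usage sketch for the class added above (not part of the patch): it instantiates AimGuardrail directly and runs the pre-call hook on a toy payload. It assumes AIM_API_KEY is set and that the default https://api.aim.security endpoint is reachable; the bare UserAPIKeyAuth() and DualCache() values are placeholders for what the proxy would normally supply.

import asyncio
import os

from litellm import DualCache
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.guardrails.guardrail_hooks.aim import AimGuardrail

async def main() -> None:
    # Reads AIM_API_KEY from the environment (the constructor also falls back
    # to the env var on its own and raises AimGuardrailMissingSecrets if unset).
    guardrail = AimGuardrail(api_key=os.environ.get("AIM_API_KEY"))
    data = {"messages": [{"role": "user", "content": "hello"}]}
    # Raises fastapi.HTTPException(status_code=400) if Aim flags the prompt.
    checked = await guardrail.async_pre_call_hook(
        user_api_key_dict=UserAPIKeyAuth(),
        cache=DualCache(),
        data=data,
        call_type="completion",
    )
    print(checked)

asyncio.run(main())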