about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/azure_content_safety.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/proxy/hooks/azure_content_safety.py')
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/proxy/hooks/azure_content_safety.py156
1 files changed, 156 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/azure_content_safety.py b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/azure_content_safety.py
new file mode 100644
index 00000000..b35d6711
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/azure_content_safety.py
@@ -0,0 +1,156 @@
+import traceback
+from typing import Optional
+
+from fastapi import HTTPException
+
+import litellm
+from litellm._logging import verbose_proxy_logger
+from litellm.caching.caching import DualCache
+from litellm.integrations.custom_logger import CustomLogger
+from litellm.proxy._types import UserAPIKeyAuth
+
+
+class _PROXY_AzureContentSafety(
+    CustomLogger
+):  # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
+    # Class variables or attributes
+
+    def __init__(self, endpoint, api_key, thresholds=None):
+        try:
+            from azure.ai.contentsafety.aio import ContentSafetyClient
+            from azure.ai.contentsafety.models import (
+                AnalyzeTextOptions,
+                AnalyzeTextOutputType,
+                TextCategory,
+            )
+            from azure.core.credentials import AzureKeyCredential
+            from azure.core.exceptions import HttpResponseError
+        except Exception as e:
+            raise Exception(
+                f"\033[91mAzure Content-Safety not installed, try running 'pip install azure-ai-contentsafety' to fix this error: {e}\n{traceback.format_exc()}\033[0m"
+            )
+        self.endpoint = endpoint
+        self.api_key = api_key
+        self.text_category = TextCategory
+        self.analyze_text_options = AnalyzeTextOptions
+        self.analyze_text_output_type = AnalyzeTextOutputType
+        self.azure_http_error = HttpResponseError
+
+        self.thresholds = self._configure_thresholds(thresholds)
+
+        self.client = ContentSafetyClient(
+            self.endpoint, AzureKeyCredential(self.api_key)
+        )
+
+    def _configure_thresholds(self, thresholds=None):
+        default_thresholds = {
+            self.text_category.HATE: 4,
+            self.text_category.SELF_HARM: 4,
+            self.text_category.SEXUAL: 4,
+            self.text_category.VIOLENCE: 4,
+        }
+
+        if thresholds is None:
+            return default_thresholds
+
+        for key, default in default_thresholds.items():
+            if key not in thresholds:
+                thresholds[key] = default
+
+        return thresholds
+
+    def _compute_result(self, response):
+        result = {}
+
+        category_severity = {
+            item.category: item.severity for item in response.categories_analysis
+        }
+        for category in self.text_category:
+            severity = category_severity.get(category)
+            if severity is not None:
+                result[category] = {
+                    "filtered": severity >= self.thresholds[category],
+                    "severity": severity,
+                }
+
+        return result
+
+    async def test_violation(self, content: str, source: Optional[str] = None):
+        verbose_proxy_logger.debug("Testing Azure Content-Safety for: %s", content)
+
+        # Construct a request
+        request = self.analyze_text_options(
+            text=content,
+            output_type=self.analyze_text_output_type.EIGHT_SEVERITY_LEVELS,
+        )
+
+        # Analyze text
+        try:
+            response = await self.client.analyze_text(request)
+        except self.azure_http_error:
+            verbose_proxy_logger.debug(
+                "Error in Azure Content-Safety: %s", traceback.format_exc()
+            )
+            verbose_proxy_logger.debug(traceback.format_exc())
+            raise
+
+        result = self._compute_result(response)
+        verbose_proxy_logger.debug("Azure Content-Safety Result: %s", result)
+
+        for key, value in result.items():
+            if value["filtered"]:
+                raise HTTPException(
+                    status_code=400,
+                    detail={
+                        "error": "Violated content safety policy",
+                        "source": source,
+                        "category": key,
+                        "severity": value["severity"],
+                    },
+                )
+
+    async def async_pre_call_hook(
+        self,
+        user_api_key_dict: UserAPIKeyAuth,
+        cache: DualCache,
+        data: dict,
+        call_type: str,  # "completion", "embeddings", "image_generation", "moderation"
+    ):
+        verbose_proxy_logger.debug("Inside Azure Content-Safety Pre-Call Hook")
+        try:
+            if call_type == "completion" and "messages" in data:
+                for m in data["messages"]:
+                    if "content" in m and isinstance(m["content"], str):
+                        await self.test_violation(content=m["content"], source="input")
+
+        except HTTPException as e:
+            raise e
+        except Exception as e:
+            verbose_proxy_logger.error(
+                "litellm.proxy.hooks.azure_content_safety.py::async_pre_call_hook(): Exception occured - {}".format(
+                    str(e)
+                )
+            )
+            verbose_proxy_logger.debug(traceback.format_exc())
+
+    async def async_post_call_success_hook(
+        self,
+        data: dict,
+        user_api_key_dict: UserAPIKeyAuth,
+        response,
+    ):
+        verbose_proxy_logger.debug("Inside Azure Content-Safety Post-Call Hook")
+        if isinstance(response, litellm.ModelResponse) and isinstance(
+            response.choices[0], litellm.utils.Choices
+        ):
+            await self.test_violation(
+                content=response.choices[0].message.content or "", source="output"
+            )
+
+    # async def async_post_call_streaming_hook(
+    #    self,
+    #    user_api_key_dict: UserAPIKeyAuth,
+    #    response: str,
+    # ):
+    #    verbose_proxy_logger.debug("Inside Azure Content-Safety Call-Stream Hook")
+    #    await self.test_violation(content=response, source="output")