aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/azure_content_safety.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/litellm/proxy/hooks/azure_content_safety.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two version of R2R are hereHEADmaster
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/proxy/hooks/azure_content_safety.py')
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/proxy/hooks/azure_content_safety.py156
1 files changed, 156 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/azure_content_safety.py b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/azure_content_safety.py
new file mode 100644
index 00000000..b35d6711
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/azure_content_safety.py
@@ -0,0 +1,156 @@
+import traceback
+from typing import Optional
+
+from fastapi import HTTPException
+
+import litellm
+from litellm._logging import verbose_proxy_logger
+from litellm.caching.caching import DualCache
+from litellm.integrations.custom_logger import CustomLogger
+from litellm.proxy._types import UserAPIKeyAuth
+
+
+class _PROXY_AzureContentSafety(
+ CustomLogger
+): # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
+ # Class variables or attributes
+
+ def __init__(self, endpoint, api_key, thresholds=None):
+ try:
+ from azure.ai.contentsafety.aio import ContentSafetyClient
+ from azure.ai.contentsafety.models import (
+ AnalyzeTextOptions,
+ AnalyzeTextOutputType,
+ TextCategory,
+ )
+ from azure.core.credentials import AzureKeyCredential
+ from azure.core.exceptions import HttpResponseError
+ except Exception as e:
+ raise Exception(
+ f"\033[91mAzure Content-Safety not installed, try running 'pip install azure-ai-contentsafety' to fix this error: {e}\n{traceback.format_exc()}\033[0m"
+ )
+ self.endpoint = endpoint
+ self.api_key = api_key
+ self.text_category = TextCategory
+ self.analyze_text_options = AnalyzeTextOptions
+ self.analyze_text_output_type = AnalyzeTextOutputType
+ self.azure_http_error = HttpResponseError
+
+ self.thresholds = self._configure_thresholds(thresholds)
+
+ self.client = ContentSafetyClient(
+ self.endpoint, AzureKeyCredential(self.api_key)
+ )
+
+ def _configure_thresholds(self, thresholds=None):
+ default_thresholds = {
+ self.text_category.HATE: 4,
+ self.text_category.SELF_HARM: 4,
+ self.text_category.SEXUAL: 4,
+ self.text_category.VIOLENCE: 4,
+ }
+
+ if thresholds is None:
+ return default_thresholds
+
+ for key, default in default_thresholds.items():
+ if key not in thresholds:
+ thresholds[key] = default
+
+ return thresholds
+
+ def _compute_result(self, response):
+ result = {}
+
+ category_severity = {
+ item.category: item.severity for item in response.categories_analysis
+ }
+ for category in self.text_category:
+ severity = category_severity.get(category)
+ if severity is not None:
+ result[category] = {
+ "filtered": severity >= self.thresholds[category],
+ "severity": severity,
+ }
+
+ return result
+
+ async def test_violation(self, content: str, source: Optional[str] = None):
+ verbose_proxy_logger.debug("Testing Azure Content-Safety for: %s", content)
+
+ # Construct a request
+ request = self.analyze_text_options(
+ text=content,
+ output_type=self.analyze_text_output_type.EIGHT_SEVERITY_LEVELS,
+ )
+
+ # Analyze text
+ try:
+ response = await self.client.analyze_text(request)
+ except self.azure_http_error:
+ verbose_proxy_logger.debug(
+ "Error in Azure Content-Safety: %s", traceback.format_exc()
+ )
+ verbose_proxy_logger.debug(traceback.format_exc())
+ raise
+
+ result = self._compute_result(response)
+ verbose_proxy_logger.debug("Azure Content-Safety Result: %s", result)
+
+ for key, value in result.items():
+ if value["filtered"]:
+ raise HTTPException(
+ status_code=400,
+ detail={
+ "error": "Violated content safety policy",
+ "source": source,
+ "category": key,
+ "severity": value["severity"],
+ },
+ )
+
+ async def async_pre_call_hook(
+ self,
+ user_api_key_dict: UserAPIKeyAuth,
+ cache: DualCache,
+ data: dict,
+ call_type: str, # "completion", "embeddings", "image_generation", "moderation"
+ ):
+ verbose_proxy_logger.debug("Inside Azure Content-Safety Pre-Call Hook")
+ try:
+ if call_type == "completion" and "messages" in data:
+ for m in data["messages"]:
+ if "content" in m and isinstance(m["content"], str):
+ await self.test_violation(content=m["content"], source="input")
+
+ except HTTPException as e:
+ raise e
+ except Exception as e:
+ verbose_proxy_logger.error(
+ "litellm.proxy.hooks.azure_content_safety.py::async_pre_call_hook(): Exception occured - {}".format(
+ str(e)
+ )
+ )
+ verbose_proxy_logger.debug(traceback.format_exc())
+
+ async def async_post_call_success_hook(
+ self,
+ data: dict,
+ user_api_key_dict: UserAPIKeyAuth,
+ response,
+ ):
+ verbose_proxy_logger.debug("Inside Azure Content-Safety Post-Call Hook")
+ if isinstance(response, litellm.ModelResponse) and isinstance(
+ response.choices[0], litellm.utils.Choices
+ ):
+ await self.test_violation(
+ content=response.choices[0].message.content or "", source="output"
+ )
+
+ # async def async_post_call_streaming_hook(
+ # self,
+ # user_api_key_dict: UserAPIKeyAuth,
+ # response: str,
+ # ):
+ # verbose_proxy_logger.debug("Inside Azure Content-Safety Call-Stream Hook")
+ # await self.test_violation(content=response, source="output")