diff options
author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
---|---|---|
committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) | |
tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/litellm/proxy/hooks/azure_content_safety.py | |
parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff) | |
download | gn-ai-master.tar.gz |
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/proxy/hooks/azure_content_safety.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/litellm/proxy/hooks/azure_content_safety.py | 156 |
1 files changed, 156 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/azure_content_safety.py b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/azure_content_safety.py new file mode 100644 index 00000000..b35d6711 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/azure_content_safety.py @@ -0,0 +1,156 @@ +import traceback +from typing import Optional + +from fastapi import HTTPException + +import litellm +from litellm._logging import verbose_proxy_logger +from litellm.caching.caching import DualCache +from litellm.integrations.custom_logger import CustomLogger +from litellm.proxy._types import UserAPIKeyAuth + + +class _PROXY_AzureContentSafety( + CustomLogger +): # https://docs.litellm.ai/docs/observability/custom_callback#callback-class + # Class variables or attributes + + def __init__(self, endpoint, api_key, thresholds=None): + try: + from azure.ai.contentsafety.aio import ContentSafetyClient + from azure.ai.contentsafety.models import ( + AnalyzeTextOptions, + AnalyzeTextOutputType, + TextCategory, + ) + from azure.core.credentials import AzureKeyCredential + from azure.core.exceptions import HttpResponseError + except Exception as e: + raise Exception( + f"\033[91mAzure Content-Safety not installed, try running 'pip install azure-ai-contentsafety' to fix this error: {e}\n{traceback.format_exc()}\033[0m" + ) + self.endpoint = endpoint + self.api_key = api_key + self.text_category = TextCategory + self.analyze_text_options = AnalyzeTextOptions + self.analyze_text_output_type = AnalyzeTextOutputType + self.azure_http_error = HttpResponseError + + self.thresholds = self._configure_thresholds(thresholds) + + self.client = ContentSafetyClient( + self.endpoint, AzureKeyCredential(self.api_key) + ) + + def _configure_thresholds(self, thresholds=None): + default_thresholds = { + self.text_category.HATE: 4, + self.text_category.SELF_HARM: 4, + self.text_category.SEXUAL: 4, + self.text_category.VIOLENCE: 4, + } + + if thresholds is None: + return default_thresholds + + for key, default in default_thresholds.items(): + if key not in thresholds: + thresholds[key] = default + + return thresholds + + def _compute_result(self, response): + result = {} + + category_severity = { + item.category: item.severity for item in response.categories_analysis + } + for category in self.text_category: + severity = category_severity.get(category) + if severity is not None: + result[category] = { + "filtered": severity >= self.thresholds[category], + "severity": severity, + } + + return result + + async def test_violation(self, content: str, source: Optional[str] = None): + verbose_proxy_logger.debug("Testing Azure Content-Safety for: %s", content) + + # Construct a request + request = self.analyze_text_options( + text=content, + output_type=self.analyze_text_output_type.EIGHT_SEVERITY_LEVELS, + ) + + # Analyze text + try: + response = await self.client.analyze_text(request) + except self.azure_http_error: + verbose_proxy_logger.debug( + "Error in Azure Content-Safety: %s", traceback.format_exc() + ) + verbose_proxy_logger.debug(traceback.format_exc()) + raise + + result = self._compute_result(response) + verbose_proxy_logger.debug("Azure Content-Safety Result: %s", result) + + for key, value in result.items(): + if value["filtered"]: + raise HTTPException( + status_code=400, + detail={ + "error": "Violated content safety policy", + "source": source, + "category": key, + "severity": value["severity"], + }, + ) + + async def async_pre_call_hook( + self, + user_api_key_dict: UserAPIKeyAuth, + cache: DualCache, + data: dict, + call_type: str, # "completion", "embeddings", "image_generation", "moderation" + ): + verbose_proxy_logger.debug("Inside Azure Content-Safety Pre-Call Hook") + try: + if call_type == "completion" and "messages" in data: + for m in data["messages"]: + if "content" in m and isinstance(m["content"], str): + await self.test_violation(content=m["content"], source="input") + + except HTTPException as e: + raise e + except Exception as e: + verbose_proxy_logger.error( + "litellm.proxy.hooks.azure_content_safety.py::async_pre_call_hook(): Exception occured - {}".format( + str(e) + ) + ) + verbose_proxy_logger.debug(traceback.format_exc()) + + async def async_post_call_success_hook( + self, + data: dict, + user_api_key_dict: UserAPIKeyAuth, + response, + ): + verbose_proxy_logger.debug("Inside Azure Content-Safety Post-Call Hook") + if isinstance(response, litellm.ModelResponse) and isinstance( + response.choices[0], litellm.utils.Choices + ): + await self.test_violation( + content=response.choices[0].message.content or "", source="output" + ) + + # async def async_post_call_streaming_hook( + # self, + # user_api_key_dict: UserAPIKeyAuth, + # response: str, + # ): + # verbose_proxy_logger.debug("Inside Azure Content-Safety Call-Stream Hook") + # await self.test_violation(content=response, source="output") |