diff options
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/proxy/guardrails/guardrail_hooks/presidio.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/litellm/proxy/guardrails/guardrail_hooks/presidio.py | 390 |
1 files changed, 390 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/guardrails/guardrail_hooks/presidio.py b/.venv/lib/python3.12/site-packages/litellm/proxy/guardrails/guardrail_hooks/presidio.py new file mode 100644 index 00000000..86d2c8b2 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/proxy/guardrails/guardrail_hooks/presidio.py @@ -0,0 +1,390 @@ +# +-----------------------------------------------+ +# | | +# | PII Masking | +# | with Microsoft Presidio | +# | https://github.com/BerriAI/litellm/issues/ | +# +-----------------------------------------------+ +# +# Tell us how we can improve! - Krrish & Ishaan + + +import asyncio +import json +import uuid +from typing import Any, List, Optional, Tuple, Union + +import aiohttp +from pydantic import BaseModel + +import litellm # noqa: E401 +from litellm import get_secret +from litellm._logging import verbose_proxy_logger +from litellm.caching.caching import DualCache +from litellm.integrations.custom_guardrail import ( + CustomGuardrail, + log_guardrail_information, +) +from litellm.proxy._types import UserAPIKeyAuth +from litellm.types.guardrails import GuardrailEventHooks +from litellm.utils import ( + EmbeddingResponse, + ImageResponse, + ModelResponse, + StreamingChoices, +) + + +class PresidioPerRequestConfig(BaseModel): + """ + presdio params that can be controlled per request, api key + """ + + language: Optional[str] = None + + +class _OPTIONAL_PresidioPIIMasking(CustomGuardrail): + user_api_key_cache = None + ad_hoc_recognizers = None + + # Class variables or attributes + def __init__( + self, + mock_testing: bool = False, + mock_redacted_text: Optional[dict] = None, + presidio_analyzer_api_base: Optional[str] = None, + presidio_anonymizer_api_base: Optional[str] = None, + output_parse_pii: Optional[bool] = False, + presidio_ad_hoc_recognizers: Optional[str] = None, + logging_only: Optional[bool] = None, + **kwargs, + ): + if logging_only is True: + self.logging_only = True + kwargs["event_hook"] = GuardrailEventHooks.logging_only + super().__init__(**kwargs) + self.pii_tokens: dict = ( + {} + ) # mapping of PII token to original text - only used with Presidio `replace` operation + self.mock_redacted_text = mock_redacted_text + self.output_parse_pii = output_parse_pii or False + if mock_testing is True: # for testing purposes only + return + + ad_hoc_recognizers = presidio_ad_hoc_recognizers + if ad_hoc_recognizers is not None: + try: + with open(ad_hoc_recognizers, "r") as file: + self.ad_hoc_recognizers = json.load(file) + except FileNotFoundError: + raise Exception(f"File not found. file_path={ad_hoc_recognizers}") + except json.JSONDecodeError as e: + raise Exception( + f"Error decoding JSON file: {str(e)}, file_path={ad_hoc_recognizers}" + ) + except Exception as e: + raise Exception( + f"An error occurred: {str(e)}, file_path={ad_hoc_recognizers}" + ) + self.validate_environment( + presidio_analyzer_api_base=presidio_analyzer_api_base, + presidio_anonymizer_api_base=presidio_anonymizer_api_base, + ) + + def validate_environment( + self, + presidio_analyzer_api_base: Optional[str] = None, + presidio_anonymizer_api_base: Optional[str] = None, + ): + self.presidio_analyzer_api_base: Optional[str] = ( + presidio_analyzer_api_base or get_secret("PRESIDIO_ANALYZER_API_BASE", None) # type: ignore + ) + self.presidio_anonymizer_api_base: Optional[ + str + ] = presidio_anonymizer_api_base or litellm.get_secret( + "PRESIDIO_ANONYMIZER_API_BASE", None + ) # type: ignore + + if self.presidio_analyzer_api_base is None: + raise Exception("Missing `PRESIDIO_ANALYZER_API_BASE` from environment") + if not self.presidio_analyzer_api_base.endswith("/"): + self.presidio_analyzer_api_base += "/" + if not ( + self.presidio_analyzer_api_base.startswith("http://") + or self.presidio_analyzer_api_base.startswith("https://") + ): + # add http:// if unset, assume communicating over private network - e.g. render + self.presidio_analyzer_api_base = ( + "http://" + self.presidio_analyzer_api_base + ) + + if self.presidio_anonymizer_api_base is None: + raise Exception("Missing `PRESIDIO_ANONYMIZER_API_BASE` from environment") + if not self.presidio_anonymizer_api_base.endswith("/"): + self.presidio_anonymizer_api_base += "/" + if not ( + self.presidio_anonymizer_api_base.startswith("http://") + or self.presidio_anonymizer_api_base.startswith("https://") + ): + # add http:// if unset, assume communicating over private network - e.g. render + self.presidio_anonymizer_api_base = ( + "http://" + self.presidio_anonymizer_api_base + ) + + async def check_pii( + self, + text: str, + output_parse_pii: bool, + presidio_config: Optional[PresidioPerRequestConfig], + request_data: dict, + ) -> str: + """ + [TODO] make this more performant for high-throughput scenario + """ + try: + async with aiohttp.ClientSession() as session: + if self.mock_redacted_text is not None: + redacted_text = self.mock_redacted_text + else: + # Make the first request to /analyze + # Construct Request 1 + analyze_url = f"{self.presidio_analyzer_api_base}analyze" + analyze_payload = {"text": text, "language": "en"} + if presidio_config and presidio_config.language: + analyze_payload["language"] = presidio_config.language + if self.ad_hoc_recognizers is not None: + analyze_payload["ad_hoc_recognizers"] = self.ad_hoc_recognizers + # End of constructing Request 1 + analyze_payload.update( + self.get_guardrail_dynamic_request_body_params( + request_data=request_data + ) + ) + redacted_text = None + verbose_proxy_logger.debug( + "Making request to: %s with payload: %s", + analyze_url, + analyze_payload, + ) + async with session.post( + analyze_url, json=analyze_payload + ) as response: + + analyze_results = await response.json() + + # Make the second request to /anonymize + anonymize_url = f"{self.presidio_anonymizer_api_base}anonymize" + verbose_proxy_logger.debug("Making request to: %s", anonymize_url) + anonymize_payload = { + "text": text, + "analyzer_results": analyze_results, + } + + async with session.post( + anonymize_url, json=anonymize_payload + ) as response: + redacted_text = await response.json() + + new_text = text + if redacted_text is not None: + verbose_proxy_logger.debug("redacted_text: %s", redacted_text) + for item in redacted_text["items"]: + start = item["start"] + end = item["end"] + replacement = item["text"] # replacement token + if item["operator"] == "replace" and output_parse_pii is True: + # check if token in dict + # if exists, add a uuid to the replacement token for swapping back to the original text in llm response output parsing + if replacement in self.pii_tokens: + replacement = replacement + str(uuid.uuid4()) + + self.pii_tokens[replacement] = new_text[ + start:end + ] # get text it'll replace + + new_text = new_text[:start] + replacement + new_text[end:] + return redacted_text["text"] + else: + raise Exception(f"Invalid anonymizer response: {redacted_text}") + except Exception as e: + raise e + + @log_guardrail_information + async def async_pre_call_hook( + self, + user_api_key_dict: UserAPIKeyAuth, + cache: DualCache, + data: dict, + call_type: str, + ): + """ + - Check if request turned off pii + - Check if user allowed to turn off pii (key permissions -> 'allow_pii_controls') + + - Take the request data + - Call /analyze -> get the results + - Call /anonymize w/ the analyze results -> get the redacted text + + For multiple messages in /chat/completions, we'll need to call them in parallel. + """ + + try: + + content_safety = data.get("content_safety", None) + verbose_proxy_logger.debug("content_safety: %s", content_safety) + presidio_config = self.get_presidio_settings_from_request_data(data) + + if call_type == "completion": # /chat/completions requests + messages = data["messages"] + tasks = [] + + for m in messages: + if isinstance(m["content"], str): + tasks.append( + self.check_pii( + text=m["content"], + output_parse_pii=self.output_parse_pii, + presidio_config=presidio_config, + request_data=data, + ) + ) + responses = await asyncio.gather(*tasks) + for index, r in enumerate(responses): + if isinstance(messages[index]["content"], str): + messages[index][ + "content" + ] = r # replace content with redacted string + verbose_proxy_logger.info( + f"Presidio PII Masking: Redacted pii message: {data['messages']}" + ) + data["messages"] = messages + return data + except Exception as e: + raise e + + @log_guardrail_information + def logging_hook( + self, kwargs: dict, result: Any, call_type: str + ) -> Tuple[dict, Any]: + from concurrent.futures import ThreadPoolExecutor + + def run_in_new_loop(): + """Run the coroutine in a new event loop within this thread.""" + new_loop = asyncio.new_event_loop() + try: + asyncio.set_event_loop(new_loop) + return new_loop.run_until_complete( + self.async_logging_hook( + kwargs=kwargs, result=result, call_type=call_type + ) + ) + finally: + new_loop.close() + asyncio.set_event_loop(None) + + try: + # First, try to get the current event loop + _ = asyncio.get_running_loop() + # If we're already in an event loop, run in a separate thread + # to avoid nested event loop issues + with ThreadPoolExecutor(max_workers=1) as executor: + future = executor.submit(run_in_new_loop) + return future.result() + + except RuntimeError: + # No running event loop, we can safely run in this thread + return run_in_new_loop() + + @log_guardrail_information + async def async_logging_hook( + self, kwargs: dict, result: Any, call_type: str + ) -> Tuple[dict, Any]: + """ + Masks the input before logging to langfuse, datadog, etc. + """ + if ( + call_type == "completion" or call_type == "acompletion" + ): # /chat/completions requests + messages: Optional[List] = kwargs.get("messages", None) + tasks = [] + + if messages is None: + return kwargs, result + + presidio_config = self.get_presidio_settings_from_request_data(kwargs) + + for m in messages: + text_str = "" + if m["content"] is None: + continue + if isinstance(m["content"], str): + text_str = m["content"] + tasks.append( + self.check_pii( + text=text_str, + output_parse_pii=False, + presidio_config=presidio_config, + request_data=kwargs, + ) + ) # need to pass separately b/c presidio has context window limits + responses = await asyncio.gather(*tasks) + for index, r in enumerate(responses): + if isinstance(messages[index]["content"], str): + messages[index][ + "content" + ] = r # replace content with redacted string + verbose_proxy_logger.info( + f"Presidio PII Masking: Redacted pii message: {messages}" + ) + kwargs["messages"] = messages + + return kwargs, result + + @log_guardrail_information + async def async_post_call_success_hook( # type: ignore + self, + data: dict, + user_api_key_dict: UserAPIKeyAuth, + response: Union[ModelResponse, EmbeddingResponse, ImageResponse], + ): + """ + Output parse the response object to replace the masked tokens with user sent values + """ + verbose_proxy_logger.debug( + f"PII Masking Args: self.output_parse_pii={self.output_parse_pii}; type of response={type(response)}" + ) + + if self.output_parse_pii is False and litellm.output_parse_pii is False: + return response + + if isinstance(response, ModelResponse) and not isinstance( + response.choices[0], StreamingChoices + ): # /chat/completions requests + if isinstance(response.choices[0].message.content, str): + verbose_proxy_logger.debug( + f"self.pii_tokens: {self.pii_tokens}; initial response: {response.choices[0].message.content}" + ) + for key, value in self.pii_tokens.items(): + response.choices[0].message.content = response.choices[ + 0 + ].message.content.replace(key, value) + return response + + def get_presidio_settings_from_request_data( + self, data: dict + ) -> Optional[PresidioPerRequestConfig]: + if "metadata" in data: + _metadata = data["metadata"] + _guardrail_config = _metadata.get("guardrail_config") + if _guardrail_config: + _presidio_config = PresidioPerRequestConfig(**_guardrail_config) + return _presidio_config + + return None + + def print_verbose(self, print_statement): + try: + verbose_proxy_logger.debug(print_statement) + if litellm.set_verbose: + print(print_statement) # noqa + except Exception: + pass |