diff options
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/proxy/guardrails/init_guardrails.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/litellm/proxy/guardrails/init_guardrails.py | 171 |
1 files changed, 171 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/guardrails/init_guardrails.py b/.venv/lib/python3.12/site-packages/litellm/proxy/guardrails/init_guardrails.py new file mode 100644 index 00000000..d08dcaba --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/proxy/guardrails/init_guardrails.py @@ -0,0 +1,171 @@ +import importlib +import os +from typing import Dict, List, Optional + +import litellm +from litellm import get_secret +from litellm._logging import verbose_proxy_logger +from litellm.proxy.common_utils.callback_utils import initialize_callbacks_on_proxy + +# v2 implementation +from litellm.types.guardrails import ( + Guardrail, + GuardrailItem, + GuardrailItemSpec, + LakeraCategoryThresholds, + LitellmParams, +) + +from .guardrail_registry import guardrail_registry + +all_guardrails: List[GuardrailItem] = [] + + +def initialize_guardrails( + guardrails_config: List[Dict[str, GuardrailItemSpec]], + premium_user: bool, + config_file_path: str, + litellm_settings: dict, +) -> Dict[str, GuardrailItem]: + try: + verbose_proxy_logger.debug(f"validating guardrails passed {guardrails_config}") + global all_guardrails + for item in guardrails_config: + """ + one item looks like this: + + {'prompt_injection': {'callbacks': ['lakera_prompt_injection', 'prompt_injection_api_2'], 'default_on': True, 'enabled_roles': ['user']}} + """ + for k, v in item.items(): + guardrail_item = GuardrailItem(**v, guardrail_name=k) + all_guardrails.append(guardrail_item) + litellm.guardrail_name_config_map[k] = guardrail_item + + # set appropriate callbacks if they are default on + default_on_callbacks = set() + callback_specific_params = {} + for guardrail in all_guardrails: + verbose_proxy_logger.debug(guardrail.guardrail_name) + verbose_proxy_logger.debug(guardrail.default_on) + + callback_specific_params.update(guardrail.callback_args) + + if guardrail.default_on is True: + # add these to litellm callbacks if they don't exist + for callback in guardrail.callbacks: + if callback not in litellm.callbacks: + default_on_callbacks.add(callback) + + if guardrail.logging_only is True: + if callback == "presidio": + callback_specific_params["presidio"] = {"logging_only": True} # type: ignore + + default_on_callbacks_list = list(default_on_callbacks) + if len(default_on_callbacks_list) > 0: + initialize_callbacks_on_proxy( + value=default_on_callbacks_list, + premium_user=premium_user, + config_file_path=config_file_path, + litellm_settings=litellm_settings, + callback_specific_params=callback_specific_params, + ) + + return litellm.guardrail_name_config_map + except Exception as e: + verbose_proxy_logger.exception( + "error initializing guardrails {}".format(str(e)) + ) + raise e + + +""" +Map guardrail_name: <pre_call>, <post_call>, during_call + +""" + + +def init_guardrails_v2( + all_guardrails: List[Dict], + config_file_path: Optional[str] = None, +): + guardrail_list = [] + + for guardrail in all_guardrails: + litellm_params_data = guardrail["litellm_params"] + verbose_proxy_logger.debug("litellm_params= %s", litellm_params_data) + + _litellm_params_kwargs = { + k: litellm_params_data.get(k) for k in LitellmParams.__annotations__.keys() + } + + litellm_params = LitellmParams(**_litellm_params_kwargs) # type: ignore + + if ( + "category_thresholds" in litellm_params_data + and litellm_params_data["category_thresholds"] + ): + lakera_category_thresholds = LakeraCategoryThresholds( + **litellm_params_data["category_thresholds"] + ) + litellm_params["category_thresholds"] = lakera_category_thresholds + + if litellm_params["api_key"] and litellm_params["api_key"].startswith( + "os.environ/" + ): + litellm_params["api_key"] = str(get_secret(litellm_params["api_key"])) # type: ignore + + if litellm_params["api_base"] and litellm_params["api_base"].startswith( + "os.environ/" + ): + litellm_params["api_base"] = str(get_secret(litellm_params["api_base"])) # type: ignore + + guardrail_type = litellm_params["guardrail"] + + initializer = guardrail_registry.get(guardrail_type) + + if initializer: + initializer(litellm_params, guardrail) + elif isinstance(guardrail_type, str) and "." in guardrail_type: + if not config_file_path: + raise Exception( + "GuardrailsAIException - Please pass the config_file_path to initialize_guardrails_v2" + ) + + _file_name, _class_name = guardrail_type.split(".") + verbose_proxy_logger.debug( + "Initializing custom guardrail: %s, file_name: %s, class_name: %s", + guardrail_type, + _file_name, + _class_name, + ) + + directory = os.path.dirname(config_file_path) + module_file_path = os.path.join(directory, _file_name) + ".py" + + spec = importlib.util.spec_from_file_location(_class_name, module_file_path) # type: ignore + if not spec: + raise ImportError( + f"Could not find a module specification for {module_file_path}" + ) + + module = importlib.util.module_from_spec(spec) # type: ignore + spec.loader.exec_module(module) # type: ignore + _guardrail_class = getattr(module, _class_name) + + _guardrail_callback = _guardrail_class( + guardrail_name=guardrail["guardrail_name"], + event_hook=litellm_params["mode"], + default_on=litellm_params["default_on"], + ) + litellm.logging_callback_manager.add_litellm_callback(_guardrail_callback) # type: ignore + else: + raise ValueError(f"Unsupported guardrail: {guardrail_type}") + + parsed_guardrail = Guardrail( + guardrail_name=guardrail["guardrail_name"], + litellm_params=litellm_params, + ) + + guardrail_list.append(parsed_guardrail) + + verbose_proxy_logger.info(f"\nGuardrail List:{guardrail_list}\n") |