diff options
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py | 305 |
1 files changed, 305 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py b/.venv/lib/python3.12/site-packages/litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py new file mode 100644 index 00000000..7686fba7 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py @@ -0,0 +1,305 @@ +# +-------------------------------------------------------------+ +# +# Use Bedrock Guardrails for your LLM calls +# +# +-------------------------------------------------------------+ +# Thank you users! We ❤️ you! - Krrish & Ishaan + +import os +import sys + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path +import json +import sys +from typing import Any, List, Literal, Optional, Union + +from fastapi import HTTPException + +import litellm +from litellm._logging import verbose_proxy_logger +from litellm.integrations.custom_guardrail import ( + CustomGuardrail, + log_guardrail_information, +) +from litellm.litellm_core_utils.prompt_templates.common_utils import ( + convert_content_list_to_str, +) +from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM +from litellm.llms.custom_httpx.http_handler import ( + get_async_httpx_client, + httpxSpecialProvider, +) +from litellm.proxy._types import UserAPIKeyAuth +from litellm.secret_managers.main import get_secret +from litellm.types.guardrails import ( + BedrockContentItem, + BedrockRequest, + BedrockTextContent, + GuardrailEventHooks, +) +from litellm.types.llms.openai import AllMessageValues +from litellm.types.utils import ModelResponse + +GUARDRAIL_NAME = "bedrock" + + +class BedrockGuardrail(CustomGuardrail, BaseAWSLLM): + def __init__( + self, + guardrailIdentifier: Optional[str] = None, + guardrailVersion: Optional[str] = None, + **kwargs, + ): + self.async_handler = get_async_httpx_client( + llm_provider=httpxSpecialProvider.GuardrailCallback + ) + self.guardrailIdentifier = guardrailIdentifier + self.guardrailVersion = guardrailVersion + + # store kwargs as optional_params + self.optional_params = kwargs + + super().__init__(**kwargs) + BaseAWSLLM.__init__(self) + + def convert_to_bedrock_format( + self, + messages: Optional[List[AllMessageValues]] = None, + response: Optional[Union[Any, ModelResponse]] = None, + ) -> BedrockRequest: + bedrock_request: BedrockRequest = BedrockRequest(source="INPUT") + bedrock_request_content: List[BedrockContentItem] = [] + + if messages: + for message in messages: + bedrock_content_item = BedrockContentItem( + text=BedrockTextContent( + text=convert_content_list_to_str(message=message) + ) + ) + bedrock_request_content.append(bedrock_content_item) + + bedrock_request["content"] = bedrock_request_content + if response: + bedrock_request["source"] = "OUTPUT" + if isinstance(response, litellm.ModelResponse): + for choice in response.choices: + if isinstance(choice, litellm.Choices): + if choice.message.content and isinstance( + choice.message.content, str + ): + bedrock_content_item = BedrockContentItem( + text=BedrockTextContent(text=choice.message.content) + ) + bedrock_request_content.append(bedrock_content_item) + bedrock_request["content"] = bedrock_request_content + return bedrock_request + + #### CALL HOOKS - proxy only #### + def _load_credentials( + self, + ): + try: + from botocore.credentials import Credentials + except ImportError: + raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.") + ## CREDENTIALS ## + # pop aws_secret_access_key, aws_access_key_id, aws_session_token, aws_region_name from kwargs, since completion calls fail with them + aws_secret_access_key = self.optional_params.pop("aws_secret_access_key", None) + aws_access_key_id = self.optional_params.pop("aws_access_key_id", None) + aws_session_token = self.optional_params.pop("aws_session_token", None) + aws_region_name = self.optional_params.pop("aws_region_name", None) + aws_role_name = self.optional_params.pop("aws_role_name", None) + aws_session_name = self.optional_params.pop("aws_session_name", None) + aws_profile_name = self.optional_params.pop("aws_profile_name", None) + self.optional_params.pop( + "aws_bedrock_runtime_endpoint", None + ) # https://bedrock-runtime.{region_name}.amazonaws.com + aws_web_identity_token = self.optional_params.pop( + "aws_web_identity_token", None + ) + aws_sts_endpoint = self.optional_params.pop("aws_sts_endpoint", None) + + ### SET REGION NAME ### + if aws_region_name is None: + # check env # + litellm_aws_region_name = get_secret("AWS_REGION_NAME", None) + + if litellm_aws_region_name is not None and isinstance( + litellm_aws_region_name, str + ): + aws_region_name = litellm_aws_region_name + + standard_aws_region_name = get_secret("AWS_REGION", None) + if standard_aws_region_name is not None and isinstance( + standard_aws_region_name, str + ): + aws_region_name = standard_aws_region_name + + if aws_region_name is None: + aws_region_name = "us-west-2" + + credentials: Credentials = self.get_credentials( + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + aws_session_token=aws_session_token, + aws_region_name=aws_region_name, + aws_session_name=aws_session_name, + aws_profile_name=aws_profile_name, + aws_role_name=aws_role_name, + aws_web_identity_token=aws_web_identity_token, + aws_sts_endpoint=aws_sts_endpoint, + ) + return credentials, aws_region_name + + def _prepare_request( + self, + credentials, + data: dict, + optional_params: dict, + aws_region_name: str, + extra_headers: Optional[dict] = None, + ): + try: + from botocore.auth import SigV4Auth + from botocore.awsrequest import AWSRequest + except ImportError: + raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.") + + sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name) + api_base = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com/guardrail/{self.guardrailIdentifier}/version/{self.guardrailVersion}/apply" + + encoded_data = json.dumps(data).encode("utf-8") + headers = {"Content-Type": "application/json"} + if extra_headers is not None: + headers = {"Content-Type": "application/json", **extra_headers} + + request = AWSRequest( + method="POST", url=api_base, data=encoded_data, headers=headers + ) + sigv4.add_auth(request) + if ( + extra_headers is not None and "Authorization" in extra_headers + ): # prevent sigv4 from overwriting the auth header + request.headers["Authorization"] = extra_headers["Authorization"] + + prepped_request = request.prepare() + + return prepped_request + + async def make_bedrock_api_request( + self, kwargs: dict, response: Optional[Union[Any, litellm.ModelResponse]] = None + ): + + credentials, aws_region_name = self._load_credentials() + bedrock_request_data: dict = dict( + self.convert_to_bedrock_format( + messages=kwargs.get("messages"), response=response + ) + ) + bedrock_request_data.update( + self.get_guardrail_dynamic_request_body_params(request_data=kwargs) + ) + prepared_request = self._prepare_request( + credentials=credentials, + data=bedrock_request_data, + optional_params=self.optional_params, + aws_region_name=aws_region_name, + ) + verbose_proxy_logger.debug( + "Bedrock AI request body: %s, url %s, headers: %s", + bedrock_request_data, + prepared_request.url, + prepared_request.headers, + ) + + response = await self.async_handler.post( + url=prepared_request.url, + data=prepared_request.body, # type: ignore + headers=prepared_request.headers, # type: ignore + ) + verbose_proxy_logger.debug("Bedrock AI response: %s", response.text) + if response.status_code == 200: + # check if the response was flagged + _json_response = response.json() + if _json_response.get("action") == "GUARDRAIL_INTERVENED": + raise HTTPException( + status_code=400, + detail={ + "error": "Violated guardrail policy", + "bedrock_guardrail_response": _json_response, + }, + ) + else: + verbose_proxy_logger.error( + "Bedrock AI: error in response. Status code: %s, response: %s", + response.status_code, + response.text, + ) + + @log_guardrail_information + async def async_moderation_hook( + self, + data: dict, + user_api_key_dict: UserAPIKeyAuth, + call_type: Literal[ + "completion", + "embeddings", + "image_generation", + "moderation", + "audio_transcription", + "responses", + ], + ): + from litellm.proxy.common_utils.callback_utils import ( + add_guardrail_to_applied_guardrails_header, + ) + + event_type: GuardrailEventHooks = GuardrailEventHooks.during_call + if self.should_run_guardrail(data=data, event_type=event_type) is not True: + return + + new_messages: Optional[List[dict]] = data.get("messages") + if new_messages is not None: + await self.make_bedrock_api_request(kwargs=data) + add_guardrail_to_applied_guardrails_header( + request_data=data, guardrail_name=self.guardrail_name + ) + else: + verbose_proxy_logger.warning( + "Bedrock AI: not running guardrail. No messages in data" + ) + pass + + @log_guardrail_information + async def async_post_call_success_hook( + self, + data: dict, + user_api_key_dict: UserAPIKeyAuth, + response, + ): + from litellm.proxy.common_utils.callback_utils import ( + add_guardrail_to_applied_guardrails_header, + ) + from litellm.types.guardrails import GuardrailEventHooks + + if ( + self.should_run_guardrail( + data=data, event_type=GuardrailEventHooks.post_call + ) + is not True + ): + return + + new_messages: Optional[List[dict]] = data.get("messages") + if new_messages is not None: + await self.make_bedrock_api_request(kwargs=data, response=response) + add_guardrail_to_applied_guardrails_header( + request_data=data, guardrail_name=self.guardrail_name + ) + else: + verbose_proxy_logger.warning( + "Bedrock AI: not running guardrail. No messages in data" + ) |