path: root/.venv/lib/python3.12/site-packages/litellm/llms/bedrock/image/image_handler.py
author    S. Solomon Darnell  2025-03-28 21:52:21 -0500
committer S. Solomon Darnell  2025-03-28 21:52:21 -0500
commit    4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree      ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/litellm/llms/bedrock/image/image_handler.py
parent    cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
download  gn-ai-master.tar.gz
two versions of R2R are here (HEAD, master)
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/bedrock/image/image_handler.py')
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/bedrock/image/image_handler.py  314
1 file changed, 314 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/bedrock/image/image_handler.py b/.venv/lib/python3.12/site-packages/litellm/llms/bedrock/image/image_handler.py
new file mode 100644
index 00000000..8f7762e5
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/bedrock/image/image_handler.py
@@ -0,0 +1,314 @@
+import copy
+import json
+import os
+from typing import TYPE_CHECKING, Any, Optional, Union
+
+import httpx
+from pydantic import BaseModel
+
+import litellm
+from litellm._logging import verbose_logger
+from litellm.litellm_core_utils.litellm_logging import Logging as LitellmLogging
+from litellm.llms.custom_httpx.http_handler import (
+ AsyncHTTPHandler,
+ HTTPHandler,
+ _get_httpx_client,
+ get_async_httpx_client,
+)
+from litellm.types.utils import ImageResponse
+
+from ..base_aws_llm import BaseAWSLLM
+from ..common_utils import BedrockError
+
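+# botocore is an optional dependency: import the real AWSPreparedRequest type
+# only while type checking, and fall back to Any at runtime so importing this
+# module does not require botocore.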
+if TYPE_CHECKING:
+ from botocore.awsrequest import AWSPreparedRequest
+else:
+ AWSPreparedRequest = Any
+
+
+class BedrockImagePreparedRequest(BaseModel):
+ """
+ Internal/Helper class for preparing the request for bedrock image generation
+ """
+
+ endpoint_url: str
+ prepped: AWSPreparedRequest
+ body: bytes
+ data: dict
+
+
+class BedrockImageGeneration(BaseAWSLLM):
+ """
+ Bedrock Image Generation handler
+ """
+
+ def image_generation(
+ self,
+ model: str,
+ prompt: str,
+ model_response: ImageResponse,
+ optional_params: dict,
+ logging_obj: LitellmLogging,
+ timeout: Optional[Union[float, httpx.Timeout]],
+ aimg_generation: bool = False,
+ api_base: Optional[str] = None,
+ extra_headers: Optional[dict] = None,
+ client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
+ ):
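+        """
+        Handle a Bedrock image generation call.
+
+        Prepares and signs the request, then either returns the
+        `async_image_generation` coroutine (when `aimg_generation` is True)
+        or POSTs synchronously and returns the transformed `ImageResponse`.
+        """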
+ prepared_request = self._prepare_request(
+ model=model,
+ optional_params=optional_params,
+ api_base=api_base,
+ extra_headers=extra_headers,
+ logging_obj=logging_obj,
+ prompt=prompt,
+ )
+
+ if aimg_generation is True:
+ return self.async_image_generation(
+ prepared_request=prepared_request,
+ timeout=timeout,
+ model=model,
+ logging_obj=logging_obj,
+ prompt=prompt,
+ model_response=model_response,
+ client=(
+ client
+ if client is not None and isinstance(client, AsyncHTTPHandler)
+ else None
+ ),
+ )
+
+ if client is None or not isinstance(client, HTTPHandler):
+ client = _get_httpx_client()
+ try:
+            response = client.post(  # type: ignore
+                url=prepared_request.endpoint_url,
+                headers=prepared_request.prepped.headers,
+                data=prepared_request.body,
+            )
+ response.raise_for_status()
+ except httpx.HTTPStatusError as err:
+ error_code = err.response.status_code
+ raise BedrockError(status_code=error_code, message=err.response.text)
+ except httpx.TimeoutException:
+ raise BedrockError(status_code=408, message="Timeout error occurred.")
+ ### FORMAT RESPONSE TO OPENAI FORMAT ###
+ model_response = self._transform_response_dict_to_openai_response(
+ model_response=model_response,
+ model=model,
+ logging_obj=logging_obj,
+ prompt=prompt,
+ response=response,
+ data=prepared_request.data,
+ )
+ return model_response
+
+ async def async_image_generation(
+ self,
+ prepared_request: BedrockImagePreparedRequest,
+ timeout: Optional[Union[float, httpx.Timeout]],
+ model: str,
+ logging_obj: LitellmLogging,
+ prompt: str,
+ model_response: ImageResponse,
+ client: Optional[AsyncHTTPHandler] = None,
+ ) -> ImageResponse:
+ """
+ Asynchronous handler for bedrock image generation
+
+ Awaits the response from the bedrock image generation endpoint
+ """
+ async_client = client or get_async_httpx_client(
+ llm_provider=litellm.LlmProviders.BEDROCK,
+ params={"timeout": timeout},
+ )
+
+ try:
+            response = await async_client.post(  # type: ignore
+                url=prepared_request.endpoint_url,
+                headers=prepared_request.prepped.headers,
+                data=prepared_request.body,
+            )
+ response.raise_for_status()
+ except httpx.HTTPStatusError as err:
+ error_code = err.response.status_code
+ raise BedrockError(status_code=error_code, message=err.response.text)
+ except httpx.TimeoutException:
+ raise BedrockError(status_code=408, message="Timeout error occurred.")
+
+ ### FORMAT RESPONSE TO OPENAI FORMAT ###
+ model_response = self._transform_response_dict_to_openai_response(
+ model=model,
+ logging_obj=logging_obj,
+ prompt=prompt,
+ response=response,
+ data=prepared_request.data,
+ model_response=model_response,
+ )
+ return model_response
+
+ def _prepare_request(
+ self,
+ model: str,
+ optional_params: dict,
+ api_base: Optional[str],
+ extra_headers: Optional[dict],
+ logging_obj: LitellmLogging,
+ prompt: str,
+ ) -> BedrockImagePreparedRequest:
+ """
+ Prepare the request body, headers, and endpoint URL for the Bedrock Image Generation API
+
+ Args:
+ model (str): The model to use for the image generation
+ optional_params (dict): The optional parameters for the image generation
+ api_base (Optional[str]): The base URL for the Bedrock API
+ extra_headers (Optional[dict]): The extra headers to include in the request
+ logging_obj (LitellmLogging): The logging object to use for logging
+ prompt (str): The prompt to use for the image generation
+ Returns:
+ BedrockImagePreparedRequest: The prepared request object
+
+ The BedrockImagePreparedRequest contains:
+ endpoint_url (str): The endpoint URL for the Bedrock Image Generation API
+            prepped (AWSPreparedRequest): The SigV4-signed request
+            body (bytes): The JSON-encoded request body
+            data (dict): The request body as a dict, kept for logging
+ """
+ try:
+ from botocore.auth import SigV4Auth
+ from botocore.awsrequest import AWSRequest
+ except ImportError:
+ raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
+ boto3_credentials_info = self._get_boto_credentials_from_optional_params(
+ optional_params, model
+ )
+
+ ### SET RUNTIME ENDPOINT ###
+ modelId = model
+ _, proxy_endpoint_url = self.get_runtime_endpoint(
+ api_base=api_base,
+ aws_bedrock_runtime_endpoint=boto3_credentials_info.aws_bedrock_runtime_endpoint,
+ aws_region_name=boto3_credentials_info.aws_region_name,
+ )
+ proxy_endpoint_url = f"{proxy_endpoint_url}/model/{modelId}/invoke"
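+        # e.g. (illustrative values) for model "stability.stable-diffusion-xl-v1"
+        # in us-west-2 this resolves to:
+        # https://bedrock-runtime.us-west-2.amazonaws.com/model/stability.stable-diffusion-xl-v1/invoke
+        # Bedrock authorizes requests via SigV4-signed headers rather than an
+        # API key, so sign the request with AWS Signature Version 4.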
+ sigv4 = SigV4Auth(
+ boto3_credentials_info.credentials,
+ "bedrock",
+ boto3_credentials_info.aws_region_name,
+ )
+
+ data = self._get_request_body(
+ model=model, prompt=prompt, optional_params=optional_params
+ )
+
+ # Make POST Request
+ body = json.dumps(data).encode("utf-8")
+
+ headers = {"Content-Type": "application/json"}
+ if extra_headers is not None:
+ headers = {"Content-Type": "application/json", **extra_headers}
+ request = AWSRequest(
+ method="POST", url=proxy_endpoint_url, data=body, headers=headers
+ )
+ sigv4.add_auth(request)
+ if (
+ extra_headers is not None and "Authorization" in extra_headers
+ ): # prevent sigv4 from overwriting the auth header
+ request.headers["Authorization"] = extra_headers["Authorization"]
+ prepped = request.prepare()
+
+ ## LOGGING
+ logging_obj.pre_call(
+ input=prompt,
+ api_key="",
+ additional_args={
+ "complete_input_dict": data,
+ "api_base": proxy_endpoint_url,
+ "headers": prepped.headers,
+ },
+ )
+ return BedrockImagePreparedRequest(
+ endpoint_url=proxy_endpoint_url,
+ prepped=prepped,
+ body=body,
+ data=data,
+ )
+
+ def _get_request_body(
+ self,
+ model: str,
+ prompt: str,
+ optional_params: dict,
+ ) -> dict:
+ """
+ Get the request body for the Bedrock Image Generation API
+
+ Checks the model/provider and transforms the request body accordingly
+
+ Returns:
+ dict: The request body to use for the Bedrock Image Generation API
+ """
+ provider = model.split(".")[0]
+ inference_params = copy.deepcopy(optional_params)
+ inference_params.pop(
+ "user", None
+ ) # make sure user is not passed in for bedrock call
+ data = {}
+ if provider == "stability":
+ if litellm.AmazonStability3Config._is_stability_3_model(model):
+ request_body = litellm.AmazonStability3Config.transform_request_body(
+ prompt=prompt, optional_params=optional_params
+ )
+ return dict(request_body)
+ else:
+ prompt = prompt.replace(os.linesep, " ")
+ ## LOAD CONFIG
+ config = litellm.AmazonStabilityConfig.get_config()
+ for k, v in config.items():
+                if k not in inference_params:
+                    # params passed directly on the call (e.g. cfg_scale=7)
+                    # take precedence over config defaults
+                    inference_params[k] = v
+ data = {
+ "text_prompts": [{"text": prompt, "weight": 1}],
+ **inference_params,
+ }
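+            # Illustrative shape: with prompt "a cat" and inference_params
+            # {"cfg_scale": 7}, data becomes
+            # {"text_prompts": [{"text": "a cat", "weight": 1}], "cfg_scale": 7}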
+ elif provider == "amazon":
+            return dict(
+                litellm.AmazonNovaCanvasConfig.transform_request_body(
+                    text=prompt, optional_params=optional_params
+                )
+            )
+ else:
+ raise BedrockError(
+                status_code=422, message=f"Unsupported model={model} passed in"
+ )
+ return data
+
+ def _transform_response_dict_to_openai_response(
+ self,
+ model_response: ImageResponse,
+ model: str,
+ logging_obj: LitellmLogging,
+ prompt: str,
+ response: httpx.Response,
+ data: dict,
+ ) -> ImageResponse:
+ """
+ Transforms the Image Generation response from Bedrock to OpenAI format
+ """
+
+ ## LOGGING
+ if logging_obj is not None:
+ logging_obj.post_call(
+ input=prompt,
+ api_key="",
+ original_response=response.text,
+ additional_args={"complete_input_dict": data},
+ )
+ verbose_logger.debug("raw model_response: %s", response.text)
+ response_dict = response.json()
+ if response_dict is None:
+ raise ValueError("Error in response object format, got None")
+
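+        # Dispatch to the provider-specific response transformer: Stability v3,
+        # Nova Canvas, or legacy Stability, matched on the model ID.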
+ config_class = (
+ litellm.AmazonStability3Config
+ if litellm.AmazonStability3Config._is_stability_3_model(model=model)
+            else litellm.AmazonNovaCanvasConfig
+            if litellm.AmazonNovaCanvasConfig._is_nova_model(model=model)
+ else litellm.AmazonStabilityConfig
+ )
+ config_class.transform_response_dict_to_openai_response(
+ model_response=model_response,
+ response_dict=response_dict,
+ )
+
+ return model_response
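+
+
+# Example usage (a minimal sketch, not part of the module; assumes AWS
+# credentials are discoverable in the environment and that `logging_obj`
+# is the LitellmLogging instance litellm constructs for each call):
+#
+#   handler = BedrockImageGeneration()
+#   image_response = handler.image_generation(
+#       model="stability.stable-diffusion-xl-v1",
+#       prompt="a watercolor lighthouse at dusk",
+#       model_response=ImageResponse(),
+#       optional_params={"cfg_scale": 7},
+#       logging_obj=logging_obj,
+#       timeout=600.0,
+#   )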