Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/image_generation')
2 files changed, 259 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/image_generation/cost_calculator.py b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/image_generation/cost_calculator.py
new file mode 100644
index 00000000..2ba18c09
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/image_generation/cost_calculator.py
@@ -0,0 +1,23 @@
+"""
+Vertex AI Image Generation Cost Calculator
+"""
+
+import litellm
+from litellm.types.utils import ImageResponse
+
+
+def cost_calculator(
+    model: str,
+    image_response: ImageResponse,
+) -> float:
+    """
+    Vertex AI Image Generation Cost Calculator
+    """
+    _model_info = litellm.get_model_info(
+        model=model,
+        custom_llm_provider="vertex_ai",
+    )
+
+    output_cost_per_image: float = _model_info.get("output_cost_per_image") or 0.0
+    num_images: int = len(image_response.data)
+    return output_cost_per_image * num_images
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/image_generation/image_generation_handler.py b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/image_generation/image_generation_handler.py
new file mode 100644
index 00000000..1d5322c0
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/image_generation/image_generation_handler.py
@@ -0,0 +1,236 @@
+import json
+from typing import Any, Dict, List, Optional
+
+import httpx
+from openai.types.image import Image
+
+import litellm
+from litellm.llms.custom_httpx.http_handler import (
+    AsyncHTTPHandler,
+    HTTPHandler,
+    get_async_httpx_client,
+)
+from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexLLM
+from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES
+from litellm.types.utils import ImageResponse
+
+
+class VertexImageGeneration(VertexLLM):
+    def process_image_generation_response(
+        self,
+        json_response: Dict[str, Any],
+        model_response: ImageResponse,
+        model: Optional[str] = None,
+    ) -> ImageResponse:
+        if "predictions" not in json_response:
+            raise litellm.InternalServerError(
+                message=f"image generation response does not contain 'predictions', got {json_response}",
+                llm_provider="vertex_ai",
+                model=model,
+            )
+
+        predictions = json_response["predictions"]
+        response_data: List[Image] = []
+
+        for prediction in predictions:
+            bytes_base64_encoded = prediction["bytesBase64Encoded"]
+            image_object = Image(b64_json=bytes_base64_encoded)
+            response_data.append(image_object)
+
+        model_response.data = response_data
+        return model_response
+
+    def image_generation(
+        self,
+        prompt: str,
+        vertex_project: Optional[str],
+        vertex_location: Optional[str],
+        vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
+        model_response: ImageResponse,
+        logging_obj: Any,
+        model: Optional[
+            str
+        ] = "imagegeneration",  # vertex ai uses imagegeneration as the default model
+        client: Optional[Any] = None,
+        optional_params: Optional[dict] = None,
+        timeout: Optional[int] = None,
+        aimg_generation=False,
+    ) -> ImageResponse:
+        if aimg_generation is True:
+            return self.aimage_generation(  # type: ignore
+                prompt=prompt,
+                vertex_project=vertex_project,
+                vertex_location=vertex_location,
+                vertex_credentials=vertex_credentials,
+                model=model,
+                client=client,
+                optional_params=optional_params,
+                timeout=timeout,
+                logging_obj=logging_obj,
+                model_response=model_response,
+            )
+
+        if client is None:
+            _params = {}
+            if timeout is not None:
+                if isinstance(timeout, float) or isinstance(timeout, int):
+                    _httpx_timeout = httpx.Timeout(timeout)
+                    _params["timeout"] = _httpx_timeout
+            else:
+                _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
+
+            sync_handler: HTTPHandler = HTTPHandler(**_params)  # type: ignore
+        else:
+            sync_handler = client  # type: ignore
+
+        url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:predict"
+
+        auth_header, _ = self._ensure_access_token(
+            credentials=vertex_credentials,
+            project_id=vertex_project,
+            custom_llm_provider="vertex_ai",
+        )
+        optional_params = optional_params or {
+            "sampleCount": 1
+        }  # default optional params
+
+        request_data = {
+            "instances": [{"prompt": prompt}],
+            "parameters": optional_params,
+        }
+
+        request_str = f"\n curl -X POST \\\n -H \"Authorization: Bearer {auth_header[:10] + 'XXXXXXXXXX'}\" \\\n -H \"Content-Type: application/json; charset=utf-8\" \\\n -d {request_data} \\\n \"{url}\""
+        logging_obj.pre_call(
+            input=prompt,
+            api_key=None,
+            additional_args={
+                "complete_input_dict": optional_params,
+                "request_str": request_str,
+            },
+        )
+
+        logging_obj.pre_call(
+            input=prompt,
+            api_key=None,
+            additional_args={
+                "complete_input_dict": optional_params,
+                "request_str": request_str,
+            },
+        )
+
+        response = sync_handler.post(
+            url=url,
+            headers={
+                "Content-Type": "application/json; charset=utf-8",
+                "Authorization": f"Bearer {auth_header}",
+            },
+            data=json.dumps(request_data),
+        )
+
+        if response.status_code != 200:
+            raise Exception(f"Error: {response.status_code} {response.text}")
+
+        json_response = response.json()
+        return self.process_image_generation_response(
+            json_response, model_response, model
+        )
+
+    async def aimage_generation(
+        self,
+        prompt: str,
+        vertex_project: Optional[str],
+        vertex_location: Optional[str],
+        vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
+        model_response: litellm.ImageResponse,
+        logging_obj: Any,
+        model: Optional[
+            str
+        ] = "imagegeneration",  # vertex ai uses imagegeneration as the default model
+        client: Optional[AsyncHTTPHandler] = None,
+        optional_params: Optional[dict] = None,
+        timeout: Optional[int] = None,
+    ):
+        response = None
+        if client is None:
+            _params = {}
+            if timeout is not None:
+                if isinstance(timeout, float) or isinstance(timeout, int):
+                    _httpx_timeout = httpx.Timeout(timeout)
+                    _params["timeout"] = _httpx_timeout
+            else:
+                _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
+
+            self.async_handler = get_async_httpx_client(
+                llm_provider=litellm.LlmProviders.VERTEX_AI,
+                params={"timeout": timeout},
+            )
+        else:
+            self.async_handler = client  # type: ignore
+
+        # make POST request to
+        # https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict
+        url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:predict"
+
+        """
+        Docs link: https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/imagegeneration?project=adroit-crow-413218
+        curl -X POST \
+        -H "Authorization: Bearer $(gcloud auth print-access-token)" \
+        -H "Content-Type: application/json; charset=utf-8" \
+        -d {
+            "instances": [
+                {
+                    "prompt": "a cat"
+                }
+            ],
+            "parameters": {
+                "sampleCount": 1
+            }
+        } \
+        "https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict"
+        """
+        auth_header, _ = self._ensure_access_token(
+            credentials=vertex_credentials,
+            project_id=vertex_project,
+            custom_llm_provider="vertex_ai",
+        )
+        optional_params = optional_params or {
+            "sampleCount": 1
+        }  # default optional params
+
+        request_data = {
+            "instances": [{"prompt": prompt}],
+            "parameters": optional_params,
+        }
+
+        request_str = f"\n curl -X POST \\\n -H \"Authorization: Bearer {auth_header[:10] + 'XXXXXXXXXX'}\" \\\n -H \"Content-Type: application/json; charset=utf-8\" \\\n -d {request_data} \\\n \"{url}\""
+        logging_obj.pre_call(
+            input=prompt,
+            api_key=None,
+            additional_args={
+                "complete_input_dict": optional_params,
+                "request_str": request_str,
+            },
+        )
+
+        response = await self.async_handler.post(
+            url=url,
+            headers={
+                "Content-Type": "application/json; charset=utf-8",
+                "Authorization": f"Bearer {auth_header}",
+            },
+            data=json.dumps(request_data),
+        )
+
+        if response.status_code != 200:
+            raise Exception(f"Error: {response.status_code} {response.text}")
+
+        json_response = response.json()
+        return self.process_image_generation_response(
+            json_response, model_response, model
+        )
+
+    def is_image_generation_response(self, json_response: Dict[str, Any]) -> bool:
+        if "predictions" in json_response:
+            if "bytesBase64Encoded" in json_response["predictions"][0]:
+                return True
+        return False
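
In practice the handler above is reached through litellm's public image_generation entrypoint rather than by instantiating VertexImageGeneration directly. A minimal sketch, assuming valid Google Cloud credentials are available in the environment; the project id below is a placeholder:

    import litellm

    # The "vertex_ai/" prefix routes to VertexImageGeneration.image_generation;
    # "imagegeneration" matches the handler's default model name.
    response = litellm.image_generation(
        prompt="a cat",
        model="vertex_ai/imagegeneration",
        vertex_project="my-project",    # placeholder GCP project id
        vertex_location="us-central1",
    )
    print(response.data[0].b64_json[:32])  # base64-encoded image bytes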
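The response parser and the cost calculator can be exercised locally without credentials. A minimal sketch: the fake payload mimics the shape the :predict endpoint returns, and "imagen-3.0-generate-001" is assumed to exist in litellm's model-cost map (any vertex_ai image model with an output_cost_per_image entry works):

    from litellm.llms.vertex_ai.image_generation.cost_calculator import cost_calculator
    from litellm.llms.vertex_ai.image_generation.image_generation_handler import (
        VertexImageGeneration,
    )
    from litellm.types.utils import ImageResponse

    handler = VertexImageGeneration()

    # Fake :predict response; "aGVsbG8=" stands in for base64-encoded PNG bytes.
    fake_json_response = {"predictions": [{"bytesBase64Encoded": "aGVsbG8="}]}
    assert handler.is_image_generation_response(fake_json_response)

    image_response = handler.process_image_generation_response(
        json_response=fake_json_response,
        model_response=ImageResponse(),
        model="imagegeneration",
    )

    # cost = output_cost_per_image * number of returned images
    cost = cost_calculator(
        model="imagen-3.0-generate-001",  # assumed present in the model-info map
        image_response=image_response,
    )
    print(len(image_response.data), cost)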