Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/image_generation')
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/image_generation/cost_calculator.py           |  23
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/image_generation/image_generation_handler.py  | 227
2 files changed, 250 insertions(+), 0 deletions(-)
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/image_generation/cost_calculator.py b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/image_generation/cost_calculator.py
new file mode 100644
index 00000000..2ba18c09
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/image_generation/cost_calculator.py
@@ -0,0 +1,23 @@
+"""
+Vertex AI Image Generation Cost Calculator
+"""
+
+import litellm
+from litellm.types.utils import ImageResponse
+
+
+def cost_calculator(
+ model: str,
+ image_response: ImageResponse,
+) -> float:
+ """
+ Vertex AI Image Generation Cost Calculator
+ """
+ _model_info = litellm.get_model_info(
+ model=model,
+ custom_llm_provider="vertex_ai",
+ )
+
+ output_cost_per_image: float = _model_info.get("output_cost_per_image") or 0.0
+    num_images: int = len(image_response.data) if image_response.data else 0
+ return output_cost_per_image * num_images
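
For reference, a minimal sketch of how this calculator could be invoked. The model name, per-image price, and response contents below are illustrative assumptions, not values taken from this diff, and the sketch assumes ImageResponse accepts its image list via data=:

    from openai.types.image import Image
    from litellm.types.utils import ImageResponse
    from litellm.llms.vertex_ai.image_generation.cost_calculator import cost_calculator

    # Hypothetical response carrying two generated images (base64 payloads elided).
    response = ImageResponse(data=[Image(b64_json="..."), Image(b64_json="...")])

    # cost_calculator returns output_cost_per_image * len(response.data); with an
    # assumed per-image price of $0.020 and two images, that comes to $0.040.
    cost = cost_calculator(model="imagegeneration", image_response=response)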
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/image_generation/image_generation_handler.py b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/image_generation/image_generation_handler.py
new file mode 100644
index 00000000..1d5322c0
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/image_generation/image_generation_handler.py
@@ -0,0 +1,227 @@
+import json
+from typing import Any, Dict, List, Optional
+
+import httpx
+from openai.types.image import Image
+
+import litellm
+from litellm.llms.custom_httpx.http_handler import (
+ AsyncHTTPHandler,
+ HTTPHandler,
+ get_async_httpx_client,
+)
+from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexLLM
+from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES
+from litellm.types.utils import ImageResponse
+
+
+class VertexImageGeneration(VertexLLM):
+ def process_image_generation_response(
+ self,
+ json_response: Dict[str, Any],
+ model_response: ImageResponse,
+ model: Optional[str] = None,
+ ) -> ImageResponse:
+ if "predictions" not in json_response:
+ raise litellm.InternalServerError(
+ message=f"image generation response does not contain 'predictions', got {json_response}",
+ llm_provider="vertex_ai",
+ model=model,
+ )
+
+ predictions = json_response["predictions"]
+ response_data: List[Image] = []
+
+ for prediction in predictions:
+ bytes_base64_encoded = prediction["bytesBase64Encoded"]
+ image_object = Image(b64_json=bytes_base64_encoded)
+ response_data.append(image_object)
+
+ model_response.data = response_data
+ return model_response
+
+ def image_generation(
+ self,
+ prompt: str,
+ vertex_project: Optional[str],
+ vertex_location: Optional[str],
+ vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
+ model_response: ImageResponse,
+ logging_obj: Any,
+ model: Optional[
+ str
+ ] = "imagegeneration", # vertex ai uses imagegeneration as the default model
+ client: Optional[Any] = None,
+ optional_params: Optional[dict] = None,
+ timeout: Optional[int] = None,
+ aimg_generation=False,
+ ) -> ImageResponse:
+ if aimg_generation is True:
+ return self.aimage_generation( # type: ignore
+ prompt=prompt,
+ vertex_project=vertex_project,
+ vertex_location=vertex_location,
+ vertex_credentials=vertex_credentials,
+ model=model,
+ client=client,
+ optional_params=optional_params,
+ timeout=timeout,
+ logging_obj=logging_obj,
+ model_response=model_response,
+ )
+
+ if client is None:
+ _params = {}
+ if timeout is not None:
+                if isinstance(timeout, (int, float)):
+ _httpx_timeout = httpx.Timeout(timeout)
+ _params["timeout"] = _httpx_timeout
+ else:
+ _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
+
+ sync_handler: HTTPHandler = HTTPHandler(**_params) # type: ignore
+ else:
+ sync_handler = client # type: ignore
+
+ url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:predict"
+
+ auth_header, _ = self._ensure_access_token(
+ credentials=vertex_credentials,
+ project_id=vertex_project,
+ custom_llm_provider="vertex_ai",
+ )
+ optional_params = optional_params or {
+ "sampleCount": 1
+ } # default optional params
+
+ request_data = {
+ "instances": [{"prompt": prompt}],
+ "parameters": optional_params,
+ }
+
+ request_str = f"\n curl -X POST \\\n -H \"Authorization: Bearer {auth_header[:10] + 'XXXXXXXXXX'}\" \\\n -H \"Content-Type: application/json; charset=utf-8\" \\\n -d {request_data} \\\n \"{url}\""
+ logging_obj.pre_call(
+ input=prompt,
+ api_key=None,
+ additional_args={
+ "complete_input_dict": optional_params,
+ "request_str": request_str,
+ },
+ )
+
+ response = sync_handler.post(
+ url=url,
+ headers={
+ "Content-Type": "application/json; charset=utf-8",
+ "Authorization": f"Bearer {auth_header}",
+ },
+ data=json.dumps(request_data),
+ )
+
+ if response.status_code != 200:
+ raise Exception(f"Error: {response.status_code} {response.text}")
+
+ json_response = response.json()
+ return self.process_image_generation_response(
+ json_response, model_response, model
+ )
+
+ async def aimage_generation(
+ self,
+ prompt: str,
+ vertex_project: Optional[str],
+ vertex_location: Optional[str],
+ vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
+ model_response: litellm.ImageResponse,
+ logging_obj: Any,
+ model: Optional[
+ str
+ ] = "imagegeneration", # vertex ai uses imagegeneration as the default model
+ client: Optional[AsyncHTTPHandler] = None,
+ optional_params: Optional[dict] = None,
+ timeout: Optional[int] = None,
+ ):
+ response = None
+ if client is None:
+ _params = {}
+ if timeout is not None:
+                if isinstance(timeout, (int, float)):
+ _httpx_timeout = httpx.Timeout(timeout)
+ _params["timeout"] = _httpx_timeout
+ else:
+ _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
+
+            self.async_handler = get_async_httpx_client(
+                llm_provider=litellm.LlmProviders.VERTEX_AI,
+                params=_params,
+            )
+ else:
+ self.async_handler = client # type: ignore
+
+ # make POST request to
+ # https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict
+ url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:predict"
+
+ """
+ Docs link: https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/imagegeneration?project=adroit-crow-413218
+ curl -X POST \
+ -H "Authorization: Bearer $(gcloud auth print-access-token)" \
+ -H "Content-Type: application/json; charset=utf-8" \
+ -d {
+ "instances": [
+ {
+ "prompt": "a cat"
+ }
+ ],
+ "parameters": {
+ "sampleCount": 1
+ }
+ } \
+ "https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict"
+ """
+ auth_header, _ = self._ensure_access_token(
+ credentials=vertex_credentials,
+ project_id=vertex_project,
+ custom_llm_provider="vertex_ai",
+ )
+ optional_params = optional_params or {
+ "sampleCount": 1
+ } # default optional params
+
+ request_data = {
+ "instances": [{"prompt": prompt}],
+ "parameters": optional_params,
+ }
+
+ request_str = f"\n curl -X POST \\\n -H \"Authorization: Bearer {auth_header[:10] + 'XXXXXXXXXX'}\" \\\n -H \"Content-Type: application/json; charset=utf-8\" \\\n -d {request_data} \\\n \"{url}\""
+ logging_obj.pre_call(
+ input=prompt,
+ api_key=None,
+ additional_args={
+ "complete_input_dict": optional_params,
+ "request_str": request_str,
+ },
+ )
+
+ response = await self.async_handler.post(
+ url=url,
+ headers={
+ "Content-Type": "application/json; charset=utf-8",
+ "Authorization": f"Bearer {auth_header}",
+ },
+ data=json.dumps(request_data),
+ )
+
+ if response.status_code != 200:
+ raise Exception(f"Error: {response.status_code} {response.text}")
+
+ json_response = response.json()
+ return self.process_image_generation_response(
+ json_response, model_response, model
+ )
+
+    def is_image_generation_response(self, json_response: Dict[str, Any]) -> bool:
+        predictions = json_response.get("predictions")
+        if predictions and "bytesBase64Encoded" in predictions[0]:
+            return True
+        return False
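
To close, a hedged end-to-end sketch of the synchronous path. The project id, region, and prompt are placeholders; the no-op logger stands in for the litellm Logging object the dispatch layer normally supplies (the handler above only invokes its pre_call method), and passing vertex_credentials=None is assumed to fall back to application-default credentials:

    from litellm.llms.vertex_ai.image_generation.image_generation_handler import (
        VertexImageGeneration,
    )
    from litellm.types.utils import ImageResponse

    class _NoOpLogger:
        # Stand-in for litellm's Logging object; real calls receive one from
        # the dispatch layer. The handler above only calls pre_call().
        def pre_call(self, *args, **kwargs):
            pass

    handler = VertexImageGeneration()
    result = handler.image_generation(
        prompt="a watercolor fox",            # illustrative prompt
        vertex_project="my-gcp-project",      # assumption: your GCP project id
        vertex_location="us-central1",        # assumption: a supported Vertex region
        vertex_credentials=None,              # assumed to fall back to ADC
        model_response=ImageResponse(),
        logging_obj=_NoOpLogger(),
        optional_params={"sampleCount": 2},   # ask the model for two samples
    )
    print(len(result.data or []), "image(s) returned")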