Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai')
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/batches/Readme.md  6
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/batches/handler.py  218
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/batches/transformation.py  193
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/common_utils.py  282
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/context_caching/transformation.py  110
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/context_caching/vertex_ai_context_caching.py  416
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/cost_calculator.py  242
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/files/handler.py  97
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/files/transformation.py  163
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/fine_tuning/handler.py  371
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/gemini/transformation.py  479
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py  1495
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/gemini_embeddings/batch_embed_content_handler.py  182
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/gemini_embeddings/batch_embed_content_transformation.py  74
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/image_generation/cost_calculator.py  23
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/image_generation/image_generation_handler.py  236
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/multimodal_embeddings/embedding_handler.py  294
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/text_to_speech/text_to_speech_handler.py  243
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_ai_non_gemini.py  784
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_ai_partner_models/ai21/transformation.py  62
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_ai_partner_models/anthropic/transformation.py  114
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_ai_partner_models/llama3/transformation.py  73
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_ai_partner_models/main.py  242
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_embeddings/embedding_handler.py  228
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_embeddings/transformation.py  265
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_embeddings/types.py  62
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_llm_base.py  319
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_model_garden/main.py  149
28 files changed, 7422 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/batches/Readme.md b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/batches/Readme.md
new file mode 100644
index 00000000..2aa7d7b0
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/batches/Readme.md
@@ -0,0 +1,6 @@
+# Vertex AI Batch Prediction Jobs
+
+Implementation for calling Vertex AI Batch endpoints using the OpenAI Batch API spec
+
+Vertex Docs: https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/batch-prediction-gemini
+
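A minimal end-to-end sketch of the flow this module implements (assuming litellm's async `acreate_file` / `acreate_batch` entrypoints with `custom_llm_provider="vertex_ai"` and a GCS bucket already configured for file storage; file names are placeholders):

import asyncio
import litellm

async def main():
    # 1. Upload the OpenAI-format batch JSONL. For vertex_ai this is written to
    #    GCS, and the returned id is a gs:// URI.
    file_obj = await litellm.acreate_file(
        file=open("batch_requests.jsonl", "rb"),
        purpose="batch",
        custom_llm_provider="vertex_ai",
    )

    # 2. Create the batch prediction job using the OpenAI Batch API parameters.
    batch = await litellm.acreate_batch(
        completion_window="24h",
        endpoint="/v1/chat/completions",
        input_file_id=file_obj.id,
        custom_llm_provider="vertex_ai",
    )
    print(batch.id, batch.status)

asyncio.run(main())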
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/batches/handler.py b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/batches/handler.py
new file mode 100644
index 00000000..b82268be
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/batches/handler.py
@@ -0,0 +1,218 @@
+import json
+from typing import Any, Coroutine, Dict, Optional, Union
+
+import httpx
+
+import litellm
+from litellm.llms.custom_httpx.http_handler import (
+ _get_httpx_client,
+ get_async_httpx_client,
+)
+from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexLLM
+from litellm.types.llms.openai import CreateBatchRequest
+from litellm.types.llms.vertex_ai import (
+ VERTEX_CREDENTIALS_TYPES,
+ VertexAIBatchPredictionJob,
+)
+from litellm.types.utils import LiteLLMBatch
+
+from .transformation import VertexAIBatchTransformation
+
+
+class VertexAIBatchPrediction(VertexLLM):
+ def __init__(self, gcs_bucket_name: str, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.gcs_bucket_name = gcs_bucket_name
+
+ def create_batch(
+ self,
+ _is_async: bool,
+ create_batch_data: CreateBatchRequest,
+ api_base: Optional[str],
+ vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
+ vertex_project: Optional[str],
+ vertex_location: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ ) -> Union[LiteLLMBatch, Coroutine[Any, Any, LiteLLMBatch]]:
+
+ sync_handler = _get_httpx_client()
+
+ access_token, project_id = self._ensure_access_token(
+ credentials=vertex_credentials,
+ project_id=vertex_project,
+ custom_llm_provider="vertex_ai",
+ )
+
+ default_api_base = self.create_vertex_url(
+ vertex_location=vertex_location or "us-central1",
+ vertex_project=vertex_project or project_id,
+ )
+
+ if len(default_api_base.split(":")) > 1:
+ endpoint = default_api_base.split(":")[-1]
+ else:
+ endpoint = ""
+
+ _, api_base = self._check_custom_proxy(
+ api_base=api_base,
+ custom_llm_provider="vertex_ai",
+ gemini_api_key=None,
+ endpoint=endpoint,
+ stream=None,
+ auth_header=None,
+ url=default_api_base,
+ )
+
+ headers = {
+ "Content-Type": "application/json; charset=utf-8",
+ "Authorization": f"Bearer {access_token}",
+ }
+
+ vertex_batch_request: VertexAIBatchPredictionJob = (
+ VertexAIBatchTransformation.transform_openai_batch_request_to_vertex_ai_batch_request(
+ request=create_batch_data
+ )
+ )
+
+ if _is_async is True:
+ return self._async_create_batch(
+ vertex_batch_request=vertex_batch_request,
+ api_base=api_base,
+ headers=headers,
+ )
+
+ response = sync_handler.post(
+ url=api_base,
+ headers=headers,
+ data=json.dumps(vertex_batch_request),
+ )
+
+ if response.status_code != 200:
+ raise Exception(f"Error: {response.status_code} {response.text}")
+
+ _json_response = response.json()
+ vertex_batch_response = VertexAIBatchTransformation.transform_vertex_ai_batch_response_to_openai_batch_response(
+ response=_json_response
+ )
+ return vertex_batch_response
+
+ async def _async_create_batch(
+ self,
+ vertex_batch_request: VertexAIBatchPredictionJob,
+ api_base: str,
+ headers: Dict[str, str],
+ ) -> LiteLLMBatch:
+ client = get_async_httpx_client(
+ llm_provider=litellm.LlmProviders.VERTEX_AI,
+ )
+ response = await client.post(
+ url=api_base,
+ headers=headers,
+ data=json.dumps(vertex_batch_request),
+ )
+ if response.status_code != 200:
+ raise Exception(f"Error: {response.status_code} {response.text}")
+
+ _json_response = response.json()
+ vertex_batch_response = VertexAIBatchTransformation.transform_vertex_ai_batch_response_to_openai_batch_response(
+ response=_json_response
+ )
+ return vertex_batch_response
+
+ def create_vertex_url(
+ self,
+ vertex_location: str,
+ vertex_project: str,
+ ) -> str:
+        """Return the base url for Vertex AI batch prediction jobs"""
+ # POST https://LOCATION-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/LOCATION/batchPredictionJobs
+ return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/batchPredictionJobs"
+
+ def retrieve_batch(
+ self,
+ _is_async: bool,
+ batch_id: str,
+ api_base: Optional[str],
+ vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
+ vertex_project: Optional[str],
+ vertex_location: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ ) -> Union[LiteLLMBatch, Coroutine[Any, Any, LiteLLMBatch]]:
+ sync_handler = _get_httpx_client()
+
+ access_token, project_id = self._ensure_access_token(
+ credentials=vertex_credentials,
+ project_id=vertex_project,
+ custom_llm_provider="vertex_ai",
+ )
+
+ default_api_base = self.create_vertex_url(
+ vertex_location=vertex_location or "us-central1",
+ vertex_project=vertex_project or project_id,
+ )
+
+ # Append batch_id to the URL
+ default_api_base = f"{default_api_base}/{batch_id}"
+
+ if len(default_api_base.split(":")) > 1:
+ endpoint = default_api_base.split(":")[-1]
+ else:
+ endpoint = ""
+
+ _, api_base = self._check_custom_proxy(
+ api_base=api_base,
+ custom_llm_provider="vertex_ai",
+ gemini_api_key=None,
+ endpoint=endpoint,
+ stream=None,
+ auth_header=None,
+ url=default_api_base,
+ )
+
+ headers = {
+ "Content-Type": "application/json; charset=utf-8",
+ "Authorization": f"Bearer {access_token}",
+ }
+
+ if _is_async is True:
+ return self._async_retrieve_batch(
+ api_base=api_base,
+ headers=headers,
+ )
+
+ response = sync_handler.get(
+ url=api_base,
+ headers=headers,
+ )
+
+ if response.status_code != 200:
+ raise Exception(f"Error: {response.status_code} {response.text}")
+
+ _json_response = response.json()
+ vertex_batch_response = VertexAIBatchTransformation.transform_vertex_ai_batch_response_to_openai_batch_response(
+ response=_json_response
+ )
+ return vertex_batch_response
+
+ async def _async_retrieve_batch(
+ self,
+ api_base: str,
+ headers: Dict[str, str],
+ ) -> LiteLLMBatch:
+ client = get_async_httpx_client(
+ llm_provider=litellm.LlmProviders.VERTEX_AI,
+ )
+ response = await client.get(
+ url=api_base,
+ headers=headers,
+ )
+ if response.status_code != 200:
+ raise Exception(f"Error: {response.status_code} {response.text}")
+
+ _json_response = response.json()
+ vertex_batch_response = VertexAIBatchTransformation.transform_vertex_ai_batch_response_to_openai_batch_response(
+ response=_json_response
+ )
+ return vertex_batch_response
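For reference, the handler above can also be exercised directly. A sketch with placeholder bucket/project values (credential resolution is delegated to `_ensure_access_token`; `CreateBatchRequest` is the TypedDict imported at the top of the file):

from litellm.types.llms.openai import CreateBatchRequest
from litellm.llms.vertex_ai.batches.handler import VertexAIBatchPrediction

handler = VertexAIBatchPrediction(gcs_bucket_name="my-bucket")  # placeholder bucket

create_batch_data = CreateBatchRequest(
    completion_window="24h",
    endpoint="/v1/chat/completions",
    # input_file_id must be a gs:// URI; the model is recovered from the
    # "publishers/..." segment of the path (see transformation.py below).
    input_file_id=(
        "gs://my-bucket/litellm-vertex-files/"
        "publishers/google/models/gemini-1.5-flash-001/input.jsonl"
    ),
)

batch = handler.create_batch(
    _is_async=False,
    create_batch_data=create_batch_data,
    api_base=None,
    vertex_credentials=None,          # resolved by _ensure_access_token
    vertex_project="my-gcp-project",  # placeholder project id
    vertex_location="us-central1",
    timeout=600.0,
    max_retries=None,
)
print(batch.id, batch.status)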
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/batches/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/batches/transformation.py
new file mode 100644
index 00000000..a97f312d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/batches/transformation.py
@@ -0,0 +1,193 @@
+import uuid
+from typing import Dict
+
+from litellm.llms.vertex_ai.common_utils import (
+ _convert_vertex_datetime_to_openai_datetime,
+)
+from litellm.types.llms.openai import BatchJobStatus, CreateBatchRequest
+from litellm.types.llms.vertex_ai import *
+from litellm.types.utils import LiteLLMBatch
+
+
+class VertexAIBatchTransformation:
+ """
+ Transforms OpenAI Batch requests to Vertex AI Batch requests
+
+ API Ref: https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/batch-prediction-gemini
+ """
+
+ @classmethod
+ def transform_openai_batch_request_to_vertex_ai_batch_request(
+ cls,
+ request: CreateBatchRequest,
+ ) -> VertexAIBatchPredictionJob:
+ """
+ Transforms OpenAI Batch requests to Vertex AI Batch requests
+ """
+ request_display_name = f"litellm-vertex-batch-{uuid.uuid4()}"
+ input_file_id = request.get("input_file_id")
+ if input_file_id is None:
+ raise ValueError("input_file_id is required, but not provided")
+ input_config: InputConfig = InputConfig(
+ gcsSource=GcsSource(uris=input_file_id), instancesFormat="jsonl"
+ )
+ model: str = cls._get_model_from_gcs_file(input_file_id)
+ output_config: OutputConfig = OutputConfig(
+ predictionsFormat="jsonl",
+ gcsDestination=GcsDestination(
+ outputUriPrefix=cls._get_gcs_uri_prefix_from_file(input_file_id)
+ ),
+ )
+ return VertexAIBatchPredictionJob(
+ inputConfig=input_config,
+ outputConfig=output_config,
+ model=model,
+ displayName=request_display_name,
+ )
+
+ @classmethod
+ def transform_vertex_ai_batch_response_to_openai_batch_response(
+ cls, response: VertexBatchPredictionResponse
+ ) -> LiteLLMBatch:
+ return LiteLLMBatch(
+ id=cls._get_batch_id_from_vertex_ai_batch_response(response),
+ completion_window="24hrs",
+ created_at=_convert_vertex_datetime_to_openai_datetime(
+ vertex_datetime=response.get("createTime", "")
+ ),
+ endpoint="",
+ input_file_id=cls._get_input_file_id_from_vertex_ai_batch_response(
+ response
+ ),
+ object="batch",
+ status=cls._get_batch_job_status_from_vertex_ai_batch_response(response),
+ error_file_id=None, # Vertex AI doesn't seem to have a direct equivalent
+ output_file_id=cls._get_output_file_id_from_vertex_ai_batch_response(
+ response
+ ),
+ )
+
+ @classmethod
+ def _get_batch_id_from_vertex_ai_batch_response(
+ cls, response: VertexBatchPredictionResponse
+ ) -> str:
+ """
+ Gets the batch id from the Vertex AI Batch response safely
+
+ vertex response: `projects/510528649030/locations/us-central1/batchPredictionJobs/3814889423749775360`
+ returns: `3814889423749775360`
+ """
+ _name = response.get("name", "")
+ if not _name:
+ return ""
+
+ # Split by '/' and get the last part if it exists
+ parts = _name.split("/")
+ return parts[-1] if parts else _name
+
+ @classmethod
+ def _get_input_file_id_from_vertex_ai_batch_response(
+ cls, response: VertexBatchPredictionResponse
+ ) -> str:
+ """
+ Gets the input file id from the Vertex AI Batch response
+ """
+ input_file_id: str = ""
+ input_config = response.get("inputConfig")
+ if input_config is None:
+ return input_file_id
+
+ gcs_source = input_config.get("gcsSource")
+ if gcs_source is None:
+ return input_file_id
+
+ uris = gcs_source.get("uris", "")
+ if len(uris) == 0:
+ return input_file_id
+
+ return uris[0]
+
+ @classmethod
+ def _get_output_file_id_from_vertex_ai_batch_response(
+ cls, response: VertexBatchPredictionResponse
+ ) -> str:
+ """
+ Gets the output file id from the Vertex AI Batch response
+ """
+ output_file_id: str = ""
+ output_config = response.get("outputConfig")
+ if output_config is None:
+ return output_file_id
+
+ gcs_destination = output_config.get("gcsDestination")
+ if gcs_destination is None:
+ return output_file_id
+
+ output_uri_prefix = gcs_destination.get("outputUriPrefix", "")
+ return output_uri_prefix
+
+ @classmethod
+ def _get_batch_job_status_from_vertex_ai_batch_response(
+ cls, response: VertexBatchPredictionResponse
+ ) -> BatchJobStatus:
+ """
+ Gets the batch job status from the Vertex AI Batch response
+
+ ref: https://cloud.google.com/vertex-ai/docs/reference/rest/v1/JobState
+ """
+ state_mapping: Dict[str, BatchJobStatus] = {
+ "JOB_STATE_UNSPECIFIED": "failed",
+ "JOB_STATE_QUEUED": "validating",
+ "JOB_STATE_PENDING": "validating",
+ "JOB_STATE_RUNNING": "in_progress",
+ "JOB_STATE_SUCCEEDED": "completed",
+ "JOB_STATE_FAILED": "failed",
+ "JOB_STATE_CANCELLING": "cancelling",
+ "JOB_STATE_CANCELLED": "cancelled",
+ "JOB_STATE_PAUSED": "in_progress",
+ "JOB_STATE_EXPIRED": "expired",
+ "JOB_STATE_UPDATING": "in_progress",
+ "JOB_STATE_PARTIALLY_SUCCEEDED": "completed",
+ }
+
+ vertex_state = response.get("state", "JOB_STATE_UNSPECIFIED")
+ return state_mapping[vertex_state]
+
+ @classmethod
+ def _get_gcs_uri_prefix_from_file(cls, input_file_id: str) -> str:
+ """
+ Gets the gcs uri prefix from the input file id
+
+ Example:
+ input_file_id: "gs://litellm-testing-bucket/vtx_batch.jsonl"
+ returns: "gs://litellm-testing-bucket"
+
+ input_file_id: "gs://litellm-testing-bucket/batches/vtx_batch.jsonl"
+ returns: "gs://litellm-testing-bucket/batches"
+ """
+ # Split the path and remove the filename
+ path_parts = input_file_id.rsplit("/", 1)
+ return path_parts[0]
+
+ @classmethod
+ def _get_model_from_gcs_file(cls, gcs_file_uri: str) -> str:
+ """
+ Extracts the model from the gcs file uri
+
+ When files are uploaded using LiteLLM (/v1/files), the model is stored in the gcs file uri
+
+ Why?
+ - Because Vertex Requires the `model` param in create batch jobs request, but OpenAI does not require this
+
+
+ gcs_file_uri format: gs://litellm-testing-bucket/litellm-vertex-files/publishers/google/models/gemini-1.5-flash-001/e9412502-2c91-42a6-8e61-f5c294cc0fc8
+ returns: "publishers/google/models/gemini-1.5-flash-001"
+ """
+ from urllib.parse import unquote
+
+ decoded_uri = unquote(gcs_file_uri)
+
+ model_path = decoded_uri.split("publishers/")[1]
+ parts = model_path.split("/")
+ model = f"publishers/{'/'.join(parts[:3])}"
+ return model
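The helpers above are pure string transformations, so their behavior can be checked directly against the values quoted in the docstrings:

from litellm.llms.vertex_ai.batches.transformation import VertexAIBatchTransformation as T

uri = (
    "gs://litellm-testing-bucket/litellm-vertex-files/"
    "publishers/google/models/gemini-1.5-flash-001/e9412502-2c91-42a6-8e61-f5c294cc0fc8"
)

# model param that Vertex requires, recovered from the input file URI
assert T._get_model_from_gcs_file(uri) == "publishers/google/models/gemini-1.5-flash-001"

# output prefix = input path minus the filename
assert (
    T._get_gcs_uri_prefix_from_file("gs://litellm-testing-bucket/batches/vtx_batch.jsonl")
    == "gs://litellm-testing-bucket/batches"
)

# Vertex job name -> OpenAI-style batch id (last path segment)
assert (
    T._get_batch_id_from_vertex_ai_batch_response(
        {"name": "projects/510528649030/locations/us-central1/batchPredictionJobs/3814889423749775360"}
    )
    == "3814889423749775360"
)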
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/common_utils.py b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/common_utils.py
new file mode 100644
index 00000000..f7149c34
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/common_utils.py
@@ -0,0 +1,282 @@
+from typing import Dict, List, Literal, Optional, Tuple, Union
+
+import httpx
+
+from litellm import supports_response_schema, supports_system_messages, verbose_logger
+from litellm.llms.base_llm.chat.transformation import BaseLLMException
+from litellm.types.llms.vertex_ai import PartType
+
+
+class VertexAIError(BaseLLMException):
+ def __init__(
+ self,
+ status_code: int,
+ message: str,
+ headers: Optional[Union[Dict, httpx.Headers]] = None,
+ ):
+ super().__init__(message=message, status_code=status_code, headers=headers)
+
+
+def get_supports_system_message(
+ model: str, custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"]
+) -> bool:
+ try:
+ _custom_llm_provider = custom_llm_provider
+ if custom_llm_provider == "vertex_ai_beta":
+ _custom_llm_provider = "vertex_ai"
+ supports_system_message = supports_system_messages(
+ model=model, custom_llm_provider=_custom_llm_provider
+ )
+ except Exception as e:
+ verbose_logger.warning(
+ "Unable to identify if system message supported. Defaulting to 'False'. Received error message - {}\nAdd it here - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json".format(
+ str(e)
+ )
+ )
+ supports_system_message = False
+
+ return supports_system_message
+
+
+def get_supports_response_schema(
+ model: str, custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"]
+) -> bool:
+ _custom_llm_provider = custom_llm_provider
+ if custom_llm_provider == "vertex_ai_beta":
+ _custom_llm_provider = "vertex_ai"
+
+ _supports_response_schema = supports_response_schema(
+ model=model, custom_llm_provider=_custom_llm_provider
+ )
+
+ return _supports_response_schema
+
+
+from typing import Literal, Optional
+
+all_gemini_url_modes = Literal["chat", "embedding", "batch_embedding"]
+
+
+def _get_vertex_url(
+ mode: all_gemini_url_modes,
+ model: str,
+ stream: Optional[bool],
+ vertex_project: Optional[str],
+ vertex_location: Optional[str],
+ vertex_api_version: Literal["v1", "v1beta1"],
+) -> Tuple[str, str]:
+ url: Optional[str] = None
+ endpoint: Optional[str] = None
+ if mode == "chat":
+ ### SET RUNTIME ENDPOINT ###
+ endpoint = "generateContent"
+ if stream is True:
+ endpoint = "streamGenerateContent"
+ url = f"https://{vertex_location}-aiplatform.googleapis.com/{vertex_api_version}/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}?alt=sse"
+ else:
+ url = f"https://{vertex_location}-aiplatform.googleapis.com/{vertex_api_version}/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}"
+
+ # if model is only numeric chars then it's a fine tuned gemini model
+ # model = 4965075652664360960
+ # send to this url: url = f"https://{vertex_location}-aiplatform.googleapis.com/{version}/projects/{vertex_project}/locations/{vertex_location}/endpoints/{model}:{endpoint}"
+ if model.isdigit():
+ # It's a fine-tuned Gemini model
+ url = f"https://{vertex_location}-aiplatform.googleapis.com/{vertex_api_version}/projects/{vertex_project}/locations/{vertex_location}/endpoints/{model}:{endpoint}"
+ if stream is True:
+ url += "?alt=sse"
+ elif mode == "embedding":
+ endpoint = "predict"
+ url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}"
+ if model.isdigit():
+ # https://us-central1-aiplatform.googleapis.com/v1/projects/$PROJECT_ID/locations/us-central1/endpoints/$ENDPOINT_ID:predict
+ url = f"https://{vertex_location}-aiplatform.googleapis.com/{vertex_api_version}/projects/{vertex_project}/locations/{vertex_location}/endpoints/{model}:{endpoint}"
+
+ if not url or not endpoint:
+ raise ValueError(f"Unable to get vertex url/endpoint for mode: {mode}")
+ return url, endpoint
+
+
+def _get_gemini_url(
+ mode: all_gemini_url_modes,
+ model: str,
+ stream: Optional[bool],
+ gemini_api_key: Optional[str],
+) -> Tuple[str, str]:
+ _gemini_model_name = "models/{}".format(model)
+ if mode == "chat":
+ endpoint = "generateContent"
+ if stream is True:
+ endpoint = "streamGenerateContent"
+ url = "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}&alt=sse".format(
+ _gemini_model_name, endpoint, gemini_api_key
+ )
+ else:
+ url = (
+ "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}".format(
+ _gemini_model_name, endpoint, gemini_api_key
+ )
+ )
+ elif mode == "embedding":
+ endpoint = "embedContent"
+ url = "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}".format(
+ _gemini_model_name, endpoint, gemini_api_key
+ )
+ elif mode == "batch_embedding":
+ endpoint = "batchEmbedContents"
+ url = "https://generativelanguage.googleapis.com/v1beta/{}:{}?key={}".format(
+ _gemini_model_name, endpoint, gemini_api_key
+ )
+
+ return url, endpoint
+
+
+def _check_text_in_content(parts: List[PartType]) -> bool:
+ """
+ check that user_content has 'text' parameter.
+ - Known Vertex Error: Unable to submit request because it must have a text parameter.
+ - 'text' param needs to be len > 0
+ - Relevant Issue: https://github.com/BerriAI/litellm/issues/5515
+ """
+ has_text_param = False
+ for part in parts:
+ if "text" in part and part.get("text"):
+ has_text_param = True
+
+ return has_text_param
+
+
+def _build_vertex_schema(parameters: dict):
+ """
+ This is a modified version of https://github.com/google-gemini/generative-ai-python/blob/8f77cc6ac99937cd3a81299ecf79608b91b06bbb/google/generativeai/types/content_types.py#L419
+ """
+ defs = parameters.pop("$defs", {})
+ # flatten the defs
+ for name, value in defs.items():
+ unpack_defs(value, defs)
+ unpack_defs(parameters, defs)
+
+ # 5. Nullable fields:
+ # * https://github.com/pydantic/pydantic/issues/1270
+ # * https://stackoverflow.com/a/58841311
+ # * https://github.com/pydantic/pydantic/discussions/4872
+ convert_to_nullable(parameters)
+ add_object_type(parameters)
+ # Postprocessing
+ # 4. Suppress unnecessary title generation:
+ # * https://github.com/pydantic/pydantic/issues/1051
+ # * http://cl/586221780
+ strip_field(parameters, field_name="title")
+
+ strip_field(
+ parameters, field_name="$schema"
+ ) # 5. Remove $schema - json schema value, not supported by OpenAPI - causes vertex errors.
+ strip_field(
+ parameters, field_name="$id"
+ ) # 6. Remove id - json schema value, not supported by OpenAPI - causes vertex errors.
+
+ return parameters
+
+
+def unpack_defs(schema, defs):
+ properties = schema.get("properties", None)
+ if properties is None:
+ return
+
+ for name, value in properties.items():
+ ref_key = value.get("$ref", None)
+ if ref_key is not None:
+ ref = defs[ref_key.split("defs/")[-1]]
+ unpack_defs(ref, defs)
+ properties[name] = ref
+ continue
+
+ anyof = value.get("anyOf", None)
+ if anyof is not None:
+ for i, atype in enumerate(anyof):
+ ref_key = atype.get("$ref", None)
+ if ref_key is not None:
+ ref = defs[ref_key.split("defs/")[-1]]
+ unpack_defs(ref, defs)
+ anyof[i] = ref
+ continue
+
+ items = value.get("items", None)
+ if items is not None:
+ ref_key = items.get("$ref", None)
+ if ref_key is not None:
+ ref = defs[ref_key.split("defs/")[-1]]
+ unpack_defs(ref, defs)
+ value["items"] = ref
+ continue
+
+
+def convert_to_nullable(schema):
+ anyof = schema.pop("anyOf", None)
+ if anyof is not None:
+ if len(anyof) != 2:
+ raise ValueError(
+ "Invalid input: Type Unions are not supported, except for `Optional` types. "
+ "Please provide an `Optional` type or a non-Union type."
+ )
+ a, b = anyof
+ if a == {"type": "null"}:
+ schema.update(b)
+ elif b == {"type": "null"}:
+ schema.update(a)
+ else:
+ raise ValueError(
+ "Invalid input: Type Unions are not supported, except for `Optional` types. "
+ "Please provide an `Optional` type or a non-Union type."
+ )
+ schema["nullable"] = True
+
+ properties = schema.get("properties", None)
+ if properties is not None:
+ for name, value in properties.items():
+ convert_to_nullable(value)
+
+ items = schema.get("items", None)
+ if items is not None:
+ convert_to_nullable(items)
+
+
+def add_object_type(schema):
+ properties = schema.get("properties", None)
+ if properties is not None:
+ if "required" in schema and schema["required"] is None:
+ schema.pop("required", None)
+ schema["type"] = "object"
+ for name, value in properties.items():
+ add_object_type(value)
+
+ items = schema.get("items", None)
+ if items is not None:
+ add_object_type(items)
+
+
+def strip_field(schema, field_name: str):
+ schema.pop(field_name, None)
+
+ properties = schema.get("properties", None)
+ if properties is not None:
+ for name, value in properties.items():
+ strip_field(value, field_name)
+
+ items = schema.get("items", None)
+ if items is not None:
+ strip_field(items, field_name)
+
+
+def _convert_vertex_datetime_to_openai_datetime(vertex_datetime: str) -> int:
+ """
+ Converts a Vertex AI datetime string to an OpenAI datetime integer
+
+    vertex_datetime: str = "2024-12-04T21:53:12.120184Z"
+    returns: int = 1733349192 (when the timestamp is interpreted as UTC)
+ """
+ from datetime import datetime
+
+ # Parse the ISO format string to datetime object
+ dt = datetime.strptime(vertex_datetime, "%Y-%m-%dT%H:%M:%S.%fZ")
+ # Convert to Unix timestamp (seconds since epoch)
+ return int(dt.timestamp())
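To see what `_build_vertex_schema` produces, here is a sketch run on a hypothetical pydantic-style JSON schema with a `$defs` reference and an `Optional` field; the expected result is shown in the trailing comment:

from litellm.llms.vertex_ai.common_utils import _build_vertex_schema

person_schema = {
    "$defs": {
        "Address": {
            "title": "Address",
            "type": "object",
            "properties": {"city": {"title": "City", "type": "string"}},
            "required": ["city"],
        }
    },
    "title": "Person",
    "type": "object",
    "properties": {
        "name": {"title": "Name", "type": "string"},
        "age": {"title": "Age", "anyOf": [{"type": "integer"}, {"type": "null"}]},
        "address": {"$ref": "#/$defs/Address"},
    },
    "required": ["name"],
}

result = _build_vertex_schema(person_schema)
# result == {
#     "type": "object",
#     "properties": {
#         "name": {"type": "string"},
#         "age": {"type": "integer", "nullable": True},   # Optional -> nullable
#         "address": {                                     # $ref inlined from $defs
#             "type": "object",
#             "properties": {"city": {"type": "string"}},
#             "required": ["city"],
#         },
#     },
#     "required": ["name"],
# }  # titles (and any $schema/$id keys) stripped throughout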
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/context_caching/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/context_caching/transformation.py
new file mode 100644
index 00000000..83c15029
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/context_caching/transformation.py
@@ -0,0 +1,110 @@
+"""
+Transformation logic for context caching.
+
+Why a separate file? To make it easy to see how the transformation works
+"""
+
+from typing import List, Tuple
+
+from litellm.types.llms.openai import AllMessageValues
+from litellm.types.llms.vertex_ai import CachedContentRequestBody
+from litellm.utils import is_cached_message
+
+from ..common_utils import get_supports_system_message
+from ..gemini.transformation import (
+ _gemini_convert_messages_with_history,
+ _transform_system_message,
+)
+
+
+def get_first_continuous_block_idx(
+ filtered_messages: List[Tuple[int, AllMessageValues]] # (idx, message)
+) -> int:
+ """
+ Find the array index that ends the first continuous sequence of message blocks.
+
+ Args:
+ filtered_messages: List of tuples containing (index, message) pairs
+
+ Returns:
+ int: The array index where the first continuous sequence ends
+ """
+ if not filtered_messages:
+ return -1
+
+ if len(filtered_messages) == 1:
+ return 0
+
+ current_value = filtered_messages[0][0]
+
+ # Search forward through the array indices
+ for i in range(1, len(filtered_messages)):
+ if filtered_messages[i][0] != current_value + 1:
+ return i - 1
+ current_value = filtered_messages[i][0]
+
+ # If we made it through the whole list, return the last index
+ return len(filtered_messages) - 1
+
+
+def separate_cached_messages(
+ messages: List[AllMessageValues],
+) -> Tuple[List[AllMessageValues], List[AllMessageValues]]:
+ """
+ Returns separated cached and non-cached messages.
+
+ Args:
+ messages: List of messages to be separated.
+
+ Returns:
+ Tuple containing:
+ - cached_messages: List of cached messages.
+ - non_cached_messages: List of non-cached messages.
+ """
+ cached_messages: List[AllMessageValues] = []
+ non_cached_messages: List[AllMessageValues] = []
+
+ # Extract cached messages and their indices
+ filtered_messages: List[Tuple[int, AllMessageValues]] = []
+ for idx, message in enumerate(messages):
+ if is_cached_message(message=message):
+ filtered_messages.append((idx, message))
+
+ # Validate only one block of continuous cached messages
+ last_continuous_block_idx = get_first_continuous_block_idx(filtered_messages)
+ # Separate messages based on the block of cached messages
+ if filtered_messages and last_continuous_block_idx is not None:
+ first_cached_idx = filtered_messages[0][0]
+ last_cached_idx = filtered_messages[last_continuous_block_idx][0]
+
+ cached_messages = messages[first_cached_idx : last_cached_idx + 1]
+ non_cached_messages = (
+ messages[:first_cached_idx] + messages[last_cached_idx + 1 :]
+ )
+ else:
+ non_cached_messages = messages
+
+ return cached_messages, non_cached_messages
+
+
+def transform_openai_messages_to_gemini_context_caching(
+ model: str, messages: List[AllMessageValues], cache_key: str
+) -> CachedContentRequestBody:
+ supports_system_message = get_supports_system_message(
+ model=model, custom_llm_provider="gemini"
+ )
+
+ transformed_system_messages, new_messages = _transform_system_message(
+ supports_system_message=supports_system_message, messages=messages
+ )
+
+ transformed_messages = _gemini_convert_messages_with_history(messages=new_messages)
+ data = CachedContentRequestBody(
+ contents=transformed_messages,
+ model="models/{}".format(model),
+ displayName=cache_key,
+ )
+ if transformed_system_messages is not None:
+ data["system_instruction"] = transformed_system_messages
+
+ return data
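A sketch of how the separation behaves, assuming `litellm.utils.is_cached_message` flags messages whose content parts carry an Anthropic-style `cache_control: {"type": "ephemeral"}` marker (that helper lives outside this diff, so the marker format is an assumption here):

from litellm.llms.vertex_ai.context_caching.transformation import separate_cached_messages

messages = [
    {   # first (and only) continuous cached block
        "role": "system",
        "content": [
            {
                "type": "text",
                "text": "Very long shared context ...",
                "cache_control": {"type": "ephemeral"},  # assumed cache marker
            }
        ],
    },
    {"role": "user", "content": "First question"},      # not cached
    {"role": "assistant", "content": "First answer"},   # not cached
]

cached, non_cached = separate_cached_messages(messages=messages)
# cached     -> the system message above (sent once, to the cachedContents endpoint)
# non_cached -> the user/assistant turns (sent with every generateContent request)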
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/context_caching/vertex_ai_context_caching.py b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/context_caching/vertex_ai_context_caching.py
new file mode 100644
index 00000000..5cfb9141
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/context_caching/vertex_ai_context_caching.py
@@ -0,0 +1,416 @@
+from typing import List, Literal, Optional, Tuple, Union
+
+import httpx
+
+import litellm
+from litellm.caching.caching import Cache, LiteLLMCacheType
+from litellm.litellm_core_utils.litellm_logging import Logging
+from litellm.llms.custom_httpx.http_handler import (
+ AsyncHTTPHandler,
+ HTTPHandler,
+ get_async_httpx_client,
+)
+from litellm.llms.openai.openai import AllMessageValues
+from litellm.types.llms.vertex_ai import (
+ CachedContentListAllResponseBody,
+ VertexAICachedContentResponseObject,
+)
+
+from ..common_utils import VertexAIError
+from ..vertex_llm_base import VertexBase
+from .transformation import (
+ separate_cached_messages,
+ transform_openai_messages_to_gemini_context_caching,
+)
+
+local_cache_obj = Cache(
+ type=LiteLLMCacheType.LOCAL
+) # only used for calling 'get_cache_key' function
+
+
+class ContextCachingEndpoints(VertexBase):
+ """
+ Covers context caching endpoints for Vertex AI + Google AI Studio
+
+ v0: covers Google AI Studio
+ """
+
+ def __init__(self) -> None:
+ pass
+
+ def _get_token_and_url_context_caching(
+ self,
+ gemini_api_key: Optional[str],
+ custom_llm_provider: Literal["gemini"],
+ api_base: Optional[str],
+ ) -> Tuple[Optional[str], str]:
+ """
+ Internal function. Returns the token and url for the call.
+
+ Handles logic if it's google ai studio vs. vertex ai.
+
+ Returns
+ token, url
+ """
+ if custom_llm_provider == "gemini":
+ auth_header = None
+ endpoint = "cachedContents"
+ url = "https://generativelanguage.googleapis.com/v1beta/{}?key={}".format(
+ endpoint, gemini_api_key
+ )
+
+ else:
+ raise NotImplementedError
+
+ return self._check_custom_proxy(
+ api_base=api_base,
+ custom_llm_provider=custom_llm_provider,
+ gemini_api_key=gemini_api_key,
+ endpoint=endpoint,
+ stream=None,
+ auth_header=auth_header,
+ url=url,
+ )
+
+ def check_cache(
+ self,
+ cache_key: str,
+ client: HTTPHandler,
+ headers: dict,
+ api_key: str,
+ api_base: Optional[str],
+ logging_obj: Logging,
+ ) -> Optional[str]:
+ """
+        Checks if the content is already cached.
+
+        Currently checks the cache list for cache_key == displayName, since Google doesn't let us set the cache name (their API docs are out of sync with the actual implementation).
+
+ Returns
+ - cached_content_name - str - cached content name stored on google. (if found.)
+ OR
+ - None
+ """
+
+ _, url = self._get_token_and_url_context_caching(
+ gemini_api_key=api_key,
+ custom_llm_provider="gemini",
+ api_base=api_base,
+ )
+ try:
+ ## LOGGING
+ logging_obj.pre_call(
+ input="",
+ api_key="",
+ additional_args={
+ "complete_input_dict": {},
+ "api_base": url,
+ "headers": headers,
+ },
+ )
+
+ resp = client.get(url=url, headers=headers)
+ resp.raise_for_status()
+ except httpx.HTTPStatusError as e:
+ if e.response.status_code == 403:
+ return None
+ raise VertexAIError(
+ status_code=e.response.status_code, message=e.response.text
+ )
+ except Exception as e:
+ raise VertexAIError(status_code=500, message=str(e))
+ raw_response = resp.json()
+ logging_obj.post_call(original_response=raw_response)
+
+ if "cachedContents" not in raw_response:
+ return None
+
+ all_cached_items = CachedContentListAllResponseBody(**raw_response)
+
+ if "cachedContents" not in all_cached_items:
+ return None
+
+ for cached_item in all_cached_items["cachedContents"]:
+ display_name = cached_item.get("displayName")
+ if display_name is not None and display_name == cache_key:
+ return cached_item.get("name")
+
+ return None
+
+ async def async_check_cache(
+ self,
+ cache_key: str,
+ client: AsyncHTTPHandler,
+ headers: dict,
+ api_key: str,
+ api_base: Optional[str],
+ logging_obj: Logging,
+ ) -> Optional[str]:
+ """
+        Checks if the content is already cached.
+
+        Currently checks the cache list for cache_key == displayName, since Google doesn't let us set the cache name (their API docs are out of sync with the actual implementation).
+
+ Returns
+ - cached_content_name - str - cached content name stored on google. (if found.)
+ OR
+ - None
+ """
+
+ _, url = self._get_token_and_url_context_caching(
+ gemini_api_key=api_key,
+ custom_llm_provider="gemini",
+ api_base=api_base,
+ )
+ try:
+ ## LOGGING
+ logging_obj.pre_call(
+ input="",
+ api_key="",
+ additional_args={
+ "complete_input_dict": {},
+ "api_base": url,
+ "headers": headers,
+ },
+ )
+
+ resp = await client.get(url=url, headers=headers)
+ resp.raise_for_status()
+ except httpx.HTTPStatusError as e:
+ if e.response.status_code == 403:
+ return None
+ raise VertexAIError(
+ status_code=e.response.status_code, message=e.response.text
+ )
+ except Exception as e:
+ raise VertexAIError(status_code=500, message=str(e))
+ raw_response = resp.json()
+ logging_obj.post_call(original_response=raw_response)
+
+ if "cachedContents" not in raw_response:
+ return None
+
+ all_cached_items = CachedContentListAllResponseBody(**raw_response)
+
+ if "cachedContents" not in all_cached_items:
+ return None
+
+ for cached_item in all_cached_items["cachedContents"]:
+ display_name = cached_item.get("displayName")
+ if display_name is not None and display_name == cache_key:
+ return cached_item.get("name")
+
+ return None
+
+ def check_and_create_cache(
+ self,
+ messages: List[AllMessageValues], # receives openai format messages
+ api_key: str,
+ api_base: Optional[str],
+ model: str,
+ client: Optional[HTTPHandler],
+ timeout: Optional[Union[float, httpx.Timeout]],
+ logging_obj: Logging,
+ extra_headers: Optional[dict] = None,
+ cached_content: Optional[str] = None,
+ ) -> Tuple[List[AllMessageValues], Optional[str]]:
+ """
+ Receives
+ - messages: List of dict - messages in the openai format
+
+ Returns
+ - messages - List[dict] - filtered list of messages in the openai format.
+ - cached_content - str - the cache content id, to be passed in the gemini request body
+
+ Follows - https://ai.google.dev/api/caching#request-body
+ """
+ if cached_content is not None:
+ return messages, cached_content
+
+ ## AUTHORIZATION ##
+ token, url = self._get_token_and_url_context_caching(
+ gemini_api_key=api_key,
+ custom_llm_provider="gemini",
+ api_base=api_base,
+ )
+
+ headers = {
+ "Content-Type": "application/json",
+ }
+ if token is not None:
+ headers["Authorization"] = f"Bearer {token}"
+ if extra_headers is not None:
+ headers.update(extra_headers)
+
+ if client is None or not isinstance(client, HTTPHandler):
+ _params = {}
+ if timeout is not None:
+ if isinstance(timeout, float) or isinstance(timeout, int):
+ timeout = httpx.Timeout(timeout)
+ _params["timeout"] = timeout
+ client = HTTPHandler(**_params) # type: ignore
+ else:
+ client = client
+
+ cached_messages, non_cached_messages = separate_cached_messages(
+ messages=messages
+ )
+
+ if len(cached_messages) == 0:
+ return messages, None
+
+ ## CHECK IF CACHED ALREADY
+ generated_cache_key = local_cache_obj.get_cache_key(messages=cached_messages)
+ google_cache_name = self.check_cache(
+ cache_key=generated_cache_key,
+ client=client,
+ headers=headers,
+ api_key=api_key,
+ api_base=api_base,
+ logging_obj=logging_obj,
+ )
+ if google_cache_name:
+ return non_cached_messages, google_cache_name
+
+ ## TRANSFORM REQUEST
+ cached_content_request_body = (
+ transform_openai_messages_to_gemini_context_caching(
+ model=model, messages=cached_messages, cache_key=generated_cache_key
+ )
+ )
+
+ ## LOGGING
+ logging_obj.pre_call(
+ input=messages,
+ api_key="",
+ additional_args={
+ "complete_input_dict": cached_content_request_body,
+ "api_base": url,
+ "headers": headers,
+ },
+ )
+
+ try:
+ response = client.post(
+ url=url, headers=headers, json=cached_content_request_body # type: ignore
+ )
+ response.raise_for_status()
+ except httpx.HTTPStatusError as err:
+ error_code = err.response.status_code
+ raise VertexAIError(status_code=error_code, message=err.response.text)
+ except httpx.TimeoutException:
+ raise VertexAIError(status_code=408, message="Timeout error occurred.")
+
+ raw_response_cached = response.json()
+ cached_content_response_obj = VertexAICachedContentResponseObject(
+ name=raw_response_cached.get("name"), model=raw_response_cached.get("model")
+ )
+ return (non_cached_messages, cached_content_response_obj["name"])
+
+ async def async_check_and_create_cache(
+ self,
+ messages: List[AllMessageValues], # receives openai format messages
+ api_key: str,
+ api_base: Optional[str],
+ model: str,
+ client: Optional[AsyncHTTPHandler],
+ timeout: Optional[Union[float, httpx.Timeout]],
+ logging_obj: Logging,
+ extra_headers: Optional[dict] = None,
+ cached_content: Optional[str] = None,
+ ) -> Tuple[List[AllMessageValues], Optional[str]]:
+ """
+ Receives
+ - messages: List of dict - messages in the openai format
+
+ Returns
+ - messages - List[dict] - filtered list of messages in the openai format.
+ - cached_content - str - the cache content id, to be passed in the gemini request body
+
+ Follows - https://ai.google.dev/api/caching#request-body
+ """
+ if cached_content is not None:
+ return messages, cached_content
+
+ cached_messages, non_cached_messages = separate_cached_messages(
+ messages=messages
+ )
+
+ if len(cached_messages) == 0:
+ return messages, None
+
+ ## AUTHORIZATION ##
+ token, url = self._get_token_and_url_context_caching(
+ gemini_api_key=api_key,
+ custom_llm_provider="gemini",
+ api_base=api_base,
+ )
+
+ headers = {
+ "Content-Type": "application/json",
+ }
+ if token is not None:
+ headers["Authorization"] = f"Bearer {token}"
+ if extra_headers is not None:
+ headers.update(extra_headers)
+
+ if client is None or not isinstance(client, AsyncHTTPHandler):
+ client = get_async_httpx_client(
+ params={"timeout": timeout}, llm_provider=litellm.LlmProviders.VERTEX_AI
+ )
+ else:
+ client = client
+
+ ## CHECK IF CACHED ALREADY
+ generated_cache_key = local_cache_obj.get_cache_key(messages=cached_messages)
+ google_cache_name = await self.async_check_cache(
+ cache_key=generated_cache_key,
+ client=client,
+ headers=headers,
+ api_key=api_key,
+ api_base=api_base,
+ logging_obj=logging_obj,
+ )
+ if google_cache_name:
+ return non_cached_messages, google_cache_name
+
+ ## TRANSFORM REQUEST
+ cached_content_request_body = (
+ transform_openai_messages_to_gemini_context_caching(
+ model=model, messages=cached_messages, cache_key=generated_cache_key
+ )
+ )
+
+ ## LOGGING
+ logging_obj.pre_call(
+ input=messages,
+ api_key="",
+ additional_args={
+ "complete_input_dict": cached_content_request_body,
+ "api_base": url,
+ "headers": headers,
+ },
+ )
+
+ try:
+ response = await client.post(
+ url=url, headers=headers, json=cached_content_request_body # type: ignore
+ )
+ response.raise_for_status()
+ except httpx.HTTPStatusError as err:
+ error_code = err.response.status_code
+ raise VertexAIError(status_code=error_code, message=err.response.text)
+ except httpx.TimeoutException:
+ raise VertexAIError(status_code=408, message="Timeout error occurred.")
+
+ raw_response_cached = response.json()
+ cached_content_response_obj = VertexAICachedContentResponseObject(
+ name=raw_response_cached.get("name"), model=raw_response_cached.get("model")
+ )
+ return (non_cached_messages, cached_content_response_obj["name"])
+
+ def get_cache(self):
+ pass
+
+ async def async_get_cache(self):
+ pass
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/cost_calculator.py b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/cost_calculator.py
new file mode 100644
index 00000000..fd238860
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/cost_calculator.py
@@ -0,0 +1,242 @@
+# What is this?
+## Cost calculation for Google AI Studio / Vertex AI models
+from typing import Literal, Optional, Tuple, Union
+
+import litellm
+from litellm import verbose_logger
+from litellm.litellm_core_utils.llm_cost_calc.utils import _is_above_128k
+
+"""
+Gemini pricing covers:
+- token
+- image
+- audio
+- video
+"""
+
+"""
+Vertex AI -> character based pricing
+
+Google AI Studio -> token based pricing
+"""
+
+models_without_dynamic_pricing = ["gemini-1.0-pro", "gemini-pro"]
+
+
+def cost_router(
+ model: str,
+ custom_llm_provider: str,
+ call_type: Union[Literal["embedding", "aembedding"], str],
+) -> Literal["cost_per_character", "cost_per_token"]:
+ """
+ Route the cost calc to the right place, based on model/call_type/etc.
+
+ Returns
+ - str, the specific google cost calc function it should route to.
+ """
+ if custom_llm_provider == "vertex_ai" and (
+ "claude" in model
+ or "llama" in model
+ or "mistral" in model
+ or "jamba" in model
+ or "codestral" in model
+ ):
+ return "cost_per_token"
+ elif custom_llm_provider == "vertex_ai" and (
+ call_type == "embedding" or call_type == "aembedding"
+ ):
+ return "cost_per_token"
+ return "cost_per_character"
+
+
+def cost_per_character(
+ model: str,
+ custom_llm_provider: str,
+ prompt_tokens: float,
+ completion_tokens: float,
+ prompt_characters: Optional[float] = None,
+ completion_characters: Optional[float] = None,
+) -> Tuple[float, float]:
+ """
+    Calculates the prompt and completion cost for a given Vertex AI model from character counts, falling back to token-based pricing when character counts or character prices are unavailable.
+
+ Input:
+ - model: str, the model name without provider prefix
+ - custom_llm_provider: str, "vertex_ai-*"
+ - prompt_characters: float, the number of input characters
+ - completion_characters: float, the number of output characters
+
+ Returns:
+ Tuple[float, float] - prompt_cost_in_usd, completion_cost_in_usd
+
+ Raises:
+ Exception if model requires >128k pricing, but model cost not mapped
+ """
+    ## GET MODEL INFO
+    model_info = litellm.get_model_info(
+        model=model, custom_llm_provider=custom_llm_provider
+    )
+
+ ## CALCULATE INPUT COST
+ if prompt_characters is None:
+ prompt_cost, _ = cost_per_token(
+ model=model,
+ custom_llm_provider=custom_llm_provider,
+ prompt_tokens=prompt_tokens,
+ completion_tokens=completion_tokens,
+ )
+ else:
+ try:
+ if (
+ _is_above_128k(tokens=prompt_characters * 4) # 1 token = 4 char
+ and model not in models_without_dynamic_pricing
+ ):
+ ## check if character pricing, else default to token pricing
+ assert (
+ "input_cost_per_character_above_128k_tokens" in model_info
+ and model_info["input_cost_per_character_above_128k_tokens"]
+ is not None
+ ), "model info for model={} does not have 'input_cost_per_character_above_128k_tokens'-pricing for > 128k tokens\nmodel_info={}".format(
+ model, model_info
+ )
+ prompt_cost = (
+ prompt_characters
+ * model_info["input_cost_per_character_above_128k_tokens"]
+ )
+ else:
+ assert (
+ "input_cost_per_character" in model_info
+ and model_info["input_cost_per_character"] is not None
+ ), "model info for model={} does not have 'input_cost_per_character'-pricing\nmodel_info={}".format(
+ model, model_info
+ )
+ prompt_cost = prompt_characters * model_info["input_cost_per_character"]
+ except Exception as e:
+ verbose_logger.debug(
+                "litellm.litellm_core_utils.llm_cost_calc.google.py::cost_per_character(): Exception occurred - {}\nFalling back to token-based pricing".format(
+ str(e)
+ )
+ )
+ prompt_cost, _ = cost_per_token(
+ model=model,
+ custom_llm_provider=custom_llm_provider,
+ prompt_tokens=prompt_tokens,
+ completion_tokens=completion_tokens,
+ )
+
+ ## CALCULATE OUTPUT COST
+ if completion_characters is None:
+ _, completion_cost = cost_per_token(
+ model=model,
+ custom_llm_provider=custom_llm_provider,
+ prompt_tokens=prompt_tokens,
+ completion_tokens=completion_tokens,
+ )
+ else:
+ try:
+ if (
+ _is_above_128k(tokens=completion_characters * 4) # 1 token = 4 char
+ and model not in models_without_dynamic_pricing
+ ):
+ assert (
+ "output_cost_per_character_above_128k_tokens" in model_info
+ and model_info["output_cost_per_character_above_128k_tokens"]
+ is not None
+ ), "model info for model={} does not have 'output_cost_per_character_above_128k_tokens' pricing\nmodel_info={}".format(
+ model, model_info
+ )
+ completion_cost = (
+                    completion_characters
+ * model_info["output_cost_per_character_above_128k_tokens"]
+ )
+ else:
+ assert (
+ "output_cost_per_character" in model_info
+ and model_info["output_cost_per_character"] is not None
+ ), "model info for model={} does not have 'output_cost_per_character'-pricing\nmodel_info={}".format(
+ model, model_info
+ )
+ completion_cost = (
+ completion_characters * model_info["output_cost_per_character"]
+ )
+ except Exception as e:
+ verbose_logger.debug(
+                "litellm.litellm_core_utils.llm_cost_calc.google.py::cost_per_character(): Exception occurred - {}\nFalling back to token-based pricing".format(
+ str(e)
+ )
+ )
+ _, completion_cost = cost_per_token(
+ model=model,
+ custom_llm_provider=custom_llm_provider,
+ prompt_tokens=prompt_tokens,
+ completion_tokens=completion_tokens,
+ )
+
+ return prompt_cost, completion_cost
+
+
+def cost_per_token(
+ model: str,
+ custom_llm_provider: str,
+ prompt_tokens: float,
+ completion_tokens: float,
+) -> Tuple[float, float]:
+ """
+ Calculates the cost per token for a given model, prompt tokens, and completion tokens.
+
+ Input:
+ - model: str, the model name without provider prefix
+ - custom_llm_provider: str, either "vertex_ai-*" or "gemini"
+ - prompt_tokens: float, the number of input tokens
+ - completion_tokens: float, the number of output tokens
+
+ Returns:
+ Tuple[float, float] - prompt_cost_in_usd, completion_cost_in_usd
+
+ Raises:
+ Exception if model requires >128k pricing, but model cost not mapped
+ """
+ ## GET MODEL INFO
+ model_info = litellm.get_model_info(
+ model=model, custom_llm_provider=custom_llm_provider
+ )
+
+ ## CALCULATE INPUT COST
+ if (
+ _is_above_128k(tokens=prompt_tokens)
+ and model not in models_without_dynamic_pricing
+ ):
+ assert (
+ "input_cost_per_token_above_128k_tokens" in model_info
+ and model_info["input_cost_per_token_above_128k_tokens"] is not None
+ ), "model info for model={} does not have pricing for > 128k tokens\nmodel_info={}".format(
+ model, model_info
+ )
+ prompt_cost = (
+ prompt_tokens * model_info["input_cost_per_token_above_128k_tokens"]
+ )
+ else:
+ prompt_cost = prompt_tokens * model_info["input_cost_per_token"]
+
+ ## CALCULATE OUTPUT COST
+ if (
+ _is_above_128k(tokens=completion_tokens)
+ and model not in models_without_dynamic_pricing
+ ):
+ assert (
+ "output_cost_per_token_above_128k_tokens" in model_info
+ and model_info["output_cost_per_token_above_128k_tokens"] is not None
+ ), "model info for model={} does not have pricing for > 128k tokens\nmodel_info={}".format(
+ model, model_info
+ )
+ completion_cost = (
+ completion_tokens * model_info["output_cost_per_token_above_128k_tokens"]
+ )
+ else:
+ completion_cost = completion_tokens * model_info["output_cost_per_token"]
+
+ return prompt_cost, completion_cost
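A worked illustration of the routing and the arithmetic above (the per-token prices are hypothetical, purely to show the formula; real prices come from `litellm.get_model_info`):

from litellm.llms.vertex_ai.cost_calculator import cost_router

# Partner models and embeddings are priced per token, Gemini chat per character.
assert cost_router("claude-3-5-sonnet@20240620", "vertex_ai", "completion") == "cost_per_token"
assert cost_router("textembedding-gecko", "vertex_ai", "embedding") == "cost_per_token"
assert cost_router("gemini-1.5-pro", "vertex_ai", "completion") == "cost_per_character"

# In the <=128k branch the cost reduces to a simple product, e.g. with
# hypothetical prices of $0.30 / 1M input tokens and $0.60 / 1M output tokens:
prompt_tokens, completion_tokens = 1_000, 500
prompt_cost = prompt_tokens * 0.30e-6          # 0.0003 USD
completion_cost = completion_tokens * 0.60e-6  # 0.0003 USD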
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/files/handler.py b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/files/handler.py
new file mode 100644
index 00000000..266169cd
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/files/handler.py
@@ -0,0 +1,97 @@
+from typing import Any, Coroutine, Optional, Union
+
+import httpx
+
+from litellm import LlmProviders
+from litellm.integrations.gcs_bucket.gcs_bucket_base import (
+ GCSBucketBase,
+ GCSLoggingConfig,
+)
+from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
+from litellm.types.llms.openai import CreateFileRequest, FileObject
+from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES
+
+from .transformation import VertexAIFilesTransformation
+
+vertex_ai_files_transformation = VertexAIFilesTransformation()
+
+
+class VertexAIFilesHandler(GCSBucketBase):
+ """
+    Handles calling Vertex AI in the OpenAI Files API format (v1/files/*)
+
+    This implementation uploads files to GCS buckets
+ """
+
+ def __init__(self):
+ super().__init__()
+ self.async_httpx_client = get_async_httpx_client(
+ llm_provider=LlmProviders.VERTEX_AI,
+ )
+
+ pass
+
+ async def async_create_file(
+ self,
+ create_file_data: CreateFileRequest,
+ api_base: Optional[str],
+ vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
+ vertex_project: Optional[str],
+ vertex_location: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ ):
+ gcs_logging_config: GCSLoggingConfig = await self.get_gcs_logging_config(
+ kwargs={}
+ )
+ headers = await self.construct_request_headers(
+ vertex_instance=gcs_logging_config["vertex_instance"],
+ service_account_json=gcs_logging_config["path_service_account"],
+ )
+ bucket_name = gcs_logging_config["bucket_name"]
+ logging_payload, object_name = (
+ vertex_ai_files_transformation.transform_openai_file_content_to_vertex_ai_file_content(
+ openai_file_content=create_file_data.get("file")
+ )
+ )
+ gcs_upload_response = await self._log_json_data_on_gcs(
+ headers=headers,
+ bucket_name=bucket_name,
+ object_name=object_name,
+ logging_payload=logging_payload,
+ )
+
+ return vertex_ai_files_transformation.transform_gcs_bucket_response_to_openai_file_object(
+ create_file_data=create_file_data,
+ gcs_upload_response=gcs_upload_response,
+ )
+
+ def create_file(
+ self,
+ _is_async: bool,
+ create_file_data: CreateFileRequest,
+ api_base: Optional[str],
+ vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
+ vertex_project: Optional[str],
+ vertex_location: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ max_retries: Optional[int],
+ ) -> Union[FileObject, Coroutine[Any, Any, FileObject]]:
+ """
+ Creates a file on VertexAI GCS Bucket
+
+ Only supported for Async litellm.acreate_file
+ """
+
+ if _is_async:
+ return self.async_create_file(
+ create_file_data=create_file_data,
+ api_base=api_base,
+ vertex_credentials=vertex_credentials,
+ vertex_project=vertex_project,
+ vertex_location=vertex_location,
+ timeout=timeout,
+ max_retries=max_retries,
+ )
+
+ return None # type: ignore
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/files/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/files/transformation.py
new file mode 100644
index 00000000..a124e205
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/files/transformation.py
@@ -0,0 +1,163 @@
+import json
+import uuid
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+from litellm.llms.vertex_ai.common_utils import (
+ _convert_vertex_datetime_to_openai_datetime,
+)
+from litellm.llms.vertex_ai.gemini.transformation import _transform_request_body
+from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
+ VertexGeminiConfig,
+)
+from litellm.types.llms.openai import CreateFileRequest, FileObject, FileTypes, PathLike
+
+
+class VertexAIFilesTransformation(VertexGeminiConfig):
+ """
+ Transforms OpenAI /v1/files/* requests to VertexAI /v1/files/* requests
+ """
+
+ def transform_openai_file_content_to_vertex_ai_file_content(
+ self, openai_file_content: Optional[FileTypes] = None
+ ) -> Tuple[str, str]:
+ """
+ Transforms OpenAI FileContentRequest to VertexAI FileContentRequest
+ """
+
+ if openai_file_content is None:
+ raise ValueError("contents of file are None")
+ # Read the content of the file
+ file_content = self._get_content_from_openai_file(openai_file_content)
+
+ # Split into lines and parse each line as JSON
+ openai_jsonl_content = [
+ json.loads(line) for line in file_content.splitlines() if line.strip()
+ ]
+ vertex_jsonl_content = (
+ self._transform_openai_jsonl_content_to_vertex_ai_jsonl_content(
+ openai_jsonl_content
+ )
+ )
+ vertex_jsonl_string = "\n".join(
+ json.dumps(item) for item in vertex_jsonl_content
+ )
+ object_name = self._get_gcs_object_name(
+ openai_jsonl_content=openai_jsonl_content
+ )
+ return vertex_jsonl_string, object_name
+
+ def _transform_openai_jsonl_content_to_vertex_ai_jsonl_content(
+ self, openai_jsonl_content: List[Dict[str, Any]]
+ ):
+ """
+ Transforms OpenAI JSONL content to VertexAI JSONL content
+
+ jsonl body for vertex is {"request": <request_body>}
+ Example Vertex jsonl
+ {"request":{"contents": [{"role": "user", "parts": [{"text": "What is the relation between the following video and image samples?"}, {"fileData": {"fileUri": "gs://cloud-samples-data/generative-ai/video/animals.mp4", "mimeType": "video/mp4"}}, {"fileData": {"fileUri": "gs://cloud-samples-data/generative-ai/image/cricket.jpeg", "mimeType": "image/jpeg"}}]}]}}
+ {"request":{"contents": [{"role": "user", "parts": [{"text": "Describe what is happening in this video."}, {"fileData": {"fileUri": "gs://cloud-samples-data/generative-ai/video/another_video.mov", "mimeType": "video/mov"}}]}]}}
+ """
+
+ vertex_jsonl_content = []
+ for _openai_jsonl_content in openai_jsonl_content:
+ openai_request_body = _openai_jsonl_content.get("body") or {}
+ vertex_request_body = _transform_request_body(
+ messages=openai_request_body.get("messages", []),
+ model=openai_request_body.get("model", ""),
+ optional_params=self._map_openai_to_vertex_params(openai_request_body),
+ custom_llm_provider="vertex_ai",
+ litellm_params={},
+ cached_content=None,
+ )
+ vertex_jsonl_content.append({"request": vertex_request_body})
+ return vertex_jsonl_content
+
+ def _get_gcs_object_name(
+ self,
+ openai_jsonl_content: List[Dict[str, Any]],
+ ) -> str:
+ """
+ Gets a unique GCS object name for the VertexAI batch prediction job
+
+        named as: litellm-vertex-files/{model}/{uuid}
+ """
+ _model = openai_jsonl_content[0].get("body", {}).get("model", "")
+ if "publishers/google/models" not in _model:
+ _model = f"publishers/google/models/{_model}"
+ object_name = f"litellm-vertex-files/{_model}/{uuid.uuid4()}"
+ return object_name
+
+ def _map_openai_to_vertex_params(
+ self,
+ openai_request_body: Dict[str, Any],
+ ) -> Dict[str, Any]:
+ """
+ wrapper to call VertexGeminiConfig.map_openai_params
+ """
+ _model = openai_request_body.get("model", "")
+ vertex_params = self.map_openai_params(
+ model=_model,
+ non_default_params=openai_request_body,
+ optional_params={},
+ drop_params=False,
+ )
+ return vertex_params
+
+ def _get_content_from_openai_file(self, openai_file_content: FileTypes) -> str:
+ """
+ Helper to extract content from various OpenAI file types and return as string.
+
+ Handles:
+ - Direct content (str, bytes, IO[bytes])
+ - Tuple formats: (filename, content, [content_type], [headers])
+ - PathLike objects
+ """
+ content: Union[str, bytes] = b""
+ # Extract file content from tuple if necessary
+ if isinstance(openai_file_content, tuple):
+ # Take the second element which is always the file content
+ file_content = openai_file_content[1]
+ else:
+ file_content = openai_file_content
+
+ # Handle different file content types
+ if isinstance(file_content, str):
+ # String content can be used directly
+ content = file_content
+ elif isinstance(file_content, bytes):
+ # Bytes content can be decoded
+ content = file_content
+ elif isinstance(file_content, PathLike): # PathLike
+ with open(str(file_content), "rb") as f:
+ content = f.read()
+ elif hasattr(file_content, "read"): # IO[bytes]
+ # File-like objects need to be read
+ content = file_content.read()
+
+ # Ensure content is string
+ if isinstance(content, bytes):
+ content = content.decode("utf-8")
+
+ return content
+
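+    # Loose sketch of the inputs handled above (examples are assumed, for illustration only):
+    #   "raw text"                      -> returned as-is
+    #   b"raw bytes"                    -> decoded as UTF-8
+    #   ("batch.jsonl", b"raw bytes")   -> second tuple element extracted, then decoded
+    #   open("batch.jsonl", "rb")       -> .read() then decoded
+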
+ def transform_gcs_bucket_response_to_openai_file_object(
+ self, create_file_data: CreateFileRequest, gcs_upload_response: Dict[str, Any]
+ ) -> FileObject:
+ """
+ Transforms GCS Bucket upload file response to OpenAI FileObject
+ """
+ gcs_id = gcs_upload_response.get("id", "")
+ # Remove the last numeric ID from the path
+ gcs_id = "/".join(gcs_id.split("/")[:-1]) if gcs_id else ""
+
+ return FileObject(
+ purpose=create_file_data.get("purpose", "batch"),
+ id=f"gs://{gcs_id}",
+ filename=gcs_upload_response.get("name", ""),
+ created_at=_convert_vertex_datetime_to_openai_datetime(
+ vertex_datetime=gcs_upload_response.get("timeCreated", "")
+ ),
+ status="uploaded",
+ bytes=gcs_upload_response.get("size", 0),
+ object="file",
+ )
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/fine_tuning/handler.py b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/fine_tuning/handler.py
new file mode 100644
index 00000000..3cf409c7
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/fine_tuning/handler.py
@@ -0,0 +1,371 @@
+import json
+import traceback
+from datetime import datetime
+from typing import Literal, Optional, Union
+
+import httpx
+from openai.types.fine_tuning.fine_tuning_job import FineTuningJob
+
+import litellm
+from litellm._logging import verbose_logger
+from litellm.llms.custom_httpx.http_handler import HTTPHandler, get_async_httpx_client
+from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexLLM
+from litellm.types.fine_tuning import OpenAIFineTuningHyperparameters
+from litellm.types.llms.openai import FineTuningJobCreate
+from litellm.types.llms.vertex_ai import (
+ VERTEX_CREDENTIALS_TYPES,
+ FineTuneHyperparameters,
+ FineTuneJobCreate,
+ FineTunesupervisedTuningSpec,
+ ResponseSupervisedTuningSpec,
+ ResponseTuningJob,
+)
+
+
+class VertexFineTuningAPI(VertexLLM):
+ """
+    Vertex AI methods to support fine-tuning jobs (and pass-through requests)
+ """
+
+ def __init__(self) -> None:
+ super().__init__()
+ self.async_handler = get_async_httpx_client(
+ llm_provider=litellm.LlmProviders.VERTEX_AI,
+ params={"timeout": 600.0},
+ )
+
+ def convert_response_created_at(self, response: ResponseTuningJob):
+ try:
+
+ create_time_str = response.get("createTime", "") or ""
+ create_time_datetime = datetime.fromisoformat(
+ create_time_str.replace("Z", "+00:00")
+ )
+ # Convert to Unix timestamp (seconds since epoch)
+ created_at = int(create_time_datetime.timestamp())
+
+ return created_at
+ except Exception:
+ return 0
+
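+    # Illustrative example of the conversion above (timestamp is for illustration only):
+    #   {"createTime": "2024-05-01T12:00:00Z"} -> 1714564800
+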
+ def convert_openai_request_to_vertex(
+ self,
+ create_fine_tuning_job_data: FineTuningJobCreate,
+ original_hyperparameters: dict = {},
+ kwargs: Optional[dict] = None,
+ ) -> FineTuneJobCreate:
+ """
+        Convert a fine-tuning job request from OpenAI format to Vertex AI format.
+        https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/tuning
+ """
+
+ supervised_tuning_spec = FineTunesupervisedTuningSpec(
+ training_dataset_uri=create_fine_tuning_job_data.training_file,
+ )
+
+ if create_fine_tuning_job_data.validation_file:
+ supervised_tuning_spec["validation_dataset"] = (
+ create_fine_tuning_job_data.validation_file
+ )
+
+ _vertex_hyperparameters = (
+ self._transform_openai_hyperparameters_to_vertex_hyperparameters(
+ create_fine_tuning_job_data=create_fine_tuning_job_data,
+ kwargs=kwargs,
+ original_hyperparameters=original_hyperparameters,
+ )
+ )
+
+ if _vertex_hyperparameters and len(_vertex_hyperparameters) > 0:
+ supervised_tuning_spec["hyperParameters"] = _vertex_hyperparameters
+
+ fine_tune_job = FineTuneJobCreate(
+ baseModel=create_fine_tuning_job_data.model,
+ supervisedTuningSpec=supervised_tuning_spec,
+ tunedModelDisplayName=create_fine_tuning_job_data.suffix,
+ )
+
+ return fine_tune_job
+
+ def _transform_openai_hyperparameters_to_vertex_hyperparameters(
+ self,
+ create_fine_tuning_job_data: FineTuningJobCreate,
+ original_hyperparameters: dict = {},
+ kwargs: Optional[dict] = None,
+ ) -> FineTuneHyperparameters:
+ _oai_hyperparameters = create_fine_tuning_job_data.hyperparameters
+ _vertex_hyperparameters = FineTuneHyperparameters()
+ if _oai_hyperparameters:
+ if _oai_hyperparameters.n_epochs:
+ _vertex_hyperparameters["epoch_count"] = int(
+ _oai_hyperparameters.n_epochs
+ )
+ if _oai_hyperparameters.learning_rate_multiplier:
+ _vertex_hyperparameters["learning_rate_multiplier"] = float(
+ _oai_hyperparameters.learning_rate_multiplier
+ )
+
+ _adapter_size = original_hyperparameters.get("adapter_size", None)
+ if _adapter_size:
+ _vertex_hyperparameters["adapter_size"] = _adapter_size
+
+ return _vertex_hyperparameters
+
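+    # Sketch of the hyperparameter mapping above (values are hypothetical):
+    #   OpenAI: {"n_epochs": 3, "learning_rate_multiplier": 0.1}
+    #   Vertex: {"epoch_count": 3, "learning_rate_multiplier": 0.1}
+    # plus "adapter_size" passed through from the original request body, if set.
+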
+ def convert_vertex_response_to_open_ai_response(
+ self, response: ResponseTuningJob
+ ) -> FineTuningJob:
+ status: Literal[
+ "validating_files", "queued", "running", "succeeded", "failed", "cancelled"
+ ] = "queued"
+        if response["state"] == "JOB_STATE_PENDING":
+            status = "queued"
+        elif response["state"] == "JOB_STATE_SUCCEEDED":
+            status = "succeeded"
+        elif response["state"] == "JOB_STATE_FAILED":
+            status = "failed"
+        elif response["state"] == "JOB_STATE_CANCELLED":
+            status = "cancelled"
+        elif response["state"] == "JOB_STATE_RUNNING":
+            status = "running"
+
+ created_at = self.convert_response_created_at(response)
+
+ _supervisedTuningSpec: ResponseSupervisedTuningSpec = (
+ response.get("supervisedTuningSpec", None) or {}
+ )
+ training_uri: str = _supervisedTuningSpec.get("trainingDatasetUri", "") or ""
+ return FineTuningJob(
+ id=response.get("name", "") or "",
+ created_at=created_at,
+ fine_tuned_model=response.get("tunedModelDisplayName", ""),
+ finished_at=None,
+ hyperparameters=self._translate_vertex_response_hyperparameters(
+ vertex_hyper_parameters=_supervisedTuningSpec.get("hyperParameters", {})
+ or {}
+ ),
+ model=response.get("baseModel", "") or "",
+ object="fine_tuning.job",
+ organization_id="",
+ result_files=[],
+ seed=0,
+ status=status,
+ trained_tokens=None,
+ training_file=training_uri,
+ validation_file=None,
+ estimated_finish=None,
+ integrations=[],
+ )
+
+ def _translate_vertex_response_hyperparameters(
+ self, vertex_hyper_parameters: FineTuneHyperparameters
+ ) -> OpenAIFineTuningHyperparameters:
+ """
+        Translate Vertex AI response hyperparameters to OpenAI hyperparameters
+ """
+ _dict_remaining_hyperparameters: dict = dict(vertex_hyper_parameters)
+ return OpenAIFineTuningHyperparameters(
+ n_epochs=_dict_remaining_hyperparameters.pop("epoch_count", 0),
+ **_dict_remaining_hyperparameters,
+ )
+
+ async def acreate_fine_tuning_job(
+ self,
+ fine_tuning_url: str,
+ headers: dict,
+ request_data: FineTuneJobCreate,
+ ):
+
+ try:
+ verbose_logger.debug(
+ "about to create fine tuning job: %s, request_data: %s",
+ fine_tuning_url,
+ json.dumps(request_data, indent=4),
+ )
+ if self.async_handler is None:
+ raise ValueError(
+ "VertexAI Fine Tuning - async_handler is not initialized"
+ )
+ response = await self.async_handler.post(
+ headers=headers,
+ url=fine_tuning_url,
+ json=request_data, # type: ignore
+ )
+
+ if response.status_code != 200:
+ raise Exception(
+ f"Error creating fine tuning job. Status code: {response.status_code}. Response: {response.text}"
+ )
+
+ verbose_logger.debug(
+ "got response from creating fine tuning job: %s", response.json()
+ )
+
+ vertex_response = ResponseTuningJob( # type: ignore
+ **response.json(),
+ )
+
+ verbose_logger.debug("vertex_response %s", vertex_response)
+ open_ai_response = self.convert_vertex_response_to_open_ai_response(
+ vertex_response
+ )
+ return open_ai_response
+
+ except Exception as e:
+            verbose_logger.error("async error creating fine tuning job %s", e)
+ trace_back_str = traceback.format_exc()
+ verbose_logger.error(trace_back_str)
+ raise e
+
+ def create_fine_tuning_job(
+ self,
+ _is_async: bool,
+ create_fine_tuning_job_data: FineTuningJobCreate,
+ vertex_project: Optional[str],
+ vertex_location: Optional[str],
+ vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
+ api_base: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ kwargs: Optional[dict] = None,
+ original_hyperparameters: Optional[dict] = {},
+ ):
+
+ verbose_logger.debug(
+ "creating fine tuning job, args= %s", create_fine_tuning_job_data
+ )
+ _auth_header, vertex_project = self._ensure_access_token(
+ credentials=vertex_credentials,
+ project_id=vertex_project,
+ custom_llm_provider="vertex_ai_beta",
+ )
+
+ auth_header, _ = self._get_token_and_url(
+ model="",
+ auth_header=_auth_header,
+ gemini_api_key=None,
+ vertex_credentials=vertex_credentials,
+ vertex_project=vertex_project,
+ vertex_location=vertex_location,
+ stream=False,
+ custom_llm_provider="vertex_ai_beta",
+ api_base=api_base,
+ )
+
+ headers = {
+ "Authorization": f"Bearer {auth_header}",
+ "Content-Type": "application/json",
+ }
+
+ fine_tune_job = self.convert_openai_request_to_vertex(
+ create_fine_tuning_job_data=create_fine_tuning_job_data,
+ kwargs=kwargs,
+ original_hyperparameters=original_hyperparameters or {},
+ )
+
+ fine_tuning_url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/tuningJobs"
+ if _is_async is True:
+ return self.acreate_fine_tuning_job( # type: ignore
+ fine_tuning_url=fine_tuning_url,
+ headers=headers,
+ request_data=fine_tune_job,
+ )
+ sync_handler = HTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
+
+ verbose_logger.debug(
+ "about to create fine tuning job: %s, request_data: %s",
+ fine_tuning_url,
+ fine_tune_job,
+ )
+ response = sync_handler.post(
+ headers=headers,
+ url=fine_tuning_url,
+ json=fine_tune_job, # type: ignore
+ )
+
+ if response.status_code != 200:
+ raise Exception(
+ f"Error creating fine tuning job. Status code: {response.status_code}. Response: {response.text}"
+ )
+
+ verbose_logger.debug(
+ "got response from creating fine tuning job: %s", response.json()
+ )
+ vertex_response = ResponseTuningJob( # type: ignore
+ **response.json(),
+ )
+
+ verbose_logger.debug("vertex_response %s", vertex_response)
+ open_ai_response = self.convert_vertex_response_to_open_ai_response(
+ vertex_response
+ )
+ return open_ai_response
+
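+    # Example of the endpoint constructed above (project and location are placeholders):
+    #   https://us-central1-aiplatform.googleapis.com/v1/projects/my-project/locations/us-central1/tuningJobs
+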
+ async def pass_through_vertex_ai_POST_request(
+ self,
+ request_data: dict,
+ vertex_project: str,
+ vertex_location: str,
+ vertex_credentials: str,
+ request_route: str,
+ ):
+ _auth_header, vertex_project = await self._ensure_access_token_async(
+ credentials=vertex_credentials,
+ project_id=vertex_project,
+ custom_llm_provider="vertex_ai_beta",
+ )
+ auth_header, _ = self._get_token_and_url(
+ model="",
+ auth_header=_auth_header,
+ gemini_api_key=None,
+ vertex_credentials=vertex_credentials,
+ vertex_project=vertex_project,
+ vertex_location=vertex_location,
+ stream=False,
+ custom_llm_provider="vertex_ai_beta",
+ api_base="",
+ )
+
+ headers = {
+ "Authorization": f"Bearer {auth_header}",
+ "Content-Type": "application/json",
+ }
+
+ url = None
+ if request_route == "/tuningJobs":
+ url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/tuningJobs"
+ elif "/tuningJobs/" in request_route and "cancel" in request_route:
+ url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/tuningJobs{request_route}"
+ elif "generateContent" in request_route:
+ url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}{request_route}"
+ elif "predict" in request_route:
+ url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}{request_route}"
+ elif "/batchPredictionJobs" in request_route:
+ url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}{request_route}"
+ elif "countTokens" in request_route:
+ url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}{request_route}"
+ elif "cachedContents" in request_route:
+ _model = request_data.get("model")
+ if _model is not None and "/publishers/google/models/" not in _model:
+ request_data["model"] = (
+ f"projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{_model}"
+ )
+
+ url = f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}{request_route}"
+ else:
+ raise ValueError(f"Unsupported Vertex AI request route: {request_route}")
+ if self.async_handler is None:
+ raise ValueError("VertexAI Fine Tuning - async_handler is not initialized")
+
+ response = await self.async_handler.post(
+ headers=headers,
+ url=url,
+ json=request_data, # type: ignore
+ )
+
+ if response.status_code != 200:
+ raise Exception(
+                f"Error making Vertex AI pass-through request. Status code: {response.status_code}. Response: {response.text}"
+ )
+
+ response_json = response.json()
+ return response_json
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/gemini/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/gemini/transformation.py
new file mode 100644
index 00000000..d6bafc7c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/gemini/transformation.py
@@ -0,0 +1,479 @@
+"""
+Transformation logic from OpenAI format to Gemini format.
+
+Why separate file? Make it easy to see how transformation works
+"""
+
+import os
+from typing import TYPE_CHECKING, List, Literal, Optional, Tuple, Union, cast
+
+import httpx
+from pydantic import BaseModel
+
+import litellm
+from litellm._logging import verbose_logger
+from litellm.litellm_core_utils.prompt_templates.factory import (
+ convert_to_anthropic_image_obj,
+ convert_to_gemini_tool_call_invoke,
+ convert_to_gemini_tool_call_result,
+ response_schema_prompt,
+)
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.types.files import (
+ get_file_mime_type_for_file_type,
+ get_file_type_from_extension,
+ is_gemini_1_5_accepted_file_type,
+)
+from litellm.types.llms.openai import (
+ AllMessageValues,
+ ChatCompletionAssistantMessage,
+ ChatCompletionImageObject,
+ ChatCompletionTextObject,
+)
+from litellm.types.llms.vertex_ai import *
+from litellm.types.llms.vertex_ai import (
+ GenerationConfig,
+ PartType,
+ RequestBody,
+ SafetSettingsConfig,
+ SystemInstructions,
+ ToolConfig,
+ Tools,
+)
+
+from ..common_utils import (
+ _check_text_in_content,
+ get_supports_response_schema,
+ get_supports_system_message,
+)
+
+if TYPE_CHECKING:
+ from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
+
+ LiteLLMLoggingObj = _LiteLLMLoggingObj
+else:
+ LiteLLMLoggingObj = Any
+
+
+def _process_gemini_image(image_url: str, format: Optional[str] = None) -> PartType:
+ """
+ Given an image URL, return the appropriate PartType for Gemini
+ """
+
+ try:
+ # GCS URIs
+ if "gs://" in image_url:
+ # Figure out file type
+ extension_with_dot = os.path.splitext(image_url)[-1] # Ex: ".png"
+ extension = extension_with_dot[1:] # Ex: "png"
+
+ if not format:
+ file_type = get_file_type_from_extension(extension)
+
+ # Validate the file type is supported by Gemini
+ if not is_gemini_1_5_accepted_file_type(file_type):
+ raise Exception(f"File type not supported by gemini - {file_type}")
+
+ mime_type = get_file_mime_type_for_file_type(file_type)
+ else:
+ mime_type = format
+ file_data = FileDataType(mime_type=mime_type, file_uri=image_url)
+
+ return PartType(file_data=file_data)
+ elif (
+ "https://" in image_url
+ and (image_type := format or _get_image_mime_type_from_url(image_url))
+ is not None
+ ):
+
+ file_data = FileDataType(file_uri=image_url, mime_type=image_type)
+ return PartType(file_data=file_data)
+ elif "http://" in image_url or "https://" in image_url or "base64" in image_url:
+ # https links for unsupported mime types and base64 images
+ image = convert_to_anthropic_image_obj(image_url, format=format)
+ _blob = BlobType(data=image["data"], mime_type=image["media_type"])
+ return PartType(inline_data=_blob)
+ raise Exception("Invalid image received - {}".format(image_url))
+ except Exception as e:
+ raise e
+
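+# Rough summary of the branches above (URIs are hypothetical examples):
+#   "gs://bucket/cat.png"         -> PartType(file_data=...) with mime type from the extension
+#   "https://example.com/cat.jpg" -> PartType(file_data=...) when the mime type can be inferred
+#   "data:image/png;base64,...."  -> PartType(inline_data=...) with base64 data inlined
+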
+
+def _get_image_mime_type_from_url(url: str) -> Optional[str]:
+ """
+    Get mime type for common image/file URLs
+    See gemini mime types: https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/image-understanding#image-requirements
+
+    Supported by Gemini:
+    - PNG (`image/png`)
+    - JPEG (`image/jpeg`)
+    - WebP (`image/webp`)
+    - MP4 (`video/mp4`)
+    - PDF (`application/pdf`)
+ Example:
+ url = https://example.com/image.jpg
+ Returns: image/jpeg
+ """
+ url = url.lower()
+ if url.endswith((".jpg", ".jpeg")):
+ return "image/jpeg"
+ elif url.endswith(".png"):
+ return "image/png"
+ elif url.endswith(".webp"):
+ return "image/webp"
+ elif url.endswith(".mp4"):
+ return "video/mp4"
+ elif url.endswith(".pdf"):
+ return "application/pdf"
+ return None
+
+
+def _gemini_convert_messages_with_history( # noqa: PLR0915
+ messages: List[AllMessageValues],
+) -> List[ContentType]:
+ """
+ Converts given messages from OpenAI format to Gemini format
+
+ - Parts must be iterable
+ - Roles must alternate b/w 'user' and 'model' (same as anthropic -> merge consecutive roles)
+ - Please ensure that function response turn comes immediately after a function call turn
+ """
+ user_message_types = {"user", "system"}
+ contents: List[ContentType] = []
+
+ last_message_with_tool_calls = None
+
+ msg_i = 0
+ tool_call_responses = []
+ try:
+ while msg_i < len(messages):
+ user_content: List[PartType] = []
+ init_msg_i = msg_i
+ ## MERGE CONSECUTIVE USER CONTENT ##
+ while (
+ msg_i < len(messages) and messages[msg_i]["role"] in user_message_types
+ ):
+ _message_content = messages[msg_i].get("content")
+ if _message_content is not None and isinstance(_message_content, list):
+ _parts: List[PartType] = []
+ for element in _message_content:
+ if (
+ element["type"] == "text"
+ and "text" in element
+ and len(element["text"]) > 0
+ ):
+ element = cast(ChatCompletionTextObject, element)
+ _part = PartType(text=element["text"])
+ _parts.append(_part)
+ elif element["type"] == "image_url":
+ element = cast(ChatCompletionImageObject, element)
+ img_element = element
+ format: Optional[str] = None
+ if isinstance(img_element["image_url"], dict):
+ image_url = img_element["image_url"]["url"]
+ format = img_element["image_url"].get("format")
+ else:
+ image_url = img_element["image_url"]
+ _part = _process_gemini_image(
+ image_url=image_url, format=format
+ )
+ _parts.append(_part)
+ user_content.extend(_parts)
+ elif (
+ _message_content is not None
+ and isinstance(_message_content, str)
+ and len(_message_content) > 0
+ ):
+ _part = PartType(text=_message_content)
+ user_content.append(_part)
+
+ msg_i += 1
+
+ if user_content:
+ """
+ check that user_content has 'text' parameter.
+ - Known Vertex Error: Unable to submit request because it must have a text parameter.
+ - Relevant Issue: https://github.com/BerriAI/litellm/issues/5515
+ """
+ has_text_in_content = _check_text_in_content(user_content)
+ if has_text_in_content is False:
+ verbose_logger.warning(
+ "No text in user content. Adding a blank text to user content, to ensure Gemini doesn't fail the request. Relevant Issue - https://github.com/BerriAI/litellm/issues/5515"
+ )
+ user_content.append(
+ PartType(text=" ")
+ ) # add a blank text, to ensure Gemini doesn't fail the request.
+ contents.append(ContentType(role="user", parts=user_content))
+ assistant_content = []
+ ## MERGE CONSECUTIVE ASSISTANT CONTENT ##
+ while msg_i < len(messages) and messages[msg_i]["role"] == "assistant":
+ if isinstance(messages[msg_i], BaseModel):
+ msg_dict: Union[ChatCompletionAssistantMessage, dict] = messages[msg_i].model_dump() # type: ignore
+ else:
+ msg_dict = messages[msg_i] # type: ignore
+ assistant_msg = ChatCompletionAssistantMessage(**msg_dict) # type: ignore
+ _message_content = assistant_msg.get("content", None)
+ if _message_content is not None and isinstance(_message_content, list):
+ _parts = []
+ for element in _message_content:
+ if isinstance(element, dict):
+ if element["type"] == "text":
+ _part = PartType(text=element["text"])
+ _parts.append(_part)
+ assistant_content.extend(_parts)
+ elif (
+ _message_content is not None
+ and isinstance(_message_content, str)
+ and _message_content
+ ):
+                    assistant_text = _message_content  # non-empty string at this point
+ assistant_content.append(PartType(text=assistant_text)) # type: ignore
+
+ ## HANDLE ASSISTANT FUNCTION CALL
+ if (
+ assistant_msg.get("tool_calls", []) is not None
+ or assistant_msg.get("function_call") is not None
+ ): # support assistant tool invoke conversion
+ assistant_content.extend(
+ convert_to_gemini_tool_call_invoke(assistant_msg)
+ )
+ last_message_with_tool_calls = assistant_msg
+
+ msg_i += 1
+
+ if assistant_content:
+ contents.append(ContentType(role="model", parts=assistant_content))
+
+ ## APPEND TOOL CALL MESSAGES ##
+ tool_call_message_roles = ["tool", "function"]
+ if (
+ msg_i < len(messages)
+ and messages[msg_i]["role"] in tool_call_message_roles
+ ):
+ _part = convert_to_gemini_tool_call_result(
+ messages[msg_i], last_message_with_tool_calls # type: ignore
+ )
+ msg_i += 1
+ tool_call_responses.append(_part)
+ if msg_i < len(messages) and (
+ messages[msg_i]["role"] not in tool_call_message_roles
+ ):
+ if len(tool_call_responses) > 0:
+ contents.append(ContentType(parts=tool_call_responses))
+ tool_call_responses = []
+
+ if msg_i == init_msg_i: # prevent infinite loops
+ raise Exception(
+ "Invalid Message passed in - {}. File an issue https://github.com/BerriAI/litellm/issues".format(
+ messages[msg_i]
+ )
+ )
+ if len(tool_call_responses) > 0:
+ contents.append(ContentType(parts=tool_call_responses))
+ return contents
+ except Exception as e:
+ raise e
+
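+# Loose sketch of the conversion above (message content is hypothetical):
+#   OpenAI: [{"role": "system", "content": "be brief"},
+#            {"role": "user", "content": "hi"},
+#            {"role": "assistant", "content": "hello"}]
+#   Gemini: [{"role": "user", "parts": [{"text": "be brief"}, {"text": "hi"}]},
+#            {"role": "model", "parts": [{"text": "hello"}]}]
+# Consecutive user/system turns are merged, assistant turns map to role "model",
+# and tool/function results are appended as separate tool-response parts.
+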
+
+def _transform_request_body(
+ messages: List[AllMessageValues],
+ model: str,
+ optional_params: dict,
+ custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
+ litellm_params: dict,
+ cached_content: Optional[str],
+) -> RequestBody:
+ """
+ Common transformation logic across sync + async Gemini /generateContent calls.
+ """
+ # Separate system prompt from rest of message
+ supports_system_message = get_supports_system_message(
+ model=model, custom_llm_provider=custom_llm_provider
+ )
+ system_instructions, messages = _transform_system_message(
+ supports_system_message=supports_system_message, messages=messages
+ )
+ # Checks for 'response_schema' support - if passed in
+ if "response_schema" in optional_params:
+ supports_response_schema = get_supports_response_schema(
+ model=model, custom_llm_provider=custom_llm_provider
+ )
+ if supports_response_schema is False:
+ user_response_schema_message = response_schema_prompt(
+ model=model, response_schema=optional_params.get("response_schema") # type: ignore
+ )
+ messages.append({"role": "user", "content": user_response_schema_message})
+ optional_params.pop("response_schema")
+
+ # Check for any 'litellm_param_*' set during optional param mapping
+
+ remove_keys = []
+ for k, v in optional_params.items():
+ if k.startswith("litellm_param_"):
+ litellm_params.update({k: v})
+ remove_keys.append(k)
+
+ optional_params = {k: v for k, v in optional_params.items() if k not in remove_keys}
+
+ try:
+ if custom_llm_provider == "gemini":
+ content = litellm.GoogleAIStudioGeminiConfig()._transform_messages(
+ messages=messages
+ )
+ else:
+ content = litellm.VertexGeminiConfig()._transform_messages(
+ messages=messages
+ )
+ tools: Optional[Tools] = optional_params.pop("tools", None)
+ tool_choice: Optional[ToolConfig] = optional_params.pop("tool_choice", None)
+ safety_settings: Optional[List[SafetSettingsConfig]] = optional_params.pop(
+ "safety_settings", None
+ ) # type: ignore
+ config_fields = GenerationConfig.__annotations__.keys()
+
+ filtered_params = {
+ k: v for k, v in optional_params.items() if k in config_fields
+ }
+
+ generation_config: Optional[GenerationConfig] = GenerationConfig(
+ **filtered_params
+ )
+ data = RequestBody(contents=content)
+ if system_instructions is not None:
+ data["system_instruction"] = system_instructions
+ if tools is not None:
+ data["tools"] = tools
+ if tool_choice is not None:
+ data["toolConfig"] = tool_choice
+ if safety_settings is not None:
+ data["safetySettings"] = safety_settings
+ if generation_config is not None:
+ data["generationConfig"] = generation_config
+ if cached_content is not None:
+ data["cachedContent"] = cached_content
+ except Exception as e:
+ raise e
+
+ return data
+
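+# Sketch of the RequestBody assembled above (optional keys appear only when set):
+#   {"contents": [...], "system_instruction": {...}, "tools": [...], "toolConfig": {...},
+#    "safetySettings": [...], "generationConfig": {...}, "cachedContent": "..."}
+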
+
+def sync_transform_request_body(
+ gemini_api_key: Optional[str],
+ messages: List[AllMessageValues],
+ api_base: Optional[str],
+ model: str,
+ client: Optional[HTTPHandler],
+ timeout: Optional[Union[float, httpx.Timeout]],
+ extra_headers: Optional[dict],
+ optional_params: dict,
+ logging_obj: LiteLLMLoggingObj,
+ custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
+ litellm_params: dict,
+) -> RequestBody:
+ from ..context_caching.vertex_ai_context_caching import ContextCachingEndpoints
+
+ context_caching_endpoints = ContextCachingEndpoints()
+
+ if gemini_api_key is not None:
+ messages, cached_content = context_caching_endpoints.check_and_create_cache(
+ messages=messages,
+ api_key=gemini_api_key,
+ api_base=api_base,
+ model=model,
+ client=client,
+ timeout=timeout,
+ extra_headers=extra_headers,
+ cached_content=optional_params.pop("cached_content", None),
+ logging_obj=logging_obj,
+ )
+    else:  # [TODO] implement automatic context caching for the vertex ai path as well
+ cached_content = optional_params.pop("cached_content", None)
+
+ return _transform_request_body(
+ messages=messages,
+ model=model,
+ custom_llm_provider=custom_llm_provider,
+ litellm_params=litellm_params,
+ cached_content=cached_content,
+ optional_params=optional_params,
+ )
+
+
+async def async_transform_request_body(
+ gemini_api_key: Optional[str],
+ messages: List[AllMessageValues],
+ api_base: Optional[str],
+ model: str,
+ client: Optional[AsyncHTTPHandler],
+ timeout: Optional[Union[float, httpx.Timeout]],
+ extra_headers: Optional[dict],
+ optional_params: dict,
+ logging_obj: litellm.litellm_core_utils.litellm_logging.Logging, # type: ignore
+ custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
+ litellm_params: dict,
+) -> RequestBody:
+ from ..context_caching.vertex_ai_context_caching import ContextCachingEndpoints
+
+ context_caching_endpoints = ContextCachingEndpoints()
+
+ if gemini_api_key is not None:
+ messages, cached_content = (
+ await context_caching_endpoints.async_check_and_create_cache(
+ messages=messages,
+ api_key=gemini_api_key,
+ api_base=api_base,
+ model=model,
+ client=client,
+ timeout=timeout,
+ extra_headers=extra_headers,
+ cached_content=optional_params.pop("cached_content", None),
+ logging_obj=logging_obj,
+ )
+ )
+    else:  # [TODO] implement automatic context caching for the vertex ai path as well
+ cached_content = optional_params.pop("cached_content", None)
+
+ return _transform_request_body(
+ messages=messages,
+ model=model,
+ custom_llm_provider=custom_llm_provider,
+ litellm_params=litellm_params,
+ cached_content=cached_content,
+ optional_params=optional_params,
+ )
+
+
+def _transform_system_message(
+ supports_system_message: bool, messages: List[AllMessageValues]
+) -> Tuple[Optional[SystemInstructions], List[AllMessageValues]]:
+ """
+ Extracts the system message from the openai message list.
+
+ Converts the system message to Gemini format
+
+ Returns
+ - system_content_blocks: Optional[SystemInstructions] - the system message list in Gemini format.
+ - messages: List[AllMessageValues] - filtered list of messages in OpenAI format (transformed separately)
+ """
+ # Separate system prompt from rest of message
+ system_prompt_indices = []
+ system_content_blocks: List[PartType] = []
+ if supports_system_message is True:
+ for idx, message in enumerate(messages):
+ if message["role"] == "system":
+ _system_content_block: Optional[PartType] = None
+ if isinstance(message["content"], str):
+ _system_content_block = PartType(text=message["content"])
+ elif isinstance(message["content"], list):
+ system_text = ""
+ for content in message["content"]:
+ system_text += content.get("text") or ""
+ _system_content_block = PartType(text=system_text)
+ if _system_content_block is not None:
+ system_content_blocks.append(_system_content_block)
+ system_prompt_indices.append(idx)
+ if len(system_prompt_indices) > 0:
+ for idx in reversed(system_prompt_indices):
+ messages.pop(idx)
+
+ if len(system_content_blocks) > 0:
+ return SystemInstructions(parts=system_content_blocks), messages
+
+ return None, messages
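+# Illustrative example of the extraction above (content is hypothetical):
+#   [{"role": "system", "content": "be brief"}, {"role": "user", "content": "hi"}]
+#   -> (SystemInstructions(parts=[PartType(text="be brief")]), [{"role": "user", "content": "hi"}])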
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py
new file mode 100644
index 00000000..9ac1b1ff
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py
@@ -0,0 +1,1495 @@
+# What is this?
+## httpx client for vertex ai calls
+## Initial implementation - covers gemini + image gen calls
+import json
+import uuid
+from copy import deepcopy
+from functools import partial
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Callable,
+ Dict,
+ List,
+ Literal,
+ Optional,
+ Tuple,
+ Union,
+ cast,
+)
+
+import httpx # type: ignore
+
+import litellm
+import litellm.litellm_core_utils
+import litellm.litellm_core_utils.litellm_logging
+from litellm import verbose_logger
+from litellm.litellm_core_utils.core_helpers import map_finish_reason
+from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
+from litellm.llms.custom_httpx.http_handler import (
+ AsyncHTTPHandler,
+ HTTPHandler,
+ get_async_httpx_client,
+)
+from litellm.types.llms.openai import (
+ AllMessageValues,
+ ChatCompletionResponseMessage,
+ ChatCompletionToolCallChunk,
+ ChatCompletionToolCallFunctionChunk,
+ ChatCompletionToolParamFunctionChunk,
+ ChatCompletionUsageBlock,
+)
+from litellm.types.llms.vertex_ai import (
+ VERTEX_CREDENTIALS_TYPES,
+ Candidates,
+ ContentType,
+ FunctionCallingConfig,
+ FunctionDeclaration,
+ GenerateContentResponseBody,
+ HttpxPartType,
+ LogprobsResult,
+ ToolConfig,
+ Tools,
+)
+from litellm.types.utils import (
+ ChatCompletionTokenLogprob,
+ ChoiceLogprobs,
+ GenericStreamingChunk,
+ PromptTokensDetailsWrapper,
+ TopLogprob,
+ Usage,
+)
+from litellm.utils import CustomStreamWrapper, ModelResponse
+
+from ....utils import _remove_additional_properties, _remove_strict_from_schema
+from ..common_utils import VertexAIError, _build_vertex_schema
+from ..vertex_llm_base import VertexBase
+from .transformation import (
+ _gemini_convert_messages_with_history,
+ async_transform_request_body,
+ sync_transform_request_body,
+)
+
+if TYPE_CHECKING:
+ from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+
+ LoggingClass = LiteLLMLoggingObj
+else:
+ LoggingClass = Any
+
+
+class VertexAIBaseConfig:
+ def get_mapped_special_auth_params(self) -> dict:
+ """
+ Common auth params across bedrock/vertex_ai/azure/watsonx
+ """
+ return {"project": "vertex_project", "region_name": "vertex_location"}
+
+ def map_special_auth_params(self, non_default_params: dict, optional_params: dict):
+ mapped_params = self.get_mapped_special_auth_params()
+
+ for param, value in non_default_params.items():
+ if param in mapped_params:
+ optional_params[mapped_params[param]] = value
+ return optional_params
+
+ def get_eu_regions(self) -> List[str]:
+ """
+ Source: https://cloud.google.com/vertex-ai/generative-ai/docs/learn/locations#available-regions
+ """
+ return [
+ "europe-central2",
+ "europe-north1",
+ "europe-southwest1",
+ "europe-west1",
+ "europe-west2",
+ "europe-west3",
+ "europe-west4",
+ "europe-west6",
+ "europe-west8",
+ "europe-west9",
+ ]
+
+ def get_us_regions(self) -> List[str]:
+ """
+ Source: https://cloud.google.com/vertex-ai/generative-ai/docs/learn/locations#available-regions
+ """
+ return [
+ "us-central1",
+ "us-east1",
+ "us-east4",
+ "us-east5",
+ "us-south1",
+ "us-west1",
+ "us-west4",
+ "us-west5",
+ ]
+
+
+class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
+ """
+ Reference: https://cloud.google.com/vertex-ai/docs/generative-ai/chat/test-chat-prompts
+ Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference
+
+    The class `VertexGeminiConfig` provides configuration for Vertex AI's Gemini API interface. Below are the parameters:
+
+ - `temperature` (float): This controls the degree of randomness in token selection.
+
+    - `max_output_tokens` (integer): This sets the maximum number of tokens in the text output. The default value is 256.
+
+ - `top_p` (float): The tokens are selected from the most probable to the least probable until the sum of their probabilities equals the `top_p` value. Default is 0.95.
+
+ - `top_k` (integer): The value of `top_k` determines how many of the most probable tokens are considered in the selection. For example, a `top_k` of 1 means the selected token is the most probable among all tokens. The default value is 40.
+
+ - `response_mime_type` (str): The MIME type of the response. The default value is 'text/plain'.
+
+ - `candidate_count` (int): Number of generated responses to return.
+
+ - `stop_sequences` (List[str]): The set of character sequences (up to 5) that will stop output generation. If specified, the API will stop at the first appearance of a stop sequence. The stop sequence will not be included as part of the response.
+
+    - `frequency_penalty` (float): Penalizes tokens in proportion to how often they have already appeared in the generated output, discouraging repetition. The default value is 0.0.
+
+    - `presence_penalty` (float): Penalizes tokens that have already appeared in the generated output, encouraging the model to introduce new content. The default value is 0.0.
+
+ - `seed` (int): The seed value is used to help generate the same output for the same input. The default value is None.
+
+ Note: Please make sure to modify the default parameters as required for your use case.
+ """
+
+ temperature: Optional[float] = None
+ max_output_tokens: Optional[int] = None
+ top_p: Optional[float] = None
+ top_k: Optional[int] = None
+ response_mime_type: Optional[str] = None
+ candidate_count: Optional[int] = None
+ stop_sequences: Optional[list] = None
+ frequency_penalty: Optional[float] = None
+ presence_penalty: Optional[float] = None
+ seed: Optional[int] = None
+
+ def __init__(
+ self,
+ temperature: Optional[float] = None,
+ max_output_tokens: Optional[int] = None,
+ top_p: Optional[float] = None,
+ top_k: Optional[int] = None,
+ response_mime_type: Optional[str] = None,
+ candidate_count: Optional[int] = None,
+ stop_sequences: Optional[list] = None,
+ frequency_penalty: Optional[float] = None,
+ presence_penalty: Optional[float] = None,
+ seed: Optional[int] = None,
+ ) -> None:
+ locals_ = locals().copy()
+ for key, value in locals_.items():
+ if key != "self" and value is not None:
+ setattr(self.__class__, key, value)
+
+ @classmethod
+ def get_config(cls):
+ return super().get_config()
+
+ def get_supported_openai_params(self, model: str) -> List[str]:
+ return [
+ "temperature",
+ "top_p",
+ "max_tokens",
+ "max_completion_tokens",
+ "stream",
+ "tools",
+ "functions",
+ "tool_choice",
+ "response_format",
+ "n",
+ "stop",
+ "frequency_penalty",
+ "presence_penalty",
+ "extra_headers",
+ "seed",
+ "logprobs",
+ ]
+
+ def map_tool_choice_values(
+ self, model: str, tool_choice: Union[str, dict]
+ ) -> Optional[ToolConfig]:
+ if tool_choice == "none":
+ return ToolConfig(functionCallingConfig=FunctionCallingConfig(mode="NONE"))
+ elif tool_choice == "required":
+ return ToolConfig(functionCallingConfig=FunctionCallingConfig(mode="ANY"))
+ elif tool_choice == "auto":
+ return ToolConfig(functionCallingConfig=FunctionCallingConfig(mode="AUTO"))
+ elif isinstance(tool_choice, dict):
+            # a specific tool was requested - restrict function calling to that tool via allowed_function_names
+ name = tool_choice.get("function", {}).get("name", "")
+ return ToolConfig(
+ functionCallingConfig=FunctionCallingConfig(
+ mode="ANY", allowed_function_names=[name]
+ )
+ )
+ else:
+ raise litellm.utils.UnsupportedParamsError(
+                message="VertexAI doesn't support tool_choice={}. Supported tool_choice values=['auto', 'required', json object]. To drop it from the call, set `litellm.drop_params = True`.".format(
+ tool_choice
+ ),
+ status_code=400,
+ )
+
+ def _map_function(self, value: List[dict]) -> List[Tools]:
+ gtool_func_declarations = []
+ googleSearch: Optional[dict] = None
+ googleSearchRetrieval: Optional[dict] = None
+ code_execution: Optional[dict] = None
+ # remove 'additionalProperties' from tools
+ value = _remove_additional_properties(value)
+ # remove 'strict' from tools
+ value = _remove_strict_from_schema(value)
+
+ for tool in value:
+ openai_function_object: Optional[ChatCompletionToolParamFunctionChunk] = (
+ None
+ )
+ if "function" in tool: # tools list
+ _openai_function_object = ChatCompletionToolParamFunctionChunk( # type: ignore
+ **tool["function"]
+ )
+
+ if (
+ "parameters" in _openai_function_object
+ and _openai_function_object["parameters"] is not None
+ ): # OPENAI accepts JSON Schema, Google accepts OpenAPI schema.
+ _openai_function_object["parameters"] = _build_vertex_schema(
+ _openai_function_object["parameters"]
+ )
+
+ openai_function_object = _openai_function_object
+
+ elif "name" in tool: # functions list
+ openai_function_object = ChatCompletionToolParamFunctionChunk(**tool) # type: ignore
+
+ # check if grounding
+ if tool.get("googleSearch", None) is not None:
+ googleSearch = tool["googleSearch"]
+ elif tool.get("googleSearchRetrieval", None) is not None:
+ googleSearchRetrieval = tool["googleSearchRetrieval"]
+ elif tool.get("code_execution", None) is not None:
+ code_execution = tool["code_execution"]
+ elif openai_function_object is not None:
+ gtool_func_declaration = FunctionDeclaration(
+ name=openai_function_object["name"],
+ )
+ _description = openai_function_object.get("description", None)
+ _parameters = openai_function_object.get("parameters", None)
+ if _description is not None:
+ gtool_func_declaration["description"] = _description
+ if _parameters is not None:
+ gtool_func_declaration["parameters"] = _parameters
+ gtool_func_declarations.append(gtool_func_declaration)
+ else:
+ # assume it's a provider-specific param
+                verbose_logger.warning(
+                    "Invalid tool={}. Use `litellm.set_verbose` or `litellm --detailed_debug` to see raw request.".format(
+                        tool
+                    )
+                )
+
+ _tools = Tools(
+ function_declarations=gtool_func_declarations,
+ )
+ if googleSearch is not None:
+ _tools["googleSearch"] = googleSearch
+ if googleSearchRetrieval is not None:
+ _tools["googleSearchRetrieval"] = googleSearchRetrieval
+ if code_execution is not None:
+ _tools["code_execution"] = code_execution
+ return [_tools]
+
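+    # Loose sketch of the tool mapping above (tool definition is hypothetical):
+    #   OpenAI: {"type": "function", "function": {"name": "get_weather",
+    #            "parameters": {"type": "object", "properties": {"city": {"type": "string"}}}}}
+    #   Vertex: [Tools(function_declarations=[{"name": "get_weather", "parameters": {...}}])]
+    #   googleSearch / googleSearchRetrieval / code_execution entries are carried over when passed.
+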
+ def _map_response_schema(self, value: dict) -> dict:
+ old_schema = deepcopy(value)
+ if isinstance(old_schema, list):
+ for item in old_schema:
+ if isinstance(item, dict):
+ item = _build_vertex_schema(parameters=item)
+ elif isinstance(old_schema, dict):
+ old_schema = _build_vertex_schema(parameters=old_schema)
+ return old_schema
+
+ def map_openai_params(
+ self,
+ non_default_params: Dict,
+ optional_params: Dict,
+ model: str,
+ drop_params: bool,
+ ) -> Dict:
+ for param, value in non_default_params.items():
+ if param == "temperature":
+ optional_params["temperature"] = value
+ if param == "top_p":
+ optional_params["top_p"] = value
+ if (
+ param == "stream" and value is True
+            ):  # only forward stream=True; passing stream=False through unchecked can cause issues downstream
+ optional_params["stream"] = value
+ if param == "n":
+ optional_params["candidate_count"] = value
+ if param == "stop":
+ if isinstance(value, str):
+ optional_params["stop_sequences"] = [value]
+ elif isinstance(value, list):
+ optional_params["stop_sequences"] = value
+ if param == "max_tokens" or param == "max_completion_tokens":
+ optional_params["max_output_tokens"] = value
+ if param == "response_format" and isinstance(value, dict): # type: ignore
+ # remove 'additionalProperties' from json schema
+ value = _remove_additional_properties(value)
+ # remove 'strict' from json schema
+ value = _remove_strict_from_schema(value)
+ if value["type"] == "json_object":
+ optional_params["response_mime_type"] = "application/json"
+ elif value["type"] == "text":
+ optional_params["response_mime_type"] = "text/plain"
+ if "response_schema" in value:
+ optional_params["response_mime_type"] = "application/json"
+ optional_params["response_schema"] = value["response_schema"]
+ elif value["type"] == "json_schema": # type: ignore
+ if "json_schema" in value and "schema" in value["json_schema"]: # type: ignore
+ optional_params["response_mime_type"] = "application/json"
+ optional_params["response_schema"] = value["json_schema"]["schema"] # type: ignore
+
+ if "response_schema" in optional_params and isinstance(
+ optional_params["response_schema"], dict
+ ):
+ optional_params["response_schema"] = self._map_response_schema(
+ value=optional_params["response_schema"]
+ )
+ if param == "frequency_penalty":
+ optional_params["frequency_penalty"] = value
+ if param == "presence_penalty":
+ optional_params["presence_penalty"] = value
+ if param == "logprobs":
+ optional_params["responseLogprobs"] = value
+ if (param == "tools" or param == "functions") and isinstance(value, list):
+ optional_params["tools"] = self._map_function(value=value)
+ optional_params["litellm_param_is_function_call"] = (
+ True if param == "functions" else False
+ )
+ if param == "tool_choice" and (
+ isinstance(value, str) or isinstance(value, dict)
+ ):
+ _tool_choice_value = self.map_tool_choice_values(
+ model=model, tool_choice=value # type: ignore
+ )
+ if _tool_choice_value is not None:
+ optional_params["tool_choice"] = _tool_choice_value
+ if param == "seed":
+ optional_params["seed"] = value
+
+ if litellm.vertex_ai_safety_settings is not None:
+ optional_params["safety_settings"] = litellm.vertex_ai_safety_settings
+ return optional_params
+
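+    # Illustrative mapping performed above (values are hypothetical):
+    #   {"max_tokens": 256, "n": 2, "stop": "END", "logprobs": True}
+    #   -> {"max_output_tokens": 256, "candidate_count": 2,
+    #       "stop_sequences": ["END"], "responseLogprobs": True}
+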
+ def get_mapped_special_auth_params(self) -> dict:
+ """
+ Common auth params across bedrock/vertex_ai/azure/watsonx
+ """
+ return {"project": "vertex_project", "region_name": "vertex_location"}
+
+ def map_special_auth_params(self, non_default_params: dict, optional_params: dict):
+ mapped_params = self.get_mapped_special_auth_params()
+
+ for param, value in non_default_params.items():
+ if param in mapped_params:
+ optional_params[mapped_params[param]] = value
+ return optional_params
+
+ def get_eu_regions(self) -> List[str]:
+ """
+ Source: https://cloud.google.com/vertex-ai/generative-ai/docs/learn/locations#available-regions
+ """
+ return [
+ "europe-central2",
+ "europe-north1",
+ "europe-southwest1",
+ "europe-west1",
+ "europe-west2",
+ "europe-west3",
+ "europe-west4",
+ "europe-west6",
+ "europe-west8",
+ "europe-west9",
+ ]
+
+ def get_flagged_finish_reasons(self) -> Dict[str, str]:
+ """
+        Return a dictionary of finish reasons which indicate the response was flagged,
+        and what each of them means.
+ """
+ return {
+ "SAFETY": "The token generation was stopped as the response was flagged for safety reasons. NOTE: When streaming the Candidate.content will be empty if content filters blocked the output.",
+ "RECITATION": "The token generation was stopped as the response was flagged for unauthorized citations.",
+ "BLOCKLIST": "The token generation was stopped as the response was flagged for the terms which are included from the terminology blocklist.",
+ "PROHIBITED_CONTENT": "The token generation was stopped as the response was flagged for the prohibited contents.",
+ "SPII": "The token generation was stopped as the response was flagged for Sensitive Personally Identifiable Information (SPII) contents.",
+ }
+
+ def translate_exception_str(self, exception_string: str):
+ if (
+ "GenerateContentRequest.tools[0].function_declarations[0].parameters.properties: should be non-empty for OBJECT type"
+ in exception_string
+ ):
+ return "'properties' field in tools[0]['function']['parameters'] cannot be empty if 'type' == 'object'. Received error from provider - {}".format(
+ exception_string
+ )
+ return exception_string
+
+ def get_assistant_content_message(
+ self, parts: List[HttpxPartType]
+ ) -> Optional[str]:
+ _content_str = ""
+ for part in parts:
+ if "text" in part:
+ _content_str += part["text"]
+ if _content_str:
+ return _content_str
+ return None
+
+ def _transform_parts(
+ self,
+ parts: List[HttpxPartType],
+ index: int,
+ is_function_call: Optional[bool],
+ ) -> Tuple[
+ Optional[ChatCompletionToolCallFunctionChunk],
+ Optional[List[ChatCompletionToolCallChunk]],
+ ]:
+ function: Optional[ChatCompletionToolCallFunctionChunk] = None
+ _tools: List[ChatCompletionToolCallChunk] = []
+ for part in parts:
+ if "functionCall" in part:
+ _function_chunk = ChatCompletionToolCallFunctionChunk(
+ name=part["functionCall"]["name"],
+ arguments=json.dumps(part["functionCall"]["args"]),
+ )
+ if is_function_call is True:
+ function = _function_chunk
+ else:
+ _tool_response_chunk = ChatCompletionToolCallChunk(
+ id=f"call_{str(uuid.uuid4())}",
+ type="function",
+ function=_function_chunk,
+ index=index,
+ )
+ _tools.append(_tool_response_chunk)
+ if len(_tools) == 0:
+ tools: Optional[List[ChatCompletionToolCallChunk]] = None
+ else:
+ tools = _tools
+ return function, tools
+
+ def _transform_logprobs(
+ self, logprobs_result: Optional[LogprobsResult]
+ ) -> Optional[ChoiceLogprobs]:
+ if logprobs_result is None:
+ return None
+ if "chosenCandidates" not in logprobs_result:
+ return None
+ logprobs_list: List[ChatCompletionTokenLogprob] = []
+ for index, candidate in enumerate(logprobs_result["chosenCandidates"]):
+ top_logprobs: List[TopLogprob] = []
+ if "topCandidates" in logprobs_result and index < len(
+ logprobs_result["topCandidates"]
+ ):
+ top_candidates_for_index = logprobs_result["topCandidates"][index][
+ "candidates"
+ ]
+
+ for options in top_candidates_for_index:
+ top_logprobs.append(
+ TopLogprob(
+ token=options["token"], logprob=options["logProbability"]
+ )
+ )
+ logprobs_list.append(
+ ChatCompletionTokenLogprob(
+ token=candidate["token"],
+ logprob=candidate["logProbability"],
+ top_logprobs=top_logprobs,
+ )
+ )
+ return ChoiceLogprobs(content=logprobs_list)
+
+ def _handle_blocked_response(
+ self,
+ model_response: ModelResponse,
+ completion_response: GenerateContentResponseBody,
+ ) -> ModelResponse:
+ # If set, the prompt was blocked and no candidates are returned. Rephrase your prompt
+ model_response.choices[0].finish_reason = "content_filter"
+
+ chat_completion_message: ChatCompletionResponseMessage = {
+ "role": "assistant",
+ "content": None,
+ }
+
+ choice = litellm.Choices(
+ finish_reason="content_filter",
+ index=0,
+ message=chat_completion_message, # type: ignore
+ logprobs=None,
+ enhancements=None,
+ )
+
+ model_response.choices = [choice]
+
+ ## GET USAGE ##
+ usage = Usage(
+ prompt_tokens=completion_response["usageMetadata"].get(
+ "promptTokenCount", 0
+ ),
+ completion_tokens=completion_response["usageMetadata"].get(
+ "candidatesTokenCount", 0
+ ),
+ total_tokens=completion_response["usageMetadata"].get("totalTokenCount", 0),
+ )
+
+ setattr(model_response, "usage", usage)
+
+ return model_response
+
+ def _handle_content_policy_violation(
+ self,
+ model_response: ModelResponse,
+ completion_response: GenerateContentResponseBody,
+ ) -> ModelResponse:
+ ## CONTENT POLICY VIOLATION ERROR
+ model_response.choices[0].finish_reason = "content_filter"
+
+ _chat_completion_message = {
+ "role": "assistant",
+ "content": None,
+ }
+
+ choice = litellm.Choices(
+ finish_reason="content_filter",
+ index=0,
+ message=_chat_completion_message,
+ logprobs=None,
+ enhancements=None,
+ )
+
+ model_response.choices = [choice]
+
+ ## GET USAGE ##
+ usage = Usage(
+ prompt_tokens=completion_response["usageMetadata"].get(
+ "promptTokenCount", 0
+ ),
+ completion_tokens=completion_response["usageMetadata"].get(
+ "candidatesTokenCount", 0
+ ),
+ total_tokens=completion_response["usageMetadata"].get("totalTokenCount", 0),
+ )
+
+ setattr(model_response, "usage", usage)
+
+ return model_response
+
+ def _calculate_usage(
+ self,
+ completion_response: GenerateContentResponseBody,
+ ) -> Usage:
+ cached_tokens: Optional[int] = None
+ prompt_tokens_details: Optional[PromptTokensDetailsWrapper] = None
+ if "cachedContentTokenCount" in completion_response["usageMetadata"]:
+ cached_tokens = completion_response["usageMetadata"][
+ "cachedContentTokenCount"
+ ]
+
+ if cached_tokens is not None:
+ prompt_tokens_details = PromptTokensDetailsWrapper(
+ cached_tokens=cached_tokens,
+ )
+ ## GET USAGE ##
+ usage = Usage(
+ prompt_tokens=completion_response["usageMetadata"].get(
+ "promptTokenCount", 0
+ ),
+ completion_tokens=completion_response["usageMetadata"].get(
+ "candidatesTokenCount", 0
+ ),
+ total_tokens=completion_response["usageMetadata"].get("totalTokenCount", 0),
+ prompt_tokens_details=prompt_tokens_details,
+ )
+
+ return usage
+
+ def transform_response(
+ self,
+ model: str,
+ raw_response: httpx.Response,
+ model_response: ModelResponse,
+ logging_obj: LoggingClass,
+ request_data: Dict,
+ messages: List[AllMessageValues],
+ optional_params: Dict,
+ litellm_params: Dict,
+ encoding: Any,
+ api_key: Optional[str] = None,
+ json_mode: Optional[bool] = None,
+ ) -> ModelResponse:
+ ## LOGGING
+ logging_obj.post_call(
+ input=messages,
+ api_key="",
+ original_response=raw_response.text,
+ additional_args={"complete_input_dict": request_data},
+ )
+
+ ## RESPONSE OBJECT
+ try:
+ completion_response = GenerateContentResponseBody(**raw_response.json()) # type: ignore
+ except Exception as e:
+ raise VertexAIError(
+ message="Received={}, Error converting to valid response block={}. File an issue if litellm error - https://github.com/BerriAI/litellm/issues".format(
+ raw_response.text, str(e)
+ ),
+ status_code=422,
+ headers=raw_response.headers,
+ )
+
+ ## GET MODEL ##
+ model_response.model = model
+
+ ## CHECK IF RESPONSE FLAGGED
+ if (
+ "promptFeedback" in completion_response
+ and "blockReason" in completion_response["promptFeedback"]
+ ):
+ return self._handle_blocked_response(
+ model_response=model_response,
+ completion_response=completion_response,
+ )
+
+ _candidates = completion_response.get("candidates")
+ if _candidates and len(_candidates) > 0:
+ content_policy_violations = (
+ VertexGeminiConfig().get_flagged_finish_reasons()
+ )
+ if (
+ "finishReason" in _candidates[0]
+ and _candidates[0]["finishReason"] in content_policy_violations.keys()
+ ):
+ return self._handle_content_policy_violation(
+ model_response=model_response,
+ completion_response=completion_response,
+ )
+
+ model_response.choices = [] # type: ignore
+
+ try:
+ ## CHECK IF GROUNDING METADATA IN REQUEST
+ grounding_metadata: List[dict] = []
+ safety_ratings: List = []
+ citation_metadata: List = []
+ ## GET TEXT ##
+ chat_completion_message: ChatCompletionResponseMessage = {
+ "role": "assistant"
+ }
+ chat_completion_logprobs: Optional[ChoiceLogprobs] = None
+ tools: Optional[List[ChatCompletionToolCallChunk]] = []
+ functions: Optional[ChatCompletionToolCallFunctionChunk] = None
+ if _candidates:
+ for idx, candidate in enumerate(_candidates):
+ if "content" not in candidate:
+ continue
+
+ if "groundingMetadata" in candidate:
+ grounding_metadata.append(candidate["groundingMetadata"]) # type: ignore
+
+ if "safetyRatings" in candidate:
+ safety_ratings.append(candidate["safetyRatings"])
+
+ if "citationMetadata" in candidate:
+ citation_metadata.append(candidate["citationMetadata"])
+ if "parts" in candidate["content"]:
+ chat_completion_message[
+ "content"
+ ] = VertexGeminiConfig().get_assistant_content_message(
+ parts=candidate["content"]["parts"]
+ )
+
+ functions, tools = self._transform_parts(
+ parts=candidate["content"]["parts"],
+ index=candidate.get("index", idx),
+ is_function_call=litellm_params.get(
+ "litellm_param_is_function_call"
+ ),
+ )
+
+ if "logprobsResult" in candidate:
+ chat_completion_logprobs = self._transform_logprobs(
+ logprobs_result=candidate["logprobsResult"]
+ )
+
+ if tools:
+ chat_completion_message["tool_calls"] = tools
+
+ if functions is not None:
+ chat_completion_message["function_call"] = functions
+ choice = litellm.Choices(
+ finish_reason=candidate.get("finishReason", "stop"),
+ index=candidate.get("index", idx),
+ message=chat_completion_message, # type: ignore
+ logprobs=chat_completion_logprobs,
+ enhancements=None,
+ )
+
+ model_response.choices.append(choice)
+
+ usage = self._calculate_usage(completion_response=completion_response)
+ setattr(model_response, "usage", usage)
+
+ ## ADD GROUNDING METADATA ##
+ setattr(model_response, "vertex_ai_grounding_metadata", grounding_metadata)
+ model_response._hidden_params[
+ "vertex_ai_grounding_metadata"
+ ] = ( # older approach - maintaining to prevent regressions
+ grounding_metadata
+ )
+
+ ## ADD SAFETY RATINGS ##
+ setattr(model_response, "vertex_ai_safety_results", safety_ratings)
+ model_response._hidden_params["vertex_ai_safety_results"] = (
+ safety_ratings # older approach - maintaining to prevent regressions
+ )
+
+ ## ADD CITATION METADATA ##
+ setattr(model_response, "vertex_ai_citation_metadata", citation_metadata)
+ model_response._hidden_params["vertex_ai_citation_metadata"] = (
+ citation_metadata # older approach - maintaining to prevent regressions
+ )
+
+ except Exception as e:
+ raise VertexAIError(
+ message="Received={}, Error converting to valid response block={}. File an issue if litellm error - https://github.com/BerriAI/litellm/issues".format(
+ completion_response, str(e)
+ ),
+ status_code=422,
+ headers=raw_response.headers,
+ )
+
+ return model_response
+
+ def _transform_messages(
+ self, messages: List[AllMessageValues]
+ ) -> List[ContentType]:
+ return _gemini_convert_messages_with_history(messages=messages)
+
+ def get_error_class(
+ self, error_message: str, status_code: int, headers: Union[Dict, httpx.Headers]
+ ) -> BaseLLMException:
+ return VertexAIError(
+ message=error_message, status_code=status_code, headers=headers
+ )
+
+ def transform_request(
+ self,
+ model: str,
+ messages: List[AllMessageValues],
+ optional_params: Dict,
+ litellm_params: Dict,
+ headers: Dict,
+ ) -> Dict:
+ raise NotImplementedError(
+ "Vertex AI has a custom implementation of transform_request. Needs sync + async."
+ )
+
+ def validate_environment(
+ self,
+ headers: Optional[Dict],
+ model: str,
+ messages: List[AllMessageValues],
+ optional_params: Dict,
+ api_key: Optional[str] = None,
+ api_base: Optional[str] = None,
+ ) -> Dict:
+ default_headers = {
+ "Content-Type": "application/json",
+ }
+ if api_key is not None:
+ default_headers["Authorization"] = f"Bearer {api_key}"
+ if headers is not None:
+ default_headers.update(headers)
+
+ return default_headers
+
+
+async def make_call(
+ client: Optional[AsyncHTTPHandler],
+ api_base: str,
+ headers: dict,
+ data: str,
+ model: str,
+ messages: list,
+ logging_obj,
+):
+ if client is None:
+ client = get_async_httpx_client(
+ llm_provider=litellm.LlmProviders.VERTEX_AI,
+ )
+
+ try:
+ response = await client.post(api_base, headers=headers, data=data, stream=True)
+ response.raise_for_status()
+ except httpx.HTTPStatusError as e:
+ exception_string = str(await e.response.aread())
+ raise VertexAIError(
+ status_code=e.response.status_code,
+ message=VertexGeminiConfig().translate_exception_str(exception_string),
+ headers=e.response.headers,
+ )
+ if response.status_code != 200 and response.status_code != 201:
+ raise VertexAIError(
+ status_code=response.status_code,
+ message=response.text,
+ headers=response.headers,
+ )
+
+ completion_stream = ModelResponseIterator(
+ streaming_response=response.aiter_lines(), sync_stream=False
+ )
+ # LOGGING
+ logging_obj.post_call(
+ input=messages,
+ api_key="",
+ original_response="first stream response received",
+ additional_args={"complete_input_dict": data},
+ )
+
+ return completion_stream
+
+
+def make_sync_call(
+ client: Optional[HTTPHandler], # module-level client
+ gemini_client: Optional[HTTPHandler], # if passed by user
+ api_base: str,
+ headers: dict,
+ data: str,
+ model: str,
+ messages: list,
+ logging_obj,
+):
+ if gemini_client is not None:
+ client = gemini_client
+ if client is None:
+ client = HTTPHandler() # Create a new client if none provided
+
+ response = client.post(api_base, headers=headers, data=data, stream=True)
+
+ if response.status_code != 200 and response.status_code != 201:
+ raise VertexAIError(
+ status_code=response.status_code,
+ message=str(response.read()),
+ headers=response.headers,
+ )
+
+ completion_stream = ModelResponseIterator(
+ streaming_response=response.iter_lines(), sync_stream=True
+ )
+
+ # LOGGING
+ logging_obj.post_call(
+ input=messages,
+ api_key="",
+ original_response="first stream response received",
+ additional_args={"complete_input_dict": data},
+ )
+
+ return completion_stream
+
+
+class VertexLLM(VertexBase):
+ def __init__(self) -> None:
+ super().__init__()
+
+ async def async_streaming(
+ self,
+ model: str,
+ custom_llm_provider: Literal[
+ "vertex_ai", "vertex_ai_beta", "gemini"
+ ], # if it's vertex_ai or gemini (google ai studio)
+ messages: list,
+ model_response: ModelResponse,
+ print_verbose: Callable,
+ data: dict,
+ timeout: Optional[Union[float, httpx.Timeout]],
+ encoding,
+ logging_obj,
+ stream,
+ optional_params: dict,
+ litellm_params=None,
+ logger_fn=None,
+ api_base: Optional[str] = None,
+ client: Optional[AsyncHTTPHandler] = None,
+ vertex_project: Optional[str] = None,
+ vertex_location: Optional[str] = None,
+ vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES] = None,
+ gemini_api_key: Optional[str] = None,
+ extra_headers: Optional[dict] = None,
+ ) -> CustomStreamWrapper:
+ request_body = await async_transform_request_body(**data) # type: ignore
+
+ should_use_v1beta1_features = self.is_using_v1beta1_features(
+ optional_params=optional_params
+ )
+
+ _auth_header, vertex_project = await self._ensure_access_token_async(
+ credentials=vertex_credentials,
+ project_id=vertex_project,
+ custom_llm_provider=custom_llm_provider,
+ )
+
+ auth_header, api_base = self._get_token_and_url(
+ model=model,
+ gemini_api_key=gemini_api_key,
+ auth_header=_auth_header,
+ vertex_project=vertex_project,
+ vertex_location=vertex_location,
+ vertex_credentials=vertex_credentials,
+ stream=stream,
+ custom_llm_provider=custom_llm_provider,
+ api_base=api_base,
+ should_use_v1beta1_features=should_use_v1beta1_features,
+ )
+
+ headers = VertexGeminiConfig().validate_environment(
+ api_key=auth_header,
+ headers=extra_headers,
+ model=model,
+ messages=messages,
+ optional_params=optional_params,
+ )
+
+ ## LOGGING
+ logging_obj.pre_call(
+ input=messages,
+ api_key="",
+ additional_args={
+ "complete_input_dict": data,
+ "api_base": api_base,
+ "headers": headers,
+ },
+ )
+
+ request_body_str = json.dumps(request_body)
+ streaming_response = CustomStreamWrapper(
+ completion_stream=None,
+ make_call=partial(
+ make_call,
+ client=client,
+ api_base=api_base,
+ headers=headers,
+ data=request_body_str,
+ model=model,
+ messages=messages,
+ logging_obj=logging_obj,
+ ),
+ model=model,
+ custom_llm_provider="vertex_ai_beta",
+ logging_obj=logging_obj,
+ )
+ return streaming_response
+
+ async def async_completion(
+ self,
+ model: str,
+ messages: list,
+ model_response: ModelResponse,
+ print_verbose: Callable,
+ data: dict,
+ custom_llm_provider: Literal[
+ "vertex_ai", "vertex_ai_beta", "gemini"
+ ], # if it's vertex_ai or gemini (google ai studio)
+ timeout: Optional[Union[float, httpx.Timeout]],
+ encoding,
+ logging_obj,
+ stream,
+ optional_params: dict,
+ litellm_params: dict,
+ logger_fn=None,
+ api_base: Optional[str] = None,
+ client: Optional[AsyncHTTPHandler] = None,
+ vertex_project: Optional[str] = None,
+ vertex_location: Optional[str] = None,
+ vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES] = None,
+ gemini_api_key: Optional[str] = None,
+ extra_headers: Optional[dict] = None,
+ ) -> Union[ModelResponse, CustomStreamWrapper]:
+ should_use_v1beta1_features = self.is_using_v1beta1_features(
+ optional_params=optional_params
+ )
+
+ _auth_header, vertex_project = await self._ensure_access_token_async(
+ credentials=vertex_credentials,
+ project_id=vertex_project,
+ custom_llm_provider=custom_llm_provider,
+ )
+
+ auth_header, api_base = self._get_token_and_url(
+ model=model,
+ gemini_api_key=gemini_api_key,
+ auth_header=_auth_header,
+ vertex_project=vertex_project,
+ vertex_location=vertex_location,
+ vertex_credentials=vertex_credentials,
+ stream=stream,
+ custom_llm_provider=custom_llm_provider,
+ api_base=api_base,
+ should_use_v1beta1_features=should_use_v1beta1_features,
+ )
+
+ headers = VertexGeminiConfig().validate_environment(
+ api_key=auth_header,
+ headers=extra_headers,
+ model=model,
+ messages=messages,
+ optional_params=optional_params,
+ )
+
+ request_body = await async_transform_request_body(**data) # type: ignore
+ _async_client_params = {}
+ if timeout:
+ _async_client_params["timeout"] = timeout
+ if client is None or not isinstance(client, AsyncHTTPHandler):
+ client = get_async_httpx_client(
+ params=_async_client_params, llm_provider=litellm.LlmProviders.VERTEX_AI
+ )
+ else:
+ client = client # type: ignore
+ ## LOGGING
+ logging_obj.pre_call(
+ input=messages,
+ api_key="",
+ additional_args={
+ "complete_input_dict": request_body,
+ "api_base": api_base,
+ "headers": headers,
+ },
+ )
+
+ try:
+ response = await client.post(
+ api_base, headers=headers, json=cast(dict, request_body)
+ ) # type: ignore
+ response.raise_for_status()
+ except httpx.HTTPStatusError as err:
+ error_code = err.response.status_code
+ raise VertexAIError(
+ status_code=error_code,
+ message=err.response.text,
+ headers=err.response.headers,
+ )
+ except httpx.TimeoutException:
+ raise VertexAIError(
+ status_code=408,
+ message="Timeout error occurred.",
+ headers=None,
+ )
+
+ return VertexGeminiConfig().transform_response(
+ model=model,
+ raw_response=response,
+ model_response=model_response,
+ logging_obj=logging_obj,
+ api_key="",
+ request_data=cast(dict, request_body),
+ messages=messages,
+ optional_params=optional_params,
+ litellm_params=litellm_params,
+ encoding=encoding,
+ )
+
+ def completion(
+ self,
+ model: str,
+ messages: list,
+ model_response: ModelResponse,
+ print_verbose: Callable,
+ custom_llm_provider: Literal[
+ "vertex_ai", "vertex_ai_beta", "gemini"
+ ], # if it's vertex_ai or gemini (google ai studio)
+ encoding,
+ logging_obj,
+ optional_params: dict,
+ acompletion: bool,
+ timeout: Optional[Union[float, httpx.Timeout]],
+ vertex_project: Optional[str],
+ vertex_location: Optional[str],
+ vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
+ gemini_api_key: Optional[str],
+ litellm_params: dict,
+ logger_fn=None,
+ extra_headers: Optional[dict] = None,
+ client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
+ api_base: Optional[str] = None,
+ ) -> Union[ModelResponse, CustomStreamWrapper]:
+ stream: Optional[bool] = optional_params.pop("stream", None) # type: ignore
+
+ transform_request_params = {
+ "gemini_api_key": gemini_api_key,
+ "messages": messages,
+ "api_base": api_base,
+ "model": model,
+ "client": client,
+ "timeout": timeout,
+ "extra_headers": extra_headers,
+ "optional_params": optional_params,
+ "logging_obj": logging_obj,
+ "custom_llm_provider": custom_llm_provider,
+ "litellm_params": litellm_params,
+ }
+
+ ### ROUTING (ASYNC, STREAMING, SYNC)
+ if acompletion:
+ ### ASYNC STREAMING
+ if stream is True:
+ return self.async_streaming(
+ model=model,
+ messages=messages,
+ api_base=api_base,
+ model_response=model_response,
+ print_verbose=print_verbose,
+ encoding=encoding,
+ logging_obj=logging_obj,
+ optional_params=optional_params,
+ stream=stream,
+ litellm_params=litellm_params,
+ logger_fn=logger_fn,
+ timeout=timeout,
+ client=client, # type: ignore
+ data=transform_request_params,
+ vertex_project=vertex_project,
+ vertex_location=vertex_location,
+ vertex_credentials=vertex_credentials,
+ gemini_api_key=gemini_api_key,
+ custom_llm_provider=custom_llm_provider,
+ extra_headers=extra_headers,
+ )
+ ### ASYNC COMPLETION
+ return self.async_completion(
+ model=model,
+ messages=messages,
+ data=transform_request_params, # type: ignore
+ api_base=api_base,
+ model_response=model_response,
+ print_verbose=print_verbose,
+ encoding=encoding,
+ logging_obj=logging_obj,
+ optional_params=optional_params,
+ stream=stream,
+ litellm_params=litellm_params,
+ logger_fn=logger_fn,
+ timeout=timeout,
+ client=client, # type: ignore
+ vertex_project=vertex_project,
+ vertex_location=vertex_location,
+ vertex_credentials=vertex_credentials,
+ gemini_api_key=gemini_api_key,
+ custom_llm_provider=custom_llm_provider,
+ extra_headers=extra_headers,
+ )
+
+ should_use_v1beta1_features = self.is_using_v1beta1_features(
+ optional_params=optional_params
+ )
+
+ _auth_header, vertex_project = self._ensure_access_token(
+ credentials=vertex_credentials,
+ project_id=vertex_project,
+ custom_llm_provider=custom_llm_provider,
+ )
+
+ auth_header, url = self._get_token_and_url(
+ model=model,
+ gemini_api_key=gemini_api_key,
+ auth_header=_auth_header,
+ vertex_project=vertex_project,
+ vertex_location=vertex_location,
+ vertex_credentials=vertex_credentials,
+ stream=stream,
+ custom_llm_provider=custom_llm_provider,
+ api_base=api_base,
+ should_use_v1beta1_features=should_use_v1beta1_features,
+ )
+ headers = VertexGeminiConfig().validate_environment(
+ api_key=auth_header,
+ headers=extra_headers,
+ model=model,
+ messages=messages,
+ optional_params=optional_params,
+ )
+
+ ## TRANSFORMATION ##
+ data = sync_transform_request_body(**transform_request_params)
+
+ ## LOGGING
+ logging_obj.pre_call(
+ input=messages,
+ api_key="",
+ additional_args={
+ "complete_input_dict": data,
+ "api_base": url,
+ "headers": headers,
+ },
+ )
+
+ ## SYNC STREAMING CALL ##
+ if stream is True:
+ request_data_str = json.dumps(data)
+ streaming_response = CustomStreamWrapper(
+ completion_stream=None,
+ make_call=partial(
+ make_sync_call,
+ gemini_client=(
+ client
+ if client is not None and isinstance(client, HTTPHandler)
+ else None
+ ),
+ api_base=url,
+ data=request_data_str,
+ model=model,
+ messages=messages,
+ logging_obj=logging_obj,
+ headers=headers,
+ ),
+ model=model,
+ custom_llm_provider="vertex_ai_beta",
+ logging_obj=logging_obj,
+ )
+
+ return streaming_response
+ ## COMPLETION CALL ##
+
+ if client is None or isinstance(client, AsyncHTTPHandler):
+ _params = {}
+ if timeout is not None:
+ if isinstance(timeout, float) or isinstance(timeout, int):
+ timeout = httpx.Timeout(timeout)
+ _params["timeout"] = timeout
+ client = HTTPHandler(**_params) # type: ignore
+ else:
+ client = client
+
+ try:
+ response = client.post(url=url, headers=headers, json=data) # type: ignore
+ response.raise_for_status()
+ except httpx.HTTPStatusError as err:
+ error_code = err.response.status_code
+ raise VertexAIError(
+ status_code=error_code,
+ message=err.response.text,
+ headers=err.response.headers,
+ )
+ except httpx.TimeoutException:
+ raise VertexAIError(
+ status_code=408,
+ message="Timeout error occurred.",
+ headers=None,
+ )
+
+ return VertexGeminiConfig().transform_response(
+ model=model,
+ raw_response=response,
+ model_response=model_response,
+ logging_obj=logging_obj,
+ optional_params=optional_params,
+ litellm_params=litellm_params,
+ api_key="",
+ request_data=data, # type: ignore
+ messages=messages,
+ encoding=encoding,
+ )
+
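+# Illustrative sketch (not part of the library source): how the routing in
+# VertexLLM.completion above is typically exercised through litellm. The model
+# name and prompt are placeholder values.
+#
+#   import litellm
+#
+#   # sync, non-streaming -> VertexLLM.completion -> HTTPHandler.post
+#   resp = litellm.completion(
+#       model="vertex_ai/gemini-1.5-pro",
+#       messages=[{"role": "user", "content": "hello"}],
+#   )
+#
+#   # stream=True -> make_sync_call + ModelResponseIterator (defined below)
+#   stream = litellm.completion(
+#       model="vertex_ai/gemini-1.5-pro",
+#       messages=[{"role": "user", "content": "hello"}],
+#       stream=True,
+#   )
+#   for chunk in stream:
+#       pass
+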
+
+class ModelResponseIterator:
+ def __init__(self, streaming_response, sync_stream: bool):
+ self.streaming_response = streaming_response
+ self.chunk_type: Literal["valid_json", "accumulated_json"] = "valid_json"
+ self.accumulated_json = ""
+ self.sent_first_chunk = False
+
+ def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
+ try:
+ processed_chunk = GenerateContentResponseBody(**chunk) # type: ignore
+
+ text = ""
+ tool_use: Optional[ChatCompletionToolCallChunk] = None
+ finish_reason = ""
+ usage: Optional[ChatCompletionUsageBlock] = None
+ _candidates: Optional[List[Candidates]] = processed_chunk.get("candidates")
+ gemini_chunk: Optional[Candidates] = None
+ if _candidates and len(_candidates) > 0:
+ gemini_chunk = _candidates[0]
+
+ if (
+ gemini_chunk
+ and "content" in gemini_chunk
+ and "parts" in gemini_chunk["content"]
+ ):
+ if "text" in gemini_chunk["content"]["parts"][0]:
+ text = gemini_chunk["content"]["parts"][0]["text"]
+ elif "functionCall" in gemini_chunk["content"]["parts"][0]:
+ function_call = ChatCompletionToolCallFunctionChunk(
+ name=gemini_chunk["content"]["parts"][0]["functionCall"][
+ "name"
+ ],
+ arguments=json.dumps(
+ gemini_chunk["content"]["parts"][0]["functionCall"]["args"]
+ ),
+ )
+ tool_use = ChatCompletionToolCallChunk(
+ id=str(uuid.uuid4()),
+ type="function",
+ function=function_call,
+ index=0,
+ )
+
+ if gemini_chunk and "finishReason" in gemini_chunk:
+ finish_reason = map_finish_reason(
+ finish_reason=gemini_chunk["finishReason"]
+ )
+ ## DO NOT SET 'is_finished' = True
+ ## GEMINI SETS FINISHREASON ON EVERY CHUNK!
+
+ if "usageMetadata" in processed_chunk:
+ usage = ChatCompletionUsageBlock(
+ prompt_tokens=processed_chunk["usageMetadata"].get(
+ "promptTokenCount", 0
+ ),
+ completion_tokens=processed_chunk["usageMetadata"].get(
+ "candidatesTokenCount", 0
+ ),
+ total_tokens=processed_chunk["usageMetadata"].get(
+ "totalTokenCount", 0
+ ),
+ )
+
+ returned_chunk = GenericStreamingChunk(
+ text=text,
+ tool_use=tool_use,
+ is_finished=False,
+ finish_reason=finish_reason,
+ usage=usage,
+ index=0,
+ )
+ return returned_chunk
+ except json.JSONDecodeError:
+ raise ValueError(f"Failed to decode JSON from chunk: {chunk}")
+
+ # Sync iterator
+ def __iter__(self):
+ self.response_iterator = self.streaming_response
+ return self
+
+ def handle_valid_json_chunk(self, chunk: str) -> GenericStreamingChunk:
+ chunk = chunk.strip()
+ try:
+ json_chunk = json.loads(chunk)
+
+ except json.JSONDecodeError as e:
+ if (
+ self.sent_first_chunk is False
+ ): # only check for accumulated json, on first chunk, else raise error. Prevent real errors from being masked.
+ self.chunk_type = "accumulated_json"
+ return self.handle_accumulated_json_chunk(chunk=chunk)
+ raise e
+
+ if self.sent_first_chunk is False:
+ self.sent_first_chunk = True
+
+ return self.chunk_parser(chunk=json_chunk)
+
+ def handle_accumulated_json_chunk(self, chunk: str) -> GenericStreamingChunk:
+ chunk = litellm.CustomStreamWrapper._strip_sse_data_from_chunk(chunk) or ""
+ message = chunk.replace("\n\n", "")
+
+ # Accumulate JSON data
+ self.accumulated_json += message
+
+ # Try to parse the accumulated JSON
+ try:
+ _data = json.loads(self.accumulated_json)
+ self.accumulated_json = "" # reset after successful parsing
+ return self.chunk_parser(chunk=_data)
+ except json.JSONDecodeError:
+ # If it's not valid JSON yet, continue to the next event
+ return GenericStreamingChunk(
+ text="",
+ is_finished=False,
+ finish_reason="",
+ usage=None,
+ index=0,
+ tool_use=None,
+ )
+
+ def _common_chunk_parsing_logic(self, chunk: str) -> GenericStreamingChunk:
+ try:
+ chunk = litellm.CustomStreamWrapper._strip_sse_data_from_chunk(chunk) or ""
+ if len(chunk) > 0:
+ """
+ Check if initial chunk valid json
+ - if partial json -> enter accumulated json logic
+ - if valid - continue
+ """
+ if self.chunk_type == "valid_json":
+ return self.handle_valid_json_chunk(chunk=chunk)
+ elif self.chunk_type == "accumulated_json":
+ return self.handle_accumulated_json_chunk(chunk=chunk)
+
+ return GenericStreamingChunk(
+ text="",
+ is_finished=False,
+ finish_reason="",
+ usage=None,
+ index=0,
+ tool_use=None,
+ )
+ except Exception:
+ raise
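+    # Illustrative sketch (not part of the library source): how the two chunk
+    # modes above behave. The SSE fragments are hypothetical.
+    #
+    #   it = ModelResponseIterator(streaming_response=iter([]), sync_stream=True)
+    #   it._common_chunk_parsing_logic('data: {"candidates": []}')
+    #   # -> parsed as a complete JSON chunk (chunk_type stays "valid_json")
+    #
+    #   it = ModelResponseIterator(streaming_response=iter([]), sync_stream=True)
+    #   it._common_chunk_parsing_logic('data: {"candidates"')
+    #   # -> partial JSON on the first chunk flips chunk_type to "accumulated_json";
+    #   #    the fragment is buffered and an empty GenericStreamingChunk is returned
+    #   #    until the accumulated buffer parses as valid JSON.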
+
+ def __next__(self):
+ try:
+ chunk = self.response_iterator.__next__()
+ except StopIteration:
+ if self.chunk_type == "accumulated_json" and self.accumulated_json:
+ return self.handle_accumulated_json_chunk(chunk="")
+ raise StopIteration
+ except ValueError as e:
+ raise RuntimeError(f"Error receiving chunk from stream: {e}")
+
+ try:
+ return self._common_chunk_parsing_logic(chunk=chunk)
+ except StopIteration:
+ raise StopIteration
+ except ValueError as e:
+ raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
+
+ # Async iterator
+ def __aiter__(self):
+ self.async_response_iterator = self.streaming_response.__aiter__()
+ return self
+
+ async def __anext__(self):
+ try:
+ chunk = await self.async_response_iterator.__anext__()
+ except StopAsyncIteration:
+ if self.chunk_type == "accumulated_json" and self.accumulated_json:
+ return self.handle_accumulated_json_chunk(chunk="")
+ raise StopAsyncIteration
+ except ValueError as e:
+ raise RuntimeError(f"Error receiving chunk from stream: {e}")
+
+ try:
+ return self._common_chunk_parsing_logic(chunk=chunk)
+ except StopAsyncIteration:
+ raise StopAsyncIteration
+ except ValueError as e:
+ raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/gemini_embeddings/batch_embed_content_handler.py b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/gemini_embeddings/batch_embed_content_handler.py
new file mode 100644
index 00000000..0fe5145a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/gemini_embeddings/batch_embed_content_handler.py
@@ -0,0 +1,182 @@
+"""
+Google AI Studio /batchEmbedContents Embeddings Endpoint
+"""
+
+import json
+from typing import Any, Literal, Optional, Union
+
+import httpx
+
+import litellm
+from litellm import EmbeddingResponse
+from litellm.llms.custom_httpx.http_handler import (
+ AsyncHTTPHandler,
+ HTTPHandler,
+ get_async_httpx_client,
+)
+from litellm.types.llms.openai import EmbeddingInput
+from litellm.types.llms.vertex_ai import (
+ VertexAIBatchEmbeddingsRequestBody,
+ VertexAIBatchEmbeddingsResponseObject,
+)
+
+from ..gemini.vertex_and_google_ai_studio_gemini import VertexLLM
+from .batch_embed_content_transformation import (
+ process_response,
+ transform_openai_input_gemini_content,
+)
+
+
+class GoogleBatchEmbeddings(VertexLLM):
+ def batch_embeddings(
+ self,
+ model: str,
+ input: EmbeddingInput,
+ print_verbose,
+ model_response: EmbeddingResponse,
+ custom_llm_provider: Literal["gemini", "vertex_ai"],
+ optional_params: dict,
+ logging_obj: Any,
+ api_key: Optional[str] = None,
+ api_base: Optional[str] = None,
+ encoding=None,
+ vertex_project=None,
+ vertex_location=None,
+ vertex_credentials=None,
+ aembedding=False,
+ timeout=300,
+ client=None,
+ ) -> EmbeddingResponse:
+
+ _auth_header, vertex_project = self._ensure_access_token(
+ credentials=vertex_credentials,
+ project_id=vertex_project,
+ custom_llm_provider=custom_llm_provider,
+ )
+
+ auth_header, url = self._get_token_and_url(
+ model=model,
+ auth_header=_auth_header,
+ gemini_api_key=api_key,
+ vertex_project=vertex_project,
+ vertex_location=vertex_location,
+ vertex_credentials=vertex_credentials,
+ stream=None,
+ custom_llm_provider=custom_llm_provider,
+ api_base=api_base,
+ should_use_v1beta1_features=False,
+ mode="batch_embedding",
+ )
+
+ if client is None:
+ _params = {}
+ if timeout is not None:
+ if isinstance(timeout, float) or isinstance(timeout, int):
+ _httpx_timeout = httpx.Timeout(timeout)
+ _params["timeout"] = _httpx_timeout
+ else:
+ _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
+
+ sync_handler: HTTPHandler = HTTPHandler(**_params) # type: ignore
+ else:
+ sync_handler = client # type: ignore
+
+ optional_params = optional_params or {}
+
+ ### TRANSFORMATION ###
+ request_data = transform_openai_input_gemini_content(
+ input=input, model=model, optional_params=optional_params
+ )
+
+ headers = {
+ "Content-Type": "application/json; charset=utf-8",
+ }
+
+ ## LOGGING
+ logging_obj.pre_call(
+ input=input,
+ api_key="",
+ additional_args={
+ "complete_input_dict": request_data,
+ "api_base": url,
+ "headers": headers,
+ },
+ )
+
+ if aembedding is True:
+ return self.async_batch_embeddings( # type: ignore
+ model=model,
+ api_base=api_base,
+ url=url,
+ data=request_data,
+ model_response=model_response,
+ timeout=timeout,
+ headers=headers,
+ input=input,
+ )
+
+ response = sync_handler.post(
+ url=url,
+ headers=headers,
+ data=json.dumps(request_data),
+ )
+
+ if response.status_code != 200:
+ raise Exception(f"Error: {response.status_code} {response.text}")
+
+ _json_response = response.json()
+ _predictions = VertexAIBatchEmbeddingsResponseObject(**_json_response) # type: ignore
+
+ return process_response(
+ model=model,
+ model_response=model_response,
+ _predictions=_predictions,
+ input=input,
+ )
+
+ async def async_batch_embeddings(
+ self,
+ model: str,
+ api_base: Optional[str],
+ url: str,
+ data: VertexAIBatchEmbeddingsRequestBody,
+ model_response: EmbeddingResponse,
+ input: EmbeddingInput,
+ timeout: Optional[Union[float, httpx.Timeout]],
+ headers={},
+ client: Optional[AsyncHTTPHandler] = None,
+ ) -> EmbeddingResponse:
+ if client is None:
+ _params = {}
+ if timeout is not None:
+ if isinstance(timeout, float) or isinstance(timeout, int):
+ _httpx_timeout = httpx.Timeout(timeout)
+ _params["timeout"] = _httpx_timeout
+ else:
+ _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
+
+ async_handler: AsyncHTTPHandler = get_async_httpx_client(
+ llm_provider=litellm.LlmProviders.VERTEX_AI,
+ params={"timeout": timeout},
+ )
+ else:
+ async_handler = client # type: ignore
+
+ response = await async_handler.post(
+ url=url,
+ headers=headers,
+ data=json.dumps(data),
+ )
+
+ if response.status_code != 200:
+ raise Exception(f"Error: {response.status_code} {response.text}")
+
+ _json_response = response.json()
+ _predictions = VertexAIBatchEmbeddingsResponseObject(**_json_response) # type: ignore
+
+ return process_response(
+ model=model,
+ model_response=model_response,
+ _predictions=_predictions,
+ input=input,
+ )
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/gemini_embeddings/batch_embed_content_transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/gemini_embeddings/batch_embed_content_transformation.py
new file mode 100644
index 00000000..592dac58
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/gemini_embeddings/batch_embed_content_transformation.py
@@ -0,0 +1,74 @@
+"""
+Transformation logic from OpenAI /v1/embeddings format to Google AI Studio /batchEmbedContents format.
+
+Why separate file? Make it easy to see how transformation works
+"""
+
+from typing import List
+
+from litellm import EmbeddingResponse
+from litellm.types.llms.openai import EmbeddingInput
+from litellm.types.llms.vertex_ai import (
+ ContentType,
+ EmbedContentRequest,
+ PartType,
+ VertexAIBatchEmbeddingsRequestBody,
+ VertexAIBatchEmbeddingsResponseObject,
+)
+from litellm.types.utils import Embedding, Usage
+from litellm.utils import get_formatted_prompt, token_counter
+
+
+def transform_openai_input_gemini_content(
+ input: EmbeddingInput, model: str, optional_params: dict
+) -> VertexAIBatchEmbeddingsRequestBody:
+ """
+    Transform an OpenAI-style embedding input into a Google AI Studio
+    /batchEmbedContents request body. Only the parts.text fields will be counted.
+ """
+ gemini_model_name = "models/{}".format(model)
+ requests: List[EmbedContentRequest] = []
+ if isinstance(input, str):
+ request = EmbedContentRequest(
+ model=gemini_model_name,
+ content=ContentType(parts=[PartType(text=input)]),
+ **optional_params
+ )
+ requests.append(request)
+ else:
+ for i in input:
+ request = EmbedContentRequest(
+ model=gemini_model_name,
+ content=ContentType(parts=[PartType(text=i)]),
+ **optional_params
+ )
+ requests.append(request)
+
+ return VertexAIBatchEmbeddingsRequestBody(requests=requests)
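+# Illustrative sketch (not part of the library source): roughly the request body
+# the function above builds. The model name and inputs are placeholder values.
+#
+#   transform_openai_input_gemini_content(
+#       input=["hello", "world"], model="text-embedding-004", optional_params={}
+#   )
+#   # -> {"requests": [
+#   #        {"model": "models/text-embedding-004",
+#   #         "content": {"parts": [{"text": "hello"}]}},
+#   #        {"model": "models/text-embedding-004",
+#   #         "content": {"parts": [{"text": "world"}]}},
+#   #    ]}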
+
+
+def process_response(
+ input: EmbeddingInput,
+ model_response: EmbeddingResponse,
+ model: str,
+ _predictions: VertexAIBatchEmbeddingsResponseObject,
+) -> EmbeddingResponse:
+
+ openai_embeddings: List[Embedding] = []
+ for embedding in _predictions["embeddings"]:
+ openai_embedding = Embedding(
+ embedding=embedding["values"],
+ index=0,
+ object="embedding",
+ )
+ openai_embeddings.append(openai_embedding)
+
+ model_response.data = openai_embeddings
+ model_response.model = model
+
+ input_text = get_formatted_prompt(data={"input": input}, call_type="embedding")
+ prompt_tokens = token_counter(model=model, text=input_text)
+ model_response.usage = Usage(
+ prompt_tokens=prompt_tokens, total_tokens=prompt_tokens
+ )
+
+ return model_response
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/image_generation/cost_calculator.py b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/image_generation/cost_calculator.py
new file mode 100644
index 00000000..2ba18c09
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/image_generation/cost_calculator.py
@@ -0,0 +1,23 @@
+"""
+Vertex AI Image Generation Cost Calculator
+"""
+
+import litellm
+from litellm.types.utils import ImageResponse
+
+
+def cost_calculator(
+ model: str,
+ image_response: ImageResponse,
+) -> float:
+ """
+ Vertex AI Image Generation Cost Calculator
+ """
+ _model_info = litellm.get_model_info(
+ model=model,
+ custom_llm_provider="vertex_ai",
+ )
+
+ output_cost_per_image: float = _model_info.get("output_cost_per_image") or 0.0
+ num_images: int = len(image_response.data)
+ return output_cost_per_image * num_images
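+# Illustrative sketch (not part of the library source): with a hypothetical
+# price of $0.02 per generated image and a response containing 3 images, the
+# calculation above returns 0.02 * 3 = 0.06. The real per-image price is
+# whatever litellm.get_model_info() reports for the given model.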
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/image_generation/image_generation_handler.py b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/image_generation/image_generation_handler.py
new file mode 100644
index 00000000..1d5322c0
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/image_generation/image_generation_handler.py
@@ -0,0 +1,236 @@
+import json
+from typing import Any, Dict, List, Optional
+
+import httpx
+from openai.types.image import Image
+
+import litellm
+from litellm.llms.custom_httpx.http_handler import (
+ AsyncHTTPHandler,
+ HTTPHandler,
+ get_async_httpx_client,
+)
+from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexLLM
+from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES
+from litellm.types.utils import ImageResponse
+
+
+class VertexImageGeneration(VertexLLM):
+ def process_image_generation_response(
+ self,
+ json_response: Dict[str, Any],
+ model_response: ImageResponse,
+ model: Optional[str] = None,
+ ) -> ImageResponse:
+ if "predictions" not in json_response:
+ raise litellm.InternalServerError(
+ message=f"image generation response does not contain 'predictions', got {json_response}",
+ llm_provider="vertex_ai",
+ model=model,
+ )
+
+ predictions = json_response["predictions"]
+ response_data: List[Image] = []
+
+ for prediction in predictions:
+ bytes_base64_encoded = prediction["bytesBase64Encoded"]
+ image_object = Image(b64_json=bytes_base64_encoded)
+ response_data.append(image_object)
+
+ model_response.data = response_data
+ return model_response
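+    # Illustrative sketch (not part of the library source): the mapping done
+    # above. The base64 payload is a placeholder.
+    #
+    #   json_response = {"predictions": [{"bytesBase64Encoded": "aGVsbG8="}]}
+    #   VertexImageGeneration().process_image_generation_response(
+    #       json_response, ImageResponse(), model="imagegeneration"
+    #   )
+    #   # -> ImageResponse whose .data is [Image(b64_json="aGVsbG8=")]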
+
+ def image_generation(
+ self,
+ prompt: str,
+ vertex_project: Optional[str],
+ vertex_location: Optional[str],
+ vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
+ model_response: ImageResponse,
+ logging_obj: Any,
+ model: Optional[
+ str
+ ] = "imagegeneration", # vertex ai uses imagegeneration as the default model
+ client: Optional[Any] = None,
+ optional_params: Optional[dict] = None,
+ timeout: Optional[int] = None,
+ aimg_generation=False,
+ ) -> ImageResponse:
+ if aimg_generation is True:
+ return self.aimage_generation( # type: ignore
+ prompt=prompt,
+ vertex_project=vertex_project,
+ vertex_location=vertex_location,
+ vertex_credentials=vertex_credentials,
+ model=model,
+ client=client,
+ optional_params=optional_params,
+ timeout=timeout,
+ logging_obj=logging_obj,
+ model_response=model_response,
+ )
+
+ if client is None:
+ _params = {}
+ if timeout is not None:
+ if isinstance(timeout, float) or isinstance(timeout, int):
+ _httpx_timeout = httpx.Timeout(timeout)
+ _params["timeout"] = _httpx_timeout
+ else:
+ _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
+
+ sync_handler: HTTPHandler = HTTPHandler(**_params) # type: ignore
+ else:
+ sync_handler = client # type: ignore
+
+ url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:predict"
+
+ auth_header, _ = self._ensure_access_token(
+ credentials=vertex_credentials,
+ project_id=vertex_project,
+ custom_llm_provider="vertex_ai",
+ )
+ optional_params = optional_params or {
+ "sampleCount": 1
+ } # default optional params
+
+ request_data = {
+ "instances": [{"prompt": prompt}],
+ "parameters": optional_params,
+ }
+
+ request_str = f"\n curl -X POST \\\n -H \"Authorization: Bearer {auth_header[:10] + 'XXXXXXXXXX'}\" \\\n -H \"Content-Type: application/json; charset=utf-8\" \\\n -d {request_data} \\\n \"{url}\""
+ logging_obj.pre_call(
+ input=prompt,
+ api_key=None,
+ additional_args={
+ "complete_input_dict": optional_params,
+ "request_str": request_str,
+ },
+ )
+
+ response = sync_handler.post(
+ url=url,
+ headers={
+ "Content-Type": "application/json; charset=utf-8",
+ "Authorization": f"Bearer {auth_header}",
+ },
+ data=json.dumps(request_data),
+ )
+
+ if response.status_code != 200:
+ raise Exception(f"Error: {response.status_code} {response.text}")
+
+ json_response = response.json()
+ return self.process_image_generation_response(
+ json_response, model_response, model
+ )
+
+ async def aimage_generation(
+ self,
+ prompt: str,
+ vertex_project: Optional[str],
+ vertex_location: Optional[str],
+ vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
+ model_response: litellm.ImageResponse,
+ logging_obj: Any,
+ model: Optional[
+ str
+ ] = "imagegeneration", # vertex ai uses imagegeneration as the default model
+ client: Optional[AsyncHTTPHandler] = None,
+ optional_params: Optional[dict] = None,
+ timeout: Optional[int] = None,
+ ):
+ response = None
+ if client is None:
+ _params = {}
+ if timeout is not None:
+ if isinstance(timeout, float) or isinstance(timeout, int):
+ _httpx_timeout = httpx.Timeout(timeout)
+ _params["timeout"] = _httpx_timeout
+ else:
+ _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
+
+ self.async_handler = get_async_httpx_client(
+ llm_provider=litellm.LlmProviders.VERTEX_AI,
+ params={"timeout": timeout},
+ )
+ else:
+ self.async_handler = client # type: ignore
+
+ # make POST request to
+ # https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict
+ url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:predict"
+
+ """
+ Docs link: https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/imagegeneration?project=adroit-crow-413218
+ curl -X POST \
+ -H "Authorization: Bearer $(gcloud auth print-access-token)" \
+ -H "Content-Type: application/json; charset=utf-8" \
+ -d {
+ "instances": [
+ {
+ "prompt": "a cat"
+ }
+ ],
+ "parameters": {
+ "sampleCount": 1
+ }
+ } \
+ "https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict"
+ """
+ auth_header, _ = self._ensure_access_token(
+ credentials=vertex_credentials,
+ project_id=vertex_project,
+ custom_llm_provider="vertex_ai",
+ )
+ optional_params = optional_params or {
+ "sampleCount": 1
+ } # default optional params
+
+ request_data = {
+ "instances": [{"prompt": prompt}],
+ "parameters": optional_params,
+ }
+
+ request_str = f"\n curl -X POST \\\n -H \"Authorization: Bearer {auth_header[:10] + 'XXXXXXXXXX'}\" \\\n -H \"Content-Type: application/json; charset=utf-8\" \\\n -d {request_data} \\\n \"{url}\""
+ logging_obj.pre_call(
+ input=prompt,
+ api_key=None,
+ additional_args={
+ "complete_input_dict": optional_params,
+ "request_str": request_str,
+ },
+ )
+
+ response = await self.async_handler.post(
+ url=url,
+ headers={
+ "Content-Type": "application/json; charset=utf-8",
+ "Authorization": f"Bearer {auth_header}",
+ },
+ data=json.dumps(request_data),
+ )
+
+ if response.status_code != 200:
+ raise Exception(f"Error: {response.status_code} {response.text}")
+
+ json_response = response.json()
+ return self.process_image_generation_response(
+ json_response, model_response, model
+ )
+
+ def is_image_generation_response(self, json_response: Dict[str, Any]) -> bool:
+ if "predictions" in json_response:
+ if "bytesBase64Encoded" in json_response["predictions"][0]:
+ return True
+ return False
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/multimodal_embeddings/embedding_handler.py b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/multimodal_embeddings/embedding_handler.py
new file mode 100644
index 00000000..f63d1ce1
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/multimodal_embeddings/embedding_handler.py
@@ -0,0 +1,294 @@
+import json
+from typing import List, Literal, Optional, Union
+
+import httpx
+
+import litellm
+from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+from litellm.llms.custom_httpx.http_handler import (
+ AsyncHTTPHandler,
+ HTTPHandler,
+ get_async_httpx_client,
+)
+from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
+ VertexAIError,
+ VertexLLM,
+)
+from litellm.types.llms.vertex_ai import (
+ Instance,
+ InstanceImage,
+ InstanceVideo,
+ MultimodalPredictions,
+ VertexMultimodalEmbeddingRequest,
+)
+from litellm.types.utils import Embedding, EmbeddingResponse
+from litellm.utils import is_base64_encoded
+
+
+class VertexMultimodalEmbedding(VertexLLM):
+ def __init__(self) -> None:
+ super().__init__()
+ self.SUPPORTED_MULTIMODAL_EMBEDDING_MODELS = [
+ "multimodalembedding",
+ "multimodalembedding@001",
+ ]
+
+ def multimodal_embedding(
+ self,
+ model: str,
+ input: Union[list, str],
+ print_verbose,
+ model_response: EmbeddingResponse,
+ custom_llm_provider: Literal["gemini", "vertex_ai"],
+ optional_params: dict,
+ logging_obj: LiteLLMLoggingObj,
+ api_key: Optional[str] = None,
+ api_base: Optional[str] = None,
+ encoding=None,
+ vertex_project=None,
+ vertex_location=None,
+ vertex_credentials=None,
+ aembedding=False,
+ timeout=300,
+ client=None,
+ ) -> EmbeddingResponse:
+
+ _auth_header, vertex_project = self._ensure_access_token(
+ credentials=vertex_credentials,
+ project_id=vertex_project,
+ custom_llm_provider=custom_llm_provider,
+ )
+
+ auth_header, url = self._get_token_and_url(
+ model=model,
+ auth_header=_auth_header,
+ gemini_api_key=api_key,
+ vertex_project=vertex_project,
+ vertex_location=vertex_location,
+ vertex_credentials=vertex_credentials,
+ stream=None,
+ custom_llm_provider=custom_llm_provider,
+ api_base=api_base,
+ should_use_v1beta1_features=False,
+ mode="embedding",
+ )
+
+ if client is None:
+ _params = {}
+ if timeout is not None:
+ if isinstance(timeout, float) or isinstance(timeout, int):
+ _httpx_timeout = httpx.Timeout(timeout)
+ _params["timeout"] = _httpx_timeout
+ else:
+ _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)
+
+ sync_handler: HTTPHandler = HTTPHandler(**_params) # type: ignore
+ else:
+ sync_handler = client # type: ignore
+
+ optional_params = optional_params or {}
+
+ request_data = VertexMultimodalEmbeddingRequest()
+
+ if "instances" in optional_params:
+ request_data["instances"] = optional_params["instances"]
+ elif isinstance(input, list):
+ vertex_instances: List[Instance] = self.process_openai_embedding_input(
+ _input=input
+ )
+ request_data["instances"] = vertex_instances
+
+ else:
+ # construct instances
+ vertex_request_instance = Instance(**optional_params)
+
+ if isinstance(input, str):
+ vertex_request_instance = self._process_input_element(input)
+
+ request_data["instances"] = [vertex_request_instance]
+
+ headers = {
+ "Content-Type": "application/json; charset=utf-8",
+ "Authorization": f"Bearer {auth_header}",
+ }
+
+ ## LOGGING
+ logging_obj.pre_call(
+ input=input,
+ api_key="",
+ additional_args={
+ "complete_input_dict": request_data,
+ "api_base": url,
+ "headers": headers,
+ },
+ )
+
+ if aembedding is True:
+ return self.async_multimodal_embedding( # type: ignore
+ model=model,
+ api_base=url,
+ data=request_data,
+ timeout=timeout,
+ headers=headers,
+ client=client,
+ model_response=model_response,
+ )
+
+ response = sync_handler.post(
+ url=url,
+ headers=headers,
+ data=json.dumps(request_data),
+ )
+
+ if response.status_code != 200:
+ raise Exception(f"Error: {response.status_code} {response.text}")
+
+ _json_response = response.json()
+ if "predictions" not in _json_response:
+ raise litellm.InternalServerError(
+ message=f"embedding response does not contain 'predictions', got {_json_response}",
+ llm_provider="vertex_ai",
+ model=model,
+ )
+ _predictions = _json_response["predictions"]
+ vertex_predictions = MultimodalPredictions(predictions=_predictions)
+ model_response.data = self.transform_embedding_response_to_openai(
+ predictions=vertex_predictions
+ )
+ model_response.model = model
+
+ return model_response
+
+ async def async_multimodal_embedding(
+ self,
+ model: str,
+ api_base: str,
+ data: VertexMultimodalEmbeddingRequest,
+ model_response: litellm.EmbeddingResponse,
+ timeout: Optional[Union[float, httpx.Timeout]],
+ headers={},
+ client: Optional[AsyncHTTPHandler] = None,
+ ) -> litellm.EmbeddingResponse:
+ if client is None:
+ _params = {}
+ if timeout is not None:
+ if isinstance(timeout, float) or isinstance(timeout, int):
+ timeout = httpx.Timeout(timeout)
+ _params["timeout"] = timeout
+ client = get_async_httpx_client(
+ llm_provider=litellm.LlmProviders.VERTEX_AI,
+ params={"timeout": timeout},
+ )
+ else:
+ client = client # type: ignore
+
+ try:
+ response = await client.post(api_base, headers=headers, json=data) # type: ignore
+ response.raise_for_status()
+ except httpx.HTTPStatusError as err:
+ error_code = err.response.status_code
+ raise VertexAIError(status_code=error_code, message=err.response.text)
+ except httpx.TimeoutException:
+ raise VertexAIError(status_code=408, message="Timeout error occurred.")
+
+ _json_response = response.json()
+ if "predictions" not in _json_response:
+ raise litellm.InternalServerError(
+ message=f"embedding response does not contain 'predictions', got {_json_response}",
+ llm_provider="vertex_ai",
+ model=model,
+ )
+ _predictions = _json_response["predictions"]
+
+ vertex_predictions = MultimodalPredictions(predictions=_predictions)
+ model_response.data = self.transform_embedding_response_to_openai(
+ predictions=vertex_predictions
+ )
+ model_response.model = model
+
+ return model_response
+
+ def _process_input_element(self, input_element: str) -> Instance:
+ """
+        Process the input element for multimodal embedding requests. Checks whether the input is a GCS URI, a base64-encoded image, or plain text.
+
+ Args:
+ input_element (str): The input element to process.
+
+ Returns:
+ Dict[str, Any]: A dictionary representing the processed input element.
+ """
+ if len(input_element) == 0:
+ return Instance(text=input_element)
+ elif "gs://" in input_element:
+ if "mp4" in input_element:
+ return Instance(video=InstanceVideo(gcsUri=input_element))
+ else:
+ return Instance(image=InstanceImage(gcsUri=input_element))
+ elif is_base64_encoded(s=input_element):
+ return Instance(image=InstanceImage(bytesBase64Encoded=input_element))
+ else:
+ return Instance(text=input_element)
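+    # Illustrative sketch (not part of the library source): roughly how element
+    # types are detected above. The URIs and text are placeholder values.
+    #
+    #   _process_input_element("gs://my-bucket/clip.mp4")
+    #   # -> {"video": {"gcsUri": "gs://my-bucket/clip.mp4"}}
+    #   _process_input_element("gs://my-bucket/cat.png")
+    #   # -> {"image": {"gcsUri": "gs://my-bucket/cat.png"}}
+    #   _process_input_element("a photo of a cat")
+    #   # -> {"text": "a photo of a cat"}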
+
+ def process_openai_embedding_input(
+ self, _input: Union[list, str]
+ ) -> List[Instance]:
+ """
+ Process the input for multimodal embedding requests.
+
+ Args:
+ _input (Union[list, str]): The input data to process.
+
+ Returns:
+ List[Instance]: A list of processed VertexAI Instance objects.
+ """
+
+ _input_list = None
+ if not isinstance(_input, list):
+ _input_list = [_input]
+ else:
+ _input_list = _input
+
+ processed_instances = []
+ for element in _input_list:
+ if isinstance(element, str):
+ instance = Instance(**self._process_input_element(element))
+ elif isinstance(element, dict):
+ instance = Instance(**element)
+ else:
+ raise ValueError(f"Unsupported input type: {type(element)}")
+ processed_instances.append(instance)
+
+ return processed_instances
+
+ def transform_embedding_response_to_openai(
+ self, predictions: MultimodalPredictions
+ ) -> List[Embedding]:
+
+ openai_embeddings: List[Embedding] = []
+ if "predictions" in predictions:
+ for idx, _prediction in enumerate(predictions["predictions"]):
+ if _prediction:
+ if "textEmbedding" in _prediction:
+ openai_embedding_object = Embedding(
+ embedding=_prediction["textEmbedding"],
+ index=idx,
+ object="embedding",
+ )
+ openai_embeddings.append(openai_embedding_object)
+ elif "imageEmbedding" in _prediction:
+ openai_embedding_object = Embedding(
+ embedding=_prediction["imageEmbedding"],
+ index=idx,
+ object="embedding",
+ )
+ openai_embeddings.append(openai_embedding_object)
+ elif "videoEmbeddings" in _prediction:
+ for video_embedding in _prediction["videoEmbeddings"]:
+ openai_embedding_object = Embedding(
+ embedding=video_embedding["embedding"],
+ index=idx,
+ object="embedding",
+ )
+ openai_embeddings.append(openai_embedding_object)
+ return openai_embeddings
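+    # Illustrative sketch (not part of the library source): mapping of Vertex
+    # multimodal predictions to OpenAI-style embeddings. Vector values are
+    # placeholders.
+    #
+    #   preds = MultimodalPredictions(
+    #       predictions=[{"textEmbedding": [0.1, 0.2], "imageEmbedding": [0.3, 0.4]}]
+    #   )
+    #   VertexMultimodalEmbedding().transform_embedding_response_to_openai(preds)
+    #   # -> [Embedding(embedding=[0.1, 0.2], index=0, object="embedding")]
+    #   # (only the first matching key per prediction is used: text, then image,
+    #   #  then video embeddings)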
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/text_to_speech/text_to_speech_handler.py b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/text_to_speech/text_to_speech_handler.py
new file mode 100644
index 00000000..18bc72db
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/text_to_speech/text_to_speech_handler.py
@@ -0,0 +1,243 @@
+from typing import Optional, TypedDict, Union
+
+import httpx
+
+import litellm
+from litellm.llms.custom_httpx.http_handler import (
+ _get_httpx_client,
+ get_async_httpx_client,
+)
+from litellm.llms.openai.openai import HttpxBinaryResponseContent
+from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexLLM
+from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES
+
+
+class VertexInput(TypedDict, total=False):
+ text: Optional[str]
+ ssml: Optional[str]
+
+
+class VertexVoice(TypedDict, total=False):
+ languageCode: str
+ name: str
+
+
+class VertexAudioConfig(TypedDict, total=False):
+ audioEncoding: str
+ speakingRate: str
+
+
+class VertexTextToSpeechRequest(TypedDict, total=False):
+ input: VertexInput
+ voice: VertexVoice
+ audioConfig: Optional[VertexAudioConfig]
+
+
+class VertexTextToSpeechAPI(VertexLLM):
+ """
+    Vertex AI methods to support text-to-speech (audio_speech)
+ """
+
+ def __init__(self) -> None:
+ super().__init__()
+
+ def audio_speech(
+ self,
+ logging_obj,
+ vertex_project: Optional[str],
+ vertex_location: Optional[str],
+ vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
+ api_base: Optional[str],
+ timeout: Union[float, httpx.Timeout],
+ model: str,
+ input: str,
+ voice: Optional[dict] = None,
+ _is_async: Optional[bool] = False,
+ optional_params: Optional[dict] = None,
+ kwargs: Optional[dict] = None,
+ ) -> HttpxBinaryResponseContent:
+ import base64
+
+ ####### Authenticate with Vertex AI ########
+ _auth_header, vertex_project = self._ensure_access_token(
+ credentials=vertex_credentials,
+ project_id=vertex_project,
+ custom_llm_provider="vertex_ai_beta",
+ )
+
+ auth_header, _ = self._get_token_and_url(
+ model="",
+ auth_header=_auth_header,
+ gemini_api_key=None,
+ vertex_credentials=vertex_credentials,
+ vertex_project=vertex_project,
+ vertex_location=vertex_location,
+ stream=False,
+ custom_llm_provider="vertex_ai_beta",
+ api_base=api_base,
+ )
+
+ headers = {
+ "Authorization": f"Bearer {auth_header}",
+ "x-goog-user-project": vertex_project,
+ "Content-Type": "application/json",
+ "charset": "UTF-8",
+ }
+
+ ######### End of Authentication ###########
+
+ ####### Build the request ################
+ # API Ref: https://cloud.google.com/text-to-speech/docs/reference/rest/v1/text/synthesize
+ kwargs = kwargs or {}
+ optional_params = optional_params or {}
+
+ vertex_input = VertexInput(text=input)
+ validate_vertex_input(vertex_input, kwargs, optional_params)
+
+ # required param
+ if voice is not None:
+ vertex_voice = VertexVoice(**voice)
+ elif "voice" in kwargs:
+ vertex_voice = VertexVoice(**kwargs["voice"])
+ else:
+ # use defaults to not fail the request
+ vertex_voice = VertexVoice(
+ languageCode="en-US",
+ name="en-US-Studio-O",
+ )
+
+ if "audioConfig" in kwargs:
+ vertex_audio_config = VertexAudioConfig(**kwargs["audioConfig"])
+ else:
+ # use defaults to not fail the request
+ vertex_audio_config = VertexAudioConfig(
+ audioEncoding="LINEAR16",
+ speakingRate="1",
+ )
+
+ request = VertexTextToSpeechRequest(
+ input=vertex_input,
+ voice=vertex_voice,
+ audioConfig=vertex_audio_config,
+ )
+
+ url = "https://texttospeech.googleapis.com/v1/text:synthesize"
+ ########## End of building request ############
+
+ ########## Log the request for debugging / logging ############
+ logging_obj.pre_call(
+ input=[],
+ api_key="",
+ additional_args={
+ "complete_input_dict": request,
+ "api_base": url,
+ "headers": headers,
+ },
+ )
+
+ ########## End of logging ############
+ ####### Send the request ###################
+ if _is_async is True:
+ return self.async_audio_speech( # type:ignore
+ logging_obj=logging_obj, url=url, headers=headers, request=request
+ )
+ sync_handler = _get_httpx_client()
+
+ response = sync_handler.post(
+ url=url,
+ headers=headers,
+ json=request, # type: ignore
+ )
+ if response.status_code != 200:
+ raise Exception(
+ f"Request failed with status code {response.status_code}, {response.text}"
+ )
+ ############ Process the response ############
+ _json_response = response.json()
+
+ response_content = _json_response["audioContent"]
+
+ # Decode base64 to get binary content
+ binary_data = base64.b64decode(response_content)
+
+ # Create an httpx.Response object
+ response = httpx.Response(
+ status_code=200,
+ content=binary_data,
+ )
+
+ # Initialize the HttpxBinaryResponseContent instance
+ http_binary_response = HttpxBinaryResponseContent(response)
+ return http_binary_response
+
+ async def async_audio_speech(
+ self,
+ logging_obj,
+ url: str,
+ headers: dict,
+ request: VertexTextToSpeechRequest,
+ ) -> HttpxBinaryResponseContent:
+ import base64
+
+ async_handler = get_async_httpx_client(
+ llm_provider=litellm.LlmProviders.VERTEX_AI
+ )
+
+ response = await async_handler.post(
+ url=url,
+ headers=headers,
+ json=request, # type: ignore
+ )
+
+ if response.status_code != 200:
+ raise Exception(
+ f"Request did not return a 200 status code: {response.status_code}, {response.text}"
+ )
+
+ _json_response = response.json()
+
+ response_content = _json_response["audioContent"]
+
+ # Decode base64 to get binary content
+ binary_data = base64.b64decode(response_content)
+
+ # Create an httpx.Response object
+ response = httpx.Response(
+ status_code=200,
+ content=binary_data,
+ )
+
+ # Initialize the HttpxBinaryResponseContent instance
+ http_binary_response = HttpxBinaryResponseContent(response)
+ return http_binary_response
+
+
+def validate_vertex_input(
+ input_data: VertexInput, kwargs: dict, optional_params: dict
+) -> None:
+ # Remove None values
+ if input_data.get("text") is None:
+ input_data.pop("text", None)
+ if input_data.get("ssml") is None:
+ input_data.pop("ssml", None)
+
+ # Check if use_ssml is set
+ use_ssml = kwargs.get("use_ssml", optional_params.get("use_ssml", False))
+
+ if use_ssml:
+ if "text" in input_data:
+ input_data["ssml"] = input_data.pop("text")
+ elif "ssml" not in input_data:
+ raise ValueError("SSML input is required when use_ssml is True.")
+ else:
+ # LiteLLM will auto-detect if text is in ssml format
+ # check if "text" is an ssml - in this case we should pass it as ssml instead of text
+ if input_data:
+ _text = input_data.get("text", None) or ""
+ if "<speak>" in _text:
+ input_data["ssml"] = input_data.pop("text")
+
+ if not input_data:
+ raise ValueError("Either 'text' or 'ssml' must be provided.")
+ if "text" in input_data and "ssml" in input_data:
+ raise ValueError("Only one of 'text' or 'ssml' should be provided, not both.")
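+# Illustrative sketch (not part of the library source): how validate_vertex_input
+# resolves text vs. ssml. Inputs are placeholder values.
+#
+#   data = VertexInput(text="<speak>Hello</speak>")
+#   validate_vertex_input(data, kwargs={}, optional_params={})
+#   # -> the "<speak>" marker is auto-detected and the value is moved to "ssml"
+#
+#   data = VertexInput(text="Hello")
+#   validate_vertex_input(data, kwargs={"use_ssml": True}, optional_params={})
+#   # -> use_ssml forces the value to be sent as "ssml" instead of "text"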
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_ai_non_gemini.py b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_ai_non_gemini.py
new file mode 100644
index 00000000..744e1eb3
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_ai_non_gemini.py
@@ -0,0 +1,784 @@
+import json
+import os
+import time
+from typing import Any, Callable, Optional, cast
+
+import httpx
+
+import litellm
+from litellm.litellm_core_utils.core_helpers import map_finish_reason
+from litellm.llms.bedrock.common_utils import ModelResponseIterator
+from litellm.llms.custom_httpx.http_handler import _DEFAULT_TTL_FOR_HTTPX_CLIENTS
+from litellm.types.llms.vertex_ai import *
+from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
+
+
+class VertexAIError(Exception):
+ def __init__(self, status_code, message):
+ self.status_code = status_code
+ self.message = message
+ self.request = httpx.Request(
+ method="POST", url=" https://cloud.google.com/vertex-ai/"
+ )
+ self.response = httpx.Response(status_code=status_code, request=self.request)
+ super().__init__(
+ self.message
+ ) # Call the base class constructor with the parameters it needs
+
+
+class TextStreamer:
+ """
+ Fake streaming iterator for Vertex AI Model Garden calls
+ """
+
+ def __init__(self, text):
+ self.text = text.split() # let's assume words as a streaming unit
+ self.index = 0
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ if self.index < len(self.text):
+ result = self.text[self.index]
+ self.index += 1
+ return result
+ else:
+ raise StopIteration
+
+ def __aiter__(self):
+ return self
+
+ async def __anext__(self):
+ if self.index < len(self.text):
+ result = self.text[self.index]
+ self.index += 1
+ return result
+ else:
+ raise StopAsyncIteration # once we run out of data to stream, we raise this error
+
+
+def _get_client_cache_key(
+ model: str, vertex_project: Optional[str], vertex_location: Optional[str]
+):
+ _cache_key = f"{model}-{vertex_project}-{vertex_location}"
+ return _cache_key
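+# Illustrative example (values are placeholders):
+#   _get_client_cache_key(model="text-bison", vertex_project="my-proj", vertex_location="us-central1")
+#   # -> "text-bison-my-proj-us-central1"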
+
+
+def _get_client_from_cache(client_cache_key: str):
+ return litellm.in_memory_llm_clients_cache.get_cache(client_cache_key)
+
+
+def _set_client_in_cache(client_cache_key: str, vertex_llm_model: Any):
+ litellm.in_memory_llm_clients_cache.set_cache(
+ key=client_cache_key,
+ value=vertex_llm_model,
+ ttl=_DEFAULT_TTL_FOR_HTTPX_CLIENTS,
+ )
+
+
+def completion( # noqa: PLR0915
+ model: str,
+ messages: list,
+ model_response: ModelResponse,
+ print_verbose: Callable,
+ encoding,
+ logging_obj,
+ optional_params: dict,
+ vertex_project=None,
+ vertex_location=None,
+ vertex_credentials=None,
+ litellm_params=None,
+ logger_fn=None,
+ acompletion: bool = False,
+):
+ """
+ NON-GEMINI/ANTHROPIC CALLS.
+
+ This is the handler for OLDER PALM MODELS and VERTEX AI MODEL GARDEN
+
+ For Vertex AI Anthropic: `vertex_anthropic.py`
+ For Gemini: `vertex_httpx.py`
+ """
+ try:
+ import vertexai
+ except Exception:
+ raise VertexAIError(
+ status_code=400,
+ message="vertexai import failed please run `pip install google-cloud-aiplatform`. This is required for the 'vertex_ai/' route on LiteLLM",
+ )
+
+ if not (
+        hasattr(vertexai, "preview") and hasattr(vertexai.preview, "language_models")
+ ):
+ raise VertexAIError(
+ status_code=400,
+ message="""Upgrade vertex ai. Run `pip install "google-cloud-aiplatform>=1.38"`""",
+ )
+ try:
+ import google.auth # type: ignore
+ from google.cloud import aiplatform # type: ignore
+ from google.cloud.aiplatform_v1beta1.types import (
+ content as gapic_content_types, # type: ignore
+ )
+ from google.protobuf import json_format # type: ignore
+ from google.protobuf.struct_pb2 import Value # type: ignore
+ from vertexai.language_models import CodeGenerationModel, TextGenerationModel
+ from vertexai.preview.generative_models import GenerativeModel
+ from vertexai.preview.language_models import ChatModel, CodeChatModel
+
+ ## Load credentials with the correct quota project ref: https://github.com/googleapis/python-aiplatform/issues/2557#issuecomment-1709284744
+ print_verbose(
+ f"VERTEX AI: vertex_project={vertex_project}; vertex_location={vertex_location}"
+ )
+
+ _cache_key = _get_client_cache_key(
+ model=model, vertex_project=vertex_project, vertex_location=vertex_location
+ )
+ _vertex_llm_model_object = _get_client_from_cache(client_cache_key=_cache_key)
+
+ if _vertex_llm_model_object is None:
+ from google.auth.credentials import Credentials
+
+ if vertex_credentials is not None and isinstance(vertex_credentials, str):
+ import google.oauth2.service_account
+
+ json_obj = json.loads(vertex_credentials)
+
+ creds = (
+ google.oauth2.service_account.Credentials.from_service_account_info(
+ json_obj,
+ scopes=["https://www.googleapis.com/auth/cloud-platform"],
+ )
+ )
+ else:
+ creds, _ = google.auth.default(quota_project_id=vertex_project)
+ print_verbose(
+ f"VERTEX AI: creds={creds}; google application credentials: {os.getenv('GOOGLE_APPLICATION_CREDENTIALS')}"
+ )
+ vertexai.init(
+ project=vertex_project,
+ location=vertex_location,
+ credentials=cast(Credentials, creds),
+ )
+
+ ## Load Config
+ config = litellm.VertexAIConfig.get_config()
+ for k, v in config.items():
+ if k not in optional_params:
+ optional_params[k] = v
+
+ ## Process safety settings into format expected by vertex AI
+ safety_settings = None
+ if "safety_settings" in optional_params:
+ safety_settings = optional_params.pop("safety_settings")
+ if not isinstance(safety_settings, list):
+ raise ValueError("safety_settings must be a list")
+ if len(safety_settings) > 0 and not isinstance(safety_settings[0], dict):
+ raise ValueError("safety_settings must be a list of dicts")
+ safety_settings = [
+ gapic_content_types.SafetySetting(x) for x in safety_settings
+ ]
+
+ # vertexai does not use an API key, it looks for credentials.json in the environment
+
+ prompt = " ".join(
+ [
+ message.get("content")
+ for message in messages
+ if isinstance(message.get("content", None), str)
+ ]
+ )
+
+ mode = ""
+
+ request_str = ""
+ response_obj = None
+ instances = None
+ client_options = {
+ "api_endpoint": f"{vertex_location}-aiplatform.googleapis.com"
+ }
+ fake_stream = False
+ if (
+ model in litellm.vertex_language_models
+ or model in litellm.vertex_vision_models
+ ):
+ llm_model: Any = _vertex_llm_model_object or GenerativeModel(model)
+ mode = "vision"
+ request_str += f"llm_model = GenerativeModel({model})\n"
+ elif model in litellm.vertex_chat_models:
+ llm_model = _vertex_llm_model_object or ChatModel.from_pretrained(model)
+ mode = "chat"
+ request_str += f"llm_model = ChatModel.from_pretrained({model})\n"
+ elif model in litellm.vertex_text_models:
+ llm_model = _vertex_llm_model_object or TextGenerationModel.from_pretrained(
+ model
+ )
+ mode = "text"
+ request_str += f"llm_model = TextGenerationModel.from_pretrained({model})\n"
+ elif model in litellm.vertex_code_text_models:
+ llm_model = _vertex_llm_model_object or CodeGenerationModel.from_pretrained(
+ model
+ )
+ mode = "text"
+ request_str += f"llm_model = CodeGenerationModel.from_pretrained({model})\n"
+ fake_stream = True
+ elif model in litellm.vertex_code_chat_models: # vertex_code_llm_models
+ llm_model = _vertex_llm_model_object or CodeChatModel.from_pretrained(model)
+ mode = "chat"
+ request_str += f"llm_model = CodeChatModel.from_pretrained({model})\n"
+ elif model == "private":
+ mode = "private"
+ model = optional_params.pop("model_id", None)
+ # private endpoint requires a dict instead of JSON
+ instances = [optional_params.copy()]
+ instances[0]["prompt"] = prompt
+ llm_model = aiplatform.PrivateEndpoint(
+ endpoint_name=model,
+ project=vertex_project,
+ location=vertex_location,
+ )
+ request_str += f"llm_model = aiplatform.PrivateEndpoint(endpoint_name={model}, project={vertex_project}, location={vertex_location})\n"
+ else: # assume vertex model garden on public endpoint
+ mode = "custom"
+
+ instances = [optional_params.copy()]
+ instances[0]["prompt"] = prompt
+ instances = [
+ json_format.ParseDict(instance_dict, Value())
+ for instance_dict in instances
+ ]
+ # Will determine the API used based on async parameter
+ llm_model = None
+
+    # NOTE: async prediction and streaming under "private" mode aren't supported by aiplatform right now
+ if acompletion is True:
+ data = {
+ "llm_model": llm_model,
+ "mode": mode,
+ "prompt": prompt,
+ "logging_obj": logging_obj,
+ "request_str": request_str,
+ "model": model,
+ "model_response": model_response,
+ "encoding": encoding,
+ "messages": messages,
+ "print_verbose": print_verbose,
+ "client_options": client_options,
+ "instances": instances,
+ "vertex_location": vertex_location,
+ "vertex_project": vertex_project,
+ "safety_settings": safety_settings,
+ **optional_params,
+ }
+ if optional_params.get("stream", False) is True:
+ # async streaming
+ return async_streaming(**data)
+
+ return async_completion(**data)
+
+ completion_response = None
+
+ stream = optional_params.pop(
+ "stream", None
+ ) # See note above on handling streaming for vertex ai
+ if mode == "chat":
+ chat = llm_model.start_chat()
+ request_str += "chat = llm_model.start_chat()\n"
+
+ if fake_stream is not True and stream is True:
+ # NOTE: Vertex AI does not accept stream=True as a param and raises an error,
+ # so we remove 'stream' from optional_params before sending the request.
+ # After we get the response, we add optional_params["stream"] = True back, since main.py needs to know it's a streaming response in order to transform it into the OpenAI format.
+ optional_params.pop(
+ "stream", None
+ ) # vertex ai raises an error when passing stream in optional params
+
+ request_str += (
+ f"chat.send_message_streaming({prompt}, **{optional_params})\n"
+ )
+ ## LOGGING
+ logging_obj.pre_call(
+ input=prompt,
+ api_key=None,
+ additional_args={
+ "complete_input_dict": optional_params,
+ "request_str": request_str,
+ },
+ )
+
+ model_response = chat.send_message_streaming(prompt, **optional_params)
+
+ return model_response
+
+ request_str += f"chat.send_message({prompt}, **{optional_params}).text\n"
+ ## LOGGING
+ logging_obj.pre_call(
+ input=prompt,
+ api_key=None,
+ additional_args={
+ "complete_input_dict": optional_params,
+ "request_str": request_str,
+ },
+ )
+ completion_response = chat.send_message(prompt, **optional_params).text
+ elif mode == "text":
+
+ if fake_stream is not True and stream is True:
+ request_str += (
+ f"llm_model.predict_streaming({prompt}, **{optional_params})\n"
+ )
+ ## LOGGING
+ logging_obj.pre_call(
+ input=prompt,
+ api_key=None,
+ additional_args={
+ "complete_input_dict": optional_params,
+ "request_str": request_str,
+ },
+ )
+ model_response = llm_model.predict_streaming(prompt, **optional_params)
+
+ return model_response
+
+ request_str += f"llm_model.predict({prompt}, **{optional_params}).text\n"
+ ## LOGGING
+ logging_obj.pre_call(
+ input=prompt,
+ api_key=None,
+ additional_args={
+ "complete_input_dict": optional_params,
+ "request_str": request_str,
+ },
+ )
+ completion_response = llm_model.predict(prompt, **optional_params).text
+ elif mode == "custom":
+ """
+ Vertex AI Model Garden
+ """
+
+ if vertex_project is None or vertex_location is None:
+ raise ValueError(
+ "Vertex project and location are required for custom endpoint"
+ )
+
+ ## LOGGING
+ logging_obj.pre_call(
+ input=prompt,
+ api_key=None,
+ additional_args={
+ "complete_input_dict": optional_params,
+ "request_str": request_str,
+ },
+ )
+ llm_model = aiplatform.gapic.PredictionServiceClient(
+ client_options=client_options
+ )
+ request_str += f"llm_model = aiplatform.gapic.PredictionServiceClient(client_options={client_options})\n"
+ endpoint_path = llm_model.endpoint_path(
+ project=vertex_project, location=vertex_location, endpoint=model
+ )
+ request_str += (
+ f"llm_model.predict(endpoint={endpoint_path}, instances={instances})\n"
+ )
+ response = llm_model.predict(
+ endpoint=endpoint_path, instances=instances
+ ).predictions
+
+ completion_response = response[0]
+ if (
+ isinstance(completion_response, str)
+ and "\nOutput:\n" in completion_response
+ ):
+ completion_response = completion_response.split("\nOutput:\n", 1)[1]
+ if stream is True:
+ response = TextStreamer(completion_response)
+ return response
+ elif mode == "private":
+ """
+ Vertex AI Model Garden deployed on private endpoint
+ """
+ if instances is None:
+ raise ValueError("instances are required for private endpoint")
+ if llm_model is None:
+ raise ValueError("Unable to pick client for private endpoint")
+ ## LOGGING
+ logging_obj.pre_call(
+ input=prompt,
+ api_key=None,
+ additional_args={
+ "complete_input_dict": optional_params,
+ "request_str": request_str,
+ },
+ )
+ request_str += f"llm_model.predict(instances={instances})\n"
+ response = llm_model.predict(instances=instances).predictions
+
+ completion_response = response[0]
+ if (
+ isinstance(completion_response, str)
+ and "\nOutput:\n" in completion_response
+ ):
+ completion_response = completion_response.split("\nOutput:\n", 1)[1]
+ if stream is True:
+ response = TextStreamer(completion_response)
+ return response
+
+ ## LOGGING
+ logging_obj.post_call(
+ input=prompt, api_key=None, original_response=completion_response
+ )
+
+ ## RESPONSE OBJECT
+ if isinstance(completion_response, litellm.Message):
+ model_response.choices[0].message = completion_response # type: ignore
+ elif len(str(completion_response)) > 0:
+ model_response.choices[0].message.content = str(completion_response) # type: ignore
+ model_response.created = int(time.time())
+ model_response.model = model
+ ## CALCULATING USAGE
+ if model in litellm.vertex_language_models and response_obj is not None:
+ model_response.choices[0].finish_reason = map_finish_reason(
+ response_obj.candidates[0].finish_reason.name
+ )
+ usage = Usage(
+ prompt_tokens=response_obj.usage_metadata.prompt_token_count,
+ completion_tokens=response_obj.usage_metadata.candidates_token_count,
+ total_tokens=response_obj.usage_metadata.total_token_count,
+ )
+ else:
+ # init prompt tokens
+ # this block attempts to get usage from response_obj if it exists; otherwise it falls back to the litellm token counter
+ prompt_tokens, completion_tokens, _ = 0, 0, 0
+ if response_obj is not None:
+ if hasattr(response_obj, "usage_metadata") and hasattr(
+ response_obj.usage_metadata, "prompt_token_count"
+ ):
+ prompt_tokens = response_obj.usage_metadata.prompt_token_count
+ completion_tokens = (
+ response_obj.usage_metadata.candidates_token_count
+ )
+ else:
+ prompt_tokens = len(encoding.encode(prompt))
+ completion_tokens = len(
+ encoding.encode(
+ model_response["choices"][0]["message"].get("content", "")
+ )
+ )
+
+ usage = Usage(
+ prompt_tokens=prompt_tokens,
+ completion_tokens=completion_tokens,
+ total_tokens=prompt_tokens + completion_tokens,
+ )
+ setattr(model_response, "usage", usage)
+
+ if fake_stream is True and stream is True:
+ return ModelResponseIterator(model_response)
+ return model_response
+ except Exception as e:
+ if isinstance(e, VertexAIError):
+ raise e
+ raise litellm.APIConnectionError(
+ message=str(e), llm_provider="vertex_ai", model=model
+ )
+
+
+async def async_completion( # noqa: PLR0915
+ llm_model,
+ mode: str,
+ prompt: str,
+ model: str,
+ messages: list,
+ model_response: ModelResponse,
+ request_str: str,
+ print_verbose: Callable,
+ logging_obj,
+ encoding,
+ client_options=None,
+ instances=None,
+ vertex_project=None,
+ vertex_location=None,
+ safety_settings=None,
+ **optional_params,
+):
+ """
+ Add support for acompletion calls for gemini-pro
+ """
+ try:
+
+ response_obj = None
+ completion_response = None
+ if mode == "chat":
+ # chat-bison etc.
+ chat = llm_model.start_chat()
+ ## LOGGING
+ logging_obj.pre_call(
+ input=prompt,
+ api_key=None,
+ additional_args={
+ "complete_input_dict": optional_params,
+ "request_str": request_str,
+ },
+ )
+ response_obj = await chat.send_message_async(prompt, **optional_params)
+ completion_response = response_obj.text
+ elif mode == "text":
+ # gecko etc.
+ request_str += f"llm_model.predict({prompt}, **{optional_params}).text\n"
+ ## LOGGING
+ logging_obj.pre_call(
+ input=prompt,
+ api_key=None,
+ additional_args={
+ "complete_input_dict": optional_params,
+ "request_str": request_str,
+ },
+ )
+ response_obj = await llm_model.predict_async(prompt, **optional_params)
+ completion_response = response_obj.text
+ elif mode == "custom":
+ """
+ Vertex AI Model Garden
+ """
+ from google.cloud import aiplatform # type: ignore
+
+ if vertex_project is None or vertex_location is None:
+ raise ValueError(
+ "Vertex project and location are required for custom endpoint"
+ )
+
+ ## LOGGING
+ logging_obj.pre_call(
+ input=prompt,
+ api_key=None,
+ additional_args={
+ "complete_input_dict": optional_params,
+ "request_str": request_str,
+ },
+ )
+
+ llm_model = aiplatform.gapic.PredictionServiceAsyncClient(
+ client_options=client_options
+ )
+ request_str += f"llm_model = aiplatform.gapic.PredictionServiceAsyncClient(client_options={client_options})\n"
+ endpoint_path = llm_model.endpoint_path(
+ project=vertex_project, location=vertex_location, endpoint=model
+ )
+ request_str += (
+ f"llm_model.predict(endpoint={endpoint_path}, instances={instances})\n"
+ )
+ response_obj = await llm_model.predict(
+ endpoint=endpoint_path,
+ instances=instances,
+ )
+ response = response_obj.predictions
+ completion_response = response[0]
+ if (
+ isinstance(completion_response, str)
+ and "\nOutput:\n" in completion_response
+ ):
+ completion_response = completion_response.split("\nOutput:\n", 1)[1]
+
+ elif mode == "private":
+ request_str += f"llm_model.predict_async(instances={instances})\n"
+ response_obj = await llm_model.predict_async(
+ instances=instances,
+ )
+
+ response = response_obj.predictions
+ completion_response = response[0]
+ if (
+ isinstance(completion_response, str)
+ and "\nOutput:\n" in completion_response
+ ):
+ completion_response = completion_response.split("\nOutput:\n", 1)[1]
+
+ ## LOGGING
+ logging_obj.post_call(
+ input=prompt, api_key=None, original_response=completion_response
+ )
+
+ ## RESPONSE OBJECT
+ if isinstance(completion_response, litellm.Message):
+ model_response.choices[0].message = completion_response # type: ignore
+ elif len(str(completion_response)) > 0:
+ model_response.choices[0].message.content = str( # type: ignore
+ completion_response
+ )
+ model_response.created = int(time.time())
+ model_response.model = model
+ ## CALCULATING USAGE
+ if model in litellm.vertex_language_models and response_obj is not None:
+ model_response.choices[0].finish_reason = map_finish_reason(
+ response_obj.candidates[0].finish_reason.name
+ )
+ usage = Usage(
+ prompt_tokens=response_obj.usage_metadata.prompt_token_count,
+ completion_tokens=response_obj.usage_metadata.candidates_token_count,
+ total_tokens=response_obj.usage_metadata.total_token_count,
+ )
+ else:
+ # init prompt tokens
+ # this block attempts to get usage from response_obj if it exists; otherwise it falls back to the litellm token counter
+ prompt_tokens, completion_tokens, _ = 0, 0, 0
+ if response_obj is not None and (
+ hasattr(response_obj, "usage_metadata")
+ and hasattr(response_obj.usage_metadata, "prompt_token_count")
+ ):
+ prompt_tokens = response_obj.usage_metadata.prompt_token_count
+ completion_tokens = response_obj.usage_metadata.candidates_token_count
+ else:
+ prompt_tokens = len(encoding.encode(prompt))
+ completion_tokens = len(
+ encoding.encode(
+ model_response["choices"][0]["message"].get("content", "")
+ )
+ )
+
+ # set usage
+ usage = Usage(
+ prompt_tokens=prompt_tokens,
+ completion_tokens=completion_tokens,
+ total_tokens=prompt_tokens + completion_tokens,
+ )
+ setattr(model_response, "usage", usage)
+ return model_response
+ except Exception as e:
+ raise VertexAIError(status_code=500, message=str(e))
+
+
+async def async_streaming( # noqa: PLR0915
+ llm_model,
+ mode: str,
+ prompt: str,
+ model: str,
+ model_response: ModelResponse,
+ messages: list,
+ print_verbose: Callable,
+ logging_obj,
+ request_str: str,
+ encoding=None,
+ client_options=None,
+ instances=None,
+ vertex_project=None,
+ vertex_location=None,
+ safety_settings=None,
+ **optional_params,
+):
+ """
+ Add support for async streaming calls for gemini-pro
+ """
+ response: Any = None
+ if mode == "chat":
+ chat = llm_model.start_chat()
+ optional_params.pop(
+ "stream", None
+ ) # vertex ai raises an error when passing stream in optional params
+ request_str += (
+ f"chat.send_message_streaming_async({prompt}, **{optional_params})\n"
+ )
+ ## LOGGING
+ logging_obj.pre_call(
+ input=prompt,
+ api_key=None,
+ additional_args={
+ "complete_input_dict": optional_params,
+ "request_str": request_str,
+ },
+ )
+ response = chat.send_message_streaming_async(prompt, **optional_params)
+
+ elif mode == "text":
+ optional_params.pop(
+ "stream", None
+ ) # See note above on handling streaming for vertex ai
+ request_str += (
+ f"llm_model.predict_streaming_async({prompt}, **{optional_params})\n"
+ )
+ ## LOGGING
+ logging_obj.pre_call(
+ input=prompt,
+ api_key=None,
+ additional_args={
+ "complete_input_dict": optional_params,
+ "request_str": request_str,
+ },
+ )
+ response = llm_model.predict_streaming_async(prompt, **optional_params)
+ elif mode == "custom":
+ from google.cloud import aiplatform # type: ignore
+
+ if vertex_project is None or vertex_location is None:
+ raise ValueError(
+ "Vertex project and location are required for custom endpoint"
+ )
+
+ stream = optional_params.pop("stream", None)
+
+ ## LOGGING
+ logging_obj.pre_call(
+ input=prompt,
+ api_key=None,
+ additional_args={
+ "complete_input_dict": optional_params,
+ "request_str": request_str,
+ },
+ )
+ llm_model = aiplatform.gapic.PredictionServiceAsyncClient(
+ client_options=client_options
+ )
+ request_str += f"llm_model = aiplatform.gapic.PredictionServiceAsyncClient(client_options={client_options})\n"
+ endpoint_path = llm_model.endpoint_path(
+ project=vertex_project, location=vertex_location, endpoint=model
+ )
+ request_str += (
+ f"llm_model.predict(endpoint={endpoint_path}, instances={instances})\n"
+ )
+ response_obj = await llm_model.predict(
+ endpoint=endpoint_path,
+ instances=instances,
+ )
+
+ response = response_obj.predictions
+ completion_response = response[0]
+ if (
+ isinstance(completion_response, str)
+ and "\nOutput:\n" in completion_response
+ ):
+ completion_response = completion_response.split("\nOutput:\n", 1)[1]
+ if stream:
+ response = TextStreamer(completion_response)
+
+ elif mode == "private":
+ if instances is None:
+ raise ValueError("Instances are required for private endpoint")
+ stream = optional_params.pop("stream", None)
+ _ = instances[0].pop("stream", None)
+ request_str += f"llm_model.predict_async(instances={instances})\n"
+ response_obj = await llm_model.predict_async(
+ instances=instances,
+ )
+ response = response_obj.predictions
+ completion_response = response[0]
+ if (
+ isinstance(completion_response, str)
+ and "\nOutput:\n" in completion_response
+ ):
+ completion_response = completion_response.split("\nOutput:\n", 1)[1]
+ if stream:
+ response = TextStreamer(completion_response)
+
+ if response is None:
+ raise ValueError("Unable to generate response")
+
+ logging_obj.post_call(input=prompt, api_key=None, original_response=response)
+
+ streamwrapper = CustomStreamWrapper(
+ completion_stream=response,
+ model=model,
+ custom_llm_provider="vertex_ai",
+ logging_obj=logging_obj,
+ )
+
+ return streamwrapper
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_ai_partner_models/ai21/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_ai_partner_models/ai21/transformation.py
new file mode 100644
index 00000000..d87b2e03
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_ai_partner_models/ai21/transformation.py
@@ -0,0 +1,62 @@
+import types
+from typing import Optional
+
+import litellm
+
+
+class VertexAIAi21Config:
+ """
+ Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/ai21
+
+ The class `VertexAIAi21Config` provides configuration for Vertex AI's AI21 API interface
+
+ -> Supports all OpenAI parameters
+ """
+
+ def __init__(
+ self,
+ max_tokens: Optional[int] = None,
+ ) -> None:
+ locals_ = locals().copy()
+ for key, value in locals_.items():
+ if key != "self" and value is not None:
+ setattr(self.__class__, key, value)
+
+ @classmethod
+ def get_config(cls):
+ return {
+ k: v
+ for k, v in cls.__dict__.items()
+ if not k.startswith("__")
+ and not isinstance(
+ v,
+ (
+ types.FunctionType,
+ types.BuiltinFunctionType,
+ classmethod,
+ staticmethod,
+ ),
+ )
+ and v is not None
+ }
+
+ def get_supported_openai_params(self):
+ return litellm.OpenAIConfig().get_supported_openai_params(model="gpt-3.5-turbo")
+
+ def map_openai_params(
+ self,
+ non_default_params: dict,
+ optional_params: dict,
+ model: str,
+ drop_params: bool,
+ ):
+ if "max_completion_tokens" in non_default_params:
+ non_default_params["max_tokens"] = non_default_params.pop(
+ "max_completion_tokens"
+ )
+ return litellm.OpenAIConfig().map_openai_params(
+ non_default_params=non_default_params,
+ optional_params=optional_params,
+ model=model,
+ drop_params=drop_params,
+ )
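+
+# Illustrative usage (assumed call site and model name, not part of the original source):
+# VertexAIAi21Config().map_openai_params(
+# non_default_params={"max_completion_tokens": 256},
+# optional_params={},
+# model="jamba-1.5-mini",
+# drop_params=False,
+# )
+# renames `max_completion_tokens` to `max_tokens` before delegating to litellm.OpenAIConfig().map_openai_params.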
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_ai_partner_models/anthropic/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_ai_partner_models/anthropic/transformation.py
new file mode 100644
index 00000000..ab0555b0
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_ai_partner_models/anthropic/transformation.py
@@ -0,0 +1,114 @@
+# What is this?
+## Handler file for calling claude-3 on vertex ai
+from typing import Any, List, Optional
+
+import httpx
+
+import litellm
+from litellm.llms.base_llm.chat.transformation import LiteLLMLoggingObj
+from litellm.types.llms.openai import AllMessageValues
+from litellm.types.utils import ModelResponse
+
+from ....anthropic.chat.transformation import AnthropicConfig
+
+
+class VertexAIError(Exception):
+ def __init__(self, status_code, message):
+ self.status_code = status_code
+ self.message = message
+ self.request = httpx.Request(
+ method="POST", url="https://cloud.google.com/vertex-ai/"
+ )
+ self.response = httpx.Response(status_code=status_code, request=self.request)
+ super().__init__(
+ self.message
+ ) # Call the base class constructor with the parameters it needs
+
+
+class VertexAIAnthropicConfig(AnthropicConfig):
+ """
+ Reference: https://docs.anthropic.com/claude/reference/messages_post
+
+ Note that the API for Claude on Vertex differs from the Anthropic API documentation in the following ways:
+
+ - `model` is not a valid parameter. The model is instead specified in the Google Cloud endpoint URL.
+ - `anthropic_version` is a required parameter and must be set to "vertex-2023-10-16".
+
+ The class `VertexAIAnthropicConfig` provides configuration for Vertex AI's Anthropic API interface. Below are the parameters:
+
+ - `max_tokens` Required (integer) max tokens,
+ - `anthropic_version` Required (string) version of anthropic for vertex - e.g. "vertex-2023-10-16"
+ - `system` Optional (string) the system prompt; conversion from the openai format to this is handled in factory.py
+ - `temperature` Optional (float) The amount of randomness injected into the response
+ - `top_p` Optional (float) Use nucleus sampling.
+ - `top_k` Optional (int) Only sample from the top K options for each subsequent token
+ - `stop_sequences` Optional (List[str]) Custom text sequences that cause the model to stop generating
+
+ Note: Please make sure to modify the default parameters as required for your use case.
+ """
+
+ def transform_request(
+ self,
+ model: str,
+ messages: List[AllMessageValues],
+ optional_params: dict,
+ litellm_params: dict,
+ headers: dict,
+ ) -> dict:
+ data = super().transform_request(
+ model=model,
+ messages=messages,
+ optional_params=optional_params,
+ litellm_params=litellm_params,
+ headers=headers,
+ )
+
+ data.pop("model", None) # vertex anthropic doesn't accept 'model' parameter
+ return data
+
+ def transform_response(
+ self,
+ model: str,
+ raw_response: httpx.Response,
+ model_response: ModelResponse,
+ logging_obj: LiteLLMLoggingObj,
+ request_data: dict,
+ messages: List[AllMessageValues],
+ optional_params: dict,
+ litellm_params: dict,
+ encoding: Any,
+ api_key: Optional[str] = None,
+ json_mode: Optional[bool] = None,
+ ) -> ModelResponse:
+ response = super().transform_response(
+ model,
+ raw_response,
+ model_response,
+ logging_obj,
+ request_data,
+ messages,
+ optional_params,
+ litellm_params,
+ encoding,
+ api_key,
+ json_mode,
+ )
+ response.model = model
+
+ return response
+
+ @classmethod
+ def is_supported_model(cls, model: str, custom_llm_provider: str) -> bool:
+ """
+ Check if the model is supported by the VertexAI Anthropic API.
+ """
+ if (
+ custom_llm_provider != "vertex_ai"
+ and custom_llm_provider != "vertex_ai_beta"
+ ):
+ return False
+ if "claude" in model.lower():
+ return True
+ elif model in litellm.vertex_anthropic_models:
+ return True
+ return False
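+
+# Illustrative checks (assumed model names, not part of the original source):
+# VertexAIAnthropicConfig.is_supported_model("claude-3-5-sonnet@20240620", "vertex_ai") -> True
+# VertexAIAnthropicConfig.is_supported_model("gemini-1.5-pro", "vertex_ai") -> False (not a Claude model)
+# VertexAIAnthropicConfig.is_supported_model("claude-3-haiku", "openai") -> False (provider is not vertex_ai / vertex_ai_beta)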
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_ai_partner_models/llama3/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_ai_partner_models/llama3/transformation.py
new file mode 100644
index 00000000..cf46f4a7
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_ai_partner_models/llama3/transformation.py
@@ -0,0 +1,73 @@
+import types
+from typing import Optional
+
+from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
+
+
+class VertexAILlama3Config(OpenAIGPTConfig):
+ """
+ Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/llama#streaming
+
+ The class `VertexAILlama3Config` provides configuration for Vertex AI's Llama API interface. Below are the parameters:
+
+ - `max_tokens` Required (integer) max tokens,
+
+ Note: Please make sure to modify the default parameters as required for your use case.
+ """
+
+ max_tokens: Optional[int] = None
+
+ def __init__(
+ self,
+ max_tokens: Optional[int] = None,
+ ) -> None:
+ locals_ = locals().copy()
+ for key, value in locals_.items():
+ if key == "max_tokens" and value is None:
+ value = self.max_tokens
+ if key != "self" and value is not None:
+ setattr(self.__class__, key, value)
+
+ @classmethod
+ def get_config(cls):
+ return {
+ k: v
+ for k, v in cls.__dict__.items()
+ if not k.startswith("__")
+ and not isinstance(
+ v,
+ (
+ types.FunctionType,
+ types.BuiltinFunctionType,
+ classmethod,
+ staticmethod,
+ ),
+ )
+ and v is not None
+ }
+
+ def get_supported_openai_params(self, model: str):
+ supported_params = super().get_supported_openai_params(model=model)
+ try:
+ supported_params.remove("max_retries")
+ except (KeyError, ValueError): # list.remove raises ValueError when "max_retries" is absent
+ pass
+ return supported_params
+
+ def map_openai_params(
+ self,
+ non_default_params: dict,
+ optional_params: dict,
+ model: str,
+ drop_params: bool,
+ ):
+ if "max_completion_tokens" in non_default_params:
+ non_default_params["max_tokens"] = non_default_params.pop(
+ "max_completion_tokens"
+ )
+ return super().map_openai_params(
+ non_default_params=non_default_params,
+ optional_params=optional_params,
+ model=model,
+ drop_params=drop_params,
+ )
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_ai_partner_models/main.py b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_ai_partner_models/main.py
new file mode 100644
index 00000000..fb239363
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_ai_partner_models/main.py
@@ -0,0 +1,242 @@
+# What is this?
+## API Handler for calling Vertex AI Partner Models
+from enum import Enum
+from typing import Callable, Optional, Union
+
+import httpx # type: ignore
+
+import litellm
+from litellm import LlmProviders
+from litellm.utils import ModelResponse
+
+from ..vertex_llm_base import VertexBase
+
+
+class VertexPartnerProvider(str, Enum):
+ mistralai = "mistralai"
+ llama = "llama"
+ ai21 = "ai21"
+ claude = "claude"
+
+
+class VertexAIError(Exception):
+ def __init__(self, status_code, message):
+ self.status_code = status_code
+ self.message = message
+ self.request = httpx.Request(
+ method="POST", url="https://cloud.google.com/vertex-ai/"
+ )
+ self.response = httpx.Response(status_code=status_code, request=self.request)
+ super().__init__(
+ self.message
+ ) # Call the base class constructor with the parameters it needs
+
+
+def create_vertex_url(
+ vertex_location: str,
+ vertex_project: str,
+ partner: VertexPartnerProvider,
+ stream: Optional[bool],
+ model: str,
+ api_base: Optional[str] = None,
+) -> str:
+ """Return the base url for the vertex partner models"""
+ if partner == VertexPartnerProvider.llama:
+ return f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/endpoints/openapi/chat/completions"
+ elif partner == VertexPartnerProvider.mistralai:
+ if stream:
+ return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/mistralai/models/{model}:streamRawPredict"
+ else:
+ return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/mistralai/models/{model}:rawPredict"
+ elif partner == VertexPartnerProvider.ai21:
+ if stream:
+ return f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/publishers/ai21/models/{model}:streamRawPredict"
+ else:
+ return f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/publishers/ai21/models/{model}:rawPredict"
+ elif partner == VertexPartnerProvider.claude:
+ if stream:
+ return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/anthropic/models/{model}:streamRawPredict"
+ else:
+ return f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/anthropic/models/{model}:rawPredict"
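+
+# Illustrative example of the URL this helper builds (project/location/model values are assumed,
+# not part of the original source):
+# create_vertex_url(
+# vertex_location="us-central1",
+# vertex_project="my-project",
+# partner=VertexPartnerProvider.claude,
+# stream=True,
+# model="claude-3-5-sonnet@20240620",
+# )
+# -> "https://us-central1-aiplatform.googleapis.com/v1/projects/my-project/locations/us-central1/publishers/anthropic/models/claude-3-5-sonnet@20240620:streamRawPredict"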
+
+
+class VertexAIPartnerModels(VertexBase):
+ def __init__(self) -> None:
+ pass
+
+ def completion(
+ self,
+ model: str,
+ messages: list,
+ model_response: ModelResponse,
+ print_verbose: Callable,
+ encoding,
+ logging_obj,
+ api_base: Optional[str],
+ optional_params: dict,
+ custom_prompt_dict: dict,
+ headers: Optional[dict],
+ timeout: Union[float, httpx.Timeout],
+ litellm_params: dict,
+ vertex_project=None,
+ vertex_location=None,
+ vertex_credentials=None,
+ logger_fn=None,
+ acompletion: bool = False,
+ client=None,
+ ):
+ try:
+ import vertexai
+
+ from litellm.llms.anthropic.chat import AnthropicChatCompletion
+ from litellm.llms.codestral.completion.handler import (
+ CodestralTextCompletion,
+ )
+ from litellm.llms.openai_like.chat.handler import OpenAILikeChatHandler
+ from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
+ VertexLLM,
+ )
+ except Exception as e:
+ raise VertexAIError(
+ status_code=400,
+ message=f"""vertexai import failed please run `pip install -U "google-cloud-aiplatform>=1.38"`. Got error: {e}""",
+ )
+
+ if not (
+ hasattr(vertexai, "preview") and hasattr(vertexai.preview, "language_models")
+ ):
+ raise VertexAIError(
+ status_code=400,
+ message="""Upgrade vertex ai. Run `pip install "google-cloud-aiplatform>=1.38"`""",
+ )
+ try:
+
+ vertex_httpx_logic = VertexLLM()
+
+ access_token, project_id = vertex_httpx_logic._ensure_access_token(
+ credentials=vertex_credentials,
+ project_id=vertex_project,
+ custom_llm_provider="vertex_ai",
+ )
+
+ openai_like_chat_completions = OpenAILikeChatHandler()
+ codestral_fim_completions = CodestralTextCompletion()
+ anthropic_chat_completions = AnthropicChatCompletion()
+
+ ## CONSTRUCT API BASE
+ stream: bool = optional_params.get("stream", False) or False
+
+ optional_params["stream"] = stream
+
+ if "llama" in model:
+ partner = VertexPartnerProvider.llama
+ elif "mistral" in model or "codestral" in model:
+ partner = VertexPartnerProvider.mistralai
+ elif "jamba" in model:
+ partner = VertexPartnerProvider.ai21
+ elif "claude" in model:
+ partner = VertexPartnerProvider.claude
+
+ default_api_base = create_vertex_url(
+ vertex_location=vertex_location or "us-central1",
+ vertex_project=vertex_project or project_id,
+ partner=partner, # type: ignore
+ stream=stream,
+ model=model,
+ )
+
+ if len(default_api_base.split(":")) > 1:
+ endpoint = default_api_base.split(":")[-1]
+ else:
+ endpoint = ""
+
+ _, api_base = self._check_custom_proxy(
+ api_base=api_base,
+ custom_llm_provider="vertex_ai",
+ gemini_api_key=None,
+ endpoint=endpoint,
+ stream=stream,
+ auth_header=None,
+ url=default_api_base,
+ )
+
+ if "codestral" in model or "mistral" in model:
+ model = model.split("@")[0]
+
+ if "codestral" in model and litellm_params.get("text_completion") is True:
+ optional_params["model"] = model
+ text_completion_model_response = litellm.TextCompletionResponse(
+ stream=stream
+ )
+ return codestral_fim_completions.completion(
+ model=model,
+ messages=messages,
+ api_base=api_base,
+ api_key=access_token,
+ custom_prompt_dict=custom_prompt_dict,
+ model_response=text_completion_model_response,
+ print_verbose=print_verbose,
+ logging_obj=logging_obj,
+ optional_params=optional_params,
+ acompletion=acompletion,
+ litellm_params=litellm_params,
+ logger_fn=logger_fn,
+ timeout=timeout,
+ encoding=encoding,
+ )
+ elif "claude" in model:
+ if headers is None:
+ headers = {}
+ headers.update({"Authorization": "Bearer {}".format(access_token)})
+
+ optional_params.update(
+ {
+ "anthropic_version": "vertex-2023-10-16",
+ "is_vertex_request": True,
+ }
+ )
+
+ return anthropic_chat_completions.completion(
+ model=model,
+ messages=messages,
+ api_base=api_base,
+ acompletion=acompletion,
+ custom_prompt_dict=litellm.custom_prompt_dict,
+ model_response=model_response,
+ print_verbose=print_verbose,
+ optional_params=optional_params,
+ litellm_params=litellm_params,
+ logger_fn=logger_fn,
+ encoding=encoding, # for calculating input/output tokens
+ api_key=access_token,
+ logging_obj=logging_obj,
+ headers=headers,
+ timeout=timeout,
+ client=client,
+ custom_llm_provider=LlmProviders.VERTEX_AI.value,
+ )
+
+ return openai_like_chat_completions.completion(
+ model=model,
+ messages=messages,
+ api_base=api_base,
+ api_key=access_token,
+ custom_prompt_dict=custom_prompt_dict,
+ model_response=model_response,
+ print_verbose=print_verbose,
+ logging_obj=logging_obj,
+ optional_params=optional_params,
+ acompletion=acompletion,
+ litellm_params=litellm_params,
+ logger_fn=logger_fn,
+ client=client,
+ timeout=timeout,
+ encoding=encoding,
+ custom_llm_provider="vertex_ai",
+ custom_endpoint=True,
+ )
+
+ except Exception as e:
+ if hasattr(e, "status_code"):
+ raise e
+ raise VertexAIError(status_code=500, message=str(e))
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_embeddings/embedding_handler.py b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_embeddings/embedding_handler.py
new file mode 100644
index 00000000..3ef40703
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_embeddings/embedding_handler.py
@@ -0,0 +1,228 @@
+from typing import Literal, Optional, Union
+
+import httpx
+
+import litellm
+from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObject
+from litellm.llms.custom_httpx.http_handler import (
+ AsyncHTTPHandler,
+ HTTPHandler,
+ _get_httpx_client,
+ get_async_httpx_client,
+)
+from litellm.llms.vertex_ai.vertex_ai_non_gemini import VertexAIError
+from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
+from litellm.types.llms.vertex_ai import *
+from litellm.types.utils import EmbeddingResponse
+
+from .types import *
+
+
+class VertexEmbedding(VertexBase):
+ def __init__(self) -> None:
+ super().__init__()
+
+ def embedding(
+ self,
+ model: str,
+ input: Union[list, str],
+ print_verbose,
+ model_response: EmbeddingResponse,
+ optional_params: dict,
+ logging_obj: LiteLLMLoggingObject,
+ custom_llm_provider: Literal[
+ "vertex_ai", "vertex_ai_beta", "gemini"
+ ], # if it's vertex_ai or gemini (google ai studio)
+ timeout: Optional[Union[float, httpx.Timeout]],
+ api_key: Optional[str] = None,
+ encoding=None,
+ aembedding=False,
+ api_base: Optional[str] = None,
+ client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
+ vertex_project: Optional[str] = None,
+ vertex_location: Optional[str] = None,
+ vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES] = None,
+ gemini_api_key: Optional[str] = None,
+ extra_headers: Optional[dict] = None,
+ ) -> EmbeddingResponse:
+ if aembedding is True:
+ return self.async_embedding( # type: ignore
+ model=model,
+ input=input,
+ logging_obj=logging_obj,
+ model_response=model_response,
+ optional_params=optional_params,
+ encoding=encoding,
+ custom_llm_provider=custom_llm_provider,
+ timeout=timeout,
+ api_base=api_base,
+ vertex_project=vertex_project,
+ vertex_location=vertex_location,
+ vertex_credentials=vertex_credentials,
+ gemini_api_key=gemini_api_key,
+ extra_headers=extra_headers,
+ )
+
+ should_use_v1beta1_features = self.is_using_v1beta1_features(
+ optional_params=optional_params
+ )
+
+ _auth_header, vertex_project = self._ensure_access_token(
+ credentials=vertex_credentials,
+ project_id=vertex_project,
+ custom_llm_provider=custom_llm_provider,
+ )
+ auth_header, api_base = self._get_token_and_url(
+ model=model,
+ gemini_api_key=gemini_api_key,
+ auth_header=_auth_header,
+ vertex_project=vertex_project,
+ vertex_location=vertex_location,
+ vertex_credentials=vertex_credentials,
+ stream=False,
+ custom_llm_provider=custom_llm_provider,
+ api_base=api_base,
+ should_use_v1beta1_features=should_use_v1beta1_features,
+ mode="embedding",
+ )
+ headers = self.set_headers(auth_header=auth_header, extra_headers=extra_headers)
+ vertex_request: VertexEmbeddingRequest = (
+ litellm.vertexAITextEmbeddingConfig.transform_openai_request_to_vertex_embedding_request(
+ input=input, optional_params=optional_params, model=model
+ )
+ )
+
+ _client_params = {}
+ if timeout:
+ _client_params["timeout"] = timeout
+ if client is None or not isinstance(client, HTTPHandler):
+ client = _get_httpx_client(params=_client_params)
+ else:
+ client = client # type: ignore
+ ## LOGGING
+ logging_obj.pre_call(
+ input=vertex_request,
+ api_key="",
+ additional_args={
+ "complete_input_dict": vertex_request,
+ "api_base": api_base,
+ "headers": headers,
+ },
+ )
+
+ try:
+ response = client.post(api_base, headers=headers, json=vertex_request) # type: ignore
+ response.raise_for_status()
+ except httpx.HTTPStatusError as err:
+ error_code = err.response.status_code
+ raise VertexAIError(status_code=error_code, message=err.response.text)
+ except httpx.TimeoutException:
+ raise VertexAIError(status_code=408, message="Timeout error occurred.")
+
+ _json_response = response.json()
+ ## LOGGING POST-CALL
+ logging_obj.post_call(
+ input=input, api_key=None, original_response=_json_response
+ )
+
+ model_response = (
+ litellm.vertexAITextEmbeddingConfig.transform_vertex_response_to_openai(
+ response=_json_response, model=model, model_response=model_response
+ )
+ )
+
+ return model_response
+
+ async def async_embedding(
+ self,
+ model: str,
+ input: Union[list, str],
+ model_response: litellm.EmbeddingResponse,
+ logging_obj: LiteLLMLoggingObject,
+ optional_params: dict,
+ custom_llm_provider: Literal[
+ "vertex_ai", "vertex_ai_beta", "gemini"
+ ], # if it's vertex_ai or gemini (google ai studio)
+ timeout: Optional[Union[float, httpx.Timeout]],
+ api_base: Optional[str] = None,
+ client: Optional[AsyncHTTPHandler] = None,
+ vertex_project: Optional[str] = None,
+ vertex_location: Optional[str] = None,
+ vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES] = None,
+ gemini_api_key: Optional[str] = None,
+ extra_headers: Optional[dict] = None,
+ encoding=None,
+ ) -> litellm.EmbeddingResponse:
+ """
+ Async embedding implementation
+ """
+ should_use_v1beta1_features = self.is_using_v1beta1_features(
+ optional_params=optional_params
+ )
+ _auth_header, vertex_project = await self._ensure_access_token_async(
+ credentials=vertex_credentials,
+ project_id=vertex_project,
+ custom_llm_provider=custom_llm_provider,
+ )
+ auth_header, api_base = self._get_token_and_url(
+ model=model,
+ gemini_api_key=gemini_api_key,
+ auth_header=_auth_header,
+ vertex_project=vertex_project,
+ vertex_location=vertex_location,
+ vertex_credentials=vertex_credentials,
+ stream=False,
+ custom_llm_provider=custom_llm_provider,
+ api_base=api_base,
+ should_use_v1beta1_features=should_use_v1beta1_features,
+ mode="embedding",
+ )
+ headers = self.set_headers(auth_header=auth_header, extra_headers=extra_headers)
+ vertex_request: VertexEmbeddingRequest = (
+ litellm.vertexAITextEmbeddingConfig.transform_openai_request_to_vertex_embedding_request(
+ input=input, optional_params=optional_params, model=model
+ )
+ )
+
+ _async_client_params = {}
+ if timeout:
+ _async_client_params["timeout"] = timeout
+ if client is None or not isinstance(client, AsyncHTTPHandler):
+ client = get_async_httpx_client(
+ params=_async_client_params, llm_provider=litellm.LlmProviders.VERTEX_AI
+ )
+ else:
+ client = client # type: ignore
+ ## LOGGING
+ logging_obj.pre_call(
+ input=vertex_request,
+ api_key="",
+ additional_args={
+ "complete_input_dict": vertex_request,
+ "api_base": api_base,
+ "headers": headers,
+ },
+ )
+
+ try:
+ response = await client.post(api_base, headers=headers, json=vertex_request) # type: ignore
+ response.raise_for_status()
+ except httpx.HTTPStatusError as err:
+ error_code = err.response.status_code
+ raise VertexAIError(status_code=error_code, message=err.response.text)
+ except httpx.TimeoutException:
+ raise VertexAIError(status_code=408, message="Timeout error occurred.")
+
+ _json_response = response.json()
+ ## LOGGING POST-CALL
+ logging_obj.post_call(
+ input=input, api_key=None, original_response=_json_response
+ )
+
+ model_response = (
+ litellm.vertexAITextEmbeddingConfig.transform_vertex_response_to_openai(
+ response=_json_response, model=model, model_response=model_response
+ )
+ )
+
+ return model_response
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_embeddings/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_embeddings/transformation.py
new file mode 100644
index 00000000..d9e84fca
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_embeddings/transformation.py
@@ -0,0 +1,265 @@
+import types
+from typing import List, Literal, Optional, Union
+
+from pydantic import BaseModel
+
+from litellm.types.utils import EmbeddingResponse, Usage
+
+from .types import *
+
+
+class VertexAITextEmbeddingConfig(BaseModel):
+ """
+ Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text-embeddings-api#TextEmbeddingInput
+
+ Args:
+ auto_truncate: Optional(bool) If True, will truncate input text to fit within the model's max input length.
+ task_type: Optional(str) The type of task to be performed. The default is "RETRIEVAL_QUERY".
+ title: Optional(str) The title of the document to be embedded. (only valid with task_type=RETRIEVAL_DOCUMENT).
+ """
+
+ auto_truncate: Optional[bool] = None
+ task_type: Optional[
+ Literal[
+ "RETRIEVAL_QUERY",
+ "RETRIEVAL_DOCUMENT",
+ "SEMANTIC_SIMILARITY",
+ "CLASSIFICATION",
+ "CLUSTERING",
+ "QUESTION_ANSWERING",
+ "FACT_VERIFICATION",
+ ]
+ ] = None
+ title: Optional[str] = None
+
+ def __init__(
+ self,
+ auto_truncate: Optional[bool] = None,
+ task_type: Optional[
+ Literal[
+ "RETRIEVAL_QUERY",
+ "RETRIEVAL_DOCUMENT",
+ "SEMANTIC_SIMILARITY",
+ "CLASSIFICATION",
+ "CLUSTERING",
+ "QUESTION_ANSWERING",
+ "FACT_VERIFICATION",
+ ]
+ ] = None,
+ title: Optional[str] = None,
+ ) -> None:
+ locals_ = locals().copy()
+ for key, value in locals_.items():
+ if key != "self" and value is not None:
+ setattr(self.__class__, key, value)
+
+ @classmethod
+ def get_config(cls):
+ return {
+ k: v
+ for k, v in cls.__dict__.items()
+ if not k.startswith("__")
+ and not isinstance(
+ v,
+ (
+ types.FunctionType,
+ types.BuiltinFunctionType,
+ classmethod,
+ staticmethod,
+ ),
+ )
+ and v is not None
+ }
+
+ def get_supported_openai_params(self):
+ return ["dimensions"]
+
+ def map_openai_params(
+ self, non_default_params: dict, optional_params: dict, kwargs: dict
+ ):
+ for param, value in non_default_params.items():
+ if param == "dimensions":
+ optional_params["outputDimensionality"] = value
+
+ if "input_type" in kwargs:
+ optional_params["task_type"] = kwargs.pop("input_type")
+ return optional_params, kwargs
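+
+ # Illustrative mapping (assumed inputs, not part of the original source):
+ # map_openai_params(non_default_params={"dimensions": 256}, optional_params={}, kwargs={"input_type": "RETRIEVAL_QUERY"})
+ # -> ({"outputDimensionality": 256, "task_type": "RETRIEVAL_QUERY"}, {})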
+
+ def get_mapped_special_auth_params(self) -> dict:
+ """
+ Common auth params across bedrock/vertex_ai/azure/watsonx
+ """
+ return {"project": "vertex_project", "region_name": "vertex_location"}
+
+ def map_special_auth_params(self, non_default_params: dict, optional_params: dict):
+ mapped_params = self.get_mapped_special_auth_params()
+
+ for param, value in non_default_params.items():
+ if param in mapped_params:
+ optional_params[mapped_params[param]] = value
+ return optional_params
+
+ def transform_openai_request_to_vertex_embedding_request(
+ self, input: Union[list, str], optional_params: dict, model: str
+ ) -> VertexEmbeddingRequest:
+ """
+ Transforms an openai request to a vertex embedding request.
+ """
+ if model.isdigit():
+ return self._transform_openai_request_to_fine_tuned_embedding_request(
+ input, optional_params, model
+ )
+
+ vertex_request: VertexEmbeddingRequest = VertexEmbeddingRequest()
+ vertex_text_embedding_input_list: List[TextEmbeddingInput] = []
+ task_type: Optional[TaskType] = optional_params.get("task_type")
+ title = optional_params.get("title")
+
+ if isinstance(input, str):
+ input = [input] # Convert single string to list for uniform processing
+
+ for text in input:
+ embedding_input = self.create_embedding_input(
+ content=text, task_type=task_type, title=title
+ )
+ vertex_text_embedding_input_list.append(embedding_input)
+
+ vertex_request["instances"] = vertex_text_embedding_input_list
+ vertex_request["parameters"] = EmbeddingParameters(**optional_params)
+
+ return vertex_request
+
+ def _transform_openai_request_to_fine_tuned_embedding_request(
+ self, input: Union[list, str], optional_params: dict, model: str
+ ) -> VertexEmbeddingRequest:
+ """
+ Transforms an openai request to a vertex fine-tuned embedding request.
+
+ Vertex Doc: https://console.cloud.google.com/vertex-ai/model-garden?hl=en&project=adroit-crow-413218&pageState=(%22galleryStateKey%22:(%22f%22:(%22g%22:%5B%5D,%22o%22:%5B%5D),%22s%22:%22%22))
+ Sample Request:
+
+ ```json
+ {
+ "instances" : [
+ {
+ "inputs": "How would the Future of AI in 10 Years look?",
+ "parameters": {
+ "max_new_tokens": 128,
+ "temperature": 1.0,
+ "top_p": 0.9,
+ "top_k": 10
+ }
+ }
+ ]
+ }
+ ```
+ """
+ vertex_request: VertexEmbeddingRequest = VertexEmbeddingRequest()
+ vertex_text_embedding_input_list: List[TextEmbeddingFineTunedInput] = []
+ if isinstance(input, str):
+ input = [input] # Convert single string to list for uniform processing
+
+ for text in input:
+ embedding_input = TextEmbeddingFineTunedInput(inputs=text)
+ vertex_text_embedding_input_list.append(embedding_input)
+
+ vertex_request["instances"] = vertex_text_embedding_input_list
+ vertex_request["parameters"] = TextEmbeddingFineTunedParameters(
+ **optional_params
+ )
+
+ return vertex_request
+
+ def create_embedding_input(
+ self,
+ content: str,
+ task_type: Optional[TaskType] = None,
+ title: Optional[str] = None,
+ ) -> TextEmbeddingInput:
+ """
+ Creates a TextEmbeddingInput object.
+
+ Vertex requires a List of TextEmbeddingInput objects. This helper function creates a single TextEmbeddingInput object.
+
+ Args:
+ content (str): The content to be embedded.
+ task_type (Optional[TaskType]): The type of task to be performed.
+ title (Optional[str]): The title of the document to be embedded.
+
+ Returns:
+ TextEmbeddingInput: A TextEmbeddingInput object.
+ """
+ text_embedding_input = TextEmbeddingInput(content=content)
+ if task_type is not None:
+ text_embedding_input["task_type"] = task_type
+ if title is not None:
+ text_embedding_input["title"] = title
+ return text_embedding_input
+
+ def transform_vertex_response_to_openai(
+ self, response: dict, model: str, model_response: EmbeddingResponse
+ ) -> EmbeddingResponse:
+ """
+ Transforms a vertex embedding response to an openai response.
+ """
+ if model.isdigit():
+ return self._transform_vertex_response_to_openai_for_fine_tuned_models(
+ response, model, model_response
+ )
+
+ _predictions = response["predictions"]
+
+ embedding_response = []
+ input_tokens: int = 0
+ for idx, element in enumerate(_predictions):
+
+ embedding = element["embeddings"]
+ embedding_response.append(
+ {
+ "object": "embedding",
+ "index": idx,
+ "embedding": embedding["values"],
+ }
+ )
+ input_tokens += embedding["statistics"]["token_count"]
+
+ model_response.object = "list"
+ model_response.data = embedding_response
+ model_response.model = model
+ usage = Usage(
+ prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
+ )
+ setattr(model_response, "usage", usage)
+ return model_response
+
+ def _transform_vertex_response_to_openai_for_fine_tuned_models(
+ self, response: dict, model: str, model_response: EmbeddingResponse
+ ) -> EmbeddingResponse:
+ """
+ Transforms a vertex fine-tuned model embedding response to an openai response format.
+ """
+ _predictions = response["predictions"]
+
+ embedding_response = []
+ # For fine-tuned models, we don't get token counts in the response
+ input_tokens = 0
+
+ for idx, embedding_values in enumerate(_predictions):
+ embedding_response.append(
+ {
+ "object": "embedding",
+ "index": idx,
+ "embedding": embedding_values[
+ 0
+ ], # The embedding values are nested one level deeper
+ }
+ )
+
+ model_response.object = "list"
+ model_response.data = embedding_response
+ model_response.model = model
+ usage = Usage(
+ prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
+ )
+ setattr(model_response, "usage", usage)
+ return model_response
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_embeddings/types.py b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_embeddings/types.py
new file mode 100644
index 00000000..c0c53b17
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_embeddings/types.py
@@ -0,0 +1,62 @@
+"""
+Types for Vertex Embeddings Requests
+"""
+
+from enum import Enum
+from typing import List, Optional, TypedDict, Union
+
+
+class TaskType(str, Enum):
+ RETRIEVAL_QUERY = "RETRIEVAL_QUERY"
+ RETRIEVAL_DOCUMENT = "RETRIEVAL_DOCUMENT"
+ SEMANTIC_SIMILARITY = "SEMANTIC_SIMILARITY"
+ CLASSIFICATION = "CLASSIFICATION"
+ CLUSTERING = "CLUSTERING"
+ QUESTION_ANSWERING = "QUESTION_ANSWERING"
+ FACT_VERIFICATION = "FACT_VERIFICATION"
+ CODE_RETRIEVAL_QUERY = "CODE_RETRIEVAL_QUERY"
+
+
+class TextEmbeddingInput(TypedDict, total=False):
+ content: str
+ task_type: Optional[TaskType]
+ title: Optional[str]
+
+
+# Fine-tuned models require a different input format
+# Ref: https://console.cloud.google.com/vertex-ai/model-garden?hl=en&project=adroit-crow-413218&pageState=(%22galleryStateKey%22:(%22f%22:(%22g%22:%5B%5D,%22o%22:%5B%5D),%22s%22:%22%22))
+class TextEmbeddingFineTunedInput(TypedDict, total=False):
+ inputs: str
+
+
+class TextEmbeddingFineTunedParameters(TypedDict, total=False):
+ max_new_tokens: Optional[int]
+ temperature: Optional[float]
+ top_p: Optional[float]
+ top_k: Optional[int]
+
+
+class EmbeddingParameters(TypedDict, total=False):
+ auto_truncate: Optional[bool]
+ output_dimensionality: Optional[int]
+
+
+class VertexEmbeddingRequest(TypedDict, total=False):
+ instances: Union[List[TextEmbeddingInput], List[TextEmbeddingFineTunedInput]]
+ parameters: Optional[Union[EmbeddingParameters, TextEmbeddingFineTunedParameters]]
+
+
+# Example usage:
+# example_request: VertexEmbeddingRequest = {
+# "instances": [
+# {
+# "content": "I would like embeddings for this text!",
+# "task_type": "RETRIEVAL_DOCUMENT",
+# "title": "document title"
+# }
+# ],
+# "parameters": {
+# "auto_truncate": True,
+# "output_dimensionality": None
+# }
+# }
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_llm_base.py b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_llm_base.py
new file mode 100644
index 00000000..8286cb51
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_llm_base.py
@@ -0,0 +1,319 @@
+"""
+Base Vertex, Google AI Studio LLM Class
+
+Handles authentication and request URL generation for Vertex AI and Google AI Studio
+"""
+
+import json
+import os
+from typing import TYPE_CHECKING, Any, Literal, Optional, Tuple
+
+from litellm._logging import verbose_logger
+from litellm.litellm_core_utils.asyncify import asyncify
+from litellm.llms.base import BaseLLM
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES
+
+from .common_utils import _get_gemini_url, _get_vertex_url, all_gemini_url_modes
+
+if TYPE_CHECKING:
+ from google.auth.credentials import Credentials as GoogleCredentialsObject
+else:
+ GoogleCredentialsObject = Any
+
+
+class VertexBase(BaseLLM):
+ def __init__(self) -> None:
+ super().__init__()
+ self.access_token: Optional[str] = None
+ self.refresh_token: Optional[str] = None
+ self._credentials: Optional[GoogleCredentialsObject] = None
+ self.project_id: Optional[str] = None
+ self.async_handler: Optional[AsyncHTTPHandler] = None
+
+ def get_vertex_region(self, vertex_region: Optional[str]) -> str:
+ return vertex_region or "us-central1"
+
+ def load_auth(
+ self, credentials: Optional[VERTEX_CREDENTIALS_TYPES], project_id: Optional[str]
+ ) -> Tuple[Any, str]:
+ import google.auth as google_auth
+ from google.auth import identity_pool
+ from google.auth.transport.requests import (
+ Request, # type: ignore[import-untyped]
+ )
+
+ if credentials is not None:
+ import google.oauth2.service_account
+
+ if isinstance(credentials, str):
+ verbose_logger.debug(
+ "Vertex: Loading vertex credentials from %s", credentials
+ )
+ verbose_logger.debug(
+ "Vertex: checking if credentials is a valid path, os.path.exists(%s)=%s, current dir %s",
+ credentials,
+ os.path.exists(credentials),
+ os.getcwd(),
+ )
+
+ try:
+ if os.path.exists(credentials):
+ json_obj = json.load(open(credentials))
+ else:
+ json_obj = json.loads(credentials)
+ except Exception:
+ raise Exception(
+ "Unable to load vertex credentials from environment. Got={}".format(
+ credentials
+ )
+ )
+ elif isinstance(credentials, dict):
+ json_obj = credentials
+ else:
+ raise ValueError(
+ "Invalid credentials type: {}".format(type(credentials))
+ )
+
+ # Check if the JSON object contains Workload Identity Federation configuration
+ if "type" in json_obj and json_obj["type"] == "external_account":
+ creds = identity_pool.Credentials.from_info(json_obj)
+ else:
+ creds = (
+ google.oauth2.service_account.Credentials.from_service_account_info(
+ json_obj,
+ scopes=["https://www.googleapis.com/auth/cloud-platform"],
+ )
+ )
+
+ if project_id is None:
+ project_id = getattr(creds, "project_id", None)
+ else:
+ creds, creds_project_id = google_auth.default(
+ quota_project_id=project_id,
+ scopes=["https://www.googleapis.com/auth/cloud-platform"],
+ )
+ if project_id is None:
+ project_id = creds_project_id
+
+ creds.refresh(Request()) # type: ignore
+
+ if not project_id:
+ raise ValueError("Could not resolve project_id")
+
+ if not isinstance(project_id, str):
+ raise TypeError(
+ f"Expected project_id to be a str but got {type(project_id)}"
+ )
+
+ return creds, project_id
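+
+ # Illustrative inputs (paths/values assumed, not part of the original source):
+ # load_auth(credentials="/path/to/service_account.json", project_id=None) reads the key file from disk,
+ # load_auth(credentials='{"type": "external_account", ...}', project_id="my-project") parses the JSON string and uses Workload Identity Federation,
+ # load_auth(credentials=None, project_id=None) falls back to google.auth.default().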
+
+ def refresh_auth(self, credentials: Any) -> None:
+ from google.auth.transport.requests import (
+ Request, # type: ignore[import-untyped]
+ )
+
+ credentials.refresh(Request())
+
+ def _ensure_access_token(
+ self,
+ credentials: Optional[VERTEX_CREDENTIALS_TYPES],
+ project_id: Optional[str],
+ custom_llm_provider: Literal[
+ "vertex_ai", "vertex_ai_beta", "gemini"
+ ], # if it's vertex_ai or gemini (google ai studio)
+ ) -> Tuple[str, str]:
+ """
+ Returns auth token and project id
+ """
+ if custom_llm_provider == "gemini":
+ return "", ""
+ if self.access_token is not None:
+ if project_id is not None:
+ return self.access_token, project_id
+ elif self.project_id is not None:
+ return self.access_token, self.project_id
+
+ if not self._credentials:
+ self._credentials, cred_project_id = self.load_auth(
+ credentials=credentials, project_id=project_id
+ )
+ if not self.project_id:
+ self.project_id = project_id or cred_project_id
+ else:
+ if self._credentials.expired or not self._credentials.token:
+ self.refresh_auth(self._credentials)
+
+ if not self.project_id:
+ self.project_id = self._credentials.quota_project_id
+
+ if not self.project_id:
+ raise ValueError("Could not resolve project_id")
+
+ if not self._credentials or not self._credentials.token:
+ raise RuntimeError("Could not resolve API token from the environment")
+
+ return self._credentials.token, project_id or self.project_id
+
+ def is_using_v1beta1_features(self, optional_params: dict) -> bool:
+ """
+ Vertex AI only supports context caching on v1beta1.
+
+ Use this helper to decide whether a request should be sent to v1 or v1beta1.
+
+ Returns True (use v1beta1) if context caching is enabled,
+ False (use v1) in all other cases.
+ """
+ if "cached_content" in optional_params:
+ return True
+ if "CachedContent" in optional_params:
+ return True
+ return False
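+
+ # Illustrative (assumed values, not part of the original source):
+ # is_using_v1beta1_features({"cached_content": "projects/p/locations/l/cachedContents/123"}) -> True (v1beta1)
+ # is_using_v1beta1_features({"temperature": 0.2}) -> False (v1)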
+
+ def _check_custom_proxy(
+ self,
+ api_base: Optional[str],
+ custom_llm_provider: str,
+ gemini_api_key: Optional[str],
+ endpoint: str,
+ stream: Optional[bool],
+ auth_header: Optional[str],
+ url: str,
+ ) -> Tuple[Optional[str], str]:
+ """
+ for cloudflare ai gateway - https://github.com/BerriAI/litellm/issues/4317
+
+ ## Returns
+ - (auth_header, url) - Tuple[Optional[str], str]
+ """
+ if api_base:
+ if custom_llm_provider == "gemini":
+ url = "{}:{}".format(api_base, endpoint)
+ if gemini_api_key is None:
+ raise ValueError(
+ "Missing gemini_api_key, please set `GEMINI_API_KEY`"
+ )
+ auth_header = (
+ gemini_api_key # cloudflare expects api key as bearer token
+ )
+ else:
+ url = "{}:{}".format(api_base, endpoint)
+
+ if stream is True:
+ url = url + "?alt=sse"
+ return auth_header, url
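+
+ # Illustrative (gateway URL assumed, not part of the original source): with
+ # api_base="https://gateway.example.com/v1beta/models/gemini-1.5-pro", endpoint="streamGenerateContent",
+ # and stream=True, the returned url is
+ # "https://gateway.example.com/v1beta/models/gemini-1.5-pro:streamGenerateContent?alt=sse";
+ # for custom_llm_provider="gemini" the gemini_api_key is returned as the auth header.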
+
+ def _get_token_and_url(
+ self,
+ model: str,
+ auth_header: Optional[str],
+ gemini_api_key: Optional[str],
+ vertex_project: Optional[str],
+ vertex_location: Optional[str],
+ vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
+ stream: Optional[bool],
+ custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
+ api_base: Optional[str],
+ should_use_v1beta1_features: Optional[bool] = False,
+ mode: all_gemini_url_modes = "chat",
+ ) -> Tuple[Optional[str], str]:
+ """
+ Internal function. Returns the token and url for the call.
+
+ Handles logic if it's google ai studio vs. vertex ai.
+
+ Returns
+ token, url
+ """
+ if custom_llm_provider == "gemini":
+ url, endpoint = _get_gemini_url(
+ mode=mode,
+ model=model,
+ stream=stream,
+ gemini_api_key=gemini_api_key,
+ )
+ auth_header = None # this field is not used for gemini
+ else:
+ vertex_location = self.get_vertex_region(vertex_region=vertex_location)
+
+ ### SET RUNTIME ENDPOINT ###
+ version: Literal["v1beta1", "v1"] = (
+ "v1beta1" if should_use_v1beta1_features is True else "v1"
+ )
+ url, endpoint = _get_vertex_url(
+ mode=mode,
+ model=model,
+ stream=stream,
+ vertex_project=vertex_project,
+ vertex_location=vertex_location,
+ vertex_api_version=version,
+ )
+
+ return self._check_custom_proxy(
+ api_base=api_base,
+ auth_header=auth_header,
+ custom_llm_provider=custom_llm_provider,
+ gemini_api_key=gemini_api_key,
+ endpoint=endpoint,
+ stream=stream,
+ url=url,
+ )
+
+ async def _ensure_access_token_async(
+ self,
+ credentials: Optional[VERTEX_CREDENTIALS_TYPES],
+ project_id: Optional[str],
+ custom_llm_provider: Literal[
+ "vertex_ai", "vertex_ai_beta", "gemini"
+ ], # if it's vertex_ai or gemini (google ai studio)
+ ) -> Tuple[str, str]:
+ """
+ Async version of _ensure_access_token
+ """
+ if custom_llm_provider == "gemini":
+ return "", ""
+ if self.access_token is not None:
+ if project_id is not None:
+ return self.access_token, project_id
+ elif self.project_id is not None:
+ return self.access_token, self.project_id
+
+ if not self._credentials:
+ try:
+ self._credentials, cred_project_id = await asyncify(self.load_auth)(
+ credentials=credentials, project_id=project_id
+ )
+ except Exception:
+ verbose_logger.exception(
+ "Failed to load vertex credentials. Check whether the credentials contain partial or invalid information."
+ )
+ raise
+ if not self.project_id:
+ self.project_id = project_id or cred_project_id
+ else:
+ if self._credentials.expired or not self._credentials.token:
+ await asyncify(self.refresh_auth)(self._credentials)
+
+ if not self.project_id:
+ self.project_id = self._credentials.quota_project_id
+
+ if not self.project_id:
+ raise ValueError("Could not resolve project_id")
+
+ if not self._credentials or not self._credentials.token:
+ raise RuntimeError("Could not resolve API token from the environment")
+
+ return self._credentials.token, project_id or self.project_id
+
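+ # Minimal usage sketch (hypothetical instance `base`, not library code), called
+ # from a coroutine; blocking google-auth work is offloaded via `asyncify`:
+ #
+ #   token, project = await base._ensure_access_token_async(
+ #       credentials=None,               # fall back to already-loaded / default credentials
+ #       project_id=None,                # resolved from credentials if not given
+ #       custom_llm_provider="vertex_ai",
+ #   )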
+ def set_headers(
+ self, auth_header: Optional[str], extra_headers: Optional[dict]
+ ) -> dict:
+ headers = {
+ "Content-Type": "application/json",
+ }
+ if auth_header is not None:
+ headers["Authorization"] = f"Bearer {auth_header}"
+ if extra_headers is not None:
+ headers.update(extra_headers)
+
+ return headers
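+ # Illustrative output sketch (hypothetical values):
+ #
+ #   set_headers(auth_header="ya29.<token>", extra_headers={"x-goog-user-project": "my-proj"})
+ #   # -> {"Content-Type": "application/json",
+ #   #     "Authorization": "Bearer ya29.<token>",
+ #   #     "x-goog-user-project": "my-proj"}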
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_model_garden/main.py b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_model_garden/main.py
new file mode 100644
index 00000000..7b54d4e3
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_model_garden/main.py
@@ -0,0 +1,149 @@
+"""
+API Handler for calling Vertex AI Model Garden Models
+
+Most Vertex Model Garden Models are OpenAI compatible - so this handler calls `openai_like_chat_completions`
+
+Usage:
+
+response = litellm.completion(
+ model="vertex_ai/openai/5464397967697903616",
+ messages=[{"role": "user", "content": "Hello, how are you?"}],
+)
+
+Sent to this route when `model` is in the format `vertex_ai/openai/{MODEL_ID}`
+
+
+Vertex Documentation for using the OpenAI /chat/completions endpoint: https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/model_garden/model_garden_pytorch_llama3_deployment.ipynb
+"""
+
+from typing import Callable, Optional, Union
+
+import httpx # type: ignore
+
+from litellm.utils import ModelResponse
+
+from ..common_utils import VertexAIError
+from ..vertex_llm_base import VertexBase
+
+
+def create_vertex_url(
+ vertex_location: str,
+ vertex_project: str,
+ stream: Optional[bool],
+ model: str,
+ api_base: Optional[str] = None,
+) -> str:
+ """Return the base url for the vertex garden models"""
+ # f"https://{self.endpoint.location}-aiplatform.googleapis.com/v1beta1/projects/{PROJECT_ID}/locations/{self.endpoint.location}"
+ return f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/endpoints/{model}"
+
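+# Illustrative call (hypothetical project and endpoint ids, not library code):
+#
+#   create_vertex_url(vertex_location="us-central1", vertex_project="my-proj",
+#                     stream=False, model="5464397967697903616")
+#   # -> "https://us-central1-aiplatform.googleapis.com/v1beta1/projects/my-proj"
+#   #    "/locations/us-central1/endpoints/5464397967697903616"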
+
+class VertexAIModelGardenModels(VertexBase):
+ def __init__(self) -> None:
+ pass
+
+ def completion(
+ self,
+ model: str,
+ messages: list,
+ model_response: ModelResponse,
+ print_verbose: Callable,
+ encoding,
+ logging_obj,
+ api_base: Optional[str],
+ optional_params: dict,
+ custom_prompt_dict: dict,
+ headers: Optional[dict],
+ timeout: Union[float, httpx.Timeout],
+ litellm_params: dict,
+ vertex_project=None,
+ vertex_location=None,
+ vertex_credentials=None,
+ logger_fn=None,
+ acompletion: bool = False,
+ client=None,
+ ):
+ """
+ Handles calling Vertex AI Model Garden models in the OpenAI-compatible format
+
+ Sent to this route when `model` is in the format `vertex_ai/openai/{MODEL_ID}`
+ """
+ try:
+ import vertexai
+
+ from litellm.llms.openai_like.chat.handler import OpenAILikeChatHandler
+ from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
+ VertexLLM,
+ )
+ except Exception as e:
+ raise VertexAIError(
+ status_code=400,
+ message=f"""vertexai import failed please run `pip install -U "google-cloud-aiplatform>=1.38"`. Got error: {e}""",
+ )
+
+ # older vertexai SDKs lack `preview.language_models`; both checks must pass
+ if not (
+ hasattr(vertexai, "preview") and hasattr(vertexai.preview, "language_models")
+ ):
+ raise VertexAIError(
+ status_code=400,
+ message="""Upgrade vertex ai. Run `pip install "google-cloud-aiplatform>=1.38"`""",
+ )
+ try:
+ model = model.replace("openai/", "")
+ vertex_httpx_logic = VertexLLM()
+
+ access_token, project_id = vertex_httpx_logic._ensure_access_token(
+ credentials=vertex_credentials,
+ project_id=vertex_project,
+ custom_llm_provider="vertex_ai",
+ )
+
+ openai_like_chat_completions = OpenAILikeChatHandler()
+
+ ## CONSTRUCT API BASE
+ stream: bool = optional_params.get("stream", False) or False
+ optional_params["stream"] = stream
+ default_api_base = create_vertex_url(
+ vertex_location=vertex_location or "us-central1",
+ vertex_project=vertex_project or project_id,
+ stream=stream,
+ model=model,
+ )
+
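+ # annotation: `endpoint` becomes whatever follows the last ":" in the default URL;
+ # `_check_custom_proxy` only uses it when a custom `api_base` is supplied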
+ if len(default_api_base.split(":")) > 1:
+ endpoint = default_api_base.split(":")[-1]
+ else:
+ endpoint = ""
+
+ _, api_base = self._check_custom_proxy(
+ api_base=api_base,
+ custom_llm_provider="vertex_ai",
+ gemini_api_key=None,
+ endpoint=endpoint,
+ stream=stream,
+ auth_header=None,
+ url=default_api_base,
+ )
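+ # annotation (intent inferred): the endpoint id is expected to be part of
+ # `api_base` (the default URL ends in "/endpoints/{MODEL_ID}"), so the
+ # OpenAI-compatible request is sent with an empty model name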
+ model = ""
+ return openai_like_chat_completions.completion(
+ model=model,
+ messages=messages,
+ api_base=api_base,
+ api_key=access_token,
+ custom_prompt_dict=custom_prompt_dict,
+ model_response=model_response,
+ print_verbose=print_verbose,
+ logging_obj=logging_obj,
+ optional_params=optional_params,
+ acompletion=acompletion,
+ litellm_params=litellm_params,
+ logger_fn=logger_fn,
+ client=client,
+ timeout=timeout,
+ encoding=encoding,
+ custom_llm_provider="vertex_ai",
+ )
+
+ except Exception as e:
+ raise VertexAIError(status_code=500, message=str(e))
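+# Illustrative caller-side sketch (hypothetical endpoint id, not library code):
+# this handler is normally reached through `litellm.completion()` as shown in the
+# module docstring; the async path is the same call via `litellm.acompletion`, e.g.
+#
+#   await litellm.acompletion(
+#       model="vertex_ai/openai/5464397967697903616",
+#       messages=[{"role": "user", "content": "Hello"}],
+#   )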