author     S. Solomon Darnell                2025-03-28 21:52:21 -0500
committer  S. Solomon Darnell                2025-03-28 21:52:21 -0500
commit     4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree       ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/gemini
parent     cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
download   gn-ai-master.tar.gz
two versions of R2R are here (HEAD, master)
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/gemini')
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/gemini/transformation.py                        479
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py   1495
2 files changed, 1974 insertions, 0 deletions
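
For context, a minimal usage sketch of how these code paths are typically reached through litellm's public API. The model names, project ID, and API key below are placeholders for illustration and are not part of this commit:

import litellm

# Google AI Studio route ("gemini" provider) -- exercises transformation.py and
# VertexGeminiConfig from the files added in this commit.
resp = litellm.completion(
    model="gemini/gemini-1.5-flash",
    messages=[{"role": "user", "content": "Say hello"}],
    api_key="YOUR_GEMINI_API_KEY",  # placeholder
)

# Vertex AI route -- same transformation logic, authenticated via GCP credentials.
resp = litellm.completion(
    model="vertex_ai/gemini-1.5-pro",
    messages=[{"role": "user", "content": "Say hello"}],
    vertex_project="my-gcp-project",  # placeholder
    vertex_location="us-central1",
)
print(resp.choices[0].message.content)
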
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/gemini/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/gemini/transformation.py
new file mode 100644
index 00000000..d6bafc7c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/gemini/transformation.py
@@ -0,0 +1,479 @@
+"""
+Transformation logic from OpenAI format to Gemini format.
+
+Why a separate file? To make it easy to see how the transformation works.
+"""
+
+import os
+from typing import TYPE_CHECKING, Any, List, Literal, Optional, Tuple, Union, cast
+
+import httpx
+from pydantic import BaseModel
+
+import litellm
+from litellm._logging import verbose_logger
+from litellm.litellm_core_utils.prompt_templates.factory import (
+ convert_to_anthropic_image_obj,
+ convert_to_gemini_tool_call_invoke,
+ convert_to_gemini_tool_call_result,
+ response_schema_prompt,
+)
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.types.files import (
+ get_file_mime_type_for_file_type,
+ get_file_type_from_extension,
+ is_gemini_1_5_accepted_file_type,
+)
+from litellm.types.llms.openai import (
+ AllMessageValues,
+ ChatCompletionAssistantMessage,
+ ChatCompletionImageObject,
+ ChatCompletionTextObject,
+)
+from litellm.types.llms.vertex_ai import *
+from litellm.types.llms.vertex_ai import (
+ GenerationConfig,
+ PartType,
+ RequestBody,
+ SafetSettingsConfig,
+ SystemInstructions,
+ ToolConfig,
+ Tools,
+)
+
+from ..common_utils import (
+ _check_text_in_content,
+ get_supports_response_schema,
+ get_supports_system_message,
+)
+
+if TYPE_CHECKING:
+ from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
+
+ LiteLLMLoggingObj = _LiteLLMLoggingObj
+else:
+ LiteLLMLoggingObj = Any
+
+
+def _process_gemini_image(image_url: str, format: Optional[str] = None) -> PartType:
+ """
+ Given an image URL, return the appropriate PartType for Gemini
+ """
+
+ try:
+ # GCS URIs
+ if "gs://" in image_url:
+ # Figure out file type
+ extension_with_dot = os.path.splitext(image_url)[-1] # Ex: ".png"
+ extension = extension_with_dot[1:] # Ex: "png"
+
+ if not format:
+ file_type = get_file_type_from_extension(extension)
+
+ # Validate the file type is supported by Gemini
+ if not is_gemini_1_5_accepted_file_type(file_type):
+ raise Exception(f"File type not supported by gemini - {file_type}")
+
+ mime_type = get_file_mime_type_for_file_type(file_type)
+ else:
+ mime_type = format
+ file_data = FileDataType(mime_type=mime_type, file_uri=image_url)
+
+ return PartType(file_data=file_data)
+ elif (
+ "https://" in image_url
+ and (image_type := format or _get_image_mime_type_from_url(image_url))
+ is not None
+ ):
+
+ file_data = FileDataType(file_uri=image_url, mime_type=image_type)
+ return PartType(file_data=file_data)
+ elif "http://" in image_url or "https://" in image_url or "base64" in image_url:
+ # https links for unsupported mime types and base64 images
+ image = convert_to_anthropic_image_obj(image_url, format=format)
+ _blob = BlobType(data=image["data"], mime_type=image["media_type"])
+ return PartType(inline_data=_blob)
+ raise Exception("Invalid image received - {}".format(image_url))
+ except Exception as e:
+ raise e
+
+
+def _get_image_mime_type_from_url(url: str) -> Optional[str]:
+ """
+ Get mime type for common image URLs
+ See gemini mime types: https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/image-understanding#image-requirements
+
+ Supported by Gemini:
+ - PNG (`image/png`)
+ - JPEG (`image/jpeg`)
+ - WebP (`image/webp`)
+ Example:
+ url = https://example.com/image.jpg
+ Returns: image/jpeg
+ """
+ url = url.lower()
+ if url.endswith((".jpg", ".jpeg")):
+ return "image/jpeg"
+ elif url.endswith(".png"):
+ return "image/png"
+ elif url.endswith(".webp"):
+ return "image/webp"
+ elif url.endswith(".mp4"):
+ return "video/mp4"
+ elif url.endswith(".pdf"):
+ return "application/pdf"
+ return None
+
+
+def _gemini_convert_messages_with_history( # noqa: PLR0915
+ messages: List[AllMessageValues],
+) -> List[ContentType]:
+ """
+ Converts given messages from OpenAI format to Gemini format
+
+ - Parts must be iterable
+ - Roles must alternate b/w 'user' and 'model' (same as anthropic -> merge consecutive roles)
+ - Please ensure that function response turn comes immediately after a function call turn
+ """
+ user_message_types = {"user", "system"}
+ contents: List[ContentType] = []
+
+ last_message_with_tool_calls = None
+
+ msg_i = 0
+ tool_call_responses = []
+ try:
+ while msg_i < len(messages):
+ user_content: List[PartType] = []
+ init_msg_i = msg_i
+ ## MERGE CONSECUTIVE USER CONTENT ##
+ while (
+ msg_i < len(messages) and messages[msg_i]["role"] in user_message_types
+ ):
+ _message_content = messages[msg_i].get("content")
+ if _message_content is not None and isinstance(_message_content, list):
+ _parts: List[PartType] = []
+ for element in _message_content:
+ if (
+ element["type"] == "text"
+ and "text" in element
+ and len(element["text"]) > 0
+ ):
+ element = cast(ChatCompletionTextObject, element)
+ _part = PartType(text=element["text"])
+ _parts.append(_part)
+ elif element["type"] == "image_url":
+ element = cast(ChatCompletionImageObject, element)
+ img_element = element
+ format: Optional[str] = None
+ if isinstance(img_element["image_url"], dict):
+ image_url = img_element["image_url"]["url"]
+ format = img_element["image_url"].get("format")
+ else:
+ image_url = img_element["image_url"]
+ _part = _process_gemini_image(
+ image_url=image_url, format=format
+ )
+ _parts.append(_part)
+ user_content.extend(_parts)
+ elif (
+ _message_content is not None
+ and isinstance(_message_content, str)
+ and len(_message_content) > 0
+ ):
+ _part = PartType(text=_message_content)
+ user_content.append(_part)
+
+ msg_i += 1
+
+ if user_content:
+ """
+ check that user_content has 'text' parameter.
+ - Known Vertex Error: Unable to submit request because it must have a text parameter.
+ - Relevant Issue: https://github.com/BerriAI/litellm/issues/5515
+ """
+ has_text_in_content = _check_text_in_content(user_content)
+ if has_text_in_content is False:
+ verbose_logger.warning(
+ "No text in user content. Adding a blank text to user content, to ensure Gemini doesn't fail the request. Relevant Issue - https://github.com/BerriAI/litellm/issues/5515"
+ )
+ user_content.append(
+ PartType(text=" ")
+ ) # add a blank text, to ensure Gemini doesn't fail the request.
+ contents.append(ContentType(role="user", parts=user_content))
+ assistant_content = []
+ ## MERGE CONSECUTIVE ASSISTANT CONTENT ##
+ while msg_i < len(messages) and messages[msg_i]["role"] == "assistant":
+ if isinstance(messages[msg_i], BaseModel):
+ msg_dict: Union[ChatCompletionAssistantMessage, dict] = messages[msg_i].model_dump() # type: ignore
+ else:
+ msg_dict = messages[msg_i] # type: ignore
+ assistant_msg = ChatCompletionAssistantMessage(**msg_dict) # type: ignore
+ _message_content = assistant_msg.get("content", None)
+ if _message_content is not None and isinstance(_message_content, list):
+ _parts = []
+ for element in _message_content:
+ if isinstance(element, dict):
+ if element["type"] == "text":
+ _part = PartType(text=element["text"])
+ _parts.append(_part)
+ assistant_content.extend(_parts)
+ elif (
+ _message_content is not None
+ and isinstance(_message_content, str)
+ and _message_content
+ ):
+ assistant_text = _message_content # either string or none
+ assistant_content.append(PartType(text=assistant_text)) # type: ignore
+
+ ## HANDLE ASSISTANT FUNCTION CALL
+ if (
+ assistant_msg.get("tool_calls", []) is not None
+ or assistant_msg.get("function_call") is not None
+ ): # support assistant tool invoke conversion
+ assistant_content.extend(
+ convert_to_gemini_tool_call_invoke(assistant_msg)
+ )
+ last_message_with_tool_calls = assistant_msg
+
+ msg_i += 1
+
+ if assistant_content:
+ contents.append(ContentType(role="model", parts=assistant_content))
+
+ ## APPEND TOOL CALL MESSAGES ##
+ tool_call_message_roles = ["tool", "function"]
+ if (
+ msg_i < len(messages)
+ and messages[msg_i]["role"] in tool_call_message_roles
+ ):
+ _part = convert_to_gemini_tool_call_result(
+ messages[msg_i], last_message_with_tool_calls # type: ignore
+ )
+ msg_i += 1
+ tool_call_responses.append(_part)
+ if msg_i < len(messages) and (
+ messages[msg_i]["role"] not in tool_call_message_roles
+ ):
+ if len(tool_call_responses) > 0:
+ contents.append(ContentType(parts=tool_call_responses))
+ tool_call_responses = []
+
+ if msg_i == init_msg_i: # prevent infinite loops
+ raise Exception(
+ "Invalid Message passed in - {}. File an issue https://github.com/BerriAI/litellm/issues".format(
+ messages[msg_i]
+ )
+ )
+ if len(tool_call_responses) > 0:
+ contents.append(ContentType(parts=tool_call_responses))
+ return contents
+ except Exception as e:
+ raise e
+
+
+def _transform_request_body(
+ messages: List[AllMessageValues],
+ model: str,
+ optional_params: dict,
+ custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
+ litellm_params: dict,
+ cached_content: Optional[str],
+) -> RequestBody:
+ """
+ Common transformation logic across sync + async Gemini /generateContent calls.
+ """
+ # Separate system prompt from rest of message
+ supports_system_message = get_supports_system_message(
+ model=model, custom_llm_provider=custom_llm_provider
+ )
+ system_instructions, messages = _transform_system_message(
+ supports_system_message=supports_system_message, messages=messages
+ )
+ # Checks for 'response_schema' support - if passed in
+ if "response_schema" in optional_params:
+ supports_response_schema = get_supports_response_schema(
+ model=model, custom_llm_provider=custom_llm_provider
+ )
+ if supports_response_schema is False:
+ user_response_schema_message = response_schema_prompt(
+ model=model, response_schema=optional_params.get("response_schema") # type: ignore
+ )
+ messages.append({"role": "user", "content": user_response_schema_message})
+ optional_params.pop("response_schema")
+
+ # Check for any 'litellm_param_*' set during optional param mapping
+
+ remove_keys = []
+ for k, v in optional_params.items():
+ if k.startswith("litellm_param_"):
+ litellm_params.update({k: v})
+ remove_keys.append(k)
+
+ optional_params = {k: v for k, v in optional_params.items() if k not in remove_keys}
+
+ try:
+ if custom_llm_provider == "gemini":
+ content = litellm.GoogleAIStudioGeminiConfig()._transform_messages(
+ messages=messages
+ )
+ else:
+ content = litellm.VertexGeminiConfig()._transform_messages(
+ messages=messages
+ )
+ tools: Optional[Tools] = optional_params.pop("tools", None)
+ tool_choice: Optional[ToolConfig] = optional_params.pop("tool_choice", None)
+ safety_settings: Optional[List[SafetSettingsConfig]] = optional_params.pop(
+ "safety_settings", None
+ ) # type: ignore
+ config_fields = GenerationConfig.__annotations__.keys()
+
+ filtered_params = {
+ k: v for k, v in optional_params.items() if k in config_fields
+ }
+
+ generation_config: Optional[GenerationConfig] = GenerationConfig(
+ **filtered_params
+ )
+ data = RequestBody(contents=content)
+ if system_instructions is not None:
+ data["system_instruction"] = system_instructions
+ if tools is not None:
+ data["tools"] = tools
+ if tool_choice is not None:
+ data["toolConfig"] = tool_choice
+ if safety_settings is not None:
+ data["safetySettings"] = safety_settings
+ if generation_config is not None:
+ data["generationConfig"] = generation_config
+ if cached_content is not None:
+ data["cachedContent"] = cached_content
+ except Exception as e:
+ raise e
+
+ return data
+
+
+def sync_transform_request_body(
+ gemini_api_key: Optional[str],
+ messages: List[AllMessageValues],
+ api_base: Optional[str],
+ model: str,
+ client: Optional[HTTPHandler],
+ timeout: Optional[Union[float, httpx.Timeout]],
+ extra_headers: Optional[dict],
+ optional_params: dict,
+ logging_obj: LiteLLMLoggingObj,
+ custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
+ litellm_params: dict,
+) -> RequestBody:
+ from ..context_caching.vertex_ai_context_caching import ContextCachingEndpoints
+
+ context_caching_endpoints = ContextCachingEndpoints()
+
+ if gemini_api_key is not None:
+ messages, cached_content = context_caching_endpoints.check_and_create_cache(
+ messages=messages,
+ api_key=gemini_api_key,
+ api_base=api_base,
+ model=model,
+ client=client,
+ timeout=timeout,
+ extra_headers=extra_headers,
+ cached_content=optional_params.pop("cached_content", None),
+ logging_obj=logging_obj,
+ )
+ else: # [TODO] implement context caching for gemini as well
+ cached_content = optional_params.pop("cached_content", None)
+
+ return _transform_request_body(
+ messages=messages,
+ model=model,
+ custom_llm_provider=custom_llm_provider,
+ litellm_params=litellm_params,
+ cached_content=cached_content,
+ optional_params=optional_params,
+ )
+
+
+async def async_transform_request_body(
+ gemini_api_key: Optional[str],
+ messages: List[AllMessageValues],
+ api_base: Optional[str],
+ model: str,
+ client: Optional[AsyncHTTPHandler],
+ timeout: Optional[Union[float, httpx.Timeout]],
+ extra_headers: Optional[dict],
+ optional_params: dict,
+ logging_obj: litellm.litellm_core_utils.litellm_logging.Logging, # type: ignore
+ custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
+ litellm_params: dict,
+) -> RequestBody:
+ from ..context_caching.vertex_ai_context_caching import ContextCachingEndpoints
+
+ context_caching_endpoints = ContextCachingEndpoints()
+
+ if gemini_api_key is not None:
+ messages, cached_content = (
+ await context_caching_endpoints.async_check_and_create_cache(
+ messages=messages,
+ api_key=gemini_api_key,
+ api_base=api_base,
+ model=model,
+ client=client,
+ timeout=timeout,
+ extra_headers=extra_headers,
+ cached_content=optional_params.pop("cached_content", None),
+ logging_obj=logging_obj,
+ )
+ )
+ else: # [TODO] implement context caching for gemini as well
+ cached_content = optional_params.pop("cached_content", None)
+
+ return _transform_request_body(
+ messages=messages,
+ model=model,
+ custom_llm_provider=custom_llm_provider,
+ litellm_params=litellm_params,
+ cached_content=cached_content,
+ optional_params=optional_params,
+ )
+
+
+def _transform_system_message(
+ supports_system_message: bool, messages: List[AllMessageValues]
+) -> Tuple[Optional[SystemInstructions], List[AllMessageValues]]:
+ """
+ Extracts the system message from the openai message list.
+
+ Converts the system message to Gemini format
+
+ Returns
+ - system_content_blocks: Optional[SystemInstructions] - the system message list in Gemini format.
+ - messages: List[AllMessageValues] - filtered list of messages in OpenAI format (transformed separately)
+ """
+ # Separate system prompt from rest of message
+ system_prompt_indices = []
+ system_content_blocks: List[PartType] = []
+ if supports_system_message is True:
+ for idx, message in enumerate(messages):
+ if message["role"] == "system":
+ _system_content_block: Optional[PartType] = None
+ if isinstance(message["content"], str):
+ _system_content_block = PartType(text=message["content"])
+ elif isinstance(message["content"], list):
+ system_text = ""
+ for content in message["content"]:
+ system_text += content.get("text") or ""
+ _system_content_block = PartType(text=system_text)
+ if _system_content_block is not None:
+ system_content_blocks.append(_system_content_block)
+ system_prompt_indices.append(idx)
+ if len(system_prompt_indices) > 0:
+ for idx in reversed(system_prompt_indices):
+ messages.pop(idx)
+
+ if len(system_content_blocks) > 0:
+ return SystemInstructions(parts=system_content_blocks), messages
+
+ return None, messages
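
Taken together, the helpers above turn an OpenAI-style message list into Gemini's request shape: system turns are lifted into a separate system instruction, and consecutive user turns are merged so roles alternate between "user" and "model". A minimal sketch of the two entry points follows; the import path assumes a standard litellm install, and the expected shapes in the comments are inferred from the code above rather than verified against a live environment:

from litellm.llms.vertex_ai.gemini.transformation import (
    _gemini_convert_messages_with_history,
    _transform_system_message,
)

messages = [
    {"role": "system", "content": "You are terse."},
    {"role": "user", "content": "Hi"},
    {"role": "user", "content": [{"type": "text", "text": "What is 2+2?"}]},
    {"role": "assistant", "content": "4"},
]

# Step 1: pull the system turn out (Gemini takes it as a separate system_instruction).
system_instruction, remaining = _transform_system_message(
    supports_system_message=True, messages=messages
)
# Step 2: merge consecutive user turns and map assistant -> "model".
contents = _gemini_convert_messages_with_history(messages=remaining)

# Expected shapes (the types are TypedDicts, so they render as plain dicts):
# system_instruction == {"parts": [{"text": "You are terse."}]}
# contents == [
#     {"role": "user", "parts": [{"text": "Hi"}, {"text": "What is 2+2?"}]},
#     {"role": "model", "parts": [{"text": "4"}]},
# ]
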
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py
new file mode 100644
index 00000000..9ac1b1ff
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py
@@ -0,0 +1,1495 @@
+# What is this?
+## httpx client for vertex ai calls
+## Initial implementation - covers gemini + image gen calls
+import json
+import uuid
+from copy import deepcopy
+from functools import partial
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Callable,
+ Dict,
+ List,
+ Literal,
+ Optional,
+ Tuple,
+ Union,
+ cast,
+)
+
+import httpx # type: ignore
+
+import litellm
+import litellm.litellm_core_utils
+import litellm.litellm_core_utils.litellm_logging
+from litellm import verbose_logger
+from litellm.litellm_core_utils.core_helpers import map_finish_reason
+from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
+from litellm.llms.custom_httpx.http_handler import (
+ AsyncHTTPHandler,
+ HTTPHandler,
+ get_async_httpx_client,
+)
+from litellm.types.llms.openai import (
+ AllMessageValues,
+ ChatCompletionResponseMessage,
+ ChatCompletionToolCallChunk,
+ ChatCompletionToolCallFunctionChunk,
+ ChatCompletionToolParamFunctionChunk,
+ ChatCompletionUsageBlock,
+)
+from litellm.types.llms.vertex_ai import (
+ VERTEX_CREDENTIALS_TYPES,
+ Candidates,
+ ContentType,
+ FunctionCallingConfig,
+ FunctionDeclaration,
+ GenerateContentResponseBody,
+ HttpxPartType,
+ LogprobsResult,
+ ToolConfig,
+ Tools,
+)
+from litellm.types.utils import (
+ ChatCompletionTokenLogprob,
+ ChoiceLogprobs,
+ GenericStreamingChunk,
+ PromptTokensDetailsWrapper,
+ TopLogprob,
+ Usage,
+)
+from litellm.utils import CustomStreamWrapper, ModelResponse
+
+from ....utils import _remove_additional_properties, _remove_strict_from_schema
+from ..common_utils import VertexAIError, _build_vertex_schema
+from ..vertex_llm_base import VertexBase
+from .transformation import (
+ _gemini_convert_messages_with_history,
+ async_transform_request_body,
+ sync_transform_request_body,
+)
+
+if TYPE_CHECKING:
+ from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+
+ LoggingClass = LiteLLMLoggingObj
+else:
+ LoggingClass = Any
+
+
+class VertexAIBaseConfig:
+ def get_mapped_special_auth_params(self) -> dict:
+ """
+ Common auth params across bedrock/vertex_ai/azure/watsonx
+ """
+ return {"project": "vertex_project", "region_name": "vertex_location"}
+
+ def map_special_auth_params(self, non_default_params: dict, optional_params: dict):
+ mapped_params = self.get_mapped_special_auth_params()
+
+ for param, value in non_default_params.items():
+ if param in mapped_params:
+ optional_params[mapped_params[param]] = value
+ return optional_params
+
+ def get_eu_regions(self) -> List[str]:
+ """
+ Source: https://cloud.google.com/vertex-ai/generative-ai/docs/learn/locations#available-regions
+ """
+ return [
+ "europe-central2",
+ "europe-north1",
+ "europe-southwest1",
+ "europe-west1",
+ "europe-west2",
+ "europe-west3",
+ "europe-west4",
+ "europe-west6",
+ "europe-west8",
+ "europe-west9",
+ ]
+
+ def get_us_regions(self) -> List[str]:
+ """
+ Source: https://cloud.google.com/vertex-ai/generative-ai/docs/learn/locations#available-regions
+ """
+ return [
+ "us-central1",
+ "us-east1",
+ "us-east4",
+ "us-east5",
+ "us-south1",
+ "us-west1",
+ "us-west4",
+ "us-west5",
+ ]
+
+
+class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig):
+ """
+ Reference: https://cloud.google.com/vertex-ai/docs/generative-ai/chat/test-chat-prompts
+ Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference
+
+    The class `VertexGeminiConfig` provides configuration for Vertex AI's Gemini API interface. Below are the parameters:
+
+ - `temperature` (float): This controls the degree of randomness in token selection.
+
+    - `max_output_tokens` (integer): This sets the maximum number of tokens in the text output. The default value is 256.
+
+ - `top_p` (float): The tokens are selected from the most probable to the least probable until the sum of their probabilities equals the `top_p` value. Default is 0.95.
+
+ - `top_k` (integer): The value of `top_k` determines how many of the most probable tokens are considered in the selection. For example, a `top_k` of 1 means the selected token is the most probable among all tokens. The default value is 40.
+
+ - `response_mime_type` (str): The MIME type of the response. The default value is 'text/plain'.
+
+ - `candidate_count` (int): Number of generated responses to return.
+
+ - `stop_sequences` (List[str]): The set of character sequences (up to 5) that will stop output generation. If specified, the API will stop at the first appearance of a stop sequence. The stop sequence will not be included as part of the response.
+
+    - `frequency_penalty` (float): This parameter is used to penalize the model for repeating the same output. The default value is 0.0.
+
+    - `presence_penalty` (float): This parameter is used to penalize the model for generating the same output as the input. The default value is 0.0.
+
+ - `seed` (int): The seed value is used to help generate the same output for the same input. The default value is None.
+
+ Note: Please make sure to modify the default parameters as required for your use case.
+ """
+
+ temperature: Optional[float] = None
+ max_output_tokens: Optional[int] = None
+ top_p: Optional[float] = None
+ top_k: Optional[int] = None
+ response_mime_type: Optional[str] = None
+ candidate_count: Optional[int] = None
+ stop_sequences: Optional[list] = None
+ frequency_penalty: Optional[float] = None
+ presence_penalty: Optional[float] = None
+ seed: Optional[int] = None
+
+ def __init__(
+ self,
+ temperature: Optional[float] = None,
+ max_output_tokens: Optional[int] = None,
+ top_p: Optional[float] = None,
+ top_k: Optional[int] = None,
+ response_mime_type: Optional[str] = None,
+ candidate_count: Optional[int] = None,
+ stop_sequences: Optional[list] = None,
+ frequency_penalty: Optional[float] = None,
+ presence_penalty: Optional[float] = None,
+ seed: Optional[int] = None,
+ ) -> None:
+ locals_ = locals().copy()
+ for key, value in locals_.items():
+ if key != "self" and value is not None:
+ setattr(self.__class__, key, value)
+
+ @classmethod
+ def get_config(cls):
+ return super().get_config()
+
+ def get_supported_openai_params(self, model: str) -> List[str]:
+ return [
+ "temperature",
+ "top_p",
+ "max_tokens",
+ "max_completion_tokens",
+ "stream",
+ "tools",
+ "functions",
+ "tool_choice",
+ "response_format",
+ "n",
+ "stop",
+ "frequency_penalty",
+ "presence_penalty",
+ "extra_headers",
+ "seed",
+ "logprobs",
+ ]
+
+ def map_tool_choice_values(
+ self, model: str, tool_choice: Union[str, dict]
+ ) -> Optional[ToolConfig]:
+ if tool_choice == "none":
+ return ToolConfig(functionCallingConfig=FunctionCallingConfig(mode="NONE"))
+ elif tool_choice == "required":
+ return ToolConfig(functionCallingConfig=FunctionCallingConfig(mode="ANY"))
+ elif tool_choice == "auto":
+ return ToolConfig(functionCallingConfig=FunctionCallingConfig(mode="AUTO"))
+ elif isinstance(tool_choice, dict):
+ # only supported for anthropic + mistral models - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html
+ name = tool_choice.get("function", {}).get("name", "")
+ return ToolConfig(
+ functionCallingConfig=FunctionCallingConfig(
+ mode="ANY", allowed_function_names=[name]
+ )
+ )
+ else:
+ raise litellm.utils.UnsupportedParamsError(
+                message="VertexAI doesn't support tool_choice={}. Supported tool_choice values=['auto', 'required', json object]. To drop it from the call, set `litellm.drop_params = True`.".format(
+ tool_choice
+ ),
+ status_code=400,
+ )
+
+ def _map_function(self, value: List[dict]) -> List[Tools]:
+ gtool_func_declarations = []
+ googleSearch: Optional[dict] = None
+ googleSearchRetrieval: Optional[dict] = None
+ code_execution: Optional[dict] = None
+ # remove 'additionalProperties' from tools
+ value = _remove_additional_properties(value)
+ # remove 'strict' from tools
+ value = _remove_strict_from_schema(value)
+
+ for tool in value:
+ openai_function_object: Optional[ChatCompletionToolParamFunctionChunk] = (
+ None
+ )
+ if "function" in tool: # tools list
+ _openai_function_object = ChatCompletionToolParamFunctionChunk( # type: ignore
+ **tool["function"]
+ )
+
+ if (
+ "parameters" in _openai_function_object
+ and _openai_function_object["parameters"] is not None
+ ): # OPENAI accepts JSON Schema, Google accepts OpenAPI schema.
+ _openai_function_object["parameters"] = _build_vertex_schema(
+ _openai_function_object["parameters"]
+ )
+
+ openai_function_object = _openai_function_object
+
+ elif "name" in tool: # functions list
+ openai_function_object = ChatCompletionToolParamFunctionChunk(**tool) # type: ignore
+
+ # check if grounding
+ if tool.get("googleSearch", None) is not None:
+ googleSearch = tool["googleSearch"]
+ elif tool.get("googleSearchRetrieval", None) is not None:
+ googleSearchRetrieval = tool["googleSearchRetrieval"]
+ elif tool.get("code_execution", None) is not None:
+ code_execution = tool["code_execution"]
+ elif openai_function_object is not None:
+ gtool_func_declaration = FunctionDeclaration(
+ name=openai_function_object["name"],
+ )
+ _description = openai_function_object.get("description", None)
+ _parameters = openai_function_object.get("parameters", None)
+ if _description is not None:
+ gtool_func_declaration["description"] = _description
+ if _parameters is not None:
+ gtool_func_declaration["parameters"] = _parameters
+ gtool_func_declarations.append(gtool_func_declaration)
+ else:
+ # assume it's a provider-specific param
+ verbose_logger.warning(
+                    "Invalid tool={}. Use `litellm.set_verbose` or `litellm --detailed_debug` to see raw request.".format(tool)
+ )
+
+ _tools = Tools(
+ function_declarations=gtool_func_declarations,
+ )
+ if googleSearch is not None:
+ _tools["googleSearch"] = googleSearch
+ if googleSearchRetrieval is not None:
+ _tools["googleSearchRetrieval"] = googleSearchRetrieval
+ if code_execution is not None:
+ _tools["code_execution"] = code_execution
+ return [_tools]
+
+ def _map_response_schema(self, value: dict) -> dict:
+ old_schema = deepcopy(value)
+        if isinstance(old_schema, list):
+            for i, item in enumerate(old_schema):
+                if isinstance(item, dict):
+                    old_schema[i] = _build_vertex_schema(parameters=item)
+ elif isinstance(old_schema, dict):
+ old_schema = _build_vertex_schema(parameters=old_schema)
+ return old_schema
+
+ def map_openai_params(
+ self,
+ non_default_params: Dict,
+ optional_params: Dict,
+ model: str,
+ drop_params: bool,
+ ) -> Dict:
+ for param, value in non_default_params.items():
+ if param == "temperature":
+ optional_params["temperature"] = value
+ if param == "top_p":
+ optional_params["top_p"] = value
+ if (
+ param == "stream" and value is True
+ ): # sending stream = False, can cause it to get passed unchecked and raise issues
+ optional_params["stream"] = value
+ if param == "n":
+ optional_params["candidate_count"] = value
+ if param == "stop":
+ if isinstance(value, str):
+ optional_params["stop_sequences"] = [value]
+ elif isinstance(value, list):
+ optional_params["stop_sequences"] = value
+ if param == "max_tokens" or param == "max_completion_tokens":
+ optional_params["max_output_tokens"] = value
+ if param == "response_format" and isinstance(value, dict): # type: ignore
+ # remove 'additionalProperties' from json schema
+ value = _remove_additional_properties(value)
+ # remove 'strict' from json schema
+ value = _remove_strict_from_schema(value)
+ if value["type"] == "json_object":
+ optional_params["response_mime_type"] = "application/json"
+ elif value["type"] == "text":
+ optional_params["response_mime_type"] = "text/plain"
+ if "response_schema" in value:
+ optional_params["response_mime_type"] = "application/json"
+ optional_params["response_schema"] = value["response_schema"]
+ elif value["type"] == "json_schema": # type: ignore
+ if "json_schema" in value and "schema" in value["json_schema"]: # type: ignore
+ optional_params["response_mime_type"] = "application/json"
+ optional_params["response_schema"] = value["json_schema"]["schema"] # type: ignore
+
+ if "response_schema" in optional_params and isinstance(
+ optional_params["response_schema"], dict
+ ):
+ optional_params["response_schema"] = self._map_response_schema(
+ value=optional_params["response_schema"]
+ )
+ if param == "frequency_penalty":
+ optional_params["frequency_penalty"] = value
+ if param == "presence_penalty":
+ optional_params["presence_penalty"] = value
+ if param == "logprobs":
+ optional_params["responseLogprobs"] = value
+ if (param == "tools" or param == "functions") and isinstance(value, list):
+ optional_params["tools"] = self._map_function(value=value)
+ optional_params["litellm_param_is_function_call"] = (
+ True if param == "functions" else False
+ )
+ if param == "tool_choice" and (
+ isinstance(value, str) or isinstance(value, dict)
+ ):
+ _tool_choice_value = self.map_tool_choice_values(
+ model=model, tool_choice=value # type: ignore
+ )
+ if _tool_choice_value is not None:
+ optional_params["tool_choice"] = _tool_choice_value
+ if param == "seed":
+ optional_params["seed"] = value
+
+ if litellm.vertex_ai_safety_settings is not None:
+ optional_params["safety_settings"] = litellm.vertex_ai_safety_settings
+ return optional_params
+
+ def get_mapped_special_auth_params(self) -> dict:
+ """
+ Common auth params across bedrock/vertex_ai/azure/watsonx
+ """
+ return {"project": "vertex_project", "region_name": "vertex_location"}
+
+ def map_special_auth_params(self, non_default_params: dict, optional_params: dict):
+ mapped_params = self.get_mapped_special_auth_params()
+
+ for param, value in non_default_params.items():
+ if param in mapped_params:
+ optional_params[mapped_params[param]] = value
+ return optional_params
+
+ def get_eu_regions(self) -> List[str]:
+ """
+ Source: https://cloud.google.com/vertex-ai/generative-ai/docs/learn/locations#available-regions
+ """
+ return [
+ "europe-central2",
+ "europe-north1",
+ "europe-southwest1",
+ "europe-west1",
+ "europe-west2",
+ "europe-west3",
+ "europe-west4",
+ "europe-west6",
+ "europe-west8",
+ "europe-west9",
+ ]
+
+ def get_flagged_finish_reasons(self) -> Dict[str, str]:
+ """
+ Return Dictionary of finish reasons which indicate response was flagged
+
+ and what it means
+ """
+        return {
+            "SAFETY": "The token generation was stopped as the response was flagged for safety reasons. NOTE: When streaming, Candidate.content will be empty if content filters blocked the output.",
+            "RECITATION": "The token generation was stopped as the response was flagged for unauthorized citations.",
+            "BLOCKLIST": "The token generation was stopped as the response was flagged for terms included in the terminology blocklist.",
+            "PROHIBITED_CONTENT": "The token generation was stopped as the response was flagged for prohibited content.",
+            "SPII": "The token generation was stopped as the response was flagged for Sensitive Personally Identifiable Information (SPII) content.",
+        }
+
+ def translate_exception_str(self, exception_string: str):
+ if (
+ "GenerateContentRequest.tools[0].function_declarations[0].parameters.properties: should be non-empty for OBJECT type"
+ in exception_string
+ ):
+ return "'properties' field in tools[0]['function']['parameters'] cannot be empty if 'type' == 'object'. Received error from provider - {}".format(
+ exception_string
+ )
+ return exception_string
+
+ def get_assistant_content_message(
+ self, parts: List[HttpxPartType]
+ ) -> Optional[str]:
+ _content_str = ""
+ for part in parts:
+ if "text" in part:
+ _content_str += part["text"]
+ if _content_str:
+ return _content_str
+ return None
+
+ def _transform_parts(
+ self,
+ parts: List[HttpxPartType],
+ index: int,
+ is_function_call: Optional[bool],
+ ) -> Tuple[
+ Optional[ChatCompletionToolCallFunctionChunk],
+ Optional[List[ChatCompletionToolCallChunk]],
+ ]:
+ function: Optional[ChatCompletionToolCallFunctionChunk] = None
+ _tools: List[ChatCompletionToolCallChunk] = []
+ for part in parts:
+ if "functionCall" in part:
+ _function_chunk = ChatCompletionToolCallFunctionChunk(
+ name=part["functionCall"]["name"],
+ arguments=json.dumps(part["functionCall"]["args"]),
+ )
+ if is_function_call is True:
+ function = _function_chunk
+ else:
+ _tool_response_chunk = ChatCompletionToolCallChunk(
+ id=f"call_{str(uuid.uuid4())}",
+ type="function",
+ function=_function_chunk,
+ index=index,
+ )
+ _tools.append(_tool_response_chunk)
+ if len(_tools) == 0:
+ tools: Optional[List[ChatCompletionToolCallChunk]] = None
+ else:
+ tools = _tools
+ return function, tools
+
+ def _transform_logprobs(
+ self, logprobs_result: Optional[LogprobsResult]
+ ) -> Optional[ChoiceLogprobs]:
+ if logprobs_result is None:
+ return None
+ if "chosenCandidates" not in logprobs_result:
+ return None
+ logprobs_list: List[ChatCompletionTokenLogprob] = []
+ for index, candidate in enumerate(logprobs_result["chosenCandidates"]):
+ top_logprobs: List[TopLogprob] = []
+ if "topCandidates" in logprobs_result and index < len(
+ logprobs_result["topCandidates"]
+ ):
+ top_candidates_for_index = logprobs_result["topCandidates"][index][
+ "candidates"
+ ]
+
+ for options in top_candidates_for_index:
+ top_logprobs.append(
+ TopLogprob(
+ token=options["token"], logprob=options["logProbability"]
+ )
+ )
+ logprobs_list.append(
+ ChatCompletionTokenLogprob(
+ token=candidate["token"],
+ logprob=candidate["logProbability"],
+ top_logprobs=top_logprobs,
+ )
+ )
+ return ChoiceLogprobs(content=logprobs_list)
+
+ def _handle_blocked_response(
+ self,
+ model_response: ModelResponse,
+ completion_response: GenerateContentResponseBody,
+ ) -> ModelResponse:
+ # If set, the prompt was blocked and no candidates are returned. Rephrase your prompt
+ model_response.choices[0].finish_reason = "content_filter"
+
+ chat_completion_message: ChatCompletionResponseMessage = {
+ "role": "assistant",
+ "content": None,
+ }
+
+ choice = litellm.Choices(
+ finish_reason="content_filter",
+ index=0,
+ message=chat_completion_message, # type: ignore
+ logprobs=None,
+ enhancements=None,
+ )
+
+ model_response.choices = [choice]
+
+ ## GET USAGE ##
+ usage = Usage(
+ prompt_tokens=completion_response["usageMetadata"].get(
+ "promptTokenCount", 0
+ ),
+ completion_tokens=completion_response["usageMetadata"].get(
+ "candidatesTokenCount", 0
+ ),
+ total_tokens=completion_response["usageMetadata"].get("totalTokenCount", 0),
+ )
+
+ setattr(model_response, "usage", usage)
+
+ return model_response
+
+ def _handle_content_policy_violation(
+ self,
+ model_response: ModelResponse,
+ completion_response: GenerateContentResponseBody,
+ ) -> ModelResponse:
+ ## CONTENT POLICY VIOLATION ERROR
+ model_response.choices[0].finish_reason = "content_filter"
+
+ _chat_completion_message = {
+ "role": "assistant",
+ "content": None,
+ }
+
+ choice = litellm.Choices(
+ finish_reason="content_filter",
+ index=0,
+ message=_chat_completion_message,
+ logprobs=None,
+ enhancements=None,
+ )
+
+ model_response.choices = [choice]
+
+ ## GET USAGE ##
+ usage = Usage(
+ prompt_tokens=completion_response["usageMetadata"].get(
+ "promptTokenCount", 0
+ ),
+ completion_tokens=completion_response["usageMetadata"].get(
+ "candidatesTokenCount", 0
+ ),
+ total_tokens=completion_response["usageMetadata"].get("totalTokenCount", 0),
+ )
+
+ setattr(model_response, "usage", usage)
+
+ return model_response
+
+ def _calculate_usage(
+ self,
+ completion_response: GenerateContentResponseBody,
+ ) -> Usage:
+ cached_tokens: Optional[int] = None
+ prompt_tokens_details: Optional[PromptTokensDetailsWrapper] = None
+ if "cachedContentTokenCount" in completion_response["usageMetadata"]:
+ cached_tokens = completion_response["usageMetadata"][
+ "cachedContentTokenCount"
+ ]
+
+ if cached_tokens is not None:
+ prompt_tokens_details = PromptTokensDetailsWrapper(
+ cached_tokens=cached_tokens,
+ )
+ ## GET USAGE ##
+ usage = Usage(
+ prompt_tokens=completion_response["usageMetadata"].get(
+ "promptTokenCount", 0
+ ),
+ completion_tokens=completion_response["usageMetadata"].get(
+ "candidatesTokenCount", 0
+ ),
+ total_tokens=completion_response["usageMetadata"].get("totalTokenCount", 0),
+ prompt_tokens_details=prompt_tokens_details,
+ )
+
+ return usage
+
+ def transform_response(
+ self,
+ model: str,
+ raw_response: httpx.Response,
+ model_response: ModelResponse,
+ logging_obj: LoggingClass,
+ request_data: Dict,
+ messages: List[AllMessageValues],
+ optional_params: Dict,
+ litellm_params: Dict,
+ encoding: Any,
+ api_key: Optional[str] = None,
+ json_mode: Optional[bool] = None,
+ ) -> ModelResponse:
+ ## LOGGING
+ logging_obj.post_call(
+ input=messages,
+ api_key="",
+ original_response=raw_response.text,
+ additional_args={"complete_input_dict": request_data},
+ )
+
+ ## RESPONSE OBJECT
+ try:
+ completion_response = GenerateContentResponseBody(**raw_response.json()) # type: ignore
+ except Exception as e:
+ raise VertexAIError(
+ message="Received={}, Error converting to valid response block={}. File an issue if litellm error - https://github.com/BerriAI/litellm/issues".format(
+ raw_response.text, str(e)
+ ),
+ status_code=422,
+ headers=raw_response.headers,
+ )
+
+ ## GET MODEL ##
+ model_response.model = model
+
+ ## CHECK IF RESPONSE FLAGGED
+ if (
+ "promptFeedback" in completion_response
+ and "blockReason" in completion_response["promptFeedback"]
+ ):
+ return self._handle_blocked_response(
+ model_response=model_response,
+ completion_response=completion_response,
+ )
+
+ _candidates = completion_response.get("candidates")
+ if _candidates and len(_candidates) > 0:
+ content_policy_violations = (
+ VertexGeminiConfig().get_flagged_finish_reasons()
+ )
+ if (
+ "finishReason" in _candidates[0]
+ and _candidates[0]["finishReason"] in content_policy_violations.keys()
+ ):
+ return self._handle_content_policy_violation(
+ model_response=model_response,
+ completion_response=completion_response,
+ )
+
+ model_response.choices = [] # type: ignore
+
+ try:
+ ## CHECK IF GROUNDING METADATA IN REQUEST
+ grounding_metadata: List[dict] = []
+ safety_ratings: List = []
+ citation_metadata: List = []
+ ## GET TEXT ##
+ chat_completion_message: ChatCompletionResponseMessage = {
+ "role": "assistant"
+ }
+ chat_completion_logprobs: Optional[ChoiceLogprobs] = None
+ tools: Optional[List[ChatCompletionToolCallChunk]] = []
+ functions: Optional[ChatCompletionToolCallFunctionChunk] = None
+ if _candidates:
+ for idx, candidate in enumerate(_candidates):
+ if "content" not in candidate:
+ continue
+
+ if "groundingMetadata" in candidate:
+ grounding_metadata.append(candidate["groundingMetadata"]) # type: ignore
+
+ if "safetyRatings" in candidate:
+ safety_ratings.append(candidate["safetyRatings"])
+
+ if "citationMetadata" in candidate:
+ citation_metadata.append(candidate["citationMetadata"])
+ if "parts" in candidate["content"]:
+ chat_completion_message[
+ "content"
+ ] = VertexGeminiConfig().get_assistant_content_message(
+ parts=candidate["content"]["parts"]
+ )
+
+ functions, tools = self._transform_parts(
+ parts=candidate["content"]["parts"],
+ index=candidate.get("index", idx),
+ is_function_call=litellm_params.get(
+ "litellm_param_is_function_call"
+ ),
+ )
+
+ if "logprobsResult" in candidate:
+ chat_completion_logprobs = self._transform_logprobs(
+ logprobs_result=candidate["logprobsResult"]
+ )
+
+ if tools:
+ chat_completion_message["tool_calls"] = tools
+
+ if functions is not None:
+ chat_completion_message["function_call"] = functions
+ choice = litellm.Choices(
+ finish_reason=candidate.get("finishReason", "stop"),
+ index=candidate.get("index", idx),
+ message=chat_completion_message, # type: ignore
+ logprobs=chat_completion_logprobs,
+ enhancements=None,
+ )
+
+ model_response.choices.append(choice)
+
+ usage = self._calculate_usage(completion_response=completion_response)
+ setattr(model_response, "usage", usage)
+
+ ## ADD GROUNDING METADATA ##
+ setattr(model_response, "vertex_ai_grounding_metadata", grounding_metadata)
+ model_response._hidden_params[
+ "vertex_ai_grounding_metadata"
+ ] = ( # older approach - maintaining to prevent regressions
+ grounding_metadata
+ )
+
+ ## ADD SAFETY RATINGS ##
+ setattr(model_response, "vertex_ai_safety_results", safety_ratings)
+ model_response._hidden_params["vertex_ai_safety_results"] = (
+ safety_ratings # older approach - maintaining to prevent regressions
+ )
+
+ ## ADD CITATION METADATA ##
+ setattr(model_response, "vertex_ai_citation_metadata", citation_metadata)
+ model_response._hidden_params["vertex_ai_citation_metadata"] = (
+ citation_metadata # older approach - maintaining to prevent regressions
+ )
+
+ except Exception as e:
+ raise VertexAIError(
+ message="Received={}, Error converting to valid response block={}. File an issue if litellm error - https://github.com/BerriAI/litellm/issues".format(
+ completion_response, str(e)
+ ),
+ status_code=422,
+ headers=raw_response.headers,
+ )
+
+ return model_response
+
+ def _transform_messages(
+ self, messages: List[AllMessageValues]
+ ) -> List[ContentType]:
+ return _gemini_convert_messages_with_history(messages=messages)
+
+ def get_error_class(
+ self, error_message: str, status_code: int, headers: Union[Dict, httpx.Headers]
+ ) -> BaseLLMException:
+ return VertexAIError(
+ message=error_message, status_code=status_code, headers=headers
+ )
+
+ def transform_request(
+ self,
+ model: str,
+ messages: List[AllMessageValues],
+ optional_params: Dict,
+ litellm_params: Dict,
+ headers: Dict,
+ ) -> Dict:
+ raise NotImplementedError(
+ "Vertex AI has a custom implementation of transform_request. Needs sync + async."
+ )
+
+ def validate_environment(
+ self,
+ headers: Optional[Dict],
+ model: str,
+ messages: List[AllMessageValues],
+ optional_params: Dict,
+ api_key: Optional[str] = None,
+ api_base: Optional[str] = None,
+ ) -> Dict:
+ default_headers = {
+ "Content-Type": "application/json",
+ }
+ if api_key is not None:
+ default_headers["Authorization"] = f"Bearer {api_key}"
+ if headers is not None:
+ default_headers.update(headers)
+
+ return default_headers
+
+
+async def make_call(
+ client: Optional[AsyncHTTPHandler],
+ api_base: str,
+ headers: dict,
+ data: str,
+ model: str,
+ messages: list,
+ logging_obj,
+):
+ if client is None:
+ client = get_async_httpx_client(
+ llm_provider=litellm.LlmProviders.VERTEX_AI,
+ )
+
+ try:
+ response = await client.post(api_base, headers=headers, data=data, stream=True)
+ response.raise_for_status()
+ except httpx.HTTPStatusError as e:
+ exception_string = str(await e.response.aread())
+ raise VertexAIError(
+ status_code=e.response.status_code,
+ message=VertexGeminiConfig().translate_exception_str(exception_string),
+ headers=e.response.headers,
+ )
+ if response.status_code != 200 and response.status_code != 201:
+ raise VertexAIError(
+ status_code=response.status_code,
+ message=response.text,
+ headers=response.headers,
+ )
+
+ completion_stream = ModelResponseIterator(
+ streaming_response=response.aiter_lines(), sync_stream=False
+ )
+ # LOGGING
+ logging_obj.post_call(
+ input=messages,
+ api_key="",
+ original_response="first stream response received",
+ additional_args={"complete_input_dict": data},
+ )
+
+ return completion_stream
+
+
+def make_sync_call(
+ client: Optional[HTTPHandler], # module-level client
+ gemini_client: Optional[HTTPHandler], # if passed by user
+ api_base: str,
+ headers: dict,
+ data: str,
+ model: str,
+ messages: list,
+ logging_obj,
+):
+ if gemini_client is not None:
+ client = gemini_client
+ if client is None:
+ client = HTTPHandler() # Create a new client if none provided
+
+ response = client.post(api_base, headers=headers, data=data, stream=True)
+
+ if response.status_code != 200 and response.status_code != 201:
+ raise VertexAIError(
+ status_code=response.status_code,
+ message=str(response.read()),
+ headers=response.headers,
+ )
+
+ completion_stream = ModelResponseIterator(
+ streaming_response=response.iter_lines(), sync_stream=True
+ )
+
+ # LOGGING
+ logging_obj.post_call(
+ input=messages,
+ api_key="",
+ original_response="first stream response received",
+ additional_args={"complete_input_dict": data},
+ )
+
+ return completion_stream
+
+
+class VertexLLM(VertexBase):
+ def __init__(self) -> None:
+ super().__init__()
+
+ async def async_streaming(
+ self,
+ model: str,
+ custom_llm_provider: Literal[
+ "vertex_ai", "vertex_ai_beta", "gemini"
+ ], # if it's vertex_ai or gemini (google ai studio)
+ messages: list,
+ model_response: ModelResponse,
+ print_verbose: Callable,
+ data: dict,
+ timeout: Optional[Union[float, httpx.Timeout]],
+ encoding,
+ logging_obj,
+ stream,
+ optional_params: dict,
+ litellm_params=None,
+ logger_fn=None,
+ api_base: Optional[str] = None,
+ client: Optional[AsyncHTTPHandler] = None,
+ vertex_project: Optional[str] = None,
+ vertex_location: Optional[str] = None,
+ vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES] = None,
+ gemini_api_key: Optional[str] = None,
+ extra_headers: Optional[dict] = None,
+ ) -> CustomStreamWrapper:
+ request_body = await async_transform_request_body(**data) # type: ignore
+
+ should_use_v1beta1_features = self.is_using_v1beta1_features(
+ optional_params=optional_params
+ )
+
+ _auth_header, vertex_project = await self._ensure_access_token_async(
+ credentials=vertex_credentials,
+ project_id=vertex_project,
+ custom_llm_provider=custom_llm_provider,
+ )
+
+ auth_header, api_base = self._get_token_and_url(
+ model=model,
+ gemini_api_key=gemini_api_key,
+ auth_header=_auth_header,
+ vertex_project=vertex_project,
+ vertex_location=vertex_location,
+ vertex_credentials=vertex_credentials,
+ stream=stream,
+ custom_llm_provider=custom_llm_provider,
+ api_base=api_base,
+ should_use_v1beta1_features=should_use_v1beta1_features,
+ )
+
+ headers = VertexGeminiConfig().validate_environment(
+ api_key=auth_header,
+ headers=extra_headers,
+ model=model,
+ messages=messages,
+ optional_params=optional_params,
+ )
+
+ ## LOGGING
+ logging_obj.pre_call(
+ input=messages,
+ api_key="",
+ additional_args={
+ "complete_input_dict": data,
+ "api_base": api_base,
+ "headers": headers,
+ },
+ )
+
+ request_body_str = json.dumps(request_body)
+ streaming_response = CustomStreamWrapper(
+ completion_stream=None,
+ make_call=partial(
+ make_call,
+ client=client,
+ api_base=api_base,
+ headers=headers,
+ data=request_body_str,
+ model=model,
+ messages=messages,
+ logging_obj=logging_obj,
+ ),
+ model=model,
+ custom_llm_provider="vertex_ai_beta",
+ logging_obj=logging_obj,
+ )
+ return streaming_response
+
+ async def async_completion(
+ self,
+ model: str,
+ messages: list,
+ model_response: ModelResponse,
+ print_verbose: Callable,
+ data: dict,
+ custom_llm_provider: Literal[
+ "vertex_ai", "vertex_ai_beta", "gemini"
+ ], # if it's vertex_ai or gemini (google ai studio)
+ timeout: Optional[Union[float, httpx.Timeout]],
+ encoding,
+ logging_obj,
+ stream,
+ optional_params: dict,
+ litellm_params: dict,
+ logger_fn=None,
+ api_base: Optional[str] = None,
+ client: Optional[AsyncHTTPHandler] = None,
+ vertex_project: Optional[str] = None,
+ vertex_location: Optional[str] = None,
+ vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES] = None,
+ gemini_api_key: Optional[str] = None,
+ extra_headers: Optional[dict] = None,
+ ) -> Union[ModelResponse, CustomStreamWrapper]:
+ should_use_v1beta1_features = self.is_using_v1beta1_features(
+ optional_params=optional_params
+ )
+
+ _auth_header, vertex_project = await self._ensure_access_token_async(
+ credentials=vertex_credentials,
+ project_id=vertex_project,
+ custom_llm_provider=custom_llm_provider,
+ )
+
+ auth_header, api_base = self._get_token_and_url(
+ model=model,
+ gemini_api_key=gemini_api_key,
+ auth_header=_auth_header,
+ vertex_project=vertex_project,
+ vertex_location=vertex_location,
+ vertex_credentials=vertex_credentials,
+ stream=stream,
+ custom_llm_provider=custom_llm_provider,
+ api_base=api_base,
+ should_use_v1beta1_features=should_use_v1beta1_features,
+ )
+
+ headers = VertexGeminiConfig().validate_environment(
+ api_key=auth_header,
+ headers=extra_headers,
+ model=model,
+ messages=messages,
+ optional_params=optional_params,
+ )
+
+ request_body = await async_transform_request_body(**data) # type: ignore
+ _async_client_params = {}
+ if timeout:
+ _async_client_params["timeout"] = timeout
+ if client is None or not isinstance(client, AsyncHTTPHandler):
+ client = get_async_httpx_client(
+ params=_async_client_params, llm_provider=litellm.LlmProviders.VERTEX_AI
+ )
+ else:
+ client = client # type: ignore
+ ## LOGGING
+ logging_obj.pre_call(
+ input=messages,
+ api_key="",
+ additional_args={
+ "complete_input_dict": request_body,
+ "api_base": api_base,
+ "headers": headers,
+ },
+ )
+
+ try:
+ response = await client.post(
+ api_base, headers=headers, json=cast(dict, request_body)
+ ) # type: ignore
+ response.raise_for_status()
+ except httpx.HTTPStatusError as err:
+ error_code = err.response.status_code
+ raise VertexAIError(
+ status_code=error_code,
+ message=err.response.text,
+ headers=err.response.headers,
+ )
+ except httpx.TimeoutException:
+ raise VertexAIError(
+ status_code=408,
+ message="Timeout error occurred.",
+ headers=None,
+ )
+
+ return VertexGeminiConfig().transform_response(
+ model=model,
+ raw_response=response,
+ model_response=model_response,
+ logging_obj=logging_obj,
+ api_key="",
+ request_data=cast(dict, request_body),
+ messages=messages,
+ optional_params=optional_params,
+ litellm_params=litellm_params,
+ encoding=encoding,
+ )
+
+ def completion(
+ self,
+ model: str,
+ messages: list,
+ model_response: ModelResponse,
+ print_verbose: Callable,
+ custom_llm_provider: Literal[
+ "vertex_ai", "vertex_ai_beta", "gemini"
+ ], # if it's vertex_ai or gemini (google ai studio)
+ encoding,
+ logging_obj,
+ optional_params: dict,
+ acompletion: bool,
+ timeout: Optional[Union[float, httpx.Timeout]],
+ vertex_project: Optional[str],
+ vertex_location: Optional[str],
+ vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
+ gemini_api_key: Optional[str],
+ litellm_params: dict,
+ logger_fn=None,
+ extra_headers: Optional[dict] = None,
+ client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
+ api_base: Optional[str] = None,
+ ) -> Union[ModelResponse, CustomStreamWrapper]:
+ stream: Optional[bool] = optional_params.pop("stream", None) # type: ignore
+
+ transform_request_params = {
+ "gemini_api_key": gemini_api_key,
+ "messages": messages,
+ "api_base": api_base,
+ "model": model,
+ "client": client,
+ "timeout": timeout,
+ "extra_headers": extra_headers,
+ "optional_params": optional_params,
+ "logging_obj": logging_obj,
+ "custom_llm_provider": custom_llm_provider,
+ "litellm_params": litellm_params,
+ }
+
+ ### ROUTING (ASYNC, STREAMING, SYNC)
+ if acompletion:
+ ### ASYNC STREAMING
+ if stream is True:
+ return self.async_streaming(
+ model=model,
+ messages=messages,
+ api_base=api_base,
+ model_response=model_response,
+ print_verbose=print_verbose,
+ encoding=encoding,
+ logging_obj=logging_obj,
+ optional_params=optional_params,
+ stream=stream,
+ litellm_params=litellm_params,
+ logger_fn=logger_fn,
+ timeout=timeout,
+ client=client, # type: ignore
+ data=transform_request_params,
+ vertex_project=vertex_project,
+ vertex_location=vertex_location,
+ vertex_credentials=vertex_credentials,
+ gemini_api_key=gemini_api_key,
+ custom_llm_provider=custom_llm_provider,
+ extra_headers=extra_headers,
+ )
+ ### ASYNC COMPLETION
+ return self.async_completion(
+ model=model,
+ messages=messages,
+ data=transform_request_params, # type: ignore
+ api_base=api_base,
+ model_response=model_response,
+ print_verbose=print_verbose,
+ encoding=encoding,
+ logging_obj=logging_obj,
+ optional_params=optional_params,
+ stream=stream,
+ litellm_params=litellm_params,
+ logger_fn=logger_fn,
+ timeout=timeout,
+ client=client, # type: ignore
+ vertex_project=vertex_project,
+ vertex_location=vertex_location,
+ vertex_credentials=vertex_credentials,
+ gemini_api_key=gemini_api_key,
+ custom_llm_provider=custom_llm_provider,
+ extra_headers=extra_headers,
+ )
+
+ should_use_v1beta1_features = self.is_using_v1beta1_features(
+ optional_params=optional_params
+ )
+
+ _auth_header, vertex_project = self._ensure_access_token(
+ credentials=vertex_credentials,
+ project_id=vertex_project,
+ custom_llm_provider=custom_llm_provider,
+ )
+
+ auth_header, url = self._get_token_and_url(
+ model=model,
+ gemini_api_key=gemini_api_key,
+ auth_header=_auth_header,
+ vertex_project=vertex_project,
+ vertex_location=vertex_location,
+ vertex_credentials=vertex_credentials,
+ stream=stream,
+ custom_llm_provider=custom_llm_provider,
+ api_base=api_base,
+ should_use_v1beta1_features=should_use_v1beta1_features,
+ )
+ headers = VertexGeminiConfig().validate_environment(
+ api_key=auth_header,
+ headers=extra_headers,
+ model=model,
+ messages=messages,
+ optional_params=optional_params,
+ )
+
+ ## TRANSFORMATION ##
+ data = sync_transform_request_body(**transform_request_params)
+
+ ## LOGGING
+ logging_obj.pre_call(
+ input=messages,
+ api_key="",
+ additional_args={
+ "complete_input_dict": data,
+ "api_base": url,
+ "headers": headers,
+ },
+ )
+
+ ## SYNC STREAMING CALL ##
+ if stream is True:
+ request_data_str = json.dumps(data)
+ streaming_response = CustomStreamWrapper(
+ completion_stream=None,
+ make_call=partial(
+ make_sync_call,
+ gemini_client=(
+ client
+ if client is not None and isinstance(client, HTTPHandler)
+ else None
+ ),
+ api_base=url,
+ data=request_data_str,
+ model=model,
+ messages=messages,
+ logging_obj=logging_obj,
+ headers=headers,
+ ),
+ model=model,
+ custom_llm_provider="vertex_ai_beta",
+ logging_obj=logging_obj,
+ )
+
+ return streaming_response
+ ## COMPLETION CALL ##
+
+ if client is None or isinstance(client, AsyncHTTPHandler):
+ _params = {}
+ if timeout is not None:
+ if isinstance(timeout, float) or isinstance(timeout, int):
+ timeout = httpx.Timeout(timeout)
+ _params["timeout"] = timeout
+ client = HTTPHandler(**_params) # type: ignore
+ else:
+ client = client
+
+ try:
+ response = client.post(url=url, headers=headers, json=data) # type: ignore
+ response.raise_for_status()
+ except httpx.HTTPStatusError as err:
+ error_code = err.response.status_code
+ raise VertexAIError(
+ status_code=error_code,
+ message=err.response.text,
+ headers=err.response.headers,
+ )
+ except httpx.TimeoutException:
+ raise VertexAIError(
+ status_code=408,
+ message="Timeout error occurred.",
+ headers=None,
+ )
+
+ return VertexGeminiConfig().transform_response(
+ model=model,
+ raw_response=response,
+ model_response=model_response,
+ logging_obj=logging_obj,
+ optional_params=optional_params,
+ litellm_params=litellm_params,
+ api_key="",
+ request_data=data, # type: ignore
+ messages=messages,
+ encoding=encoding,
+ )
+
+
+class ModelResponseIterator:
+ def __init__(self, streaming_response, sync_stream: bool):
+ self.streaming_response = streaming_response
+ self.chunk_type: Literal["valid_json", "accumulated_json"] = "valid_json"
+ self.accumulated_json = ""
+ self.sent_first_chunk = False
+
+ def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
+ try:
+ processed_chunk = GenerateContentResponseBody(**chunk) # type: ignore
+
+ text = ""
+ tool_use: Optional[ChatCompletionToolCallChunk] = None
+ finish_reason = ""
+ usage: Optional[ChatCompletionUsageBlock] = None
+ _candidates: Optional[List[Candidates]] = processed_chunk.get("candidates")
+ gemini_chunk: Optional[Candidates] = None
+ if _candidates and len(_candidates) > 0:
+ gemini_chunk = _candidates[0]
+
+ if (
+ gemini_chunk
+ and "content" in gemini_chunk
+ and "parts" in gemini_chunk["content"]
+ ):
+ if "text" in gemini_chunk["content"]["parts"][0]:
+ text = gemini_chunk["content"]["parts"][0]["text"]
+ elif "functionCall" in gemini_chunk["content"]["parts"][0]:
+ function_call = ChatCompletionToolCallFunctionChunk(
+ name=gemini_chunk["content"]["parts"][0]["functionCall"][
+ "name"
+ ],
+ arguments=json.dumps(
+ gemini_chunk["content"]["parts"][0]["functionCall"]["args"]
+ ),
+ )
+ tool_use = ChatCompletionToolCallChunk(
+ id=str(uuid.uuid4()),
+ type="function",
+ function=function_call,
+ index=0,
+ )
+
+ if gemini_chunk and "finishReason" in gemini_chunk:
+ finish_reason = map_finish_reason(
+ finish_reason=gemini_chunk["finishReason"]
+ )
+ ## DO NOT SET 'is_finished' = True
+ ## GEMINI SETS FINISHREASON ON EVERY CHUNK!
+
+ if "usageMetadata" in processed_chunk:
+ usage = ChatCompletionUsageBlock(
+ prompt_tokens=processed_chunk["usageMetadata"].get(
+ "promptTokenCount", 0
+ ),
+ completion_tokens=processed_chunk["usageMetadata"].get(
+ "candidatesTokenCount", 0
+ ),
+ total_tokens=processed_chunk["usageMetadata"].get(
+ "totalTokenCount", 0
+ ),
+ )
+
+ returned_chunk = GenericStreamingChunk(
+ text=text,
+ tool_use=tool_use,
+ is_finished=False,
+ finish_reason=finish_reason,
+ usage=usage,
+ index=0,
+ )
+ return returned_chunk
+ except json.JSONDecodeError:
+ raise ValueError(f"Failed to decode JSON from chunk: {chunk}")
+
+ # Sync iterator
+ def __iter__(self):
+ self.response_iterator = self.streaming_response
+ return self
+
+ def handle_valid_json_chunk(self, chunk: str) -> GenericStreamingChunk:
+ chunk = chunk.strip()
+ try:
+ json_chunk = json.loads(chunk)
+
+ except json.JSONDecodeError as e:
+ if (
+ self.sent_first_chunk is False
+ ): # only check for accumulated json, on first chunk, else raise error. Prevent real errors from being masked.
+ self.chunk_type = "accumulated_json"
+ return self.handle_accumulated_json_chunk(chunk=chunk)
+ raise e
+
+ if self.sent_first_chunk is False:
+ self.sent_first_chunk = True
+
+ return self.chunk_parser(chunk=json_chunk)
+
+ def handle_accumulated_json_chunk(self, chunk: str) -> GenericStreamingChunk:
+ chunk = litellm.CustomStreamWrapper._strip_sse_data_from_chunk(chunk) or ""
+ message = chunk.replace("\n\n", "")
+
+ # Accumulate JSON data
+ self.accumulated_json += message
+
+ # Try to parse the accumulated JSON
+ try:
+ _data = json.loads(self.accumulated_json)
+ self.accumulated_json = "" # reset after successful parsing
+ return self.chunk_parser(chunk=_data)
+ except json.JSONDecodeError:
+ # If it's not valid JSON yet, continue to the next event
+ return GenericStreamingChunk(
+ text="",
+ is_finished=False,
+ finish_reason="",
+ usage=None,
+ index=0,
+ tool_use=None,
+ )
+
+ def _common_chunk_parsing_logic(self, chunk: str) -> GenericStreamingChunk:
+ try:
+ chunk = litellm.CustomStreamWrapper._strip_sse_data_from_chunk(chunk) or ""
+ if len(chunk) > 0:
+ """
+ Check if initial chunk valid json
+ - if partial json -> enter accumulated json logic
+ - if valid - continue
+ """
+ if self.chunk_type == "valid_json":
+ return self.handle_valid_json_chunk(chunk=chunk)
+ elif self.chunk_type == "accumulated_json":
+ return self.handle_accumulated_json_chunk(chunk=chunk)
+
+ return GenericStreamingChunk(
+ text="",
+ is_finished=False,
+ finish_reason="",
+ usage=None,
+ index=0,
+ tool_use=None,
+ )
+ except Exception:
+ raise
+
+ def __next__(self):
+ try:
+ chunk = self.response_iterator.__next__()
+ except StopIteration:
+ if self.chunk_type == "accumulated_json" and self.accumulated_json:
+ return self.handle_accumulated_json_chunk(chunk="")
+ raise StopIteration
+ except ValueError as e:
+ raise RuntimeError(f"Error receiving chunk from stream: {e}")
+
+ try:
+ return self._common_chunk_parsing_logic(chunk=chunk)
+ except StopIteration:
+ raise StopIteration
+ except ValueError as e:
+ raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
+
+ # Async iterator
+ def __aiter__(self):
+ self.async_response_iterator = self.streaming_response.__aiter__()
+ return self
+
+ async def __anext__(self):
+ try:
+ chunk = await self.async_response_iterator.__anext__()
+ except StopAsyncIteration:
+ if self.chunk_type == "accumulated_json" and self.accumulated_json:
+ return self.handle_accumulated_json_chunk(chunk="")
+ raise StopAsyncIteration
+ except ValueError as e:
+ raise RuntimeError(f"Error receiving chunk from stream: {e}")
+
+ try:
+ return self._common_chunk_parsing_logic(chunk=chunk)
+ except StopAsyncIteration:
+ raise StopAsyncIteration
+ except ValueError as e:
+ raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
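
As a closing illustration, a sketch of how `VertexGeminiConfig.map_openai_params` maps OpenAI-style parameters onto Gemini generation-config fields. The import path assumes a standard litellm install, and the expected mapping in the comments is inferred from the code above rather than verified against a live environment:

from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
    VertexGeminiConfig,
)

config = VertexGeminiConfig()
optional_params = config.map_openai_params(
    non_default_params={
        "max_tokens": 256,
        "temperature": 0.2,
        "stop": "END",
        "response_format": {"type": "json_object"},
    },
    optional_params={},
    model="gemini-1.5-pro",
    drop_params=False,
)

# Expected mapping, per the parameter-translation logic above:
# {
#     "max_output_tokens": 256,
#     "temperature": 0.2,
#     "stop_sequences": ["END"],
#     "response_mime_type": "application/json",
# }
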