author    | S. Solomon Darnell | 2025-03-28 21:52:21 -0500
committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500
commit    | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree      | ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/gemini
parent    | cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/gemini')

.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/gemini/transformation.py                     |  479 +
.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py | 1495 +

2 files changed, 1974 insertions, 0 deletions
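The two new modules below implement litellm's OpenAI-to-Gemini request/response transformation for Vertex AI and Google AI Studio. As a minimal, hedged sketch of how these code paths are normally reached through litellm's public API (the model names, project ID, location, and API key below are illustrative placeholders, not values taken from this diff):

```python
# Sketch only: assumes litellm is installed and credentials are valid.
import os
import litellm

# Google AI Studio route ("gemini/..."): authenticated via GEMINI_API_KEY.
os.environ["GEMINI_API_KEY"] = "<your-api-key>"  # placeholder
resp = litellm.completion(
    model="gemini/gemini-1.5-flash",
    messages=[{"role": "user", "content": "Say hello in one word."}],
)
print(resp.choices[0].message.content)

# Vertex AI route ("vertex_ai/..."): uses Google Cloud credentials instead.
resp = litellm.completion(
    model="vertex_ai/gemini-1.5-pro",
    messages=[{"role": "user", "content": "Say hello in one word."}],
    vertex_project="my-gcp-project",   # placeholder
    vertex_location="us-central1",     # placeholder
)
print(resp.choices[0].message.content)
```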
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/gemini/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/gemini/transformation.py new file mode 100644 index 00000000..d6bafc7c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/gemini/transformation.py @@ -0,0 +1,479 @@ +""" +Transformation logic from OpenAI format to Gemini format. + +Why separate file? Make it easy to see how transformation works +""" + +import os +from typing import TYPE_CHECKING, List, Literal, Optional, Tuple, Union, cast + +import httpx +from pydantic import BaseModel + +import litellm +from litellm._logging import verbose_logger +from litellm.litellm_core_utils.prompt_templates.factory import ( + convert_to_anthropic_image_obj, + convert_to_gemini_tool_call_invoke, + convert_to_gemini_tool_call_result, + response_schema_prompt, +) +from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler +from litellm.types.files import ( + get_file_mime_type_for_file_type, + get_file_type_from_extension, + is_gemini_1_5_accepted_file_type, +) +from litellm.types.llms.openai import ( + AllMessageValues, + ChatCompletionAssistantMessage, + ChatCompletionImageObject, + ChatCompletionTextObject, +) +from litellm.types.llms.vertex_ai import * +from litellm.types.llms.vertex_ai import ( + GenerationConfig, + PartType, + RequestBody, + SafetSettingsConfig, + SystemInstructions, + ToolConfig, + Tools, +) + +from ..common_utils import ( + _check_text_in_content, + get_supports_response_schema, + get_supports_system_message, +) + +if TYPE_CHECKING: + from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj + + LiteLLMLoggingObj = _LiteLLMLoggingObj +else: + LiteLLMLoggingObj = Any + + +def _process_gemini_image(image_url: str, format: Optional[str] = None) -> PartType: + """ + Given an image URL, return the appropriate PartType for Gemini + """ + + try: + # GCS URIs + if "gs://" in image_url: + # Figure out file type + extension_with_dot = os.path.splitext(image_url)[-1] # Ex: ".png" + extension = extension_with_dot[1:] # Ex: "png" + + if not format: + file_type = get_file_type_from_extension(extension) + + # Validate the file type is supported by Gemini + if not is_gemini_1_5_accepted_file_type(file_type): + raise Exception(f"File type not supported by gemini - {file_type}") + + mime_type = get_file_mime_type_for_file_type(file_type) + else: + mime_type = format + file_data = FileDataType(mime_type=mime_type, file_uri=image_url) + + return PartType(file_data=file_data) + elif ( + "https://" in image_url + and (image_type := format or _get_image_mime_type_from_url(image_url)) + is not None + ): + + file_data = FileDataType(file_uri=image_url, mime_type=image_type) + return PartType(file_data=file_data) + elif "http://" in image_url or "https://" in image_url or "base64" in image_url: + # https links for unsupported mime types and base64 images + image = convert_to_anthropic_image_obj(image_url, format=format) + _blob = BlobType(data=image["data"], mime_type=image["media_type"]) + return PartType(inline_data=_blob) + raise Exception("Invalid image received - {}".format(image_url)) + except Exception as e: + raise e + + +def _get_image_mime_type_from_url(url: str) -> Optional[str]: + """ + Get mime type for common image URLs + See gemini mime types: https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/image-understanding#image-requirements + + Supported by Gemini: + - PNG (`image/png`) + - JPEG 
(`image/jpeg`) + - WebP (`image/webp`) + Example: + url = https://example.com/image.jpg + Returns: image/jpeg + """ + url = url.lower() + if url.endswith((".jpg", ".jpeg")): + return "image/jpeg" + elif url.endswith(".png"): + return "image/png" + elif url.endswith(".webp"): + return "image/webp" + elif url.endswith(".mp4"): + return "video/mp4" + elif url.endswith(".pdf"): + return "application/pdf" + return None + + +def _gemini_convert_messages_with_history( # noqa: PLR0915 + messages: List[AllMessageValues], +) -> List[ContentType]: + """ + Converts given messages from OpenAI format to Gemini format + + - Parts must be iterable + - Roles must alternate b/w 'user' and 'model' (same as anthropic -> merge consecutive roles) + - Please ensure that function response turn comes immediately after a function call turn + """ + user_message_types = {"user", "system"} + contents: List[ContentType] = [] + + last_message_with_tool_calls = None + + msg_i = 0 + tool_call_responses = [] + try: + while msg_i < len(messages): + user_content: List[PartType] = [] + init_msg_i = msg_i + ## MERGE CONSECUTIVE USER CONTENT ## + while ( + msg_i < len(messages) and messages[msg_i]["role"] in user_message_types + ): + _message_content = messages[msg_i].get("content") + if _message_content is not None and isinstance(_message_content, list): + _parts: List[PartType] = [] + for element in _message_content: + if ( + element["type"] == "text" + and "text" in element + and len(element["text"]) > 0 + ): + element = cast(ChatCompletionTextObject, element) + _part = PartType(text=element["text"]) + _parts.append(_part) + elif element["type"] == "image_url": + element = cast(ChatCompletionImageObject, element) + img_element = element + format: Optional[str] = None + if isinstance(img_element["image_url"], dict): + image_url = img_element["image_url"]["url"] + format = img_element["image_url"].get("format") + else: + image_url = img_element["image_url"] + _part = _process_gemini_image( + image_url=image_url, format=format + ) + _parts.append(_part) + user_content.extend(_parts) + elif ( + _message_content is not None + and isinstance(_message_content, str) + and len(_message_content) > 0 + ): + _part = PartType(text=_message_content) + user_content.append(_part) + + msg_i += 1 + + if user_content: + """ + check that user_content has 'text' parameter. + - Known Vertex Error: Unable to submit request because it must have a text parameter. + - Relevant Issue: https://github.com/BerriAI/litellm/issues/5515 + """ + has_text_in_content = _check_text_in_content(user_content) + if has_text_in_content is False: + verbose_logger.warning( + "No text in user content. Adding a blank text to user content, to ensure Gemini doesn't fail the request. Relevant Issue - https://github.com/BerriAI/litellm/issues/5515" + ) + user_content.append( + PartType(text=" ") + ) # add a blank text, to ensure Gemini doesn't fail the request. 
+ contents.append(ContentType(role="user", parts=user_content)) + assistant_content = [] + ## MERGE CONSECUTIVE ASSISTANT CONTENT ## + while msg_i < len(messages) and messages[msg_i]["role"] == "assistant": + if isinstance(messages[msg_i], BaseModel): + msg_dict: Union[ChatCompletionAssistantMessage, dict] = messages[msg_i].model_dump() # type: ignore + else: + msg_dict = messages[msg_i] # type: ignore + assistant_msg = ChatCompletionAssistantMessage(**msg_dict) # type: ignore + _message_content = assistant_msg.get("content", None) + if _message_content is not None and isinstance(_message_content, list): + _parts = [] + for element in _message_content: + if isinstance(element, dict): + if element["type"] == "text": + _part = PartType(text=element["text"]) + _parts.append(_part) + assistant_content.extend(_parts) + elif ( + _message_content is not None + and isinstance(_message_content, str) + and _message_content + ): + assistant_text = _message_content # either string or none + assistant_content.append(PartType(text=assistant_text)) # type: ignore + + ## HANDLE ASSISTANT FUNCTION CALL + if ( + assistant_msg.get("tool_calls", []) is not None + or assistant_msg.get("function_call") is not None + ): # support assistant tool invoke conversion + assistant_content.extend( + convert_to_gemini_tool_call_invoke(assistant_msg) + ) + last_message_with_tool_calls = assistant_msg + + msg_i += 1 + + if assistant_content: + contents.append(ContentType(role="model", parts=assistant_content)) + + ## APPEND TOOL CALL MESSAGES ## + tool_call_message_roles = ["tool", "function"] + if ( + msg_i < len(messages) + and messages[msg_i]["role"] in tool_call_message_roles + ): + _part = convert_to_gemini_tool_call_result( + messages[msg_i], last_message_with_tool_calls # type: ignore + ) + msg_i += 1 + tool_call_responses.append(_part) + if msg_i < len(messages) and ( + messages[msg_i]["role"] not in tool_call_message_roles + ): + if len(tool_call_responses) > 0: + contents.append(ContentType(parts=tool_call_responses)) + tool_call_responses = [] + + if msg_i == init_msg_i: # prevent infinite loops + raise Exception( + "Invalid Message passed in - {}. File an issue https://github.com/BerriAI/litellm/issues".format( + messages[msg_i] + ) + ) + if len(tool_call_responses) > 0: + contents.append(ContentType(parts=tool_call_responses)) + return contents + except Exception as e: + raise e + + +def _transform_request_body( + messages: List[AllMessageValues], + model: str, + optional_params: dict, + custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"], + litellm_params: dict, + cached_content: Optional[str], +) -> RequestBody: + """ + Common transformation logic across sync + async Gemini /generateContent calls. 
+ """ + # Separate system prompt from rest of message + supports_system_message = get_supports_system_message( + model=model, custom_llm_provider=custom_llm_provider + ) + system_instructions, messages = _transform_system_message( + supports_system_message=supports_system_message, messages=messages + ) + # Checks for 'response_schema' support - if passed in + if "response_schema" in optional_params: + supports_response_schema = get_supports_response_schema( + model=model, custom_llm_provider=custom_llm_provider + ) + if supports_response_schema is False: + user_response_schema_message = response_schema_prompt( + model=model, response_schema=optional_params.get("response_schema") # type: ignore + ) + messages.append({"role": "user", "content": user_response_schema_message}) + optional_params.pop("response_schema") + + # Check for any 'litellm_param_*' set during optional param mapping + + remove_keys = [] + for k, v in optional_params.items(): + if k.startswith("litellm_param_"): + litellm_params.update({k: v}) + remove_keys.append(k) + + optional_params = {k: v for k, v in optional_params.items() if k not in remove_keys} + + try: + if custom_llm_provider == "gemini": + content = litellm.GoogleAIStudioGeminiConfig()._transform_messages( + messages=messages + ) + else: + content = litellm.VertexGeminiConfig()._transform_messages( + messages=messages + ) + tools: Optional[Tools] = optional_params.pop("tools", None) + tool_choice: Optional[ToolConfig] = optional_params.pop("tool_choice", None) + safety_settings: Optional[List[SafetSettingsConfig]] = optional_params.pop( + "safety_settings", None + ) # type: ignore + config_fields = GenerationConfig.__annotations__.keys() + + filtered_params = { + k: v for k, v in optional_params.items() if k in config_fields + } + + generation_config: Optional[GenerationConfig] = GenerationConfig( + **filtered_params + ) + data = RequestBody(contents=content) + if system_instructions is not None: + data["system_instruction"] = system_instructions + if tools is not None: + data["tools"] = tools + if tool_choice is not None: + data["toolConfig"] = tool_choice + if safety_settings is not None: + data["safetySettings"] = safety_settings + if generation_config is not None: + data["generationConfig"] = generation_config + if cached_content is not None: + data["cachedContent"] = cached_content + except Exception as e: + raise e + + return data + + +def sync_transform_request_body( + gemini_api_key: Optional[str], + messages: List[AllMessageValues], + api_base: Optional[str], + model: str, + client: Optional[HTTPHandler], + timeout: Optional[Union[float, httpx.Timeout]], + extra_headers: Optional[dict], + optional_params: dict, + logging_obj: LiteLLMLoggingObj, + custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"], + litellm_params: dict, +) -> RequestBody: + from ..context_caching.vertex_ai_context_caching import ContextCachingEndpoints + + context_caching_endpoints = ContextCachingEndpoints() + + if gemini_api_key is not None: + messages, cached_content = context_caching_endpoints.check_and_create_cache( + messages=messages, + api_key=gemini_api_key, + api_base=api_base, + model=model, + client=client, + timeout=timeout, + extra_headers=extra_headers, + cached_content=optional_params.pop("cached_content", None), + logging_obj=logging_obj, + ) + else: # [TODO] implement context caching for gemini as well + cached_content = optional_params.pop("cached_content", None) + + return _transform_request_body( + messages=messages, + model=model, + 
custom_llm_provider=custom_llm_provider, + litellm_params=litellm_params, + cached_content=cached_content, + optional_params=optional_params, + ) + + +async def async_transform_request_body( + gemini_api_key: Optional[str], + messages: List[AllMessageValues], + api_base: Optional[str], + model: str, + client: Optional[AsyncHTTPHandler], + timeout: Optional[Union[float, httpx.Timeout]], + extra_headers: Optional[dict], + optional_params: dict, + logging_obj: litellm.litellm_core_utils.litellm_logging.Logging, # type: ignore + custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"], + litellm_params: dict, +) -> RequestBody: + from ..context_caching.vertex_ai_context_caching import ContextCachingEndpoints + + context_caching_endpoints = ContextCachingEndpoints() + + if gemini_api_key is not None: + messages, cached_content = ( + await context_caching_endpoints.async_check_and_create_cache( + messages=messages, + api_key=gemini_api_key, + api_base=api_base, + model=model, + client=client, + timeout=timeout, + extra_headers=extra_headers, + cached_content=optional_params.pop("cached_content", None), + logging_obj=logging_obj, + ) + ) + else: # [TODO] implement context caching for gemini as well + cached_content = optional_params.pop("cached_content", None) + + return _transform_request_body( + messages=messages, + model=model, + custom_llm_provider=custom_llm_provider, + litellm_params=litellm_params, + cached_content=cached_content, + optional_params=optional_params, + ) + + +def _transform_system_message( + supports_system_message: bool, messages: List[AllMessageValues] +) -> Tuple[Optional[SystemInstructions], List[AllMessageValues]]: + """ + Extracts the system message from the openai message list. + + Converts the system message to Gemini format + + Returns + - system_content_blocks: Optional[SystemInstructions] - the system message list in Gemini format. + - messages: List[AllMessageValues] - filtered list of messages in OpenAI format (transformed separately) + """ + # Separate system prompt from rest of message + system_prompt_indices = [] + system_content_blocks: List[PartType] = [] + if supports_system_message is True: + for idx, message in enumerate(messages): + if message["role"] == "system": + _system_content_block: Optional[PartType] = None + if isinstance(message["content"], str): + _system_content_block = PartType(text=message["content"]) + elif isinstance(message["content"], list): + system_text = "" + for content in message["content"]: + system_text += content.get("text") or "" + _system_content_block = PartType(text=system_text) + if _system_content_block is not None: + system_content_blocks.append(_system_content_block) + system_prompt_indices.append(idx) + if len(system_prompt_indices) > 0: + for idx in reversed(system_prompt_indices): + messages.pop(idx) + + if len(system_content_blocks) > 0: + return SystemInstructions(parts=system_content_blocks), messages + + return None, messages diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py new file mode 100644 index 00000000..9ac1b1ff --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py @@ -0,0 +1,1495 @@ +# What is this? 
+## httpx client for vertex ai calls +## Initial implementation - covers gemini + image gen calls +import json +import uuid +from copy import deepcopy +from functools import partial +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + Literal, + Optional, + Tuple, + Union, + cast, +) + +import httpx # type: ignore + +import litellm +import litellm.litellm_core_utils +import litellm.litellm_core_utils.litellm_logging +from litellm import verbose_logger +from litellm.litellm_core_utils.core_helpers import map_finish_reason +from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException +from litellm.llms.custom_httpx.http_handler import ( + AsyncHTTPHandler, + HTTPHandler, + get_async_httpx_client, +) +from litellm.types.llms.openai import ( + AllMessageValues, + ChatCompletionResponseMessage, + ChatCompletionToolCallChunk, + ChatCompletionToolCallFunctionChunk, + ChatCompletionToolParamFunctionChunk, + ChatCompletionUsageBlock, +) +from litellm.types.llms.vertex_ai import ( + VERTEX_CREDENTIALS_TYPES, + Candidates, + ContentType, + FunctionCallingConfig, + FunctionDeclaration, + GenerateContentResponseBody, + HttpxPartType, + LogprobsResult, + ToolConfig, + Tools, +) +from litellm.types.utils import ( + ChatCompletionTokenLogprob, + ChoiceLogprobs, + GenericStreamingChunk, + PromptTokensDetailsWrapper, + TopLogprob, + Usage, +) +from litellm.utils import CustomStreamWrapper, ModelResponse + +from ....utils import _remove_additional_properties, _remove_strict_from_schema +from ..common_utils import VertexAIError, _build_vertex_schema +from ..vertex_llm_base import VertexBase +from .transformation import ( + _gemini_convert_messages_with_history, + async_transform_request_body, + sync_transform_request_body, +) + +if TYPE_CHECKING: + from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj + + LoggingClass = LiteLLMLoggingObj +else: + LoggingClass = Any + + +class VertexAIBaseConfig: + def get_mapped_special_auth_params(self) -> dict: + """ + Common auth params across bedrock/vertex_ai/azure/watsonx + """ + return {"project": "vertex_project", "region_name": "vertex_location"} + + def map_special_auth_params(self, non_default_params: dict, optional_params: dict): + mapped_params = self.get_mapped_special_auth_params() + + for param, value in non_default_params.items(): + if param in mapped_params: + optional_params[mapped_params[param]] = value + return optional_params + + def get_eu_regions(self) -> List[str]: + """ + Source: https://cloud.google.com/vertex-ai/generative-ai/docs/learn/locations#available-regions + """ + return [ + "europe-central2", + "europe-north1", + "europe-southwest1", + "europe-west1", + "europe-west2", + "europe-west3", + "europe-west4", + "europe-west6", + "europe-west8", + "europe-west9", + ] + + def get_us_regions(self) -> List[str]: + """ + Source: https://cloud.google.com/vertex-ai/generative-ai/docs/learn/locations#available-regions + """ + return [ + "us-central1", + "us-east1", + "us-east4", + "us-east5", + "us-south1", + "us-west1", + "us-west4", + "us-west5", + ] + + +class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig): + """ + Reference: https://cloud.google.com/vertex-ai/docs/generative-ai/chat/test-chat-prompts + Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference + + The class `VertexAIConfig` provides configuration for the VertexAI's API interface. 
Below are the parameters: + + - `temperature` (float): This controls the degree of randomness in token selection. + + - `max_output_tokens` (integer): This sets the limitation for the maximum amount of token in the text output. In this case, the default value is 256. + + - `top_p` (float): The tokens are selected from the most probable to the least probable until the sum of their probabilities equals the `top_p` value. Default is 0.95. + + - `top_k` (integer): The value of `top_k` determines how many of the most probable tokens are considered in the selection. For example, a `top_k` of 1 means the selected token is the most probable among all tokens. The default value is 40. + + - `response_mime_type` (str): The MIME type of the response. The default value is 'text/plain'. + + - `candidate_count` (int): Number of generated responses to return. + + - `stop_sequences` (List[str]): The set of character sequences (up to 5) that will stop output generation. If specified, the API will stop at the first appearance of a stop sequence. The stop sequence will not be included as part of the response. + + - `frequency_penalty` (float): This parameter is used to penalize the model from repeating the same output. The default value is 0.0. + + - `presence_penalty` (float): This parameter is used to penalize the model from generating the same output as the input. The default value is 0.0. + + - `seed` (int): The seed value is used to help generate the same output for the same input. The default value is None. + + Note: Please make sure to modify the default parameters as required for your use case. + """ + + temperature: Optional[float] = None + max_output_tokens: Optional[int] = None + top_p: Optional[float] = None + top_k: Optional[int] = None + response_mime_type: Optional[str] = None + candidate_count: Optional[int] = None + stop_sequences: Optional[list] = None + frequency_penalty: Optional[float] = None + presence_penalty: Optional[float] = None + seed: Optional[int] = None + + def __init__( + self, + temperature: Optional[float] = None, + max_output_tokens: Optional[int] = None, + top_p: Optional[float] = None, + top_k: Optional[int] = None, + response_mime_type: Optional[str] = None, + candidate_count: Optional[int] = None, + stop_sequences: Optional[list] = None, + frequency_penalty: Optional[float] = None, + presence_penalty: Optional[float] = None, + seed: Optional[int] = None, + ) -> None: + locals_ = locals().copy() + for key, value in locals_.items(): + if key != "self" and value is not None: + setattr(self.__class__, key, value) + + @classmethod + def get_config(cls): + return super().get_config() + + def get_supported_openai_params(self, model: str) -> List[str]: + return [ + "temperature", + "top_p", + "max_tokens", + "max_completion_tokens", + "stream", + "tools", + "functions", + "tool_choice", + "response_format", + "n", + "stop", + "frequency_penalty", + "presence_penalty", + "extra_headers", + "seed", + "logprobs", + ] + + def map_tool_choice_values( + self, model: str, tool_choice: Union[str, dict] + ) -> Optional[ToolConfig]: + if tool_choice == "none": + return ToolConfig(functionCallingConfig=FunctionCallingConfig(mode="NONE")) + elif tool_choice == "required": + return ToolConfig(functionCallingConfig=FunctionCallingConfig(mode="ANY")) + elif tool_choice == "auto": + return ToolConfig(functionCallingConfig=FunctionCallingConfig(mode="AUTO")) + elif isinstance(tool_choice, dict): + # only supported for anthropic + mistral models - 
https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html + name = tool_choice.get("function", {}).get("name", "") + return ToolConfig( + functionCallingConfig=FunctionCallingConfig( + mode="ANY", allowed_function_names=[name] + ) + ) + else: + raise litellm.utils.UnsupportedParamsError( + message="VertexAI doesn't support tool_choice={}. Supported tool_choice values=['auto', 'required', json object]. To drop it from the call, set `litellm.drop_params = True.".format( + tool_choice + ), + status_code=400, + ) + + def _map_function(self, value: List[dict]) -> List[Tools]: + gtool_func_declarations = [] + googleSearch: Optional[dict] = None + googleSearchRetrieval: Optional[dict] = None + code_execution: Optional[dict] = None + # remove 'additionalProperties' from tools + value = _remove_additional_properties(value) + # remove 'strict' from tools + value = _remove_strict_from_schema(value) + + for tool in value: + openai_function_object: Optional[ChatCompletionToolParamFunctionChunk] = ( + None + ) + if "function" in tool: # tools list + _openai_function_object = ChatCompletionToolParamFunctionChunk( # type: ignore + **tool["function"] + ) + + if ( + "parameters" in _openai_function_object + and _openai_function_object["parameters"] is not None + ): # OPENAI accepts JSON Schema, Google accepts OpenAPI schema. + _openai_function_object["parameters"] = _build_vertex_schema( + _openai_function_object["parameters"] + ) + + openai_function_object = _openai_function_object + + elif "name" in tool: # functions list + openai_function_object = ChatCompletionToolParamFunctionChunk(**tool) # type: ignore + + # check if grounding + if tool.get("googleSearch", None) is not None: + googleSearch = tool["googleSearch"] + elif tool.get("googleSearchRetrieval", None) is not None: + googleSearchRetrieval = tool["googleSearchRetrieval"] + elif tool.get("code_execution", None) is not None: + code_execution = tool["code_execution"] + elif openai_function_object is not None: + gtool_func_declaration = FunctionDeclaration( + name=openai_function_object["name"], + ) + _description = openai_function_object.get("description", None) + _parameters = openai_function_object.get("parameters", None) + if _description is not None: + gtool_func_declaration["description"] = _description + if _parameters is not None: + gtool_func_declaration["parameters"] = _parameters + gtool_func_declarations.append(gtool_func_declaration) + else: + # assume it's a provider-specific param + verbose_logger.warning( + "Invalid tool={}. Use `litellm.set_verbose` or `litellm --detailed_debug` to see raw request." 
+ ) + + _tools = Tools( + function_declarations=gtool_func_declarations, + ) + if googleSearch is not None: + _tools["googleSearch"] = googleSearch + if googleSearchRetrieval is not None: + _tools["googleSearchRetrieval"] = googleSearchRetrieval + if code_execution is not None: + _tools["code_execution"] = code_execution + return [_tools] + + def _map_response_schema(self, value: dict) -> dict: + old_schema = deepcopy(value) + if isinstance(old_schema, list): + for item in old_schema: + if isinstance(item, dict): + item = _build_vertex_schema(parameters=item) + elif isinstance(old_schema, dict): + old_schema = _build_vertex_schema(parameters=old_schema) + return old_schema + + def map_openai_params( + self, + non_default_params: Dict, + optional_params: Dict, + model: str, + drop_params: bool, + ) -> Dict: + for param, value in non_default_params.items(): + if param == "temperature": + optional_params["temperature"] = value + if param == "top_p": + optional_params["top_p"] = value + if ( + param == "stream" and value is True + ): # sending stream = False, can cause it to get passed unchecked and raise issues + optional_params["stream"] = value + if param == "n": + optional_params["candidate_count"] = value + if param == "stop": + if isinstance(value, str): + optional_params["stop_sequences"] = [value] + elif isinstance(value, list): + optional_params["stop_sequences"] = value + if param == "max_tokens" or param == "max_completion_tokens": + optional_params["max_output_tokens"] = value + if param == "response_format" and isinstance(value, dict): # type: ignore + # remove 'additionalProperties' from json schema + value = _remove_additional_properties(value) + # remove 'strict' from json schema + value = _remove_strict_from_schema(value) + if value["type"] == "json_object": + optional_params["response_mime_type"] = "application/json" + elif value["type"] == "text": + optional_params["response_mime_type"] = "text/plain" + if "response_schema" in value: + optional_params["response_mime_type"] = "application/json" + optional_params["response_schema"] = value["response_schema"] + elif value["type"] == "json_schema": # type: ignore + if "json_schema" in value and "schema" in value["json_schema"]: # type: ignore + optional_params["response_mime_type"] = "application/json" + optional_params["response_schema"] = value["json_schema"]["schema"] # type: ignore + + if "response_schema" in optional_params and isinstance( + optional_params["response_schema"], dict + ): + optional_params["response_schema"] = self._map_response_schema( + value=optional_params["response_schema"] + ) + if param == "frequency_penalty": + optional_params["frequency_penalty"] = value + if param == "presence_penalty": + optional_params["presence_penalty"] = value + if param == "logprobs": + optional_params["responseLogprobs"] = value + if (param == "tools" or param == "functions") and isinstance(value, list): + optional_params["tools"] = self._map_function(value=value) + optional_params["litellm_param_is_function_call"] = ( + True if param == "functions" else False + ) + if param == "tool_choice" and ( + isinstance(value, str) or isinstance(value, dict) + ): + _tool_choice_value = self.map_tool_choice_values( + model=model, tool_choice=value # type: ignore + ) + if _tool_choice_value is not None: + optional_params["tool_choice"] = _tool_choice_value + if param == "seed": + optional_params["seed"] = value + + if litellm.vertex_ai_safety_settings is not None: + optional_params["safety_settings"] = litellm.vertex_ai_safety_settings + 
return optional_params + + def get_mapped_special_auth_params(self) -> dict: + """ + Common auth params across bedrock/vertex_ai/azure/watsonx + """ + return {"project": "vertex_project", "region_name": "vertex_location"} + + def map_special_auth_params(self, non_default_params: dict, optional_params: dict): + mapped_params = self.get_mapped_special_auth_params() + + for param, value in non_default_params.items(): + if param in mapped_params: + optional_params[mapped_params[param]] = value + return optional_params + + def get_eu_regions(self) -> List[str]: + """ + Source: https://cloud.google.com/vertex-ai/generative-ai/docs/learn/locations#available-regions + """ + return [ + "europe-central2", + "europe-north1", + "europe-southwest1", + "europe-west1", + "europe-west2", + "europe-west3", + "europe-west4", + "europe-west6", + "europe-west8", + "europe-west9", + ] + + def get_flagged_finish_reasons(self) -> Dict[str, str]: + """ + Return Dictionary of finish reasons which indicate response was flagged + + and what it means + """ + return { + "SAFETY": "The token generation was stopped as the response was flagged for safety reasons. NOTE: When streaming the Candidate.content will be empty if content filters blocked the output.", + "RECITATION": "The token generation was stopped as the response was flagged for unauthorized citations.", + "BLOCKLIST": "The token generation was stopped as the response was flagged for the terms which are included from the terminology blocklist.", + "PROHIBITED_CONTENT": "The token generation was stopped as the response was flagged for the prohibited contents.", + "SPII": "The token generation was stopped as the response was flagged for Sensitive Personally Identifiable Information (SPII) contents.", + } + + def translate_exception_str(self, exception_string: str): + if ( + "GenerateContentRequest.tools[0].function_declarations[0].parameters.properties: should be non-empty for OBJECT type" + in exception_string + ): + return "'properties' field in tools[0]['function']['parameters'] cannot be empty if 'type' == 'object'. 
Received error from provider - {}".format( + exception_string + ) + return exception_string + + def get_assistant_content_message( + self, parts: List[HttpxPartType] + ) -> Optional[str]: + _content_str = "" + for part in parts: + if "text" in part: + _content_str += part["text"] + if _content_str: + return _content_str + return None + + def _transform_parts( + self, + parts: List[HttpxPartType], + index: int, + is_function_call: Optional[bool], + ) -> Tuple[ + Optional[ChatCompletionToolCallFunctionChunk], + Optional[List[ChatCompletionToolCallChunk]], + ]: + function: Optional[ChatCompletionToolCallFunctionChunk] = None + _tools: List[ChatCompletionToolCallChunk] = [] + for part in parts: + if "functionCall" in part: + _function_chunk = ChatCompletionToolCallFunctionChunk( + name=part["functionCall"]["name"], + arguments=json.dumps(part["functionCall"]["args"]), + ) + if is_function_call is True: + function = _function_chunk + else: + _tool_response_chunk = ChatCompletionToolCallChunk( + id=f"call_{str(uuid.uuid4())}", + type="function", + function=_function_chunk, + index=index, + ) + _tools.append(_tool_response_chunk) + if len(_tools) == 0: + tools: Optional[List[ChatCompletionToolCallChunk]] = None + else: + tools = _tools + return function, tools + + def _transform_logprobs( + self, logprobs_result: Optional[LogprobsResult] + ) -> Optional[ChoiceLogprobs]: + if logprobs_result is None: + return None + if "chosenCandidates" not in logprobs_result: + return None + logprobs_list: List[ChatCompletionTokenLogprob] = [] + for index, candidate in enumerate(logprobs_result["chosenCandidates"]): + top_logprobs: List[TopLogprob] = [] + if "topCandidates" in logprobs_result and index < len( + logprobs_result["topCandidates"] + ): + top_candidates_for_index = logprobs_result["topCandidates"][index][ + "candidates" + ] + + for options in top_candidates_for_index: + top_logprobs.append( + TopLogprob( + token=options["token"], logprob=options["logProbability"] + ) + ) + logprobs_list.append( + ChatCompletionTokenLogprob( + token=candidate["token"], + logprob=candidate["logProbability"], + top_logprobs=top_logprobs, + ) + ) + return ChoiceLogprobs(content=logprobs_list) + + def _handle_blocked_response( + self, + model_response: ModelResponse, + completion_response: GenerateContentResponseBody, + ) -> ModelResponse: + # If set, the prompt was blocked and no candidates are returned. 
Rephrase your prompt + model_response.choices[0].finish_reason = "content_filter" + + chat_completion_message: ChatCompletionResponseMessage = { + "role": "assistant", + "content": None, + } + + choice = litellm.Choices( + finish_reason="content_filter", + index=0, + message=chat_completion_message, # type: ignore + logprobs=None, + enhancements=None, + ) + + model_response.choices = [choice] + + ## GET USAGE ## + usage = Usage( + prompt_tokens=completion_response["usageMetadata"].get( + "promptTokenCount", 0 + ), + completion_tokens=completion_response["usageMetadata"].get( + "candidatesTokenCount", 0 + ), + total_tokens=completion_response["usageMetadata"].get("totalTokenCount", 0), + ) + + setattr(model_response, "usage", usage) + + return model_response + + def _handle_content_policy_violation( + self, + model_response: ModelResponse, + completion_response: GenerateContentResponseBody, + ) -> ModelResponse: + ## CONTENT POLICY VIOLATION ERROR + model_response.choices[0].finish_reason = "content_filter" + + _chat_completion_message = { + "role": "assistant", + "content": None, + } + + choice = litellm.Choices( + finish_reason="content_filter", + index=0, + message=_chat_completion_message, + logprobs=None, + enhancements=None, + ) + + model_response.choices = [choice] + + ## GET USAGE ## + usage = Usage( + prompt_tokens=completion_response["usageMetadata"].get( + "promptTokenCount", 0 + ), + completion_tokens=completion_response["usageMetadata"].get( + "candidatesTokenCount", 0 + ), + total_tokens=completion_response["usageMetadata"].get("totalTokenCount", 0), + ) + + setattr(model_response, "usage", usage) + + return model_response + + def _calculate_usage( + self, + completion_response: GenerateContentResponseBody, + ) -> Usage: + cached_tokens: Optional[int] = None + prompt_tokens_details: Optional[PromptTokensDetailsWrapper] = None + if "cachedContentTokenCount" in completion_response["usageMetadata"]: + cached_tokens = completion_response["usageMetadata"][ + "cachedContentTokenCount" + ] + + if cached_tokens is not None: + prompt_tokens_details = PromptTokensDetailsWrapper( + cached_tokens=cached_tokens, + ) + ## GET USAGE ## + usage = Usage( + prompt_tokens=completion_response["usageMetadata"].get( + "promptTokenCount", 0 + ), + completion_tokens=completion_response["usageMetadata"].get( + "candidatesTokenCount", 0 + ), + total_tokens=completion_response["usageMetadata"].get("totalTokenCount", 0), + prompt_tokens_details=prompt_tokens_details, + ) + + return usage + + def transform_response( + self, + model: str, + raw_response: httpx.Response, + model_response: ModelResponse, + logging_obj: LoggingClass, + request_data: Dict, + messages: List[AllMessageValues], + optional_params: Dict, + litellm_params: Dict, + encoding: Any, + api_key: Optional[str] = None, + json_mode: Optional[bool] = None, + ) -> ModelResponse: + ## LOGGING + logging_obj.post_call( + input=messages, + api_key="", + original_response=raw_response.text, + additional_args={"complete_input_dict": request_data}, + ) + + ## RESPONSE OBJECT + try: + completion_response = GenerateContentResponseBody(**raw_response.json()) # type: ignore + except Exception as e: + raise VertexAIError( + message="Received={}, Error converting to valid response block={}. 
File an issue if litellm error - https://github.com/BerriAI/litellm/issues".format( + raw_response.text, str(e) + ), + status_code=422, + headers=raw_response.headers, + ) + + ## GET MODEL ## + model_response.model = model + + ## CHECK IF RESPONSE FLAGGED + if ( + "promptFeedback" in completion_response + and "blockReason" in completion_response["promptFeedback"] + ): + return self._handle_blocked_response( + model_response=model_response, + completion_response=completion_response, + ) + + _candidates = completion_response.get("candidates") + if _candidates and len(_candidates) > 0: + content_policy_violations = ( + VertexGeminiConfig().get_flagged_finish_reasons() + ) + if ( + "finishReason" in _candidates[0] + and _candidates[0]["finishReason"] in content_policy_violations.keys() + ): + return self._handle_content_policy_violation( + model_response=model_response, + completion_response=completion_response, + ) + + model_response.choices = [] # type: ignore + + try: + ## CHECK IF GROUNDING METADATA IN REQUEST + grounding_metadata: List[dict] = [] + safety_ratings: List = [] + citation_metadata: List = [] + ## GET TEXT ## + chat_completion_message: ChatCompletionResponseMessage = { + "role": "assistant" + } + chat_completion_logprobs: Optional[ChoiceLogprobs] = None + tools: Optional[List[ChatCompletionToolCallChunk]] = [] + functions: Optional[ChatCompletionToolCallFunctionChunk] = None + if _candidates: + for idx, candidate in enumerate(_candidates): + if "content" not in candidate: + continue + + if "groundingMetadata" in candidate: + grounding_metadata.append(candidate["groundingMetadata"]) # type: ignore + + if "safetyRatings" in candidate: + safety_ratings.append(candidate["safetyRatings"]) + + if "citationMetadata" in candidate: + citation_metadata.append(candidate["citationMetadata"]) + if "parts" in candidate["content"]: + chat_completion_message[ + "content" + ] = VertexGeminiConfig().get_assistant_content_message( + parts=candidate["content"]["parts"] + ) + + functions, tools = self._transform_parts( + parts=candidate["content"]["parts"], + index=candidate.get("index", idx), + is_function_call=litellm_params.get( + "litellm_param_is_function_call" + ), + ) + + if "logprobsResult" in candidate: + chat_completion_logprobs = self._transform_logprobs( + logprobs_result=candidate["logprobsResult"] + ) + + if tools: + chat_completion_message["tool_calls"] = tools + + if functions is not None: + chat_completion_message["function_call"] = functions + choice = litellm.Choices( + finish_reason=candidate.get("finishReason", "stop"), + index=candidate.get("index", idx), + message=chat_completion_message, # type: ignore + logprobs=chat_completion_logprobs, + enhancements=None, + ) + + model_response.choices.append(choice) + + usage = self._calculate_usage(completion_response=completion_response) + setattr(model_response, "usage", usage) + + ## ADD GROUNDING METADATA ## + setattr(model_response, "vertex_ai_grounding_metadata", grounding_metadata) + model_response._hidden_params[ + "vertex_ai_grounding_metadata" + ] = ( # older approach - maintaining to prevent regressions + grounding_metadata + ) + + ## ADD SAFETY RATINGS ## + setattr(model_response, "vertex_ai_safety_results", safety_ratings) + model_response._hidden_params["vertex_ai_safety_results"] = ( + safety_ratings # older approach - maintaining to prevent regressions + ) + + ## ADD CITATION METADATA ## + setattr(model_response, "vertex_ai_citation_metadata", citation_metadata) + 
model_response._hidden_params["vertex_ai_citation_metadata"] = ( + citation_metadata # older approach - maintaining to prevent regressions + ) + + except Exception as e: + raise VertexAIError( + message="Received={}, Error converting to valid response block={}. File an issue if litellm error - https://github.com/BerriAI/litellm/issues".format( + completion_response, str(e) + ), + status_code=422, + headers=raw_response.headers, + ) + + return model_response + + def _transform_messages( + self, messages: List[AllMessageValues] + ) -> List[ContentType]: + return _gemini_convert_messages_with_history(messages=messages) + + def get_error_class( + self, error_message: str, status_code: int, headers: Union[Dict, httpx.Headers] + ) -> BaseLLMException: + return VertexAIError( + message=error_message, status_code=status_code, headers=headers + ) + + def transform_request( + self, + model: str, + messages: List[AllMessageValues], + optional_params: Dict, + litellm_params: Dict, + headers: Dict, + ) -> Dict: + raise NotImplementedError( + "Vertex AI has a custom implementation of transform_request. Needs sync + async." + ) + + def validate_environment( + self, + headers: Optional[Dict], + model: str, + messages: List[AllMessageValues], + optional_params: Dict, + api_key: Optional[str] = None, + api_base: Optional[str] = None, + ) -> Dict: + default_headers = { + "Content-Type": "application/json", + } + if api_key is not None: + default_headers["Authorization"] = f"Bearer {api_key}" + if headers is not None: + default_headers.update(headers) + + return default_headers + + +async def make_call( + client: Optional[AsyncHTTPHandler], + api_base: str, + headers: dict, + data: str, + model: str, + messages: list, + logging_obj, +): + if client is None: + client = get_async_httpx_client( + llm_provider=litellm.LlmProviders.VERTEX_AI, + ) + + try: + response = await client.post(api_base, headers=headers, data=data, stream=True) + response.raise_for_status() + except httpx.HTTPStatusError as e: + exception_string = str(await e.response.aread()) + raise VertexAIError( + status_code=e.response.status_code, + message=VertexGeminiConfig().translate_exception_str(exception_string), + headers=e.response.headers, + ) + if response.status_code != 200 and response.status_code != 201: + raise VertexAIError( + status_code=response.status_code, + message=response.text, + headers=response.headers, + ) + + completion_stream = ModelResponseIterator( + streaming_response=response.aiter_lines(), sync_stream=False + ) + # LOGGING + logging_obj.post_call( + input=messages, + api_key="", + original_response="first stream response received", + additional_args={"complete_input_dict": data}, + ) + + return completion_stream + + +def make_sync_call( + client: Optional[HTTPHandler], # module-level client + gemini_client: Optional[HTTPHandler], # if passed by user + api_base: str, + headers: dict, + data: str, + model: str, + messages: list, + logging_obj, +): + if gemini_client is not None: + client = gemini_client + if client is None: + client = HTTPHandler() # Create a new client if none provided + + response = client.post(api_base, headers=headers, data=data, stream=True) + + if response.status_code != 200 and response.status_code != 201: + raise VertexAIError( + status_code=response.status_code, + message=str(response.read()), + headers=response.headers, + ) + + completion_stream = ModelResponseIterator( + streaming_response=response.iter_lines(), sync_stream=True + ) + + # LOGGING + logging_obj.post_call( + input=messages, + 
api_key="", + original_response="first stream response received", + additional_args={"complete_input_dict": data}, + ) + + return completion_stream + + +class VertexLLM(VertexBase): + def __init__(self) -> None: + super().__init__() + + async def async_streaming( + self, + model: str, + custom_llm_provider: Literal[ + "vertex_ai", "vertex_ai_beta", "gemini" + ], # if it's vertex_ai or gemini (google ai studio) + messages: list, + model_response: ModelResponse, + print_verbose: Callable, + data: dict, + timeout: Optional[Union[float, httpx.Timeout]], + encoding, + logging_obj, + stream, + optional_params: dict, + litellm_params=None, + logger_fn=None, + api_base: Optional[str] = None, + client: Optional[AsyncHTTPHandler] = None, + vertex_project: Optional[str] = None, + vertex_location: Optional[str] = None, + vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES] = None, + gemini_api_key: Optional[str] = None, + extra_headers: Optional[dict] = None, + ) -> CustomStreamWrapper: + request_body = await async_transform_request_body(**data) # type: ignore + + should_use_v1beta1_features = self.is_using_v1beta1_features( + optional_params=optional_params + ) + + _auth_header, vertex_project = await self._ensure_access_token_async( + credentials=vertex_credentials, + project_id=vertex_project, + custom_llm_provider=custom_llm_provider, + ) + + auth_header, api_base = self._get_token_and_url( + model=model, + gemini_api_key=gemini_api_key, + auth_header=_auth_header, + vertex_project=vertex_project, + vertex_location=vertex_location, + vertex_credentials=vertex_credentials, + stream=stream, + custom_llm_provider=custom_llm_provider, + api_base=api_base, + should_use_v1beta1_features=should_use_v1beta1_features, + ) + + headers = VertexGeminiConfig().validate_environment( + api_key=auth_header, + headers=extra_headers, + model=model, + messages=messages, + optional_params=optional_params, + ) + + ## LOGGING + logging_obj.pre_call( + input=messages, + api_key="", + additional_args={ + "complete_input_dict": data, + "api_base": api_base, + "headers": headers, + }, + ) + + request_body_str = json.dumps(request_body) + streaming_response = CustomStreamWrapper( + completion_stream=None, + make_call=partial( + make_call, + client=client, + api_base=api_base, + headers=headers, + data=request_body_str, + model=model, + messages=messages, + logging_obj=logging_obj, + ), + model=model, + custom_llm_provider="vertex_ai_beta", + logging_obj=logging_obj, + ) + return streaming_response + + async def async_completion( + self, + model: str, + messages: list, + model_response: ModelResponse, + print_verbose: Callable, + data: dict, + custom_llm_provider: Literal[ + "vertex_ai", "vertex_ai_beta", "gemini" + ], # if it's vertex_ai or gemini (google ai studio) + timeout: Optional[Union[float, httpx.Timeout]], + encoding, + logging_obj, + stream, + optional_params: dict, + litellm_params: dict, + logger_fn=None, + api_base: Optional[str] = None, + client: Optional[AsyncHTTPHandler] = None, + vertex_project: Optional[str] = None, + vertex_location: Optional[str] = None, + vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES] = None, + gemini_api_key: Optional[str] = None, + extra_headers: Optional[dict] = None, + ) -> Union[ModelResponse, CustomStreamWrapper]: + should_use_v1beta1_features = self.is_using_v1beta1_features( + optional_params=optional_params + ) + + _auth_header, vertex_project = await self._ensure_access_token_async( + credentials=vertex_credentials, + project_id=vertex_project, + 
custom_llm_provider=custom_llm_provider, + ) + + auth_header, api_base = self._get_token_and_url( + model=model, + gemini_api_key=gemini_api_key, + auth_header=_auth_header, + vertex_project=vertex_project, + vertex_location=vertex_location, + vertex_credentials=vertex_credentials, + stream=stream, + custom_llm_provider=custom_llm_provider, + api_base=api_base, + should_use_v1beta1_features=should_use_v1beta1_features, + ) + + headers = VertexGeminiConfig().validate_environment( + api_key=auth_header, + headers=extra_headers, + model=model, + messages=messages, + optional_params=optional_params, + ) + + request_body = await async_transform_request_body(**data) # type: ignore + _async_client_params = {} + if timeout: + _async_client_params["timeout"] = timeout + if client is None or not isinstance(client, AsyncHTTPHandler): + client = get_async_httpx_client( + params=_async_client_params, llm_provider=litellm.LlmProviders.VERTEX_AI + ) + else: + client = client # type: ignore + ## LOGGING + logging_obj.pre_call( + input=messages, + api_key="", + additional_args={ + "complete_input_dict": request_body, + "api_base": api_base, + "headers": headers, + }, + ) + + try: + response = await client.post( + api_base, headers=headers, json=cast(dict, request_body) + ) # type: ignore + response.raise_for_status() + except httpx.HTTPStatusError as err: + error_code = err.response.status_code + raise VertexAIError( + status_code=error_code, + message=err.response.text, + headers=err.response.headers, + ) + except httpx.TimeoutException: + raise VertexAIError( + status_code=408, + message="Timeout error occurred.", + headers=None, + ) + + return VertexGeminiConfig().transform_response( + model=model, + raw_response=response, + model_response=model_response, + logging_obj=logging_obj, + api_key="", + request_data=cast(dict, request_body), + messages=messages, + optional_params=optional_params, + litellm_params=litellm_params, + encoding=encoding, + ) + + def completion( + self, + model: str, + messages: list, + model_response: ModelResponse, + print_verbose: Callable, + custom_llm_provider: Literal[ + "vertex_ai", "vertex_ai_beta", "gemini" + ], # if it's vertex_ai or gemini (google ai studio) + encoding, + logging_obj, + optional_params: dict, + acompletion: bool, + timeout: Optional[Union[float, httpx.Timeout]], + vertex_project: Optional[str], + vertex_location: Optional[str], + vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES], + gemini_api_key: Optional[str], + litellm_params: dict, + logger_fn=None, + extra_headers: Optional[dict] = None, + client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None, + api_base: Optional[str] = None, + ) -> Union[ModelResponse, CustomStreamWrapper]: + stream: Optional[bool] = optional_params.pop("stream", None) # type: ignore + + transform_request_params = { + "gemini_api_key": gemini_api_key, + "messages": messages, + "api_base": api_base, + "model": model, + "client": client, + "timeout": timeout, + "extra_headers": extra_headers, + "optional_params": optional_params, + "logging_obj": logging_obj, + "custom_llm_provider": custom_llm_provider, + "litellm_params": litellm_params, + } + + ### ROUTING (ASYNC, STREAMING, SYNC) + if acompletion: + ### ASYNC STREAMING + if stream is True: + return self.async_streaming( + model=model, + messages=messages, + api_base=api_base, + model_response=model_response, + print_verbose=print_verbose, + encoding=encoding, + logging_obj=logging_obj, + optional_params=optional_params, + stream=stream, + 
litellm_params=litellm_params, + logger_fn=logger_fn, + timeout=timeout, + client=client, # type: ignore + data=transform_request_params, + vertex_project=vertex_project, + vertex_location=vertex_location, + vertex_credentials=vertex_credentials, + gemini_api_key=gemini_api_key, + custom_llm_provider=custom_llm_provider, + extra_headers=extra_headers, + ) + ### ASYNC COMPLETION + return self.async_completion( + model=model, + messages=messages, + data=transform_request_params, # type: ignore + api_base=api_base, + model_response=model_response, + print_verbose=print_verbose, + encoding=encoding, + logging_obj=logging_obj, + optional_params=optional_params, + stream=stream, + litellm_params=litellm_params, + logger_fn=logger_fn, + timeout=timeout, + client=client, # type: ignore + vertex_project=vertex_project, + vertex_location=vertex_location, + vertex_credentials=vertex_credentials, + gemini_api_key=gemini_api_key, + custom_llm_provider=custom_llm_provider, + extra_headers=extra_headers, + ) + + should_use_v1beta1_features = self.is_using_v1beta1_features( + optional_params=optional_params + ) + + _auth_header, vertex_project = self._ensure_access_token( + credentials=vertex_credentials, + project_id=vertex_project, + custom_llm_provider=custom_llm_provider, + ) + + auth_header, url = self._get_token_and_url( + model=model, + gemini_api_key=gemini_api_key, + auth_header=_auth_header, + vertex_project=vertex_project, + vertex_location=vertex_location, + vertex_credentials=vertex_credentials, + stream=stream, + custom_llm_provider=custom_llm_provider, + api_base=api_base, + should_use_v1beta1_features=should_use_v1beta1_features, + ) + headers = VertexGeminiConfig().validate_environment( + api_key=auth_header, + headers=extra_headers, + model=model, + messages=messages, + optional_params=optional_params, + ) + + ## TRANSFORMATION ## + data = sync_transform_request_body(**transform_request_params) + + ## LOGGING + logging_obj.pre_call( + input=messages, + api_key="", + additional_args={ + "complete_input_dict": data, + "api_base": url, + "headers": headers, + }, + ) + + ## SYNC STREAMING CALL ## + if stream is True: + request_data_str = json.dumps(data) + streaming_response = CustomStreamWrapper( + completion_stream=None, + make_call=partial( + make_sync_call, + gemini_client=( + client + if client is not None and isinstance(client, HTTPHandler) + else None + ), + api_base=url, + data=request_data_str, + model=model, + messages=messages, + logging_obj=logging_obj, + headers=headers, + ), + model=model, + custom_llm_provider="vertex_ai_beta", + logging_obj=logging_obj, + ) + + return streaming_response + ## COMPLETION CALL ## + + if client is None or isinstance(client, AsyncHTTPHandler): + _params = {} + if timeout is not None: + if isinstance(timeout, float) or isinstance(timeout, int): + timeout = httpx.Timeout(timeout) + _params["timeout"] = timeout + client = HTTPHandler(**_params) # type: ignore + else: + client = client + + try: + response = client.post(url=url, headers=headers, json=data) # type: ignore + response.raise_for_status() + except httpx.HTTPStatusError as err: + error_code = err.response.status_code + raise VertexAIError( + status_code=error_code, + message=err.response.text, + headers=err.response.headers, + ) + except httpx.TimeoutException: + raise VertexAIError( + status_code=408, + message="Timeout error occurred.", + headers=None, + ) + + return VertexGeminiConfig().transform_response( + model=model, + raw_response=response, + model_response=model_response, + 
logging_obj=logging_obj, + optional_params=optional_params, + litellm_params=litellm_params, + api_key="", + request_data=data, # type: ignore + messages=messages, + encoding=encoding, + ) + + +class ModelResponseIterator: + def __init__(self, streaming_response, sync_stream: bool): + self.streaming_response = streaming_response + self.chunk_type: Literal["valid_json", "accumulated_json"] = "valid_json" + self.accumulated_json = "" + self.sent_first_chunk = False + + def chunk_parser(self, chunk: dict) -> GenericStreamingChunk: + try: + processed_chunk = GenerateContentResponseBody(**chunk) # type: ignore + + text = "" + tool_use: Optional[ChatCompletionToolCallChunk] = None + finish_reason = "" + usage: Optional[ChatCompletionUsageBlock] = None + _candidates: Optional[List[Candidates]] = processed_chunk.get("candidates") + gemini_chunk: Optional[Candidates] = None + if _candidates and len(_candidates) > 0: + gemini_chunk = _candidates[0] + + if ( + gemini_chunk + and "content" in gemini_chunk + and "parts" in gemini_chunk["content"] + ): + if "text" in gemini_chunk["content"]["parts"][0]: + text = gemini_chunk["content"]["parts"][0]["text"] + elif "functionCall" in gemini_chunk["content"]["parts"][0]: + function_call = ChatCompletionToolCallFunctionChunk( + name=gemini_chunk["content"]["parts"][0]["functionCall"][ + "name" + ], + arguments=json.dumps( + gemini_chunk["content"]["parts"][0]["functionCall"]["args"] + ), + ) + tool_use = ChatCompletionToolCallChunk( + id=str(uuid.uuid4()), + type="function", + function=function_call, + index=0, + ) + + if gemini_chunk and "finishReason" in gemini_chunk: + finish_reason = map_finish_reason( + finish_reason=gemini_chunk["finishReason"] + ) + ## DO NOT SET 'is_finished' = True + ## GEMINI SETS FINISHREASON ON EVERY CHUNK! + + if "usageMetadata" in processed_chunk: + usage = ChatCompletionUsageBlock( + prompt_tokens=processed_chunk["usageMetadata"].get( + "promptTokenCount", 0 + ), + completion_tokens=processed_chunk["usageMetadata"].get( + "candidatesTokenCount", 0 + ), + total_tokens=processed_chunk["usageMetadata"].get( + "totalTokenCount", 0 + ), + ) + + returned_chunk = GenericStreamingChunk( + text=text, + tool_use=tool_use, + is_finished=False, + finish_reason=finish_reason, + usage=usage, + index=0, + ) + return returned_chunk + except json.JSONDecodeError: + raise ValueError(f"Failed to decode JSON from chunk: {chunk}") + + # Sync iterator + def __iter__(self): + self.response_iterator = self.streaming_response + return self + + def handle_valid_json_chunk(self, chunk: str) -> GenericStreamingChunk: + chunk = chunk.strip() + try: + json_chunk = json.loads(chunk) + + except json.JSONDecodeError as e: + if ( + self.sent_first_chunk is False + ): # only check for accumulated json, on first chunk, else raise error. Prevent real errors from being masked. 
+ self.chunk_type = "accumulated_json" + return self.handle_accumulated_json_chunk(chunk=chunk) + raise e + + if self.sent_first_chunk is False: + self.sent_first_chunk = True + + return self.chunk_parser(chunk=json_chunk) + + def handle_accumulated_json_chunk(self, chunk: str) -> GenericStreamingChunk: + chunk = litellm.CustomStreamWrapper._strip_sse_data_from_chunk(chunk) or "" + message = chunk.replace("\n\n", "") + + # Accumulate JSON data + self.accumulated_json += message + + # Try to parse the accumulated JSON + try: + _data = json.loads(self.accumulated_json) + self.accumulated_json = "" # reset after successful parsing + return self.chunk_parser(chunk=_data) + except json.JSONDecodeError: + # If it's not valid JSON yet, continue to the next event + return GenericStreamingChunk( + text="", + is_finished=False, + finish_reason="", + usage=None, + index=0, + tool_use=None, + ) + + def _common_chunk_parsing_logic(self, chunk: str) -> GenericStreamingChunk: + try: + chunk = litellm.CustomStreamWrapper._strip_sse_data_from_chunk(chunk) or "" + if len(chunk) > 0: + """ + Check if initial chunk valid json + - if partial json -> enter accumulated json logic + - if valid - continue + """ + if self.chunk_type == "valid_json": + return self.handle_valid_json_chunk(chunk=chunk) + elif self.chunk_type == "accumulated_json": + return self.handle_accumulated_json_chunk(chunk=chunk) + + return GenericStreamingChunk( + text="", + is_finished=False, + finish_reason="", + usage=None, + index=0, + tool_use=None, + ) + except Exception: + raise + + def __next__(self): + try: + chunk = self.response_iterator.__next__() + except StopIteration: + if self.chunk_type == "accumulated_json" and self.accumulated_json: + return self.handle_accumulated_json_chunk(chunk="") + raise StopIteration + except ValueError as e: + raise RuntimeError(f"Error receiving chunk from stream: {e}") + + try: + return self._common_chunk_parsing_logic(chunk=chunk) + except StopIteration: + raise StopIteration + except ValueError as e: + raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}") + + # Async iterator + def __aiter__(self): + self.async_response_iterator = self.streaming_response.__aiter__() + return self + + async def __anext__(self): + try: + chunk = await self.async_response_iterator.__anext__() + except StopAsyncIteration: + if self.chunk_type == "accumulated_json" and self.accumulated_json: + return self.handle_accumulated_json_chunk(chunk="") + raise StopAsyncIteration + except ValueError as e: + raise RuntimeError(f"Error receiving chunk from stream: {e}") + + try: + return self._common_chunk_parsing_logic(chunk=chunk) + except StopAsyncIteration: + raise StopAsyncIteration + except ValueError as e: + raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}") |