aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/gemini/transformation.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/gemini/transformation.py')
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/gemini/transformation.py479
1 files changed, 479 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/gemini/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/gemini/transformation.py
new file mode 100644
index 00000000..d6bafc7c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/gemini/transformation.py
@@ -0,0 +1,479 @@
+"""
+Transformation logic from OpenAI format to Gemini format.
+
+Why separate file? Make it easy to see how transformation works
+"""
+
+import os
+from typing import TYPE_CHECKING, List, Literal, Optional, Tuple, Union, cast
+
+import httpx
+from pydantic import BaseModel
+
+import litellm
+from litellm._logging import verbose_logger
+from litellm.litellm_core_utils.prompt_templates.factory import (
+ convert_to_anthropic_image_obj,
+ convert_to_gemini_tool_call_invoke,
+ convert_to_gemini_tool_call_result,
+ response_schema_prompt,
+)
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.types.files import (
+ get_file_mime_type_for_file_type,
+ get_file_type_from_extension,
+ is_gemini_1_5_accepted_file_type,
+)
+from litellm.types.llms.openai import (
+ AllMessageValues,
+ ChatCompletionAssistantMessage,
+ ChatCompletionImageObject,
+ ChatCompletionTextObject,
+)
+from litellm.types.llms.vertex_ai import *
+from litellm.types.llms.vertex_ai import (
+ GenerationConfig,
+ PartType,
+ RequestBody,
+ SafetSettingsConfig,
+ SystemInstructions,
+ ToolConfig,
+ Tools,
+)
+
+from ..common_utils import (
+ _check_text_in_content,
+ get_supports_response_schema,
+ get_supports_system_message,
+)
+
+if TYPE_CHECKING:
+ from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
+
+ LiteLLMLoggingObj = _LiteLLMLoggingObj
+else:
+ LiteLLMLoggingObj = Any
+
+
+def _process_gemini_image(image_url: str, format: Optional[str] = None) -> PartType:
+ """
+ Given an image URL, return the appropriate PartType for Gemini
+ """
+
+ try:
+ # GCS URIs
+ if "gs://" in image_url:
+ # Figure out file type
+ extension_with_dot = os.path.splitext(image_url)[-1] # Ex: ".png"
+ extension = extension_with_dot[1:] # Ex: "png"
+
+ if not format:
+ file_type = get_file_type_from_extension(extension)
+
+ # Validate the file type is supported by Gemini
+ if not is_gemini_1_5_accepted_file_type(file_type):
+ raise Exception(f"File type not supported by gemini - {file_type}")
+
+ mime_type = get_file_mime_type_for_file_type(file_type)
+ else:
+ mime_type = format
+ file_data = FileDataType(mime_type=mime_type, file_uri=image_url)
+
+ return PartType(file_data=file_data)
+ elif (
+ "https://" in image_url
+ and (image_type := format or _get_image_mime_type_from_url(image_url))
+ is not None
+ ):
+
+ file_data = FileDataType(file_uri=image_url, mime_type=image_type)
+ return PartType(file_data=file_data)
+ elif "http://" in image_url or "https://" in image_url or "base64" in image_url:
+ # https links for unsupported mime types and base64 images
+ image = convert_to_anthropic_image_obj(image_url, format=format)
+ _blob = BlobType(data=image["data"], mime_type=image["media_type"])
+ return PartType(inline_data=_blob)
+ raise Exception("Invalid image received - {}".format(image_url))
+ except Exception as e:
+ raise e
+
+
+def _get_image_mime_type_from_url(url: str) -> Optional[str]:
+ """
+ Get mime type for common image URLs
+ See gemini mime types: https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/image-understanding#image-requirements
+
+ Supported by Gemini:
+ - PNG (`image/png`)
+ - JPEG (`image/jpeg`)
+ - WebP (`image/webp`)
+ Example:
+ url = https://example.com/image.jpg
+ Returns: image/jpeg
+ """
+ url = url.lower()
+ if url.endswith((".jpg", ".jpeg")):
+ return "image/jpeg"
+ elif url.endswith(".png"):
+ return "image/png"
+ elif url.endswith(".webp"):
+ return "image/webp"
+ elif url.endswith(".mp4"):
+ return "video/mp4"
+ elif url.endswith(".pdf"):
+ return "application/pdf"
+ return None
+
+
+def _gemini_convert_messages_with_history( # noqa: PLR0915
+ messages: List[AllMessageValues],
+) -> List[ContentType]:
+ """
+ Converts given messages from OpenAI format to Gemini format
+
+ - Parts must be iterable
+ - Roles must alternate b/w 'user' and 'model' (same as anthropic -> merge consecutive roles)
+ - Please ensure that function response turn comes immediately after a function call turn
+ """
+ user_message_types = {"user", "system"}
+ contents: List[ContentType] = []
+
+ last_message_with_tool_calls = None
+
+ msg_i = 0
+ tool_call_responses = []
+ try:
+ while msg_i < len(messages):
+ user_content: List[PartType] = []
+ init_msg_i = msg_i
+ ## MERGE CONSECUTIVE USER CONTENT ##
+ while (
+ msg_i < len(messages) and messages[msg_i]["role"] in user_message_types
+ ):
+ _message_content = messages[msg_i].get("content")
+ if _message_content is not None and isinstance(_message_content, list):
+ _parts: List[PartType] = []
+ for element in _message_content:
+ if (
+ element["type"] == "text"
+ and "text" in element
+ and len(element["text"]) > 0
+ ):
+ element = cast(ChatCompletionTextObject, element)
+ _part = PartType(text=element["text"])
+ _parts.append(_part)
+ elif element["type"] == "image_url":
+ element = cast(ChatCompletionImageObject, element)
+ img_element = element
+ format: Optional[str] = None
+ if isinstance(img_element["image_url"], dict):
+ image_url = img_element["image_url"]["url"]
+ format = img_element["image_url"].get("format")
+ else:
+ image_url = img_element["image_url"]
+ _part = _process_gemini_image(
+ image_url=image_url, format=format
+ )
+ _parts.append(_part)
+ user_content.extend(_parts)
+ elif (
+ _message_content is not None
+ and isinstance(_message_content, str)
+ and len(_message_content) > 0
+ ):
+ _part = PartType(text=_message_content)
+ user_content.append(_part)
+
+ msg_i += 1
+
+ if user_content:
+ """
+ check that user_content has 'text' parameter.
+ - Known Vertex Error: Unable to submit request because it must have a text parameter.
+ - Relevant Issue: https://github.com/BerriAI/litellm/issues/5515
+ """
+ has_text_in_content = _check_text_in_content(user_content)
+ if has_text_in_content is False:
+ verbose_logger.warning(
+ "No text in user content. Adding a blank text to user content, to ensure Gemini doesn't fail the request. Relevant Issue - https://github.com/BerriAI/litellm/issues/5515"
+ )
+ user_content.append(
+ PartType(text=" ")
+ ) # add a blank text, to ensure Gemini doesn't fail the request.
+ contents.append(ContentType(role="user", parts=user_content))
+ assistant_content = []
+ ## MERGE CONSECUTIVE ASSISTANT CONTENT ##
+ while msg_i < len(messages) and messages[msg_i]["role"] == "assistant":
+ if isinstance(messages[msg_i], BaseModel):
+ msg_dict: Union[ChatCompletionAssistantMessage, dict] = messages[msg_i].model_dump() # type: ignore
+ else:
+ msg_dict = messages[msg_i] # type: ignore
+ assistant_msg = ChatCompletionAssistantMessage(**msg_dict) # type: ignore
+ _message_content = assistant_msg.get("content", None)
+ if _message_content is not None and isinstance(_message_content, list):
+ _parts = []
+ for element in _message_content:
+ if isinstance(element, dict):
+ if element["type"] == "text":
+ _part = PartType(text=element["text"])
+ _parts.append(_part)
+ assistant_content.extend(_parts)
+ elif (
+ _message_content is not None
+ and isinstance(_message_content, str)
+ and _message_content
+ ):
+ assistant_text = _message_content # either string or none
+ assistant_content.append(PartType(text=assistant_text)) # type: ignore
+
+ ## HANDLE ASSISTANT FUNCTION CALL
+ if (
+ assistant_msg.get("tool_calls", []) is not None
+ or assistant_msg.get("function_call") is not None
+ ): # support assistant tool invoke conversion
+ assistant_content.extend(
+ convert_to_gemini_tool_call_invoke(assistant_msg)
+ )
+ last_message_with_tool_calls = assistant_msg
+
+ msg_i += 1
+
+ if assistant_content:
+ contents.append(ContentType(role="model", parts=assistant_content))
+
+ ## APPEND TOOL CALL MESSAGES ##
+ tool_call_message_roles = ["tool", "function"]
+ if (
+ msg_i < len(messages)
+ and messages[msg_i]["role"] in tool_call_message_roles
+ ):
+ _part = convert_to_gemini_tool_call_result(
+ messages[msg_i], last_message_with_tool_calls # type: ignore
+ )
+ msg_i += 1
+ tool_call_responses.append(_part)
+ if msg_i < len(messages) and (
+ messages[msg_i]["role"] not in tool_call_message_roles
+ ):
+ if len(tool_call_responses) > 0:
+ contents.append(ContentType(parts=tool_call_responses))
+ tool_call_responses = []
+
+ if msg_i == init_msg_i: # prevent infinite loops
+ raise Exception(
+ "Invalid Message passed in - {}. File an issue https://github.com/BerriAI/litellm/issues".format(
+ messages[msg_i]
+ )
+ )
+ if len(tool_call_responses) > 0:
+ contents.append(ContentType(parts=tool_call_responses))
+ return contents
+ except Exception as e:
+ raise e
+
+
+def _transform_request_body(
+ messages: List[AllMessageValues],
+ model: str,
+ optional_params: dict,
+ custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
+ litellm_params: dict,
+ cached_content: Optional[str],
+) -> RequestBody:
+ """
+ Common transformation logic across sync + async Gemini /generateContent calls.
+ """
+ # Separate system prompt from rest of message
+ supports_system_message = get_supports_system_message(
+ model=model, custom_llm_provider=custom_llm_provider
+ )
+ system_instructions, messages = _transform_system_message(
+ supports_system_message=supports_system_message, messages=messages
+ )
+ # Checks for 'response_schema' support - if passed in
+ if "response_schema" in optional_params:
+ supports_response_schema = get_supports_response_schema(
+ model=model, custom_llm_provider=custom_llm_provider
+ )
+ if supports_response_schema is False:
+ user_response_schema_message = response_schema_prompt(
+ model=model, response_schema=optional_params.get("response_schema") # type: ignore
+ )
+ messages.append({"role": "user", "content": user_response_schema_message})
+ optional_params.pop("response_schema")
+
+ # Check for any 'litellm_param_*' set during optional param mapping
+
+ remove_keys = []
+ for k, v in optional_params.items():
+ if k.startswith("litellm_param_"):
+ litellm_params.update({k: v})
+ remove_keys.append(k)
+
+ optional_params = {k: v for k, v in optional_params.items() if k not in remove_keys}
+
+ try:
+ if custom_llm_provider == "gemini":
+ content = litellm.GoogleAIStudioGeminiConfig()._transform_messages(
+ messages=messages
+ )
+ else:
+ content = litellm.VertexGeminiConfig()._transform_messages(
+ messages=messages
+ )
+ tools: Optional[Tools] = optional_params.pop("tools", None)
+ tool_choice: Optional[ToolConfig] = optional_params.pop("tool_choice", None)
+ safety_settings: Optional[List[SafetSettingsConfig]] = optional_params.pop(
+ "safety_settings", None
+ ) # type: ignore
+ config_fields = GenerationConfig.__annotations__.keys()
+
+ filtered_params = {
+ k: v for k, v in optional_params.items() if k in config_fields
+ }
+
+ generation_config: Optional[GenerationConfig] = GenerationConfig(
+ **filtered_params
+ )
+ data = RequestBody(contents=content)
+ if system_instructions is not None:
+ data["system_instruction"] = system_instructions
+ if tools is not None:
+ data["tools"] = tools
+ if tool_choice is not None:
+ data["toolConfig"] = tool_choice
+ if safety_settings is not None:
+ data["safetySettings"] = safety_settings
+ if generation_config is not None:
+ data["generationConfig"] = generation_config
+ if cached_content is not None:
+ data["cachedContent"] = cached_content
+ except Exception as e:
+ raise e
+
+ return data
+
+
+def sync_transform_request_body(
+ gemini_api_key: Optional[str],
+ messages: List[AllMessageValues],
+ api_base: Optional[str],
+ model: str,
+ client: Optional[HTTPHandler],
+ timeout: Optional[Union[float, httpx.Timeout]],
+ extra_headers: Optional[dict],
+ optional_params: dict,
+ logging_obj: LiteLLMLoggingObj,
+ custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
+ litellm_params: dict,
+) -> RequestBody:
+ from ..context_caching.vertex_ai_context_caching import ContextCachingEndpoints
+
+ context_caching_endpoints = ContextCachingEndpoints()
+
+ if gemini_api_key is not None:
+ messages, cached_content = context_caching_endpoints.check_and_create_cache(
+ messages=messages,
+ api_key=gemini_api_key,
+ api_base=api_base,
+ model=model,
+ client=client,
+ timeout=timeout,
+ extra_headers=extra_headers,
+ cached_content=optional_params.pop("cached_content", None),
+ logging_obj=logging_obj,
+ )
+ else: # [TODO] implement context caching for gemini as well
+ cached_content = optional_params.pop("cached_content", None)
+
+ return _transform_request_body(
+ messages=messages,
+ model=model,
+ custom_llm_provider=custom_llm_provider,
+ litellm_params=litellm_params,
+ cached_content=cached_content,
+ optional_params=optional_params,
+ )
+
+
+async def async_transform_request_body(
+ gemini_api_key: Optional[str],
+ messages: List[AllMessageValues],
+ api_base: Optional[str],
+ model: str,
+ client: Optional[AsyncHTTPHandler],
+ timeout: Optional[Union[float, httpx.Timeout]],
+ extra_headers: Optional[dict],
+ optional_params: dict,
+ logging_obj: litellm.litellm_core_utils.litellm_logging.Logging, # type: ignore
+ custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
+ litellm_params: dict,
+) -> RequestBody:
+ from ..context_caching.vertex_ai_context_caching import ContextCachingEndpoints
+
+ context_caching_endpoints = ContextCachingEndpoints()
+
+ if gemini_api_key is not None:
+ messages, cached_content = (
+ await context_caching_endpoints.async_check_and_create_cache(
+ messages=messages,
+ api_key=gemini_api_key,
+ api_base=api_base,
+ model=model,
+ client=client,
+ timeout=timeout,
+ extra_headers=extra_headers,
+ cached_content=optional_params.pop("cached_content", None),
+ logging_obj=logging_obj,
+ )
+ )
+ else: # [TODO] implement context caching for gemini as well
+ cached_content = optional_params.pop("cached_content", None)
+
+ return _transform_request_body(
+ messages=messages,
+ model=model,
+ custom_llm_provider=custom_llm_provider,
+ litellm_params=litellm_params,
+ cached_content=cached_content,
+ optional_params=optional_params,
+ )
+
+
+def _transform_system_message(
+ supports_system_message: bool, messages: List[AllMessageValues]
+) -> Tuple[Optional[SystemInstructions], List[AllMessageValues]]:
+ """
+ Extracts the system message from the openai message list.
+
+ Converts the system message to Gemini format
+
+ Returns
+ - system_content_blocks: Optional[SystemInstructions] - the system message list in Gemini format.
+ - messages: List[AllMessageValues] - filtered list of messages in OpenAI format (transformed separately)
+ """
+ # Separate system prompt from rest of message
+ system_prompt_indices = []
+ system_content_blocks: List[PartType] = []
+ if supports_system_message is True:
+ for idx, message in enumerate(messages):
+ if message["role"] == "system":
+ _system_content_block: Optional[PartType] = None
+ if isinstance(message["content"], str):
+ _system_content_block = PartType(text=message["content"])
+ elif isinstance(message["content"], list):
+ system_text = ""
+ for content in message["content"]:
+ system_text += content.get("text") or ""
+ _system_content_block = PartType(text=system_text)
+ if _system_content_block is not None:
+ system_content_blocks.append(_system_content_block)
+ system_prompt_indices.append(idx)
+ if len(system_prompt_indices) > 0:
+ for idx in reversed(system_prompt_indices):
+ messages.pop(idx)
+
+ if len(system_content_blocks) > 0:
+ return SystemInstructions(parts=system_content_blocks), messages
+
+ return None, messages