Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/gemini/transformation.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/gemini/transformation.py | 479 |
1 file changed, 479 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/gemini/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/gemini/transformation.py
new file mode 100644
index 00000000..d6bafc7c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/gemini/transformation.py
@@ -0,0 +1,479 @@
+"""
+Transformation logic from OpenAI format to Gemini format.
+
+Why separate file? Make it easy to see how transformation works
+"""
+
+import os
+from typing import TYPE_CHECKING, List, Literal, Optional, Tuple, Union, cast
+
+import httpx
+from pydantic import BaseModel
+
+import litellm
+from litellm._logging import verbose_logger
+from litellm.litellm_core_utils.prompt_templates.factory import (
+    convert_to_anthropic_image_obj,
+    convert_to_gemini_tool_call_invoke,
+    convert_to_gemini_tool_call_result,
+    response_schema_prompt,
+)
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.types.files import (
+    get_file_mime_type_for_file_type,
+    get_file_type_from_extension,
+    is_gemini_1_5_accepted_file_type,
+)
+from litellm.types.llms.openai import (
+    AllMessageValues,
+    ChatCompletionAssistantMessage,
+    ChatCompletionImageObject,
+    ChatCompletionTextObject,
+)
+from litellm.types.llms.vertex_ai import *
+from litellm.types.llms.vertex_ai import (
+    GenerationConfig,
+    PartType,
+    RequestBody,
+    SafetSettingsConfig,
+    SystemInstructions,
+    ToolConfig,
+    Tools,
+)
+
+from ..common_utils import (
+    _check_text_in_content,
+    get_supports_response_schema,
+    get_supports_system_message,
+)
+
+if TYPE_CHECKING:
+    from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
+
+    LiteLLMLoggingObj = _LiteLLMLoggingObj
+else:
+    LiteLLMLoggingObj = Any
+
+
+def _process_gemini_image(image_url: str, format: Optional[str] = None) -> PartType:
+    """
+    Given an image URL, return the appropriate PartType for Gemini
+    """
+
+    try:
+        # GCS URIs
+        if "gs://" in image_url:
+            # Figure out file type
+            extension_with_dot = os.path.splitext(image_url)[-1]  # Ex: ".png"
+            extension = extension_with_dot[1:]  # Ex: "png"
+
+            if not format:
+                file_type = get_file_type_from_extension(extension)
+
+                # Validate the file type is supported by Gemini
+                if not is_gemini_1_5_accepted_file_type(file_type):
+                    raise Exception(f"File type not supported by gemini - {file_type}")
+
+                mime_type = get_file_mime_type_for_file_type(file_type)
+            else:
+                mime_type = format
+            file_data = FileDataType(mime_type=mime_type, file_uri=image_url)
+
+            return PartType(file_data=file_data)
+        elif (
+            "https://" in image_url
+            and (image_type := format or _get_image_mime_type_from_url(image_url))
+            is not None
+        ):
+
+            file_data = FileDataType(file_uri=image_url, mime_type=image_type)
+            return PartType(file_data=file_data)
+        elif "http://" in image_url or "https://" in image_url or "base64" in image_url:
+            # https links for unsupported mime types and base64 images
+            image = convert_to_anthropic_image_obj(image_url, format=format)
+            _blob = BlobType(data=image["data"], mime_type=image["media_type"])
+            return PartType(inline_data=_blob)
+        raise Exception("Invalid image received - {}".format(image_url))
+    except Exception as e:
+        raise e
+
+
+def _get_image_mime_type_from_url(url: str) -> Optional[str]:
+    """
+    Get mime type for common image URLs
+    See gemini mime types: https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/image-understanding#image-requirements
+
+    Supported by Gemini:
+     - PNG (`image/png`)
+     - JPEG (`image/jpeg`)
+     - WebP (`image/webp`)
+    Example:
+        url = https://example.com/image.jpg
+        Returns: image/jpeg
+    """
+    url = url.lower()
+    if url.endswith((".jpg", ".jpeg")):
+        return "image/jpeg"
+    elif url.endswith(".png"):
+        return "image/png"
+    elif url.endswith(".webp"):
+        return "image/webp"
+    elif url.endswith(".mp4"):
+        return "video/mp4"
+    elif url.endswith(".pdf"):
+        return "application/pdf"
+    return None
+
+
+def _gemini_convert_messages_with_history(  # noqa: PLR0915
+    messages: List[AllMessageValues],
+) -> List[ContentType]:
+    """
+    Converts given messages from OpenAI format to Gemini format
+
+    - Parts must be iterable
+    - Roles must alternate b/w 'user' and 'model' (same as anthropic -> merge consecutive roles)
+    - Please ensure that function response turn comes immediately after a function call turn
+    """
+    user_message_types = {"user", "system"}
+    contents: List[ContentType] = []
+
+    last_message_with_tool_calls = None
+
+    msg_i = 0
+    tool_call_responses = []
+    try:
+        while msg_i < len(messages):
+            user_content: List[PartType] = []
+            init_msg_i = msg_i
+            ## MERGE CONSECUTIVE USER CONTENT ##
+            while (
+                msg_i < len(messages) and messages[msg_i]["role"] in user_message_types
+            ):
+                _message_content = messages[msg_i].get("content")
+                if _message_content is not None and isinstance(_message_content, list):
+                    _parts: List[PartType] = []
+                    for element in _message_content:
+                        if (
+                            element["type"] == "text"
+                            and "text" in element
+                            and len(element["text"]) > 0
+                        ):
+                            element = cast(ChatCompletionTextObject, element)
+                            _part = PartType(text=element["text"])
+                            _parts.append(_part)
+                        elif element["type"] == "image_url":
+                            element = cast(ChatCompletionImageObject, element)
+                            img_element = element
+                            format: Optional[str] = None
+                            if isinstance(img_element["image_url"], dict):
+                                image_url = img_element["image_url"]["url"]
+                                format = img_element["image_url"].get("format")
+                            else:
+                                image_url = img_element["image_url"]
+                            _part = _process_gemini_image(
+                                image_url=image_url, format=format
+                            )
+                            _parts.append(_part)
+                    user_content.extend(_parts)
+                elif (
+                    _message_content is not None
+                    and isinstance(_message_content, str)
+                    and len(_message_content) > 0
+                ):
+                    _part = PartType(text=_message_content)
+                    user_content.append(_part)
+
+                msg_i += 1
+
+            if user_content:
+                """
+                check that user_content has 'text' parameter.
+                    - Known Vertex Error: Unable to submit request because it must have a text parameter.
+                    - Relevant Issue: https://github.com/BerriAI/litellm/issues/5515
+                """
+                has_text_in_content = _check_text_in_content(user_content)
+                if has_text_in_content is False:
+                    verbose_logger.warning(
+                        "No text in user content. Adding a blank text to user content, to ensure Gemini doesn't fail the request. Relevant Issue - https://github.com/BerriAI/litellm/issues/5515"
+                    )
+                    user_content.append(
+                        PartType(text=" ")
+                    )  # add a blank text, to ensure Gemini doesn't fail the request.
+                contents.append(ContentType(role="user", parts=user_content))
+            assistant_content = []
+            ## MERGE CONSECUTIVE ASSISTANT CONTENT ##
+            while msg_i < len(messages) and messages[msg_i]["role"] == "assistant":
+                if isinstance(messages[msg_i], BaseModel):
+                    msg_dict: Union[ChatCompletionAssistantMessage, dict] = messages[msg_i].model_dump()  # type: ignore
+                else:
+                    msg_dict = messages[msg_i]  # type: ignore
+                assistant_msg = ChatCompletionAssistantMessage(**msg_dict)  # type: ignore
+                _message_content = assistant_msg.get("content", None)
+                if _message_content is not None and isinstance(_message_content, list):
+                    _parts = []
+                    for element in _message_content:
+                        if isinstance(element, dict):
+                            if element["type"] == "text":
+                                _part = PartType(text=element["text"])
+                                _parts.append(_part)
+                    assistant_content.extend(_parts)
+                elif (
+                    _message_content is not None
+                    and isinstance(_message_content, str)
+                    and _message_content
+                ):
+                    assistant_text = _message_content  # either string or none
+                    assistant_content.append(PartType(text=assistant_text))  # type: ignore
+
+                ## HANDLE ASSISTANT FUNCTION CALL
+                if (
+                    assistant_msg.get("tool_calls", []) is not None
+                    or assistant_msg.get("function_call") is not None
+                ):  # support assistant tool invoke conversion
+                    assistant_content.extend(
+                        convert_to_gemini_tool_call_invoke(assistant_msg)
+                    )
+                    last_message_with_tool_calls = assistant_msg
+
+                msg_i += 1
+
+            if assistant_content:
+                contents.append(ContentType(role="model", parts=assistant_content))
+
+            ## APPEND TOOL CALL MESSAGES ##
+            tool_call_message_roles = ["tool", "function"]
+            if (
+                msg_i < len(messages)
+                and messages[msg_i]["role"] in tool_call_message_roles
+            ):
+                _part = convert_to_gemini_tool_call_result(
+                    messages[msg_i], last_message_with_tool_calls  # type: ignore
+                )
+                msg_i += 1
+                tool_call_responses.append(_part)
+            if msg_i < len(messages) and (
+                messages[msg_i]["role"] not in tool_call_message_roles
+            ):
+                if len(tool_call_responses) > 0:
+                    contents.append(ContentType(parts=tool_call_responses))
+                    tool_call_responses = []
+
+            if msg_i == init_msg_i:  # prevent infinite loops
+                raise Exception(
+                    "Invalid Message passed in - {}. File an issue https://github.com/BerriAI/litellm/issues".format(
+                        messages[msg_i]
+                    )
+                )
+        if len(tool_call_responses) > 0:
+            contents.append(ContentType(parts=tool_call_responses))
+        return contents
+    except Exception as e:
+        raise e
+
+
+def _transform_request_body(
+    messages: List[AllMessageValues],
+    model: str,
+    optional_params: dict,
+    custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
+    litellm_params: dict,
+    cached_content: Optional[str],
+) -> RequestBody:
+    """
+    Common transformation logic across sync + async Gemini /generateContent calls.
+    """
+    # Separate system prompt from rest of message
+    supports_system_message = get_supports_system_message(
+        model=model, custom_llm_provider=custom_llm_provider
+    )
+    system_instructions, messages = _transform_system_message(
+        supports_system_message=supports_system_message, messages=messages
+    )
+    # Checks for 'response_schema' support - if passed in
+    if "response_schema" in optional_params:
+        supports_response_schema = get_supports_response_schema(
+            model=model, custom_llm_provider=custom_llm_provider
+        )
+        if supports_response_schema is False:
+            user_response_schema_message = response_schema_prompt(
+                model=model, response_schema=optional_params.get("response_schema")  # type: ignore
+            )
+            messages.append({"role": "user", "content": user_response_schema_message})
+            optional_params.pop("response_schema")
+
+    # Check for any 'litellm_param_*' set during optional param mapping
+
+    remove_keys = []
+    for k, v in optional_params.items():
+        if k.startswith("litellm_param_"):
+            litellm_params.update({k: v})
+            remove_keys.append(k)
+
+    optional_params = {k: v for k, v in optional_params.items() if k not in remove_keys}
+
+    try:
+        if custom_llm_provider == "gemini":
+            content = litellm.GoogleAIStudioGeminiConfig()._transform_messages(
+                messages=messages
+            )
+        else:
+            content = litellm.VertexGeminiConfig()._transform_messages(
+                messages=messages
+            )
+        tools: Optional[Tools] = optional_params.pop("tools", None)
+        tool_choice: Optional[ToolConfig] = optional_params.pop("tool_choice", None)
+        safety_settings: Optional[List[SafetSettingsConfig]] = optional_params.pop(
+            "safety_settings", None
+        )  # type: ignore
+        config_fields = GenerationConfig.__annotations__.keys()
+
+        filtered_params = {
+            k: v for k, v in optional_params.items() if k in config_fields
+        }
+
+        generation_config: Optional[GenerationConfig] = GenerationConfig(
+            **filtered_params
+        )
+        data = RequestBody(contents=content)
+        if system_instructions is not None:
+            data["system_instruction"] = system_instructions
+        if tools is not None:
+            data["tools"] = tools
+        if tool_choice is not None:
+            data["toolConfig"] = tool_choice
+        if safety_settings is not None:
+            data["safetySettings"] = safety_settings
+        if generation_config is not None:
+            data["generationConfig"] = generation_config
+        if cached_content is not None:
+            data["cachedContent"] = cached_content
+    except Exception as e:
+        raise e
+
+    return data
+
+
+def sync_transform_request_body(
+    gemini_api_key: Optional[str],
+    messages: List[AllMessageValues],
+    api_base: Optional[str],
+    model: str,
+    client: Optional[HTTPHandler],
+    timeout: Optional[Union[float, httpx.Timeout]],
+    extra_headers: Optional[dict],
+    optional_params: dict,
+    logging_obj: LiteLLMLoggingObj,
+    custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
+    litellm_params: dict,
+) -> RequestBody:
+    from ..context_caching.vertex_ai_context_caching import ContextCachingEndpoints
+
+    context_caching_endpoints = ContextCachingEndpoints()
+
+    if gemini_api_key is not None:
+        messages, cached_content = context_caching_endpoints.check_and_create_cache(
+            messages=messages,
+            api_key=gemini_api_key,
+            api_base=api_base,
+            model=model,
+            client=client,
+            timeout=timeout,
+            extra_headers=extra_headers,
+            cached_content=optional_params.pop("cached_content", None),
+            logging_obj=logging_obj,
+        )
+    else:  # [TODO] implement context caching for gemini as well
+        cached_content = optional_params.pop("cached_content", None)
+
+    return _transform_request_body(
+        messages=messages,
+        model=model,
+        custom_llm_provider=custom_llm_provider,
+        litellm_params=litellm_params,
+        cached_content=cached_content,
+        optional_params=optional_params,
+    )
+
+
+async def async_transform_request_body(
+    gemini_api_key: Optional[str],
+    messages: List[AllMessageValues],
+    api_base: Optional[str],
+    model: str,
+    client: Optional[AsyncHTTPHandler],
+    timeout: Optional[Union[float, httpx.Timeout]],
+    extra_headers: Optional[dict],
+    optional_params: dict,
+    logging_obj: litellm.litellm_core_utils.litellm_logging.Logging,  # type: ignore
+    custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
+    litellm_params: dict,
+) -> RequestBody:
+    from ..context_caching.vertex_ai_context_caching import ContextCachingEndpoints
+
+    context_caching_endpoints = ContextCachingEndpoints()
+
+    if gemini_api_key is not None:
+        messages, cached_content = (
+            await context_caching_endpoints.async_check_and_create_cache(
+                messages=messages,
+                api_key=gemini_api_key,
+                api_base=api_base,
+                model=model,
+                client=client,
+                timeout=timeout,
+                extra_headers=extra_headers,
+                cached_content=optional_params.pop("cached_content", None),
+                logging_obj=logging_obj,
+            )
+        )
+    else:  # [TODO] implement context caching for gemini as well
+        cached_content = optional_params.pop("cached_content", None)
+
+    return _transform_request_body(
+        messages=messages,
+        model=model,
+        custom_llm_provider=custom_llm_provider,
+        litellm_params=litellm_params,
+        cached_content=cached_content,
+        optional_params=optional_params,
+    )
+
+
+def _transform_system_message(
+    supports_system_message: bool, messages: List[AllMessageValues]
+) -> Tuple[Optional[SystemInstructions], List[AllMessageValues]]:
+    """
+    Extracts the system message from the openai message list.
+
+    Converts the system message to Gemini format
+
+    Returns
+    - system_content_blocks: Optional[SystemInstructions] - the system message list in Gemini format.
+    - messages: List[AllMessageValues] - filtered list of messages in OpenAI format (transformed separately)
+    """
+    # Separate system prompt from rest of message
+    system_prompt_indices = []
+    system_content_blocks: List[PartType] = []
+    if supports_system_message is True:
+        for idx, message in enumerate(messages):
+            if message["role"] == "system":
+                _system_content_block: Optional[PartType] = None
+                if isinstance(message["content"], str):
+                    _system_content_block = PartType(text=message["content"])
+                elif isinstance(message["content"], list):
+                    system_text = ""
+                    for content in message["content"]:
+                        system_text += content.get("text") or ""
+                    _system_content_block = PartType(text=system_text)
+                if _system_content_block is not None:
+                    system_content_blocks.append(_system_content_block)
+                    system_prompt_indices.append(idx)
+        if len(system_prompt_indices) > 0:
+            for idx in reversed(system_prompt_indices):
+                messages.pop(idx)
+
+    if len(system_content_blocks) > 0:
+        return SystemInstructions(parts=system_content_blocks), messages
+
+    return None, messages
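
Usage note (not part of the diff above): a minimal sketch of how the system-message split and OpenAI-to-Gemini history conversion defined in this file might be exercised. It assumes litellm is installed and that these private helpers remain importable from this module path; the message payloads and the outputs shown in comments are illustrative only.

# Minimal sketch, assuming litellm is installed and these private helpers
# stay importable from this module; payloads below are made up.
from litellm.llms.vertex_ai.gemini.transformation import (
    _gemini_convert_messages_with_history,
    _transform_system_message,
)

openai_messages = [
    {"role": "system", "content": "You are a terse assistant."},
    {"role": "user", "content": "What is the capital of France?"},
    {"role": "assistant", "content": "Paris."},
    {"role": "user", "content": "And of Italy?"},
]

# Pull the system prompt out into Gemini's system_instruction block.
system_instruction, remaining = _transform_system_message(
    supports_system_message=True, messages=openai_messages
)
# system_instruction should look roughly like:
#   {"parts": [{"text": "You are a terse assistant."}]}

# Convert the remaining turns; consecutive same-role turns are merged and
# OpenAI roles map to Gemini's "user" / "model".
contents = _gemini_convert_messages_with_history(messages=remaining)
# contents should look roughly like:
#   [{"role": "user", "parts": [{"text": "What is the capital of France?"}]},
#    {"role": "model", "parts": [{"text": "Paris."}]},
#    {"role": "user", "parts": [{"text": "And of Italy?"}]}]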