Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/context_caching/transformation.py')
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/context_caching/transformation.py  110
1 files changed, 110 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/context_caching/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/context_caching/transformation.py
new file mode 100644
index 00000000..83c15029
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/context_caching/transformation.py
@@ -0,0 +1,110 @@
+"""
+Transformation logic for context caching.
+
+Why separate file? Make it easy to see how transformation works
+"""
+
+from typing import List, Tuple
+
+from litellm.types.llms.openai import AllMessageValues
+from litellm.types.llms.vertex_ai import CachedContentRequestBody
+from litellm.utils import is_cached_message
+
+from ..common_utils import get_supports_system_message
+from ..gemini.transformation import (
+    _gemini_convert_messages_with_history,
+    _transform_system_message,
+)
+
+
+def get_first_continuous_block_idx(
+    filtered_messages: List[Tuple[int, AllMessageValues]]  # (idx, message)
+) -> int:
+    """
+    Find the array index that ends the first continuous sequence of message blocks.
+
+    Args:
+        filtered_messages: List of tuples containing (index, message) pairs
+
+    Returns:
+        int: The array index where the first continuous sequence ends
+    """
+    if not filtered_messages:
+        return -1
+
+    if len(filtered_messages) == 1:
+        return 0
+
+    current_value = filtered_messages[0][0]
+
+    # Search forward through the array indices
+    for i in range(1, len(filtered_messages)):
+        if filtered_messages[i][0] != current_value + 1:
+            return i - 1
+        current_value = filtered_messages[i][0]
+
+    # If we made it through the whole list, return the last index
+    return len(filtered_messages) - 1
+
+
+def separate_cached_messages(
+    messages: List[AllMessageValues],
+) -> Tuple[List[AllMessageValues], List[AllMessageValues]]:
+    """
+    Returns separated cached and non-cached messages.
+
+    Args:
+        messages: List of messages to be separated.
+
+    Returns:
+        Tuple containing:
+        - cached_messages: List of cached messages.
+        - non_cached_messages: List of non-cached messages.
+    """
+    cached_messages: List[AllMessageValues] = []
+    non_cached_messages: List[AllMessageValues] = []
+
+    # Extract cached messages and their indices
+    filtered_messages: List[Tuple[int, AllMessageValues]] = []
+    for idx, message in enumerate(messages):
+        if is_cached_message(message=message):
+            filtered_messages.append((idx, message))
+
+    # Validate only one block of continuous cached messages
+    last_continuous_block_idx = get_first_continuous_block_idx(filtered_messages)
+    # Separate messages based on the block of cached messages
+    if filtered_messages and last_continuous_block_idx is not None:
+        first_cached_idx = filtered_messages[0][0]
+        last_cached_idx = filtered_messages[last_continuous_block_idx][0]
+
+        cached_messages = messages[first_cached_idx : last_cached_idx + 1]
+        non_cached_messages = (
+            messages[:first_cached_idx] + messages[last_cached_idx + 1 :]
+        )
+    else:
+        non_cached_messages = messages
+
+    return cached_messages, non_cached_messages
+
+
+def transform_openai_messages_to_gemini_context_caching(
+    model: str, messages: List[AllMessageValues], cache_key: str
+) -> CachedContentRequestBody:
+    supports_system_message = get_supports_system_message(
+        model=model, custom_llm_provider="gemini"
+    )
+
+    transformed_system_messages, new_messages = _transform_system_message(
+        supports_system_message=supports_system_message, messages=messages
+    )
+
+    transformed_messages = _gemini_convert_messages_with_history(messages=new_messages)
+    data = CachedContentRequestBody(
+        contents=transformed_messages,
+        model="models/{}".format(model),
+        displayName=cache_key,
+    )
+    if transformed_system_messages is not None:
+        data["system_instruction"] = transformed_system_messages
+
+    return data