Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/context_caching/transformation.py')
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/context_caching/transformation.py  110
1 file changed, 110 insertions(+), 0 deletions(-)
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/context_caching/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/context_caching/transformation.py
new file mode 100644
index 00000000..83c15029
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/context_caching/transformation.py
@@ -0,0 +1,110 @@
+"""
+Transformation logic for context caching.
+
+Why a separate file? To make it easy to see how the transformation works.
+"""
+
+from typing import List, Tuple
+
+from litellm.types.llms.openai import AllMessageValues
+from litellm.types.llms.vertex_ai import CachedContentRequestBody
+from litellm.utils import is_cached_message
+
+from ..common_utils import get_supports_system_message
+from ..gemini.transformation import (
+ _gemini_convert_messages_with_history,
+ _transform_system_message,
+)
+
+
+def get_first_continuous_block_idx(
+ filtered_messages: List[Tuple[int, AllMessageValues]] # (idx, message)
+) -> int:
+ """
+ Find the array index that ends the first continuous sequence of message blocks.
+
+ Args:
+ filtered_messages: List of tuples containing (index, message) pairs
+
+ Returns:
+ int: The array index where the first continuous sequence ends
+ """
+ if not filtered_messages:
+ return -1
+
+ if len(filtered_messages) == 1:
+ return 0
+
+ current_value = filtered_messages[0][0]
+
+ # Search forward through the array indices
+ for i in range(1, len(filtered_messages)):
+ if filtered_messages[i][0] != current_value + 1:
+ return i - 1
+ current_value = filtered_messages[i][0]
+
+ # If we made it through the whole list, return the last index
+ return len(filtered_messages) - 1
+
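As a sanity check on the index walk above, here is a minimal, self-contained sketch of the same loop. The helper name first_block_end and the stand-in dict payloads are illustrative only (real inputs are AllMessageValues objects); it is inlined so the snippet runs without importing litellm:

    filtered = [
        (2, {"role": "user", "content": "a"}),  # original index 2
        (3, {"role": "user", "content": "b"}),  # 3 == 2 + 1, same block
        (7, {"role": "user", "content": "c"}),  # gap: 7 != 3 + 1
    ]

    def first_block_end(pairs):
        # Same walk as get_first_continuous_block_idx above.
        if not pairs:
            return -1
        end = 0
        for i in range(1, len(pairs)):
            if pairs[i][0] != pairs[i - 1][0] + 1:
                break
            end = i
        return end

    assert first_block_end(filtered) == 1  # block covers original indices 2..3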
+
+def separate_cached_messages(
+ messages: List[AllMessageValues],
+) -> Tuple[List[AllMessageValues], List[AllMessageValues]]:
+ """
+ Returns separated cached and non-cached messages.
+
+ Args:
+ messages: List of messages to be separated.
+
+ Returns:
+ Tuple containing:
+ - cached_messages: List of cached messages.
+ - non_cached_messages: List of non-cached messages.
+ """
+ cached_messages: List[AllMessageValues] = []
+ non_cached_messages: List[AllMessageValues] = []
+
+ # Extract cached messages and their indices
+ filtered_messages: List[Tuple[int, AllMessageValues]] = []
+ for idx, message in enumerate(messages):
+ if is_cached_message(message=message):
+ filtered_messages.append((idx, message))
+
+    # Only the first continuous block of cached messages is used; any
+    # cache-marked messages after a gap fall through to non_cached_messages
+    last_continuous_block_idx = get_first_continuous_block_idx(filtered_messages)
+    # Split the conversation around that block
+    if filtered_messages and last_continuous_block_idx is not None:
+ first_cached_idx = filtered_messages[0][0]
+ last_cached_idx = filtered_messages[last_continuous_block_idx][0]
+
+ cached_messages = messages[first_cached_idx : last_cached_idx + 1]
+ non_cached_messages = (
+ messages[:first_cached_idx] + messages[last_cached_idx + 1 :]
+ )
+ else:
+ non_cached_messages = messages
+
+ return cached_messages, non_cached_messages
+
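A usage sketch for the split above, assuming is_cached_message flags messages whose content parts carry an Anthropic-style cache_control marker (the exact payloads here are illustrative, not a confirmed fixture):

    from litellm.llms.vertex_ai.context_caching.transformation import (
        separate_cached_messages,
    )

    messages = [
        {
            "role": "system",
            "content": [
                {
                    "type": "text",
                    "text": "Long static context worth caching.",
                    "cache_control": {"type": "ephemeral"},
                }
            ],
        },
        {"role": "user", "content": "Fresh question, not cached."},
    ]

    cached, non_cached = separate_cached_messages(messages=messages)
    # cached     -> the first continuous cache-marked block (index 0 here)
    # non_cached -> everything outside that block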
+
+def transform_openai_messages_to_gemini_context_caching(
+    model: str, messages: List[AllMessageValues], cache_key: str
+) -> CachedContentRequestBody:
+    """
+    Build a Vertex AI ``cachedContents`` request body from OpenAI-format
+    messages, using ``cache_key`` as the cached content's display name.
+    """
+ supports_system_message = get_supports_system_message(
+ model=model, custom_llm_provider="gemini"
+ )
+
+ transformed_system_messages, new_messages = _transform_system_message(
+ supports_system_message=supports_system_message, messages=messages
+ )
+
+ transformed_messages = _gemini_convert_messages_with_history(messages=new_messages)
+ data = CachedContentRequestBody(
+ contents=transformed_messages,
+ model="models/{}".format(model),
+ displayName=cache_key,
+ )
+ if transformed_system_messages is not None:
+ data["system_instruction"] = transformed_system_messages
+
+ return data
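
Putting it together, a hedged end-to-end sketch: the model name and cache_key are arbitrary caller-chosen values, and the commented-out shape simply restates the keys assigned in the function above rather than a verified wire format:

    from litellm.llms.vertex_ai.context_caching.transformation import (
        transform_openai_messages_to_gemini_context_caching,
    )

    body = transform_openai_messages_to_gemini_context_caching(
        model="gemini-1.5-pro",
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Summarize the attached contract."},
        ],
        cache_key="contract-summary-v1",
    )

    # Expected rough shape, per the assignments above:
    # {
    #     "contents": [...],                      # Gemini-format history
    #     "model": "models/gemini-1.5-pro",
    #     "displayName": "contract-summary-v1",
    #     "system_instruction": {...},            # only when supported
    # }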