path: root/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/context_caching/transformation.py
author     S. Solomon Darnell  2025-03-28 21:52:21 -0500
committer  S. Solomon Darnell  2025-03-28 21:52:21 -0500
commit     4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree       ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/context_caching/transformation.py
parent     cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
download   gn-ai-master.tar.gz
two versions of R2R are here (HEAD, master)
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/context_caching/transformation.py')
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/context_caching/transformation.py  110
1 file changed, 110 insertions(+), 0 deletions(-)
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/context_caching/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/context_caching/transformation.py
new file mode 100644
index 00000000..83c15029
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/context_caching/transformation.py
@@ -0,0 +1,110 @@
+"""
+Transformation logic for context caching. 
+
+Why a separate file? To make it easy to see how the transformation works.
+"""
+
+from typing import List, Tuple
+
+from litellm.types.llms.openai import AllMessageValues
+from litellm.types.llms.vertex_ai import CachedContentRequestBody
+from litellm.utils import is_cached_message
+
+from ..common_utils import get_supports_system_message
+from ..gemini.transformation import (
+    _gemini_convert_messages_with_history,
+    _transform_system_message,
+)
+
+
+def get_first_continuous_block_idx(
+    filtered_messages: List[Tuple[int, AllMessageValues]]  # (idx, message)
+) -> int:
+    """
+    Find the array index that ends the first continuous sequence of message blocks.
+
+    Args:
+        filtered_messages: List of tuples containing (index, message) pairs
+
+    Returns:
+        int: The array index where the first continuous sequence ends
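+
+    Example (illustrative):
+        filtered_messages with original indices [2, 3, 4, 7] returns 2, the
+        array position of original index 4, which ends the first continuous run.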
+    """
+    if not filtered_messages:
+        return -1
+
+    if len(filtered_messages) == 1:
+        return 0
+
+    current_value = filtered_messages[0][0]
+
+    # Search forward through the array indices
+    for i in range(1, len(filtered_messages)):
+        if filtered_messages[i][0] != current_value + 1:
+            return i - 1
+        current_value = filtered_messages[i][0]
+
+    # If we made it through the whole list, return the last index
+    return len(filtered_messages) - 1
+
+
+def separate_cached_messages(
+    messages: List[AllMessageValues],
+) -> Tuple[List[AllMessageValues], List[AllMessageValues]]:
+    """
+    Returns separated cached and non-cached messages.
+
+    Args:
+        messages: List of messages to be separated.
+
+    Returns:
+        Tuple containing:
+        - cached_messages: List of cached messages.
+        - non_cached_messages: List of non-cached messages.
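+
+    Example (illustrative):
+        If messages[0:2] are marked as cached and messages[2:5] are not,
+        this returns (messages[0:2], messages[2:5]).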
+    """
+    cached_messages: List[AllMessageValues] = []
+    non_cached_messages: List[AllMessageValues] = []
+
+    # Extract cached messages and their indices
+    filtered_messages: List[Tuple[int, AllMessageValues]] = []
+    for idx, message in enumerate(messages):
+        if is_cached_message(message=message):
+            filtered_messages.append((idx, message))
+
+    # Find where the first continuous block of cached messages ends
+    last_continuous_block_idx = get_first_continuous_block_idx(filtered_messages)
+    # Separate messages based on the block of cached messages
+    if filtered_messages and last_continuous_block_idx is not None:
+        first_cached_idx = filtered_messages[0][0]
+        last_cached_idx = filtered_messages[last_continuous_block_idx][0]
+
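+        # Slice out the continuous run of cached messages; anything before or
+        # after that run stays in the non-cached list.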
+        cached_messages = messages[first_cached_idx : last_cached_idx + 1]
+        non_cached_messages = (
+            messages[:first_cached_idx] + messages[last_cached_idx + 1 :]
+        )
+    else:
+        non_cached_messages = messages
+
+    return cached_messages, non_cached_messages
+
+
+def transform_openai_messages_to_gemini_context_caching(
+    model: str, messages: List[AllMessageValues], cache_key: str
+) -> CachedContentRequestBody:
+    supports_system_message = get_supports_system_message(
+        model=model, custom_llm_provider="gemini"
+    )
+
+    transformed_system_messages, new_messages = _transform_system_message(
+        supports_system_message=supports_system_message, messages=messages
+    )
+
+    transformed_messages = _gemini_convert_messages_with_history(messages=new_messages)
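+    # Build the cachedContents request body; the cache_key is stored as the
+    # displayName so the cached entry can be identified later.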
+    data = CachedContentRequestBody(
+        contents=transformed_messages,
+        model="models/{}".format(model),
+        displayName=cache_key,
+    )
+    if transformed_system_messages is not None:
+        data["system_instruction"] = transformed_system_messages
+
+    return data
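
A minimal usage sketch of the two public helpers above (not part of the diff; the cache_control content marker, model name, and cache key below are illustrative assumptions, not values taken from this file):

    from litellm.llms.vertex_ai.context_caching.transformation import (
        separate_cached_messages,
        transform_openai_messages_to_gemini_context_caching,
    )

    messages = [
        {
            "role": "system",
            "content": [
                {
                    "type": "text",
                    "text": "Long shared context that should be cached ...",
                    "cache_control": {"type": "ephemeral"},  # assumed cache marker
                }
            ],
        },
        {"role": "user", "content": "What changed in the latest release?"},
    ]

    # Split the prompt into the continuous cached block and everything else.
    cached, non_cached = separate_cached_messages(messages=messages)

    # Build the Vertex AI cachedContents request body for the cached block.
    request_body = transform_openai_messages_to_gemini_context_caching(
        model="gemini-1.5-pro",             # illustrative model name
        messages=cached,
        cache_key="example-cache-key",      # illustrative cache key
    )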