author    S. Solomon Darnell  2025-03-28 21:52:21 -0500
committer S. Solomon Darnell  2025-03-28 21:52:21 -0500
commit    4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree      ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/context_caching
parent    cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
download  gn-ai-master.tar.gz
two versions of R2R are here (HEAD, master)
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/context_caching')
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/context_caching/transformation.py             110
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/context_caching/vertex_ai_context_caching.py  416
2 files changed, 526 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/context_caching/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/context_caching/transformation.py
new file mode 100644
index 00000000..83c15029
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/context_caching/transformation.py
@@ -0,0 +1,110 @@
+"""
+Transformation logic for context caching.
+
+Kept in a separate file to make it easy to see how the transformation works.
+"""
+
+from typing import List, Tuple
+
+from litellm.types.llms.openai import AllMessageValues
+from litellm.types.llms.vertex_ai import CachedContentRequestBody
+from litellm.utils import is_cached_message
+
+from ..common_utils import get_supports_system_message
+from ..gemini.transformation import (
+ _gemini_convert_messages_with_history,
+ _transform_system_message,
+)
+
+
+def get_first_continuous_block_idx(
+ filtered_messages: List[Tuple[int, AllMessageValues]] # (idx, message)
+) -> int:
+ """
+ Find the array index that ends the first continuous sequence of message blocks.
+
+ Args:
+ filtered_messages: List of tuples containing (index, message) pairs
+
+ Returns:
+        int: The index into filtered_messages where the first continuous sequence ends, or -1 if filtered_messages is empty
+ """
+ if not filtered_messages:
+ return -1
+
+ if len(filtered_messages) == 1:
+ return 0
+
+ current_value = filtered_messages[0][0]
+
+ # Search forward through the array indices
+ for i in range(1, len(filtered_messages)):
+ if filtered_messages[i][0] != current_value + 1:
+ return i - 1
+ current_value = filtered_messages[i][0]
+
+ # If we made it through the whole list, return the last index
+ return len(filtered_messages) - 1
+
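+# Illustrative doctest-style sketch (not part of the original file): given
+# (original_index, message) pairs, the function returns the position in
+# filtered_messages at which the first run of consecutive original indices ends.
+#
+#     >>> pairs = [(0, "m0"), (1, "m1"), (2, "m2"), (5, "m5")]
+#     >>> get_first_continuous_block_idx(pairs)  # indices 0, 1, 2 are consecutive
+#     2
+#     >>> get_first_continuous_block_idx([(3, "m3")])
+#     0
+#     >>> get_first_continuous_block_idx([])
+#     -1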
+
+def separate_cached_messages(
+ messages: List[AllMessageValues],
+) -> Tuple[List[AllMessageValues], List[AllMessageValues]]:
+ """
+ Returns separated cached and non-cached messages.
+
+ Args:
+ messages: List of messages to be separated.
+
+ Returns:
+ Tuple containing:
+ - cached_messages: List of cached messages.
+ - non_cached_messages: List of non-cached messages.
+ """
+ cached_messages: List[AllMessageValues] = []
+ non_cached_messages: List[AllMessageValues] = []
+
+ # Extract cached messages and their indices
+ filtered_messages: List[Tuple[int, AllMessageValues]] = []
+ for idx, message in enumerate(messages):
+ if is_cached_message(message=message):
+ filtered_messages.append((idx, message))
+
+    # Only the first continuous block of cached messages is honored; find where it ends
+ last_continuous_block_idx = get_first_continuous_block_idx(filtered_messages)
+ # Separate messages based on the block of cached messages
+ if filtered_messages and last_continuous_block_idx is not None:
+ first_cached_idx = filtered_messages[0][0]
+ last_cached_idx = filtered_messages[last_continuous_block_idx][0]
+
+ cached_messages = messages[first_cached_idx : last_cached_idx + 1]
+ non_cached_messages = (
+ messages[:first_cached_idx] + messages[last_cached_idx + 1 :]
+ )
+ else:
+ non_cached_messages = messages
+
+ return cached_messages, non_cached_messages
+
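+# Illustrative sketch (not part of the original file), assuming is_cached_message
+# flags messages whose content blocks carry a "cache_control" marker:
+#
+#     >>> messages = [
+#     ...     {
+#     ...         "role": "system",
+#     ...         "content": [
+#     ...             {
+#     ...                 "type": "text",
+#     ...                 "text": "Long shared context ...",
+#     ...                 "cache_control": {"type": "ephemeral"},
+#     ...             }
+#     ...         ],
+#     ...     },
+#     ...     {"role": "user", "content": "What does the context say?"},
+#     ... ]
+#     >>> cached, non_cached = separate_cached_messages(messages=messages)
+#     >>> len(cached), len(non_cached)
+#     (1, 1)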
+
+def transform_openai_messages_to_gemini_context_caching(
+ model: str, messages: List[AllMessageValues], cache_key: str
+) -> CachedContentRequestBody:
+ supports_system_message = get_supports_system_message(
+ model=model, custom_llm_provider="gemini"
+ )
+
+ transformed_system_messages, new_messages = _transform_system_message(
+ supports_system_message=supports_system_message, messages=messages
+ )
+
+ transformed_messages = _gemini_convert_messages_with_history(messages=new_messages)
+ data = CachedContentRequestBody(
+ contents=transformed_messages,
+ model="models/{}".format(model),
+ displayName=cache_key,
+ )
+ if transformed_system_messages is not None:
+ data["system_instruction"] = transformed_system_messages
+
+ return data
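+# Illustrative sketch (not part of the original file): for model="gemini-1.5-pro"
+# and a single cached user message, the returned request body looks roughly like
+#
+#     {
+#         "contents": [{"role": "user", "parts": [{"text": "Long shared context ..."}]}],
+#         "model": "models/gemini-1.5-pro",
+#         "displayName": "<litellm-generated cache key>",
+#         "system_instruction": {...},  # only set when system messages were transformed out
+#     }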
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/context_caching/vertex_ai_context_caching.py b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/context_caching/vertex_ai_context_caching.py
new file mode 100644
index 00000000..5cfb9141
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/context_caching/vertex_ai_context_caching.py
@@ -0,0 +1,416 @@
+from typing import List, Literal, Optional, Tuple, Union
+
+import httpx
+
+import litellm
+from litellm.caching.caching import Cache, LiteLLMCacheType
+from litellm.litellm_core_utils.litellm_logging import Logging
+from litellm.llms.custom_httpx.http_handler import (
+ AsyncHTTPHandler,
+ HTTPHandler,
+ get_async_httpx_client,
+)
+from litellm.llms.openai.openai import AllMessageValues
+from litellm.types.llms.vertex_ai import (
+ CachedContentListAllResponseBody,
+ VertexAICachedContentResponseObject,
+)
+
+from ..common_utils import VertexAIError
+from ..vertex_llm_base import VertexBase
+from .transformation import (
+ separate_cached_messages,
+ transform_openai_messages_to_gemini_context_caching,
+)
+
+local_cache_obj = Cache(
+ type=LiteLLMCacheType.LOCAL
+) # only used for calling 'get_cache_key' function
+
+
+class ContextCachingEndpoints(VertexBase):
+ """
+ Covers context caching endpoints for Vertex AI + Google AI Studio
+
+ v0: covers Google AI Studio
+ """
+
+ def __init__(self) -> None:
+ pass
+
+ def _get_token_and_url_context_caching(
+ self,
+ gemini_api_key: Optional[str],
+ custom_llm_provider: Literal["gemini"],
+ api_base: Optional[str],
+ ) -> Tuple[Optional[str], str]:
+ """
+ Internal function. Returns the token and url for the call.
+
+ Handles logic if it's google ai studio vs. vertex ai.
+
+ Returns
+ token, url
+ """
+ if custom_llm_provider == "gemini":
+ auth_header = None
+ endpoint = "cachedContents"
+ url = "https://generativelanguage.googleapis.com/v1beta/{}?key={}".format(
+ endpoint, gemini_api_key
+ )
+
+ else:
+ raise NotImplementedError
+
+ return self._check_custom_proxy(
+ api_base=api_base,
+ custom_llm_provider=custom_llm_provider,
+ gemini_api_key=gemini_api_key,
+ endpoint=endpoint,
+ stream=None,
+ auth_header=auth_header,
+ url=url,
+ )
+
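+    # Illustrative sketch (not part of the original file), assuming no custom
+    # api_base is configured so _check_custom_proxy leaves the URL untouched:
+    #
+    #     >>> endpoints = ContextCachingEndpoints()
+    #     >>> _, url = endpoints._get_token_and_url_context_caching(
+    #     ...     gemini_api_key="<GEMINI_API_KEY>",
+    #     ...     custom_llm_provider="gemini",
+    #     ...     api_base=None,
+    #     ... )
+    #     >>> url
+    #     'https://generativelanguage.googleapis.com/v1beta/cachedContents?key=<GEMINI_API_KEY>'
+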
+ def check_cache(
+ self,
+ cache_key: str,
+ client: HTTPHandler,
+ headers: dict,
+ api_key: str,
+ api_base: Optional[str],
+ logging_obj: Logging,
+ ) -> Optional[str]:
+ """
+ Checks if content already cached.
+
+        Currently checks the cache list for an entry whose displayName matches the cache key, since Google doesn't let us set the name of the cache (their API docs are out of sync with the actual implementation).
+
+        Returns
+        - cached_content_name - str - cached content name stored on Google (if found)
+ OR
+ - None
+ """
+
+ _, url = self._get_token_and_url_context_caching(
+ gemini_api_key=api_key,
+ custom_llm_provider="gemini",
+ api_base=api_base,
+ )
+ try:
+ ## LOGGING
+ logging_obj.pre_call(
+ input="",
+ api_key="",
+ additional_args={
+ "complete_input_dict": {},
+ "api_base": url,
+ "headers": headers,
+ },
+ )
+
+ resp = client.get(url=url, headers=headers)
+ resp.raise_for_status()
+ except httpx.HTTPStatusError as e:
+ if e.response.status_code == 403:
+ return None
+ raise VertexAIError(
+ status_code=e.response.status_code, message=e.response.text
+ )
+ except Exception as e:
+ raise VertexAIError(status_code=500, message=str(e))
+ raw_response = resp.json()
+ logging_obj.post_call(original_response=raw_response)
+
+ if "cachedContents" not in raw_response:
+ return None
+
+ all_cached_items = CachedContentListAllResponseBody(**raw_response)
+
+ if "cachedContents" not in all_cached_items:
+ return None
+
+ for cached_item in all_cached_items["cachedContents"]:
+ display_name = cached_item.get("displayName")
+ if display_name is not None and display_name == cache_key:
+ return cached_item.get("name")
+
+ return None
+
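+    # Illustrative sketch (not part of the original file): the lookup above walks a
+    # list response shaped roughly like
+    #
+    #     {
+    #         "cachedContents": [
+    #             {
+    #                 "name": "cachedContents/abc123",
+    #                 "displayName": "<litellm-generated cache key>",
+    #                 "model": "models/gemini-1.5-pro",
+    #             }
+    #         ]
+    #     }
+    #
+    # and returns the Google-assigned "name" of the entry whose displayName matches
+    # the cache key, e.g. "cachedContents/abc123".
+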
+ async def async_check_cache(
+ self,
+ cache_key: str,
+ client: AsyncHTTPHandler,
+ headers: dict,
+ api_key: str,
+ api_base: Optional[str],
+ logging_obj: Logging,
+ ) -> Optional[str]:
+ """
+ Checks if content already cached.
+
+        Currently checks the cache list for an entry whose displayName matches the cache key, since Google doesn't let us set the name of the cache (their API docs are out of sync with the actual implementation).
+
+        Returns
+        - cached_content_name - str - cached content name stored on Google (if found)
+ OR
+ - None
+ """
+
+ _, url = self._get_token_and_url_context_caching(
+ gemini_api_key=api_key,
+ custom_llm_provider="gemini",
+ api_base=api_base,
+ )
+ try:
+ ## LOGGING
+ logging_obj.pre_call(
+ input="",
+ api_key="",
+ additional_args={
+ "complete_input_dict": {},
+ "api_base": url,
+ "headers": headers,
+ },
+ )
+
+ resp = await client.get(url=url, headers=headers)
+ resp.raise_for_status()
+ except httpx.HTTPStatusError as e:
+ if e.response.status_code == 403:
+ return None
+ raise VertexAIError(
+ status_code=e.response.status_code, message=e.response.text
+ )
+ except Exception as e:
+ raise VertexAIError(status_code=500, message=str(e))
+ raw_response = resp.json()
+ logging_obj.post_call(original_response=raw_response)
+
+ if "cachedContents" not in raw_response:
+ return None
+
+ all_cached_items = CachedContentListAllResponseBody(**raw_response)
+
+ if "cachedContents" not in all_cached_items:
+ return None
+
+ for cached_item in all_cached_items["cachedContents"]:
+ display_name = cached_item.get("displayName")
+ if display_name is not None and display_name == cache_key:
+ return cached_item.get("name")
+
+ return None
+
+ def check_and_create_cache(
+ self,
+ messages: List[AllMessageValues], # receives openai format messages
+ api_key: str,
+ api_base: Optional[str],
+ model: str,
+ client: Optional[HTTPHandler],
+ timeout: Optional[Union[float, httpx.Timeout]],
+ logging_obj: Logging,
+ extra_headers: Optional[dict] = None,
+ cached_content: Optional[str] = None,
+ ) -> Tuple[List[AllMessageValues], Optional[str]]:
+ """
+ Receives
+ - messages: List of dict - messages in the openai format
+
+ Returns
+ - messages - List[dict] - filtered list of messages in the openai format.
+ - cached_content - str - the cache content id, to be passed in the gemini request body
+
+ Follows - https://ai.google.dev/api/caching#request-body
+ """
+ if cached_content is not None:
+ return messages, cached_content
+
+ ## AUTHORIZATION ##
+ token, url = self._get_token_and_url_context_caching(
+ gemini_api_key=api_key,
+ custom_llm_provider="gemini",
+ api_base=api_base,
+ )
+
+ headers = {
+ "Content-Type": "application/json",
+ }
+ if token is not None:
+ headers["Authorization"] = f"Bearer {token}"
+ if extra_headers is not None:
+ headers.update(extra_headers)
+
+ if client is None or not isinstance(client, HTTPHandler):
+ _params = {}
+ if timeout is not None:
+ if isinstance(timeout, float) or isinstance(timeout, int):
+ timeout = httpx.Timeout(timeout)
+ _params["timeout"] = timeout
+ client = HTTPHandler(**_params) # type: ignore
+ else:
+ client = client
+
+ cached_messages, non_cached_messages = separate_cached_messages(
+ messages=messages
+ )
+
+ if len(cached_messages) == 0:
+ return messages, None
+
+ ## CHECK IF CACHED ALREADY
+ generated_cache_key = local_cache_obj.get_cache_key(messages=cached_messages)
+ google_cache_name = self.check_cache(
+ cache_key=generated_cache_key,
+ client=client,
+ headers=headers,
+ api_key=api_key,
+ api_base=api_base,
+ logging_obj=logging_obj,
+ )
+ if google_cache_name:
+ return non_cached_messages, google_cache_name
+
+ ## TRANSFORM REQUEST
+ cached_content_request_body = (
+ transform_openai_messages_to_gemini_context_caching(
+ model=model, messages=cached_messages, cache_key=generated_cache_key
+ )
+ )
+
+ ## LOGGING
+ logging_obj.pre_call(
+ input=messages,
+ api_key="",
+ additional_args={
+ "complete_input_dict": cached_content_request_body,
+ "api_base": url,
+ "headers": headers,
+ },
+ )
+
+ try:
+ response = client.post(
+ url=url, headers=headers, json=cached_content_request_body # type: ignore
+ )
+ response.raise_for_status()
+ except httpx.HTTPStatusError as err:
+ error_code = err.response.status_code
+ raise VertexAIError(status_code=error_code, message=err.response.text)
+ except httpx.TimeoutException:
+ raise VertexAIError(status_code=408, message="Timeout error occurred.")
+
+ raw_response_cached = response.json()
+ cached_content_response_obj = VertexAICachedContentResponseObject(
+ name=raw_response_cached.get("name"), model=raw_response_cached.get("model")
+ )
+ return (non_cached_messages, cached_content_response_obj["name"])
+
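+    # Illustrative usage sketch (not part of the original file); openai_messages,
+    # gemini_api_key and logging_obj are hypothetical caller-supplied objects:
+    #
+    #     >>> endpoints = ContextCachingEndpoints()
+    #     >>> filtered_messages, cached_content_id = endpoints.check_and_create_cache(
+    #     ...     messages=openai_messages,  # cached block marked via cache_control
+    #     ...     api_key=gemini_api_key,
+    #     ...     api_base=None,
+    #     ...     model="gemini-1.5-pro",
+    #     ...     client=None,               # an HTTPHandler is created on demand
+    #     ...     timeout=60.0,
+    #     ...     logging_obj=logging_obj,
+    #     ... )
+    #     >>> # filtered_messages drops the cached block; cached_content_id (e.g.
+    #     >>> # "cachedContents/abc123") is passed as cached_content on the Gemini request.
+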
+ async def async_check_and_create_cache(
+ self,
+ messages: List[AllMessageValues], # receives openai format messages
+ api_key: str,
+ api_base: Optional[str],
+ model: str,
+ client: Optional[AsyncHTTPHandler],
+ timeout: Optional[Union[float, httpx.Timeout]],
+ logging_obj: Logging,
+ extra_headers: Optional[dict] = None,
+ cached_content: Optional[str] = None,
+ ) -> Tuple[List[AllMessageValues], Optional[str]]:
+ """
+ Receives
+ - messages: List of dict - messages in the openai format
+
+ Returns
+ - messages - List[dict] - filtered list of messages in the openai format.
+ - cached_content - str - the cache content id, to be passed in the gemini request body
+
+ Follows - https://ai.google.dev/api/caching#request-body
+ """
+ if cached_content is not None:
+ return messages, cached_content
+
+ cached_messages, non_cached_messages = separate_cached_messages(
+ messages=messages
+ )
+
+ if len(cached_messages) == 0:
+ return messages, None
+
+ ## AUTHORIZATION ##
+ token, url = self._get_token_and_url_context_caching(
+ gemini_api_key=api_key,
+ custom_llm_provider="gemini",
+ api_base=api_base,
+ )
+
+ headers = {
+ "Content-Type": "application/json",
+ }
+ if token is not None:
+ headers["Authorization"] = f"Bearer {token}"
+ if extra_headers is not None:
+ headers.update(extra_headers)
+
+ if client is None or not isinstance(client, AsyncHTTPHandler):
+ client = get_async_httpx_client(
+ params={"timeout": timeout}, llm_provider=litellm.LlmProviders.VERTEX_AI
+ )
+ else:
+ client = client
+
+ ## CHECK IF CACHED ALREADY
+ generated_cache_key = local_cache_obj.get_cache_key(messages=cached_messages)
+ google_cache_name = await self.async_check_cache(
+ cache_key=generated_cache_key,
+ client=client,
+ headers=headers,
+ api_key=api_key,
+ api_base=api_base,
+ logging_obj=logging_obj,
+ )
+ if google_cache_name:
+ return non_cached_messages, google_cache_name
+
+ ## TRANSFORM REQUEST
+ cached_content_request_body = (
+ transform_openai_messages_to_gemini_context_caching(
+ model=model, messages=cached_messages, cache_key=generated_cache_key
+ )
+ )
+
+ ## LOGGING
+ logging_obj.pre_call(
+ input=messages,
+ api_key="",
+ additional_args={
+ "complete_input_dict": cached_content_request_body,
+ "api_base": url,
+ "headers": headers,
+ },
+ )
+
+ try:
+ response = await client.post(
+ url=url, headers=headers, json=cached_content_request_body # type: ignore
+ )
+ response.raise_for_status()
+ except httpx.HTTPStatusError as err:
+ error_code = err.response.status_code
+ raise VertexAIError(status_code=error_code, message=err.response.text)
+ except httpx.TimeoutException:
+ raise VertexAIError(status_code=408, message="Timeout error occurred.")
+
+ raw_response_cached = response.json()
+ cached_content_response_obj = VertexAICachedContentResponseObject(
+ name=raw_response_cached.get("name"), model=raw_response_cached.get("model")
+ )
+ return (non_cached_messages, cached_content_response_obj["name"])
+
+ def get_cache(self):
+ pass
+
+ async def async_get_cache(self):
+ pass