Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/context_caching')
2 files changed, 526 insertions, 0 deletions
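For orientation: transformation.py (added below) assembles the Google AI Studio cachedContents request body described at https://ai.google.dev/api/caching#request-body. A rough sketch of that shape follows; the model name, text, and cache key are placeholders for illustration, not values taken from the diff, and the exact field layout should be checked against the litellm/Gemini version in use.

    # Approximate shape of the body built by transform_openai_messages_to_gemini_context_caching();
    # all values below are placeholders, not taken from the diff.
    cached_content_request_body = {
        "model": "models/gemini-1.5-flash-001",  # "models/{model}"
        "displayName": "<litellm cache key derived from the cached messages>",
        "contents": [  # Gemini-format history for the cached message block
            {"role": "user", "parts": [{"text": "<large static context to cache>"}]},
        ],
        # included only when the model supports system instructions
        "system_instruction": {"parts": [{"text": "You are a helpful assistant."}]},
    }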
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/context_caching/transformation.py b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/context_caching/transformation.py
new file mode 100644
index 00000000..83c15029
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/context_caching/transformation.py
@@ -0,0 +1,110 @@
+"""
+Transformation logic for context caching.
+
+Why separate file? Make it easy to see how transformation works
+"""
+
+from typing import List, Tuple
+
+from litellm.types.llms.openai import AllMessageValues
+from litellm.types.llms.vertex_ai import CachedContentRequestBody
+from litellm.utils import is_cached_message
+
+from ..common_utils import get_supports_system_message
+from ..gemini.transformation import (
+    _gemini_convert_messages_with_history,
+    _transform_system_message,
+)
+
+
+def get_first_continuous_block_idx(
+    filtered_messages: List[Tuple[int, AllMessageValues]]  # (idx, message)
+) -> int:
+    """
+    Find the array index that ends the first continuous sequence of message blocks.
+
+    Args:
+        filtered_messages: List of tuples containing (index, message) pairs
+
+    Returns:
+        int: The array index where the first continuous sequence ends
+    """
+    if not filtered_messages:
+        return -1
+
+    if len(filtered_messages) == 1:
+        return 0
+
+    current_value = filtered_messages[0][0]
+
+    # Search forward through the array indices
+    for i in range(1, len(filtered_messages)):
+        if filtered_messages[i][0] != current_value + 1:
+            return i - 1
+        current_value = filtered_messages[i][0]
+
+    # If we made it through the whole list, return the last index
+    return len(filtered_messages) - 1
+
+
+def separate_cached_messages(
+    messages: List[AllMessageValues],
+) -> Tuple[List[AllMessageValues], List[AllMessageValues]]:
+    """
+    Returns separated cached and non-cached messages.
+
+    Args:
+        messages: List of messages to be separated.
+
+    Returns:
+        Tuple containing:
+        - cached_messages: List of cached messages.
+        - non_cached_messages: List of non-cached messages.
+    """
+    cached_messages: List[AllMessageValues] = []
+    non_cached_messages: List[AllMessageValues] = []
+
+    # Extract cached messages and their indices
+    filtered_messages: List[Tuple[int, AllMessageValues]] = []
+    for idx, message in enumerate(messages):
+        if is_cached_message(message=message):
+            filtered_messages.append((idx, message))
+
+    # Validate only one block of continuous cached messages
+    last_continuous_block_idx = get_first_continuous_block_idx(filtered_messages)
+    # Separate messages based on the block of cached messages
+    if filtered_messages and last_continuous_block_idx is not None:
+        first_cached_idx = filtered_messages[0][0]
+        last_cached_idx = filtered_messages[last_continuous_block_idx][0]
+
+        cached_messages = messages[first_cached_idx : last_cached_idx + 1]
+        non_cached_messages = (
+            messages[:first_cached_idx] + messages[last_cached_idx + 1 :]
+        )
+    else:
+        non_cached_messages = messages
+
+    return cached_messages, non_cached_messages
+
+
+def transform_openai_messages_to_gemini_context_caching(
+    model: str, messages: List[AllMessageValues], cache_key: str
+) -> CachedContentRequestBody:
+    supports_system_message = get_supports_system_message(
+        model=model, custom_llm_provider="gemini"
+    )
+
+    transformed_system_messages, new_messages = _transform_system_message(
+        supports_system_message=supports_system_message, messages=messages
+    )
+
+    transformed_messages = _gemini_convert_messages_with_history(messages=new_messages)
+    data = CachedContentRequestBody(
+        contents=transformed_messages,
+        model="models/{}".format(model),
+        displayName=cache_key,
+    )
+    if transformed_system_messages is not None:
+        data["system_instruction"] = transformed_system_messages
+
+    return data
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/context_caching/vertex_ai_context_caching.py b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/context_caching/vertex_ai_context_caching.py
new file mode 100644
index 00000000..5cfb9141
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/context_caching/vertex_ai_context_caching.py
@@ -0,0 +1,416 @@
+from typing import List, Literal, Optional, Tuple, Union
+
+import httpx
+
+import litellm
+from litellm.caching.caching import Cache, LiteLLMCacheType
+from litellm.litellm_core_utils.litellm_logging import Logging
+from litellm.llms.custom_httpx.http_handler import (
+    AsyncHTTPHandler,
+    HTTPHandler,
+    get_async_httpx_client,
+)
+from litellm.llms.openai.openai import AllMessageValues
+from litellm.types.llms.vertex_ai import (
+    CachedContentListAllResponseBody,
+    VertexAICachedContentResponseObject,
+)
+
+from ..common_utils import VertexAIError
+from ..vertex_llm_base import VertexBase
+from .transformation import (
+    separate_cached_messages,
+    transform_openai_messages_to_gemini_context_caching,
+)
+
+local_cache_obj = Cache(
+    type=LiteLLMCacheType.LOCAL
+)  # only used for calling 'get_cache_key' function
+
+
+class ContextCachingEndpoints(VertexBase):
+    """
+    Covers context caching endpoints for Vertex AI + Google AI Studio
+
+    v0: covers Google AI Studio
+    """
+
+    def __init__(self) -> None:
+        pass
+
+    def _get_token_and_url_context_caching(
+        self,
+        gemini_api_key: Optional[str],
+        custom_llm_provider: Literal["gemini"],
+        api_base: Optional[str],
+    ) -> Tuple[Optional[str], str]:
+        """
+        Internal function. Returns the token and url for the call.
+
+        Handles logic if it's google ai studio vs. vertex ai.
+
+        Returns
+            token, url
+        """
+        if custom_llm_provider == "gemini":
+            auth_header = None
+            endpoint = "cachedContents"
+            url = "https://generativelanguage.googleapis.com/v1beta/{}?key={}".format(
+                endpoint, gemini_api_key
+            )
+
+        else:
+            raise NotImplementedError
+
+        return self._check_custom_proxy(
+            api_base=api_base,
+            custom_llm_provider=custom_llm_provider,
+            gemini_api_key=gemini_api_key,
+            endpoint=endpoint,
+            stream=None,
+            auth_header=auth_header,
+            url=url,
+        )
+
+    def check_cache(
+        self,
+        cache_key: str,
+        client: HTTPHandler,
+        headers: dict,
+        api_key: str,
+        api_base: Optional[str],
+        logging_obj: Logging,
+    ) -> Optional[str]:
+        """
+        Checks if content already cached.
+
+        Currently, checks cache list, for cache key == displayName, since Google doesn't let us set the name of the cache (their API docs are out of sync with actual implementation).
+
+        Returns
+        - cached_content_name - str - cached content name stored on google. (if found.)
+        OR
+        - None
+        """
+
+        _, url = self._get_token_and_url_context_caching(
+            gemini_api_key=api_key,
+            custom_llm_provider="gemini",
+            api_base=api_base,
+        )
+        try:
+            ## LOGGING
+            logging_obj.pre_call(
+                input="",
+                api_key="",
+                additional_args={
+                    "complete_input_dict": {},
+                    "api_base": url,
+                    "headers": headers,
+                },
+            )
+
+            resp = client.get(url=url, headers=headers)
+            resp.raise_for_status()
+        except httpx.HTTPStatusError as e:
+            if e.response.status_code == 403:
+                return None
+            raise VertexAIError(
+                status_code=e.response.status_code, message=e.response.text
+            )
+        except Exception as e:
+            raise VertexAIError(status_code=500, message=str(e))
+        raw_response = resp.json()
+        logging_obj.post_call(original_response=raw_response)
+
+        if "cachedContents" not in raw_response:
+            return None
+
+        all_cached_items = CachedContentListAllResponseBody(**raw_response)
+
+        if "cachedContents" not in all_cached_items:
+            return None
+
+        for cached_item in all_cached_items["cachedContents"]:
+            display_name = cached_item.get("displayName")
+            if display_name is not None and display_name == cache_key:
+                return cached_item.get("name")
+
+        return None
+
+    async def async_check_cache(
+        self,
+        cache_key: str,
+        client: AsyncHTTPHandler,
+        headers: dict,
+        api_key: str,
+        api_base: Optional[str],
+        logging_obj: Logging,
+    ) -> Optional[str]:
+        """
+        Checks if content already cached.
+
+        Currently, checks cache list, for cache key == displayName, since Google doesn't let us set the name of the cache (their API docs are out of sync with actual implementation).
+
+        Returns
+        - cached_content_name - str - cached content name stored on google. (if found.)
+        OR
+        - None
+        """
+
+        _, url = self._get_token_and_url_context_caching(
+            gemini_api_key=api_key,
+            custom_llm_provider="gemini",
+            api_base=api_base,
+        )
+        try:
+            ## LOGGING
+            logging_obj.pre_call(
+                input="",
+                api_key="",
+                additional_args={
+                    "complete_input_dict": {},
+                    "api_base": url,
+                    "headers": headers,
+                },
+            )
+
+            resp = await client.get(url=url, headers=headers)
+            resp.raise_for_status()
+        except httpx.HTTPStatusError as e:
+            if e.response.status_code == 403:
+                return None
+            raise VertexAIError(
+                status_code=e.response.status_code, message=e.response.text
+            )
+        except Exception as e:
+            raise VertexAIError(status_code=500, message=str(e))
+        raw_response = resp.json()
+        logging_obj.post_call(original_response=raw_response)
+
+        if "cachedContents" not in raw_response:
+            return None
+
+        all_cached_items = CachedContentListAllResponseBody(**raw_response)
+
+        if "cachedContents" not in all_cached_items:
+            return None
+
+        for cached_item in all_cached_items["cachedContents"]:
+            display_name = cached_item.get("displayName")
+            if display_name is not None and display_name == cache_key:
+                return cached_item.get("name")
+
+        return None
+
+    def check_and_create_cache(
+        self,
+        messages: List[AllMessageValues],  # receives openai format messages
+        api_key: str,
+        api_base: Optional[str],
+        model: str,
+        client: Optional[HTTPHandler],
+        timeout: Optional[Union[float, httpx.Timeout]],
+        logging_obj: Logging,
+        extra_headers: Optional[dict] = None,
+        cached_content: Optional[str] = None,
+    ) -> Tuple[List[AllMessageValues], Optional[str]]:
+        """
+        Receives
+        - messages: List of dict - messages in the openai format
+
+        Returns
+        - messages - List[dict] - filtered list of messages in the openai format.
+        - cached_content - str - the cache content id, to be passed in the gemini request body
+
+        Follows - https://ai.google.dev/api/caching#request-body
+        """
+        if cached_content is not None:
+            return messages, cached_content
+
+        ## AUTHORIZATION ##
+        token, url = self._get_token_and_url_context_caching(
+            gemini_api_key=api_key,
+            custom_llm_provider="gemini",
+            api_base=api_base,
+        )
+
+        headers = {
+            "Content-Type": "application/json",
+        }
+        if token is not None:
+            headers["Authorization"] = f"Bearer {token}"
+        if extra_headers is not None:
+            headers.update(extra_headers)
+
+        if client is None or not isinstance(client, HTTPHandler):
+            _params = {}
+            if timeout is not None:
+                if isinstance(timeout, float) or isinstance(timeout, int):
+                    timeout = httpx.Timeout(timeout)
+                _params["timeout"] = timeout
+            client = HTTPHandler(**_params)  # type: ignore
+        else:
+            client = client
+
+        cached_messages, non_cached_messages = separate_cached_messages(
+            messages=messages
+        )
+
+        if len(cached_messages) == 0:
+            return messages, None
+
+        ## CHECK IF CACHED ALREADY
+        generated_cache_key = local_cache_obj.get_cache_key(messages=cached_messages)
+        google_cache_name = self.check_cache(
+            cache_key=generated_cache_key,
+            client=client,
+            headers=headers,
+            api_key=api_key,
+            api_base=api_base,
+            logging_obj=logging_obj,
+        )
+        if google_cache_name:
+            return non_cached_messages, google_cache_name
+
+        ## TRANSFORM REQUEST
+        cached_content_request_body = (
+            transform_openai_messages_to_gemini_context_caching(
+                model=model, messages=cached_messages, cache_key=generated_cache_key
+            )
+        )
+
+        ## LOGGING
+        logging_obj.pre_call(
+            input=messages,
+            api_key="",
+            additional_args={
+                "complete_input_dict": cached_content_request_body,
+                "api_base": url,
+                "headers": headers,
+            },
+        )
+
+        try:
+            response = client.post(
+                url=url, headers=headers, json=cached_content_request_body  # type: ignore
+            )
+            response.raise_for_status()
+        except httpx.HTTPStatusError as err:
+            error_code = err.response.status_code
+            raise VertexAIError(status_code=error_code, message=err.response.text)
+        except httpx.TimeoutException:
+            raise VertexAIError(status_code=408, message="Timeout error occurred.")
+
+        raw_response_cached = response.json()
+        cached_content_response_obj = VertexAICachedContentResponseObject(
+            name=raw_response_cached.get("name"), model=raw_response_cached.get("model")
+        )
+        return (non_cached_messages, cached_content_response_obj["name"])
+
+    async def async_check_and_create_cache(
+        self,
+        messages: List[AllMessageValues],  # receives openai format messages
+        api_key: str,
+        api_base: Optional[str],
+        model: str,
+        client: Optional[AsyncHTTPHandler],
+        timeout: Optional[Union[float, httpx.Timeout]],
+        logging_obj: Logging,
+        extra_headers: Optional[dict] = None,
+        cached_content: Optional[str] = None,
+    ) -> Tuple[List[AllMessageValues], Optional[str]]:
+        """
+        Receives
+        - messages: List of dict - messages in the openai format
+
+        Returns
+        - messages - List[dict] - filtered list of messages in the openai format.
+        - cached_content - str - the cache content id, to be passed in the gemini request body
+
+        Follows - https://ai.google.dev/api/caching#request-body
+        """
+        if cached_content is not None:
+            return messages, cached_content
+
+        cached_messages, non_cached_messages = separate_cached_messages(
+            messages=messages
+        )
+
+        if len(cached_messages) == 0:
+            return messages, None
+
+        ## AUTHORIZATION ##
+        token, url = self._get_token_and_url_context_caching(
+            gemini_api_key=api_key,
+            custom_llm_provider="gemini",
+            api_base=api_base,
+        )
+
+        headers = {
+            "Content-Type": "application/json",
+        }
+        if token is not None:
+            headers["Authorization"] = f"Bearer {token}"
+        if extra_headers is not None:
+            headers.update(extra_headers)
+
+        if client is None or not isinstance(client, AsyncHTTPHandler):
+            client = get_async_httpx_client(
+                params={"timeout": timeout}, llm_provider=litellm.LlmProviders.VERTEX_AI
+            )
+        else:
+            client = client
+
+        ## CHECK IF CACHED ALREADY
+        generated_cache_key = local_cache_obj.get_cache_key(messages=cached_messages)
+        google_cache_name = await self.async_check_cache(
+            cache_key=generated_cache_key,
+            client=client,
+            headers=headers,
+            api_key=api_key,
+            api_base=api_base,
+            logging_obj=logging_obj,
+        )
+        if google_cache_name:
+            return non_cached_messages, google_cache_name
+
+        ## TRANSFORM REQUEST
+        cached_content_request_body = (
+            transform_openai_messages_to_gemini_context_caching(
+                model=model, messages=cached_messages, cache_key=generated_cache_key
+            )
+        )
+
+        ## LOGGING
+        logging_obj.pre_call(
+            input=messages,
+            api_key="",
+            additional_args={
+                "complete_input_dict": cached_content_request_body,
+                "api_base": url,
+                "headers": headers,
+            },
+        )
+
+        try:
+            response = await client.post(
+                url=url, headers=headers, json=cached_content_request_body  # type: ignore
+            )
+            response.raise_for_status()
+        except httpx.HTTPStatusError as err:
+            error_code = err.response.status_code
+            raise VertexAIError(status_code=error_code, message=err.response.text)
+        except httpx.TimeoutException:
+            raise VertexAIError(status_code=408, message="Timeout error occurred.")
+
+        raw_response_cached = response.json()
+        cached_content_response_obj = VertexAICachedContentResponseObject(
+            name=raw_response_cached.get("name"), model=raw_response_cached.get("model")
+        )
+        return (non_cached_messages, cached_content_response_obj["name"])
+
+    def get_cache(self):
+        pass
+
+    async def async_get_cache(self):
+        pass
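Taken together, the helpers above split out the cache-marked block of messages, reuse an existing Google cache when a cachedContents entry's displayName matches litellm's cache key, and otherwise create a new cachedContents resource. A minimal usage sketch of the transformation helpers (not part of the diff): it assumes litellm is installed with the modules added above and that litellm.utils.is_cached_message() treats messages whose content parts carry cache_control={"type": "ephemeral"} as cacheable; the sample messages and expected outputs are illustrative only.

    from litellm.llms.vertex_ai.context_caching.transformation import (
        get_first_continuous_block_idx,
        separate_cached_messages,
    )

    messages = [
        {
            "role": "system",
            "content": [
                {
                    "type": "text",
                    "text": "You are a helpful assistant. <large static context>",
                    # assumption: this marker is what is_cached_message() looks for
                    "cache_control": {"type": "ephemeral"},
                }
            ],
        },
        {"role": "user", "content": "What does the document say about pricing?"},
    ]

    # The first continuous run of cache-marked messages becomes the cachedContents
    # payload; everything else stays in the regular request.
    cached, non_cached = separate_cached_messages(messages=messages)
    print(len(cached), len(non_cached))  # expected: 1 1

    # The index helper only looks at the indices in the filtered (index, message)
    # list and returns where the first continuous run ends, e.g. [0, 1, 4] -> 1.
    print(get_first_continuous_block_idx([(0, {}), (1, {}), (4, {})]))  # expected: 1

The displayName-based lookup in check_cache / async_check_cache exists because, per the docstrings above, the Google AI Studio API does not let callers choose the cache resource name, so the code stores litellm's own cache key in displayName and scans the cachedContents list for it, reusing the existing cache on a hit instead of creating a new one.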