path: root/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_llm_base.py
author    S. Solomon Darnell    2025-03-28 21:52:21 -0500
committer    S. Solomon Darnell    2025-03-28 21:52:21 -0500
commit    4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree    ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_llm_base.py
parent    cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
download    gn-ai-master.tar.gz
two versions of R2R are here (HEAD, master)
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_llm_base.py')
-rw-r--r--    .venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_llm_base.py    319
1 file changed, 319 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_llm_base.py b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_llm_base.py
new file mode 100644
index 00000000..8286cb51
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_llm_base.py
@@ -0,0 +1,319 @@
+"""
+Base Vertex, Google AI Studio LLM Class
+
+Handles Authentication and generating request urls for Vertex AI and Google AI Studio
+"""
+
+import json
+import os
+from typing import TYPE_CHECKING, Any, Literal, Optional, Tuple
+
+from litellm._logging import verbose_logger
+from litellm.litellm_core_utils.asyncify import asyncify
+from litellm.llms.base import BaseLLM
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES
+
+from .common_utils import _get_gemini_url, _get_vertex_url, all_gemini_url_modes
+
+if TYPE_CHECKING:
+ from google.auth.credentials import Credentials as GoogleCredentialsObject
+else:
+ GoogleCredentialsObject = Any
+
+
+class VertexBase(BaseLLM):
+ def __init__(self) -> None:
+ super().__init__()
+ self.access_token: Optional[str] = None
+ self.refresh_token: Optional[str] = None
+ self._credentials: Optional[GoogleCredentialsObject] = None
+ self.project_id: Optional[str] = None
+ self.async_handler: Optional[AsyncHTTPHandler] = None
+
+ def get_vertex_region(self, vertex_region: Optional[str]) -> str:
+ return vertex_region or "us-central1"
+
+ def load_auth(
+ self, credentials: Optional[VERTEX_CREDENTIALS_TYPES], project_id: Optional[str]
+ ) -> Tuple[Any, str]:
+ import google.auth as google_auth
+ from google.auth import identity_pool
+ from google.auth.transport.requests import (
+ Request, # type: ignore[import-untyped]
+ )
+
+ if credentials is not None:
+ import google.oauth2.service_account
+
+ if isinstance(credentials, str):
+ verbose_logger.debug(
+ "Vertex: Loading vertex credentials from %s", credentials
+ )
+ verbose_logger.debug(
+ "Vertex: checking if credentials is a valid path, os.path.exists(%s)=%s, current dir %s",
+ credentials,
+ os.path.exists(credentials),
+ os.getcwd(),
+ )
+
+            try:
+                if os.path.exists(credentials):
+                    with open(credentials) as f:
+                        json_obj = json.load(f)
+                else:
+                    json_obj = json.loads(credentials)
+            except Exception:
+                raise Exception(
+                    "Unable to load vertex credentials from environment. Got={}".format(
+                        credentials
+                    )
+                )
+ elif isinstance(credentials, dict):
+ json_obj = credentials
+ else:
+ raise ValueError(
+ "Invalid credentials type: {}".format(type(credentials))
+ )
+
+ # Check if the JSON object contains Workload Identity Federation configuration
+ if "type" in json_obj and json_obj["type"] == "external_account":
+ creds = identity_pool.Credentials.from_info(json_obj)
+ else:
+ creds = (
+ google.oauth2.service_account.Credentials.from_service_account_info(
+ json_obj,
+ scopes=["https://www.googleapis.com/auth/cloud-platform"],
+ )
+ )
+
+ if project_id is None:
+ project_id = getattr(creds, "project_id", None)
+ else:
+ creds, creds_project_id = google_auth.default(
+ quota_project_id=project_id,
+ scopes=["https://www.googleapis.com/auth/cloud-platform"],
+ )
+ if project_id is None:
+ project_id = creds_project_id
+
+ creds.refresh(Request()) # type: ignore
+
+ if not project_id:
+ raise ValueError("Could not resolve project_id")
+
+ if not isinstance(project_id, str):
+ raise TypeError(
+ f"Expected project_id to be a str but got {type(project_id)}"
+ )
+
+ return creds, project_id
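
[editor's note] The branching above accepts three credential shapes: a filesystem path to a JSON key file, a raw JSON string, or an already-parsed dict. A minimal standalone sketch of that dispatch (a hypothetical helper named parse_vertex_credentials, not part of this file):

import json
import os

def parse_vertex_credentials(credentials):
    # Mirror load_auth's input handling: path -> read file, str -> parse JSON, dict -> use as-is.
    if isinstance(credentials, str):
        if os.path.exists(credentials):
            with open(credentials) as f:
                return json.load(f)
        return json.loads(credentials)
    if isinstance(credentials, dict):
        return credentials
    raise ValueError("Invalid credentials type: {}".format(type(credentials)))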
+
+ def refresh_auth(self, credentials: Any) -> None:
+ from google.auth.transport.requests import (
+ Request, # type: ignore[import-untyped]
+ )
+
+ credentials.refresh(Request())
+
+ def _ensure_access_token(
+ self,
+ credentials: Optional[VERTEX_CREDENTIALS_TYPES],
+ project_id: Optional[str],
+ custom_llm_provider: Literal[
+ "vertex_ai", "vertex_ai_beta", "gemini"
+ ], # if it's vertex_ai or gemini (google ai studio)
+ ) -> Tuple[str, str]:
+ """
+ Returns auth token and project id
+ """
+ if custom_llm_provider == "gemini":
+ return "", ""
+ if self.access_token is not None:
+ if project_id is not None:
+ return self.access_token, project_id
+ elif self.project_id is not None:
+ return self.access_token, self.project_id
+
+ if not self._credentials:
+ self._credentials, cred_project_id = self.load_auth(
+ credentials=credentials, project_id=project_id
+ )
+ if not self.project_id:
+ self.project_id = project_id or cred_project_id
+ else:
+ if self._credentials.expired or not self._credentials.token:
+ self.refresh_auth(self._credentials)
+
+ if not self.project_id:
+ self.project_id = self._credentials.quota_project_id
+
+ if not self.project_id:
+ raise ValueError("Could not resolve project_id")
+
+ if not self._credentials or not self._credentials.token:
+ raise RuntimeError("Could not resolve API token from the environment")
+
+ return self._credentials.token, project_id or self.project_id
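
[editor's note] _ensure_access_token caches both the credentials object and the resolved project id on the instance, so repeated calls only hit Google's token endpoint once the cached token expires. A hedged usage sketch, assuming valid application-default credentials exist in the environment:

from litellm.llms.vertex_ai.vertex_llm_base import VertexBase

base = VertexBase()
# First call loads credentials and resolves the project; later calls reuse both.
token, project = base._ensure_access_token(
    credentials=None,       # fall back to google.auth.default()
    project_id=None,        # resolve from the loaded credentials
    custom_llm_provider="vertex_ai",
)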
+
+ def is_using_v1beta1_features(self, optional_params: dict) -> bool:
+ """
+ VertexAI only supports ContextCaching on v1beta1
+
+        Use this helper to decide whether a request should be sent to v1 or v1beta1.
+
+        Returns True (use v1beta1) if context caching is enabled.
+        Returns False (use v1) in all other cases.
+ """
+ if "cached_content" in optional_params:
+ return True
+ if "CachedContent" in optional_params:
+ return True
+ return False
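
[editor's note] Per the docstring, only context caching forces the v1beta1 endpoint. A quick illustration of the expected behavior (the cached_content value is a placeholder):

from litellm.llms.vertex_ai.vertex_llm_base import VertexBase

base = VertexBase()
assert base.is_using_v1beta1_features({"cached_content": "projects/p/cachedContents/c"}) is True
assert base.is_using_v1beta1_features({"temperature": 0.2}) is False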
+
+ def _check_custom_proxy(
+ self,
+ api_base: Optional[str],
+ custom_llm_provider: str,
+ gemini_api_key: Optional[str],
+ endpoint: str,
+ stream: Optional[bool],
+ auth_header: Optional[str],
+ url: str,
+ ) -> Tuple[Optional[str], str]:
+ """
+ for cloudflare ai gateway - https://github.com/BerriAI/litellm/issues/4317
+
+ ## Returns
+ - (auth_header, url) - Tuple[Optional[str], str]
+ """
+ if api_base:
+ if custom_llm_provider == "gemini":
+ url = "{}:{}".format(api_base, endpoint)
+ if gemini_api_key is None:
+ raise ValueError(
+ "Missing gemini_api_key, please set `GEMINI_API_KEY`"
+ )
+ auth_header = (
+ gemini_api_key # cloudflare expects api key as bearer token
+ )
+ else:
+ url = "{}:{}".format(api_base, endpoint)
+
+ if stream is True:
+ url = url + "?alt=sse"
+ return auth_header, url
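
[editor's note] When api_base is set (e.g. a Cloudflare AI Gateway front-end), the method simply re-roots the request at "{api_base}:{endpoint}" and, for Gemini, moves the API key into the auth header. An illustrative call with hypothetical values:

from litellm.llms.vertex_ai.vertex_llm_base import VertexBase

base = VertexBase()
auth_header, url = base._check_custom_proxy(
    api_base="https://gateway.example.com/v1/acct/gw/google-ai-studio",  # hypothetical gateway URL
    custom_llm_provider="gemini",
    gemini_api_key="AIza-example",  # placeholder key
    endpoint="generateContent",
    stream=True,
    auth_header=None,
    url="unused-when-api_base-is-set",
)
# url         -> "https://gateway.example.com/v1/acct/gw/google-ai-studio:generateContent?alt=sse"
# auth_header -> the Gemini API key, sent as the bearer token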
+
+ def _get_token_and_url(
+ self,
+ model: str,
+ auth_header: Optional[str],
+ gemini_api_key: Optional[str],
+ vertex_project: Optional[str],
+ vertex_location: Optional[str],
+ vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
+ stream: Optional[bool],
+ custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
+ api_base: Optional[str],
+ should_use_v1beta1_features: Optional[bool] = False,
+ mode: all_gemini_url_modes = "chat",
+ ) -> Tuple[Optional[str], str]:
+ """
+ Internal function. Returns the token and url for the call.
+
+        Handles the logic for Google AI Studio vs. Vertex AI.
+
+ Returns
+ token, url
+ """
+ if custom_llm_provider == "gemini":
+ url, endpoint = _get_gemini_url(
+ mode=mode,
+ model=model,
+ stream=stream,
+ gemini_api_key=gemini_api_key,
+ )
+            auth_header = None  # this field is not used for gemini
+ else:
+ vertex_location = self.get_vertex_region(vertex_region=vertex_location)
+
+ ### SET RUNTIME ENDPOINT ###
+ version: Literal["v1beta1", "v1"] = (
+ "v1beta1" if should_use_v1beta1_features is True else "v1"
+ )
+ url, endpoint = _get_vertex_url(
+ mode=mode,
+ model=model,
+ stream=stream,
+ vertex_project=vertex_project,
+ vertex_location=vertex_location,
+ vertex_api_version=version,
+ )
+
+ return self._check_custom_proxy(
+ api_base=api_base,
+ auth_header=auth_header,
+ custom_llm_provider=custom_llm_provider,
+ gemini_api_key=gemini_api_key,
+ endpoint=endpoint,
+ stream=stream,
+ url=url,
+ )
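
[editor's note] For Vertex AI, the method fills in a default region, picks v1 vs v1beta1 from should_use_v1beta1_features, and delegates URL construction to _get_vertex_url. For orientation, a chat-mode Vertex URL has roughly this shape (a sketch assuming _get_vertex_url follows Google's published endpoint format):

project, location, model = "my-project", "us-central1", "gemini-1.5-pro"
version = "v1"  # "v1beta1" when should_use_v1beta1_features is True
url = (
    f"https://{location}-aiplatform.googleapis.com/{version}/projects/{project}"
    f"/locations/{location}/publishers/google/models/{model}:generateContent"
)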
+
+ async def _ensure_access_token_async(
+ self,
+ credentials: Optional[VERTEX_CREDENTIALS_TYPES],
+ project_id: Optional[str],
+ custom_llm_provider: Literal[
+ "vertex_ai", "vertex_ai_beta", "gemini"
+ ], # if it's vertex_ai or gemini (google ai studio)
+ ) -> Tuple[str, str]:
+ """
+ Async version of _ensure_access_token
+ """
+ if custom_llm_provider == "gemini":
+ return "", ""
+ if self.access_token is not None:
+ if project_id is not None:
+ return self.access_token, project_id
+ elif self.project_id is not None:
+ return self.access_token, self.project_id
+
+ if not self._credentials:
+ try:
+ self._credentials, cred_project_id = await asyncify(self.load_auth)(
+ credentials=credentials, project_id=project_id
+ )
+ except Exception:
+ verbose_logger.exception(
+ "Failed to load vertex credentials. Check to see if credentials containing partial/invalid information."
+ )
+ raise
+ if not self.project_id:
+ self.project_id = project_id or cred_project_id
+ else:
+ if self._credentials.expired or not self._credentials.token:
+ await asyncify(self.refresh_auth)(self._credentials)
+
+ if not self.project_id:
+ self.project_id = self._credentials.quota_project_id
+
+ if not self.project_id:
+ raise ValueError("Could not resolve project_id")
+
+ if not self._credentials or not self._credentials.token:
+ raise RuntimeError("Could not resolve API token from the environment")
+
+ return self._credentials.token, project_id or self.project_id
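
[editor's note] The async variant wraps the same blocking google-auth calls in litellm's asyncify helper so they run in a thread pool instead of blocking the event loop. A hedged usage sketch, again assuming application-default credentials are available:

import asyncio
from litellm.llms.vertex_ai.vertex_llm_base import VertexBase

async def main():
    base = VertexBase()
    token, project = await base._ensure_access_token_async(
        credentials=None,
        project_id=None,
        custom_llm_provider="vertex_ai",
    )
    print(project)

asyncio.run(main())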
+
+ def set_headers(
+ self, auth_header: Optional[str], extra_headers: Optional[dict]
+ ) -> dict:
+ headers = {
+ "Content-Type": "application/json",
+ }
+ if auth_header is not None:
+ headers["Authorization"] = f"Bearer {auth_header}"
+ if extra_headers is not None:
+ headers.update(extra_headers)
+
+ return headers
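
[editor's note] set_headers layers an optional bearer token and caller-supplied extras over a JSON content type; extras win on key collisions because they are applied last via dict.update. For example (token and header values are placeholders):

from litellm.llms.vertex_ai.vertex_llm_base import VertexBase

headers = VertexBase().set_headers(
    auth_header="ya29.example-token",  # placeholder OAuth access token
    extra_headers={"x-goog-user-project": "my-project"},
)
# {'Content-Type': 'application/json',
#  'Authorization': 'Bearer ya29.example-token',
#  'x-goog-user-project': 'my-project'}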