diff options
author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
---|---|---|
committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) | |
tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_llm_base.py | |
parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff) | |
download | gn-ai-master.tar.gz |
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_llm_base.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_llm_base.py | 319 |
1 files changed, 319 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_llm_base.py b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_llm_base.py new file mode 100644 index 00000000..8286cb51 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/llms/vertex_ai/vertex_llm_base.py @@ -0,0 +1,319 @@ +""" +Base Vertex, Google AI Studio LLM Class + +Handles Authentication and generating request urls for Vertex AI and Google AI Studio +""" + +import json +import os +from typing import TYPE_CHECKING, Any, Literal, Optional, Tuple + +from litellm._logging import verbose_logger +from litellm.litellm_core_utils.asyncify import asyncify +from litellm.llms.base import BaseLLM +from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler +from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES + +from .common_utils import _get_gemini_url, _get_vertex_url, all_gemini_url_modes + +if TYPE_CHECKING: + from google.auth.credentials import Credentials as GoogleCredentialsObject +else: + GoogleCredentialsObject = Any + + +class VertexBase(BaseLLM): + def __init__(self) -> None: + super().__init__() + self.access_token: Optional[str] = None + self.refresh_token: Optional[str] = None + self._credentials: Optional[GoogleCredentialsObject] = None + self.project_id: Optional[str] = None + self.async_handler: Optional[AsyncHTTPHandler] = None + + def get_vertex_region(self, vertex_region: Optional[str]) -> str: + return vertex_region or "us-central1" + + def load_auth( + self, credentials: Optional[VERTEX_CREDENTIALS_TYPES], project_id: Optional[str] + ) -> Tuple[Any, str]: + import google.auth as google_auth + from google.auth import identity_pool + from google.auth.transport.requests import ( + Request, # type: ignore[import-untyped] + ) + + if credentials is not None: + import google.oauth2.service_account + + if isinstance(credentials, str): + verbose_logger.debug( + "Vertex: Loading vertex credentials from %s", credentials + ) + verbose_logger.debug( + "Vertex: checking if credentials is a valid path, os.path.exists(%s)=%s, current dir %s", + credentials, + os.path.exists(credentials), + os.getcwd(), + ) + + try: + if os.path.exists(credentials): + json_obj = json.load(open(credentials)) + else: + json_obj = json.loads(credentials) + except Exception: + raise Exception( + "Unable to load vertex credentials from environment. Got={}".format( + credentials + ) + ) + elif isinstance(credentials, dict): + json_obj = credentials + else: + raise ValueError( + "Invalid credentials type: {}".format(type(credentials)) + ) + + # Check if the JSON object contains Workload Identity Federation configuration + if "type" in json_obj and json_obj["type"] == "external_account": + creds = identity_pool.Credentials.from_info(json_obj) + else: + creds = ( + google.oauth2.service_account.Credentials.from_service_account_info( + json_obj, + scopes=["https://www.googleapis.com/auth/cloud-platform"], + ) + ) + + if project_id is None: + project_id = getattr(creds, "project_id", None) + else: + creds, creds_project_id = google_auth.default( + quota_project_id=project_id, + scopes=["https://www.googleapis.com/auth/cloud-platform"], + ) + if project_id is None: + project_id = creds_project_id + + creds.refresh(Request()) # type: ignore + + if not project_id: + raise ValueError("Could not resolve project_id") + + if not isinstance(project_id, str): + raise TypeError( + f"Expected project_id to be a str but got {type(project_id)}" + ) + + return creds, project_id + + def refresh_auth(self, credentials: Any) -> None: + from google.auth.transport.requests import ( + Request, # type: ignore[import-untyped] + ) + + credentials.refresh(Request()) + + def _ensure_access_token( + self, + credentials: Optional[VERTEX_CREDENTIALS_TYPES], + project_id: Optional[str], + custom_llm_provider: Literal[ + "vertex_ai", "vertex_ai_beta", "gemini" + ], # if it's vertex_ai or gemini (google ai studio) + ) -> Tuple[str, str]: + """ + Returns auth token and project id + """ + if custom_llm_provider == "gemini": + return "", "" + if self.access_token is not None: + if project_id is not None: + return self.access_token, project_id + elif self.project_id is not None: + return self.access_token, self.project_id + + if not self._credentials: + self._credentials, cred_project_id = self.load_auth( + credentials=credentials, project_id=project_id + ) + if not self.project_id: + self.project_id = project_id or cred_project_id + else: + if self._credentials.expired or not self._credentials.token: + self.refresh_auth(self._credentials) + + if not self.project_id: + self.project_id = self._credentials.quota_project_id + + if not self.project_id: + raise ValueError("Could not resolve project_id") + + if not self._credentials or not self._credentials.token: + raise RuntimeError("Could not resolve API token from the environment") + + return self._credentials.token, project_id or self.project_id + + def is_using_v1beta1_features(self, optional_params: dict) -> bool: + """ + VertexAI only supports ContextCaching on v1beta1 + + use this helper to decide if request should be sent to v1 or v1beta1 + + Returns v1beta1 if context caching is enabled + Returns v1 in all other cases + """ + if "cached_content" in optional_params: + return True + if "CachedContent" in optional_params: + return True + return False + + def _check_custom_proxy( + self, + api_base: Optional[str], + custom_llm_provider: str, + gemini_api_key: Optional[str], + endpoint: str, + stream: Optional[bool], + auth_header: Optional[str], + url: str, + ) -> Tuple[Optional[str], str]: + """ + for cloudflare ai gateway - https://github.com/BerriAI/litellm/issues/4317 + + ## Returns + - (auth_header, url) - Tuple[Optional[str], str] + """ + if api_base: + if custom_llm_provider == "gemini": + url = "{}:{}".format(api_base, endpoint) + if gemini_api_key is None: + raise ValueError( + "Missing gemini_api_key, please set `GEMINI_API_KEY`" + ) + auth_header = ( + gemini_api_key # cloudflare expects api key as bearer token + ) + else: + url = "{}:{}".format(api_base, endpoint) + + if stream is True: + url = url + "?alt=sse" + return auth_header, url + + def _get_token_and_url( + self, + model: str, + auth_header: Optional[str], + gemini_api_key: Optional[str], + vertex_project: Optional[str], + vertex_location: Optional[str], + vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES], + stream: Optional[bool], + custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"], + api_base: Optional[str], + should_use_v1beta1_features: Optional[bool] = False, + mode: all_gemini_url_modes = "chat", + ) -> Tuple[Optional[str], str]: + """ + Internal function. Returns the token and url for the call. + + Handles logic if it's google ai studio vs. vertex ai. + + Returns + token, url + """ + if custom_llm_provider == "gemini": + url, endpoint = _get_gemini_url( + mode=mode, + model=model, + stream=stream, + gemini_api_key=gemini_api_key, + ) + auth_header = None # this field is not used for gemin + else: + vertex_location = self.get_vertex_region(vertex_region=vertex_location) + + ### SET RUNTIME ENDPOINT ### + version: Literal["v1beta1", "v1"] = ( + "v1beta1" if should_use_v1beta1_features is True else "v1" + ) + url, endpoint = _get_vertex_url( + mode=mode, + model=model, + stream=stream, + vertex_project=vertex_project, + vertex_location=vertex_location, + vertex_api_version=version, + ) + + return self._check_custom_proxy( + api_base=api_base, + auth_header=auth_header, + custom_llm_provider=custom_llm_provider, + gemini_api_key=gemini_api_key, + endpoint=endpoint, + stream=stream, + url=url, + ) + + async def _ensure_access_token_async( + self, + credentials: Optional[VERTEX_CREDENTIALS_TYPES], + project_id: Optional[str], + custom_llm_provider: Literal[ + "vertex_ai", "vertex_ai_beta", "gemini" + ], # if it's vertex_ai or gemini (google ai studio) + ) -> Tuple[str, str]: + """ + Async version of _ensure_access_token + """ + if custom_llm_provider == "gemini": + return "", "" + if self.access_token is not None: + if project_id is not None: + return self.access_token, project_id + elif self.project_id is not None: + return self.access_token, self.project_id + + if not self._credentials: + try: + self._credentials, cred_project_id = await asyncify(self.load_auth)( + credentials=credentials, project_id=project_id + ) + except Exception: + verbose_logger.exception( + "Failed to load vertex credentials. Check to see if credentials containing partial/invalid information." + ) + raise + if not self.project_id: + self.project_id = project_id or cred_project_id + else: + if self._credentials.expired or not self._credentials.token: + await asyncify(self.refresh_auth)(self._credentials) + + if not self.project_id: + self.project_id = self._credentials.quota_project_id + + if not self.project_id: + raise ValueError("Could not resolve project_id") + + if not self._credentials or not self._credentials.token: + raise RuntimeError("Could not resolve API token from the environment") + + return self._credentials.token, project_id or self.project_id + + def set_headers( + self, auth_header: Optional[str], extra_headers: Optional[dict] + ) -> dict: + headers = { + "Content-Type": "application/json", + } + if auth_header is not None: + headers["Authorization"] = f"Bearer {auth_header}" + if extra_headers is not None: + headers.update(extra_headers) + + return headers |