aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/litellm/integrations/gcs_bucket/gcs_bucket_base.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/integrations/gcs_bucket/gcs_bucket_base.py')
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/integrations/gcs_bucket/gcs_bucket_base.py326
1 files changed, 326 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/gcs_bucket/gcs_bucket_base.py b/.venv/lib/python3.12/site-packages/litellm/integrations/gcs_bucket/gcs_bucket_base.py
new file mode 100644
index 00000000..66995d84
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/integrations/gcs_bucket/gcs_bucket_base.py
@@ -0,0 +1,326 @@
+import json
+import os
+from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
+
+from litellm._logging import verbose_logger
+from litellm.integrations.custom_batch_logger import CustomBatchLogger
+from litellm.llms.custom_httpx.http_handler import (
+ get_async_httpx_client,
+ httpxSpecialProvider,
+)
+from litellm.types.integrations.gcs_bucket import *
+from litellm.types.utils import StandardCallbackDynamicParams, StandardLoggingPayload
+
+if TYPE_CHECKING:
+ from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
+else:
+ VertexBase = Any
+IAM_AUTH_KEY = "IAM_AUTH"
+
+
+class GCSBucketBase(CustomBatchLogger):
+ def __init__(self, bucket_name: Optional[str] = None, **kwargs) -> None:
+ self.async_httpx_client = get_async_httpx_client(
+ llm_provider=httpxSpecialProvider.LoggingCallback
+ )
+ _path_service_account = os.getenv("GCS_PATH_SERVICE_ACCOUNT")
+ _bucket_name = bucket_name or os.getenv("GCS_BUCKET_NAME")
+ self.path_service_account_json: Optional[str] = _path_service_account
+ self.BUCKET_NAME: Optional[str] = _bucket_name
+ self.vertex_instances: Dict[str, VertexBase] = {}
+ super().__init__(**kwargs)
+
+ async def construct_request_headers(
+ self,
+ service_account_json: Optional[str],
+ vertex_instance: Optional[VertexBase] = None,
+ ) -> Dict[str, str]:
+ from litellm import vertex_chat_completion
+
+ if vertex_instance is None:
+ vertex_instance = vertex_chat_completion
+
+ _auth_header, vertex_project = await vertex_instance._ensure_access_token_async(
+ credentials=service_account_json,
+ project_id=None,
+ custom_llm_provider="vertex_ai",
+ )
+
+ auth_header, _ = vertex_instance._get_token_and_url(
+ model="gcs-bucket",
+ auth_header=_auth_header,
+ vertex_credentials=service_account_json,
+ vertex_project=vertex_project,
+ vertex_location=None,
+ gemini_api_key=None,
+ stream=None,
+ custom_llm_provider="vertex_ai",
+ api_base=None,
+ )
+ verbose_logger.debug("constructed auth_header %s", auth_header)
+ headers = {
+ "Authorization": f"Bearer {auth_header}", # auth_header
+ "Content-Type": "application/json",
+ }
+
+ return headers
+
+ def sync_construct_request_headers(self) -> Dict[str, str]:
+ from litellm import vertex_chat_completion
+
+ _auth_header, vertex_project = vertex_chat_completion._ensure_access_token(
+ credentials=self.path_service_account_json,
+ project_id=None,
+ custom_llm_provider="vertex_ai",
+ )
+
+ auth_header, _ = vertex_chat_completion._get_token_and_url(
+ model="gcs-bucket",
+ auth_header=_auth_header,
+ vertex_credentials=self.path_service_account_json,
+ vertex_project=vertex_project,
+ vertex_location=None,
+ gemini_api_key=None,
+ stream=None,
+ custom_llm_provider="vertex_ai",
+ api_base=None,
+ )
+ verbose_logger.debug("constructed auth_header %s", auth_header)
+ headers = {
+ "Authorization": f"Bearer {auth_header}", # auth_header
+ "Content-Type": "application/json",
+ }
+
+ return headers
+
+ def _handle_folders_in_bucket_name(
+ self,
+ bucket_name: str,
+ object_name: str,
+ ) -> Tuple[str, str]:
+ """
+ Handles when the user passes a bucket name with a folder postfix
+
+
+ Example:
+ - Bucket name: "my-bucket/my-folder/dev"
+ - Object name: "my-object"
+ - Returns: bucket_name="my-bucket", object_name="my-folder/dev/my-object"
+
+ """
+ if "/" in bucket_name:
+ bucket_name, prefix = bucket_name.split("/", 1)
+ object_name = f"{prefix}/{object_name}"
+ return bucket_name, object_name
+ return bucket_name, object_name
+
+ async def get_gcs_logging_config(
+ self, kwargs: Optional[Dict[str, Any]] = {}
+ ) -> GCSLoggingConfig:
+ """
+ This function is used to get the GCS logging config for the GCS Bucket Logger.
+ It checks if the dynamic parameters are provided in the kwargs and uses them to get the GCS logging config.
+ If no dynamic parameters are provided, it uses the default values.
+ """
+ if kwargs is None:
+ kwargs = {}
+
+ standard_callback_dynamic_params: Optional[StandardCallbackDynamicParams] = (
+ kwargs.get("standard_callback_dynamic_params", None)
+ )
+
+ bucket_name: str
+ path_service_account: Optional[str]
+ if standard_callback_dynamic_params is not None:
+ verbose_logger.debug("Using dynamic GCS logging")
+ verbose_logger.debug(
+ "standard_callback_dynamic_params: %s", standard_callback_dynamic_params
+ )
+
+ _bucket_name: Optional[str] = (
+ standard_callback_dynamic_params.get("gcs_bucket_name", None)
+ or self.BUCKET_NAME
+ )
+ _path_service_account: Optional[str] = (
+ standard_callback_dynamic_params.get("gcs_path_service_account", None)
+ or self.path_service_account_json
+ )
+
+ if _bucket_name is None:
+ raise ValueError(
+ "GCS_BUCKET_NAME is not set in the environment, but GCS Bucket is being used as a logging callback. Please set 'GCS_BUCKET_NAME' in the environment."
+ )
+ bucket_name = _bucket_name
+ path_service_account = _path_service_account
+ vertex_instance = await self.get_or_create_vertex_instance(
+ credentials=path_service_account
+ )
+ else:
+ # If no dynamic parameters, use the default instance
+ if self.BUCKET_NAME is None:
+ raise ValueError(
+ "GCS_BUCKET_NAME is not set in the environment, but GCS Bucket is being used as a logging callback. Please set 'GCS_BUCKET_NAME' in the environment."
+ )
+ bucket_name = self.BUCKET_NAME
+ path_service_account = self.path_service_account_json
+ vertex_instance = await self.get_or_create_vertex_instance(
+ credentials=path_service_account
+ )
+
+ return GCSLoggingConfig(
+ bucket_name=bucket_name,
+ vertex_instance=vertex_instance,
+ path_service_account=path_service_account,
+ )
+
+ async def get_or_create_vertex_instance(
+ self, credentials: Optional[str]
+ ) -> VertexBase:
+ """
+ This function is used to get the Vertex instance for the GCS Bucket Logger.
+ It checks if the Vertex instance is already created and cached, if not it creates a new instance and caches it.
+ """
+ from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
+
+ _in_memory_key = self._get_in_memory_key_for_vertex_instance(credentials)
+ if _in_memory_key not in self.vertex_instances:
+ vertex_instance = VertexBase()
+ await vertex_instance._ensure_access_token_async(
+ credentials=credentials,
+ project_id=None,
+ custom_llm_provider="vertex_ai",
+ )
+ self.vertex_instances[_in_memory_key] = vertex_instance
+ return self.vertex_instances[_in_memory_key]
+
+ def _get_in_memory_key_for_vertex_instance(self, credentials: Optional[str]) -> str:
+ """
+ Returns key to use for caching the Vertex instance in-memory.
+
+ When using Vertex with Key based logging, we need to cache the Vertex instance in-memory.
+
+ - If a credentials string is provided, it is used as the key.
+ - If no credentials string is provided, "IAM_AUTH" is used as the key.
+ """
+ return credentials or IAM_AUTH_KEY
+
+ async def download_gcs_object(self, object_name: str, **kwargs):
+ """
+ Download an object from GCS.
+
+ https://cloud.google.com/storage/docs/downloading-objects#download-object-json
+ """
+ try:
+ gcs_logging_config: GCSLoggingConfig = await self.get_gcs_logging_config(
+ kwargs=kwargs
+ )
+ headers = await self.construct_request_headers(
+ vertex_instance=gcs_logging_config["vertex_instance"],
+ service_account_json=gcs_logging_config["path_service_account"],
+ )
+ bucket_name = gcs_logging_config["bucket_name"]
+ bucket_name, object_name = self._handle_folders_in_bucket_name(
+ bucket_name=bucket_name,
+ object_name=object_name,
+ )
+
+ url = f"https://storage.googleapis.com/storage/v1/b/{bucket_name}/o/{object_name}?alt=media"
+
+ # Send the GET request to download the object
+ response = await self.async_httpx_client.get(url=url, headers=headers)
+
+ if response.status_code != 200:
+ verbose_logger.error(
+ "GCS object download error: %s", str(response.text)
+ )
+ return None
+
+ verbose_logger.debug(
+ "GCS object download response status code: %s", response.status_code
+ )
+
+ # Return the content of the downloaded object
+ return response.content
+
+ except Exception as e:
+ verbose_logger.error("GCS object download error: %s", str(e))
+ return None
+
+ async def delete_gcs_object(self, object_name: str, **kwargs):
+ """
+ Delete an object from GCS.
+ """
+ try:
+ gcs_logging_config: GCSLoggingConfig = await self.get_gcs_logging_config(
+ kwargs=kwargs
+ )
+ headers = await self.construct_request_headers(
+ vertex_instance=gcs_logging_config["vertex_instance"],
+ service_account_json=gcs_logging_config["path_service_account"],
+ )
+ bucket_name = gcs_logging_config["bucket_name"]
+ bucket_name, object_name = self._handle_folders_in_bucket_name(
+ bucket_name=bucket_name,
+ object_name=object_name,
+ )
+
+ url = f"https://storage.googleapis.com/storage/v1/b/{bucket_name}/o/{object_name}"
+
+ # Send the DELETE request to delete the object
+ response = await self.async_httpx_client.delete(url=url, headers=headers)
+
+ if (response.status_code != 200) or (response.status_code != 204):
+ verbose_logger.error(
+ "GCS object delete error: %s, status code: %s",
+ str(response.text),
+ response.status_code,
+ )
+ return None
+
+ verbose_logger.debug(
+ "GCS object delete response status code: %s, response: %s",
+ response.status_code,
+ response.text,
+ )
+
+ # Return the content of the downloaded object
+ return response.text
+
+ except Exception as e:
+ verbose_logger.error("GCS object download error: %s", str(e))
+ return None
+
+ async def _log_json_data_on_gcs(
+ self,
+ headers: Dict[str, str],
+ bucket_name: str,
+ object_name: str,
+ logging_payload: Union[StandardLoggingPayload, str],
+ ):
+ """
+ Helper function to make POST request to GCS Bucket in the specified bucket.
+ """
+ if isinstance(logging_payload, str):
+ json_logged_payload = logging_payload
+ else:
+ json_logged_payload = json.dumps(logging_payload, default=str)
+
+ bucket_name, object_name = self._handle_folders_in_bucket_name(
+ bucket_name=bucket_name,
+ object_name=object_name,
+ )
+
+ response = await self.async_httpx_client.post(
+ headers=headers,
+ url=f"https://storage.googleapis.com/upload/storage/v1/b/{bucket_name}/o?uploadType=media&name={object_name}",
+ data=json_logged_payload,
+ )
+
+ if response.status_code != 200:
+ verbose_logger.error("GCS Bucket logging error: %s", str(response.text))
+
+ verbose_logger.debug("GCS Bucket response %s", response)
+ verbose_logger.debug("GCS Bucket status code %s", response.status_code)
+ verbose_logger.debug("GCS Bucket response.text %s", response.text)
+
+ return response.json()