about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/litellm/integrations/gcs_bucket
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/integrations/gcs_bucket')
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/integrations/gcs_bucket/Readme.md12
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/integrations/gcs_bucket/gcs_bucket.py237
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/integrations/gcs_bucket/gcs_bucket_base.py326
3 files changed, 575 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/gcs_bucket/Readme.md b/.venv/lib/python3.12/site-packages/litellm/integrations/gcs_bucket/Readme.md
new file mode 100644
index 00000000..2ab0b233
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/integrations/gcs_bucket/Readme.md
@@ -0,0 +1,12 @@
+# GCS (Google Cloud Storage) Bucket Logging on LiteLLM Gateway 
+
+This folder contains the GCS Bucket Logging integration for LiteLLM Gateway. 
+
+## Folder Structure 
+
+- `gcs_bucket.py`: This is the main file that handles failure/success logging to GCS Bucket
+- `gcs_bucket_base.py`: This file contains the GCSBucketBase class which handles Authentication for GCS Buckets
+
+## Further Reading
+- [Doc setting up GCS Bucket Logging on LiteLLM Proxy (Gateway)](https://docs.litellm.ai/docs/proxy/bucket)
+- [Doc on Key / Team Based logging with GCS](https://docs.litellm.ai/docs/proxy/team_logging)
\ No newline at end of file
diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/gcs_bucket/gcs_bucket.py b/.venv/lib/python3.12/site-packages/litellm/integrations/gcs_bucket/gcs_bucket.py
new file mode 100644
index 00000000..187ab779
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/integrations/gcs_bucket/gcs_bucket.py
@@ -0,0 +1,237 @@
+import asyncio
+import json
+import os
+import uuid
+from datetime import datetime, timedelta, timezone
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
+from urllib.parse import quote
+
+from litellm._logging import verbose_logger
+from litellm.integrations.additional_logging_utils import AdditionalLoggingUtils
+from litellm.integrations.gcs_bucket.gcs_bucket_base import GCSBucketBase
+from litellm.proxy._types import CommonProxyErrors
+from litellm.types.integrations.base_health_check import IntegrationHealthCheckStatus
+from litellm.types.integrations.gcs_bucket import *
+from litellm.types.utils import StandardLoggingPayload
+
+if TYPE_CHECKING:
+    from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
+else:
+    VertexBase = Any
+
+
+GCS_DEFAULT_BATCH_SIZE = 2048
+GCS_DEFAULT_FLUSH_INTERVAL_SECONDS = 20
+
+
+class GCSBucketLogger(GCSBucketBase, AdditionalLoggingUtils):
+    def __init__(self, bucket_name: Optional[str] = None) -> None:
+        from litellm.proxy.proxy_server import premium_user
+
+        super().__init__(bucket_name=bucket_name)
+
+        # Init Batch logging settings
+        self.log_queue: List[GCSLogQueueItem] = []
+        self.batch_size = int(os.getenv("GCS_BATCH_SIZE", GCS_DEFAULT_BATCH_SIZE))
+        self.flush_interval = int(
+            os.getenv("GCS_FLUSH_INTERVAL", GCS_DEFAULT_FLUSH_INTERVAL_SECONDS)
+        )
+        asyncio.create_task(self.periodic_flush())
+        self.flush_lock = asyncio.Lock()
+        super().__init__(
+            flush_lock=self.flush_lock,
+            batch_size=self.batch_size,
+            flush_interval=self.flush_interval,
+        )
+        AdditionalLoggingUtils.__init__(self)
+
+        if premium_user is not True:
+            raise ValueError(
+                f"GCS Bucket logging is a premium feature. Please upgrade to use it. {CommonProxyErrors.not_premium_user.value}"
+            )
+
+    #### ASYNC ####
+    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+        from litellm.proxy.proxy_server import premium_user
+
+        if premium_user is not True:
+            raise ValueError(
+                f"GCS Bucket logging is a premium feature. Please upgrade to use it. {CommonProxyErrors.not_premium_user.value}"
+            )
+        try:
+            verbose_logger.debug(
+                "GCS Logger: async_log_success_event logging kwargs: %s, response_obj: %s",
+                kwargs,
+                response_obj,
+            )
+            logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
+                "standard_logging_object", None
+            )
+            if logging_payload is None:
+                raise ValueError("standard_logging_object not found in kwargs")
+            # Add to logging queue - this will be flushed periodically
+            self.log_queue.append(
+                GCSLogQueueItem(
+                    payload=logging_payload, kwargs=kwargs, response_obj=response_obj
+                )
+            )
+
+        except Exception as e:
+            verbose_logger.exception(f"GCS Bucket logging error: {str(e)}")
+
+    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
+        try:
+            verbose_logger.debug(
+                "GCS Logger: async_log_failure_event logging kwargs: %s, response_obj: %s",
+                kwargs,
+                response_obj,
+            )
+
+            logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
+                "standard_logging_object", None
+            )
+            if logging_payload is None:
+                raise ValueError("standard_logging_object not found in kwargs")
+            # Add to logging queue - this will be flushed periodically
+            self.log_queue.append(
+                GCSLogQueueItem(
+                    payload=logging_payload, kwargs=kwargs, response_obj=response_obj
+                )
+            )
+
+        except Exception as e:
+            verbose_logger.exception(f"GCS Bucket logging error: {str(e)}")
+
+    async def async_send_batch(self):
+        """
+        Process queued logs in batch - sends logs to GCS Bucket
+
+
+        GCS Bucket does not have a Batch endpoint to batch upload logs
+
+        Instead, we
+            - collect the logs to flush every `GCS_FLUSH_INTERVAL` seconds
+            - during async_send_batch, we make 1 POST request per log to GCS Bucket
+
+        """
+        if not self.log_queue:
+            return
+
+        for log_item in self.log_queue:
+            logging_payload = log_item["payload"]
+            kwargs = log_item["kwargs"]
+            response_obj = log_item.get("response_obj", None) or {}
+
+            gcs_logging_config: GCSLoggingConfig = await self.get_gcs_logging_config(
+                kwargs
+            )
+            headers = await self.construct_request_headers(
+                vertex_instance=gcs_logging_config["vertex_instance"],
+                service_account_json=gcs_logging_config["path_service_account"],
+            )
+            bucket_name = gcs_logging_config["bucket_name"]
+            object_name = self._get_object_name(kwargs, logging_payload, response_obj)
+
+            try:
+                await self._log_json_data_on_gcs(
+                    headers=headers,
+                    bucket_name=bucket_name,
+                    object_name=object_name,
+                    logging_payload=logging_payload,
+                )
+            except Exception as e:
+                # don't let one log item fail the entire batch
+                verbose_logger.exception(
+                    f"GCS Bucket error logging payload to GCS bucket: {str(e)}"
+                )
+                pass
+
+        # Clear the queue after processing
+        self.log_queue.clear()
+
+    def _get_object_name(
+        self, kwargs: Dict, logging_payload: StandardLoggingPayload, response_obj: Any
+    ) -> str:
+        """
+        Get the object name to use for the current payload
+        """
+        current_date = self._get_object_date_from_datetime(datetime.now(timezone.utc))
+        if logging_payload.get("error_str", None) is not None:
+            object_name = self._generate_failure_object_name(
+                request_date_str=current_date,
+            )
+        else:
+            object_name = self._generate_success_object_name(
+                request_date_str=current_date,
+                response_id=response_obj.get("id", ""),
+            )
+
+        # used for testing
+        _litellm_params = kwargs.get("litellm_params", None) or {}
+        _metadata = _litellm_params.get("metadata", None) or {}
+        if "gcs_log_id" in _metadata:
+            object_name = _metadata["gcs_log_id"]
+
+        return object_name
+
+    async def get_request_response_payload(
+        self,
+        request_id: str,
+        start_time_utc: Optional[datetime],
+        end_time_utc: Optional[datetime],
+    ) -> Optional[dict]:
+        """
+        Get the request and response payload for a given `request_id`
+        Tries current day, next day, and previous day until it finds the payload
+        """
+        if start_time_utc is None:
+            raise ValueError(
+                "start_time_utc is required for getting a payload from GCS Bucket"
+            )
+
+        # Try current day, next day, and previous day
+        dates_to_try = [
+            start_time_utc,
+            start_time_utc + timedelta(days=1),
+            start_time_utc - timedelta(days=1),
+        ]
+        date_str = None
+        for date in dates_to_try:
+            try:
+                date_str = self._get_object_date_from_datetime(datetime_obj=date)
+                object_name = self._generate_success_object_name(
+                    request_date_str=date_str,
+                    response_id=request_id,
+                )
+                encoded_object_name = quote(object_name, safe="")
+                response = await self.download_gcs_object(encoded_object_name)
+
+                if response is not None:
+                    loaded_response = json.loads(response)
+                    return loaded_response
+            except Exception as e:
+                verbose_logger.debug(
+                    f"Failed to fetch payload for date {date_str}: {str(e)}"
+                )
+                continue
+
+        return None
+
+    def _generate_success_object_name(
+        self,
+        request_date_str: str,
+        response_id: str,
+    ) -> str:
+        return f"{request_date_str}/{response_id}"
+
+    def _generate_failure_object_name(
+        self,
+        request_date_str: str,
+    ) -> str:
+        return f"{request_date_str}/failure-{uuid.uuid4().hex}"
+
+    def _get_object_date_from_datetime(self, datetime_obj: datetime) -> str:
+        return datetime_obj.strftime("%Y-%m-%d")
+
+    async def async_health_check(self) -> IntegrationHealthCheckStatus:
+        raise NotImplementedError("GCS Bucket does not support health check")
diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/gcs_bucket/gcs_bucket_base.py b/.venv/lib/python3.12/site-packages/litellm/integrations/gcs_bucket/gcs_bucket_base.py
new file mode 100644
index 00000000..66995d84
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/integrations/gcs_bucket/gcs_bucket_base.py
@@ -0,0 +1,326 @@
+import json
+import os
+from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
+
+from litellm._logging import verbose_logger
+from litellm.integrations.custom_batch_logger import CustomBatchLogger
+from litellm.llms.custom_httpx.http_handler import (
+    get_async_httpx_client,
+    httpxSpecialProvider,
+)
+from litellm.types.integrations.gcs_bucket import *
+from litellm.types.utils import StandardCallbackDynamicParams, StandardLoggingPayload
+
+if TYPE_CHECKING:
+    from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
+else:
+    VertexBase = Any
+IAM_AUTH_KEY = "IAM_AUTH"
+
+
+class GCSBucketBase(CustomBatchLogger):
+    def __init__(self, bucket_name: Optional[str] = None, **kwargs) -> None:
+        self.async_httpx_client = get_async_httpx_client(
+            llm_provider=httpxSpecialProvider.LoggingCallback
+        )
+        _path_service_account = os.getenv("GCS_PATH_SERVICE_ACCOUNT")
+        _bucket_name = bucket_name or os.getenv("GCS_BUCKET_NAME")
+        self.path_service_account_json: Optional[str] = _path_service_account
+        self.BUCKET_NAME: Optional[str] = _bucket_name
+        self.vertex_instances: Dict[str, VertexBase] = {}
+        super().__init__(**kwargs)
+
+    async def construct_request_headers(
+        self,
+        service_account_json: Optional[str],
+        vertex_instance: Optional[VertexBase] = None,
+    ) -> Dict[str, str]:
+        from litellm import vertex_chat_completion
+
+        if vertex_instance is None:
+            vertex_instance = vertex_chat_completion
+
+        _auth_header, vertex_project = await vertex_instance._ensure_access_token_async(
+            credentials=service_account_json,
+            project_id=None,
+            custom_llm_provider="vertex_ai",
+        )
+
+        auth_header, _ = vertex_instance._get_token_and_url(
+            model="gcs-bucket",
+            auth_header=_auth_header,
+            vertex_credentials=service_account_json,
+            vertex_project=vertex_project,
+            vertex_location=None,
+            gemini_api_key=None,
+            stream=None,
+            custom_llm_provider="vertex_ai",
+            api_base=None,
+        )
+        verbose_logger.debug("constructed auth_header %s", auth_header)
+        headers = {
+            "Authorization": f"Bearer {auth_header}",  # auth_header
+            "Content-Type": "application/json",
+        }
+
+        return headers
+
+    def sync_construct_request_headers(self) -> Dict[str, str]:
+        from litellm import vertex_chat_completion
+
+        _auth_header, vertex_project = vertex_chat_completion._ensure_access_token(
+            credentials=self.path_service_account_json,
+            project_id=None,
+            custom_llm_provider="vertex_ai",
+        )
+
+        auth_header, _ = vertex_chat_completion._get_token_and_url(
+            model="gcs-bucket",
+            auth_header=_auth_header,
+            vertex_credentials=self.path_service_account_json,
+            vertex_project=vertex_project,
+            vertex_location=None,
+            gemini_api_key=None,
+            stream=None,
+            custom_llm_provider="vertex_ai",
+            api_base=None,
+        )
+        verbose_logger.debug("constructed auth_header %s", auth_header)
+        headers = {
+            "Authorization": f"Bearer {auth_header}",  # auth_header
+            "Content-Type": "application/json",
+        }
+
+        return headers
+
+    def _handle_folders_in_bucket_name(
+        self,
+        bucket_name: str,
+        object_name: str,
+    ) -> Tuple[str, str]:
+        """
+        Handles when the user passes a bucket name with a folder postfix
+
+
+        Example:
+            - Bucket name: "my-bucket/my-folder/dev"
+            - Object name: "my-object"
+            - Returns: bucket_name="my-bucket", object_name="my-folder/dev/my-object"
+
+        """
+        if "/" in bucket_name:
+            bucket_name, prefix = bucket_name.split("/", 1)
+            object_name = f"{prefix}/{object_name}"
+            return bucket_name, object_name
+        return bucket_name, object_name
+
+    async def get_gcs_logging_config(
+        self, kwargs: Optional[Dict[str, Any]] = {}
+    ) -> GCSLoggingConfig:
+        """
+        This function is used to get the GCS logging config for the GCS Bucket Logger.
+        It checks if the dynamic parameters are provided in the kwargs and uses them to get the GCS logging config.
+        If no dynamic parameters are provided, it uses the default values.
+        """
+        if kwargs is None:
+            kwargs = {}
+
+        standard_callback_dynamic_params: Optional[StandardCallbackDynamicParams] = (
+            kwargs.get("standard_callback_dynamic_params", None)
+        )
+
+        bucket_name: str
+        path_service_account: Optional[str]
+        if standard_callback_dynamic_params is not None:
+            verbose_logger.debug("Using dynamic GCS logging")
+            verbose_logger.debug(
+                "standard_callback_dynamic_params: %s", standard_callback_dynamic_params
+            )
+
+            _bucket_name: Optional[str] = (
+                standard_callback_dynamic_params.get("gcs_bucket_name", None)
+                or self.BUCKET_NAME
+            )
+            _path_service_account: Optional[str] = (
+                standard_callback_dynamic_params.get("gcs_path_service_account", None)
+                or self.path_service_account_json
+            )
+
+            if _bucket_name is None:
+                raise ValueError(
+                    "GCS_BUCKET_NAME is not set in the environment, but GCS Bucket is being used as a logging callback. Please set 'GCS_BUCKET_NAME' in the environment."
+                )
+            bucket_name = _bucket_name
+            path_service_account = _path_service_account
+            vertex_instance = await self.get_or_create_vertex_instance(
+                credentials=path_service_account
+            )
+        else:
+            # If no dynamic parameters, use the default instance
+            if self.BUCKET_NAME is None:
+                raise ValueError(
+                    "GCS_BUCKET_NAME is not set in the environment, but GCS Bucket is being used as a logging callback. Please set 'GCS_BUCKET_NAME' in the environment."
+                )
+            bucket_name = self.BUCKET_NAME
+            path_service_account = self.path_service_account_json
+            vertex_instance = await self.get_or_create_vertex_instance(
+                credentials=path_service_account
+            )
+
+        return GCSLoggingConfig(
+            bucket_name=bucket_name,
+            vertex_instance=vertex_instance,
+            path_service_account=path_service_account,
+        )
+
+    async def get_or_create_vertex_instance(
+        self, credentials: Optional[str]
+    ) -> VertexBase:
+        """
+        This function is used to get the Vertex instance for the GCS Bucket Logger.
+        It checks if the Vertex instance is already created and cached, if not it creates a new instance and caches it.
+        """
+        from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
+
+        _in_memory_key = self._get_in_memory_key_for_vertex_instance(credentials)
+        if _in_memory_key not in self.vertex_instances:
+            vertex_instance = VertexBase()
+            await vertex_instance._ensure_access_token_async(
+                credentials=credentials,
+                project_id=None,
+                custom_llm_provider="vertex_ai",
+            )
+            self.vertex_instances[_in_memory_key] = vertex_instance
+        return self.vertex_instances[_in_memory_key]
+
+    def _get_in_memory_key_for_vertex_instance(self, credentials: Optional[str]) -> str:
+        """
+        Returns key to use for caching the Vertex instance in-memory.
+
+        When using Vertex with Key based logging, we need to cache the Vertex instance in-memory.
+
+        - If a credentials string is provided, it is used as the key.
+        - If no credentials string is provided, "IAM_AUTH" is used as the key.
+        """
+        return credentials or IAM_AUTH_KEY
+
+    async def download_gcs_object(self, object_name: str, **kwargs):
+        """
+        Download an object from GCS.
+
+        https://cloud.google.com/storage/docs/downloading-objects#download-object-json
+        """
+        try:
+            gcs_logging_config: GCSLoggingConfig = await self.get_gcs_logging_config(
+                kwargs=kwargs
+            )
+            headers = await self.construct_request_headers(
+                vertex_instance=gcs_logging_config["vertex_instance"],
+                service_account_json=gcs_logging_config["path_service_account"],
+            )
+            bucket_name = gcs_logging_config["bucket_name"]
+            bucket_name, object_name = self._handle_folders_in_bucket_name(
+                bucket_name=bucket_name,
+                object_name=object_name,
+            )
+
+            url = f"https://storage.googleapis.com/storage/v1/b/{bucket_name}/o/{object_name}?alt=media"
+
+            # Send the GET request to download the object
+            response = await self.async_httpx_client.get(url=url, headers=headers)
+
+            if response.status_code != 200:
+                verbose_logger.error(
+                    "GCS object download error: %s", str(response.text)
+                )
+                return None
+
+            verbose_logger.debug(
+                "GCS object download response status code: %s", response.status_code
+            )
+
+            # Return the content of the downloaded object
+            return response.content
+
+        except Exception as e:
+            verbose_logger.error("GCS object download error: %s", str(e))
+            return None
+
+    async def delete_gcs_object(self, object_name: str, **kwargs):
+        """
+        Delete an object from GCS.
+        """
+        try:
+            gcs_logging_config: GCSLoggingConfig = await self.get_gcs_logging_config(
+                kwargs=kwargs
+            )
+            headers = await self.construct_request_headers(
+                vertex_instance=gcs_logging_config["vertex_instance"],
+                service_account_json=gcs_logging_config["path_service_account"],
+            )
+            bucket_name = gcs_logging_config["bucket_name"]
+            bucket_name, object_name = self._handle_folders_in_bucket_name(
+                bucket_name=bucket_name,
+                object_name=object_name,
+            )
+
+            url = f"https://storage.googleapis.com/storage/v1/b/{bucket_name}/o/{object_name}"
+
+            # Send the DELETE request to delete the object
+            response = await self.async_httpx_client.delete(url=url, headers=headers)
+
+            if (response.status_code != 200) or (response.status_code != 204):
+                verbose_logger.error(
+                    "GCS object delete error: %s, status code: %s",
+                    str(response.text),
+                    response.status_code,
+                )
+                return None
+
+            verbose_logger.debug(
+                "GCS object delete response status code: %s, response: %s",
+                response.status_code,
+                response.text,
+            )
+
+            # Return the content of the downloaded object
+            return response.text
+
+        except Exception as e:
+            verbose_logger.error("GCS object download error: %s", str(e))
+            return None
+
+    async def _log_json_data_on_gcs(
+        self,
+        headers: Dict[str, str],
+        bucket_name: str,
+        object_name: str,
+        logging_payload: Union[StandardLoggingPayload, str],
+    ):
+        """
+        Helper function to make POST request to GCS Bucket in the specified bucket.
+        """
+        if isinstance(logging_payload, str):
+            json_logged_payload = logging_payload
+        else:
+            json_logged_payload = json.dumps(logging_payload, default=str)
+
+        bucket_name, object_name = self._handle_folders_in_bucket_name(
+            bucket_name=bucket_name,
+            object_name=object_name,
+        )
+
+        response = await self.async_httpx_client.post(
+            headers=headers,
+            url=f"https://storage.googleapis.com/upload/storage/v1/b/{bucket_name}/o?uploadType=media&name={object_name}",
+            data=json_logged_payload,
+        )
+
+        if response.status_code != 200:
+            verbose_logger.error("GCS Bucket logging error: %s", str(response.text))
+
+        verbose_logger.debug("GCS Bucket response %s", response)
+        verbose_logger.debug("GCS Bucket status code %s", response.status_code)
+        verbose_logger.debug("GCS Bucket response.text %s", response.text)
+
+        return response.json()