Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/integrations/gcs_bucket/gcs_bucket.py')
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/integrations/gcs_bucket/gcs_bucket.py  237
1 file changed, 237 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/gcs_bucket/gcs_bucket.py b/.venv/lib/python3.12/site-packages/litellm/integrations/gcs_bucket/gcs_bucket.py
new file mode 100644
index 00000000..187ab779
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/integrations/gcs_bucket/gcs_bucket.py
@@ -0,0 +1,237 @@
+import asyncio
+import json
+import os
+import uuid
+from datetime import datetime, timedelta, timezone
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
+from urllib.parse import quote
+
+from litellm._logging import verbose_logger
+from litellm.integrations.additional_logging_utils import AdditionalLoggingUtils
+from litellm.integrations.gcs_bucket.gcs_bucket_base import GCSBucketBase
+from litellm.proxy._types import CommonProxyErrors
+from litellm.types.integrations.base_health_check import IntegrationHealthCheckStatus
+from litellm.types.integrations.gcs_bucket import GCSLoggingConfig, GCSLogQueueItem
+from litellm.types.utils import StandardLoggingPayload
+
+if TYPE_CHECKING:
+ from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
+else:
+ VertexBase = Any
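+    # VertexBase is only referenced in type annotations, so Any is a
+    # sufficient runtime placeholder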
+
+
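+# defaults for the batch logger; overridable at runtime via the
+# GCS_BATCH_SIZE and GCS_FLUSH_INTERVAL environment variables (see __init__)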
+GCS_DEFAULT_BATCH_SIZE = 2048
+GCS_DEFAULT_FLUSH_INTERVAL_SECONDS = 20
+
+
+class GCSBucketLogger(GCSBucketBase, AdditionalLoggingUtils):
+ def __init__(self, bucket_name: Optional[str] = None) -> None:
+ from litellm.proxy.proxy_server import premium_user
+
+ super().__init__(bucket_name=bucket_name)
+
+ # Init Batch logging settings
+ self.log_queue: List[GCSLogQueueItem] = []
+ self.batch_size = int(os.getenv("GCS_BATCH_SIZE", GCS_DEFAULT_BATCH_SIZE))
+ self.flush_interval = int(
+ os.getenv("GCS_FLUSH_INTERVAL", GCS_DEFAULT_FLUSH_INTERVAL_SECONDS)
+ )
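+        # NOTE: this requires an already-running event loop; the task body
+        # does not execute until control returns to the loop, so flush_lock
+        # (created below) is in place before the first flush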
+ asyncio.create_task(self.periodic_flush())
+ self.flush_lock = asyncio.Lock()
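+        # the second base-class init threads the batch settings through to
+        # the CustomBatchLogger that GCSBucketBase builds on; the first call
+        # above only set the bucket name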
+ super().__init__(
+ flush_lock=self.flush_lock,
+ batch_size=self.batch_size,
+ flush_interval=self.flush_interval,
+ )
+ AdditionalLoggingUtils.__init__(self)
+
+ if premium_user is not True:
+ raise ValueError(
+ f"GCS Bucket logging is a premium feature. Please upgrade to use it. {CommonProxyErrors.not_premium_user.value}"
+ )
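+    # Minimal usage sketch (assumes a premium proxy setup and a running
+    # event loop; "example-bucket" is a hypothetical name):
+    #
+    #   logger = GCSBucketLogger(bucket_name="example-bucket")
+    #   litellm.callbacks = [logger]  # litellm then drives the async hooks below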
+
+ #### ASYNC ####
+ async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+ from litellm.proxy.proxy_server import premium_user
+
+ if premium_user is not True:
+ raise ValueError(
+ f"GCS Bucket logging is a premium feature. Please upgrade to use it. {CommonProxyErrors.not_premium_user.value}"
+ )
+ try:
+ verbose_logger.debug(
+ "GCS Logger: async_log_success_event logging kwargs: %s, response_obj: %s",
+ kwargs,
+ response_obj,
+ )
+ logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
+ "standard_logging_object", None
+ )
+ if logging_payload is None:
+ raise ValueError("standard_logging_object not found in kwargs")
+ # Add to logging queue - this will be flushed periodically
+ self.log_queue.append(
+ GCSLogQueueItem(
+ payload=logging_payload, kwargs=kwargs, response_obj=response_obj
+ )
+ )
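+            # the inherited periodic_flush task drains this queue every
+            # `flush_interval` seconds via async_send_batch below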
+
+ except Exception as e:
+ verbose_logger.exception(f"GCS Bucket logging error: {str(e)}")
+
+ async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
+ try:
+ verbose_logger.debug(
+ "GCS Logger: async_log_failure_event logging kwargs: %s, response_obj: %s",
+ kwargs,
+ response_obj,
+ )
+
+ logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
+ "standard_logging_object", None
+ )
+ if logging_payload is None:
+ raise ValueError("standard_logging_object not found in kwargs")
+ # Add to logging queue - this will be flushed periodically
+ self.log_queue.append(
+ GCSLogQueueItem(
+ payload=logging_payload, kwargs=kwargs, response_obj=response_obj
+ )
+ )
+
+ except Exception as e:
+ verbose_logger.exception(f"GCS Bucket logging error: {str(e)}")
+
+ async def async_send_batch(self):
+        """
+        Process queued logs in batch - sends logs to the GCS bucket.
+
+        GCS does not expose a batch endpoint for uploading logs.
+
+        Instead, we
+        - collect the logs to flush every `GCS_FLUSH_INTERVAL` seconds
+        - during async_send_batch, make one POST request per log to GCS
+        """
+ if not self.log_queue:
+ return
+
+ for log_item in self.log_queue:
+ logging_payload = log_item["payload"]
+ kwargs = log_item["kwargs"]
+ response_obj = log_item.get("response_obj", None) or {}
+
+            try:
+                gcs_logging_config: GCSLoggingConfig = (
+                    await self.get_gcs_logging_config(kwargs)
+                )
+                headers = await self.construct_request_headers(
+                    vertex_instance=gcs_logging_config["vertex_instance"],
+                    service_account_json=gcs_logging_config["path_service_account"],
+                )
+                bucket_name = gcs_logging_config["bucket_name"]
+                object_name = self._get_object_name(
+                    kwargs, logging_payload, response_obj
+                )
+                await self._log_json_data_on_gcs(
+                    headers=headers,
+                    bucket_name=bucket_name,
+                    object_name=object_name,
+                    logging_payload=logging_payload,
+                )
+            except Exception as e:
+                # don't let one log item fail the entire batch; config and
+                # header resolution sit inside the try for the same reason
+                verbose_logger.exception(
+                    f"GCS Bucket error logging payload to GCS bucket: {str(e)}"
+                )
+
+ # Clear the queue after processing
+ self.log_queue.clear()
+
+ def _get_object_name(
+ self, kwargs: Dict, logging_payload: StandardLoggingPayload, response_obj: Any
+ ) -> str:
+ """
+ Get the object name to use for the current payload
+ """
+ current_date = self._get_object_date_from_datetime(datetime.now(timezone.utc))
+ if logging_payload.get("error_str", None) is not None:
+ object_name = self._generate_failure_object_name(
+ request_date_str=current_date,
+ )
+ else:
+ object_name = self._generate_success_object_name(
+ request_date_str=current_date,
+ response_id=response_obj.get("id", ""),
+ )
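+        # e.g. "2024-05-01/chatcmpl-123" for a success payload and
+        # "2024-05-01/failure-<uuid hex>" for an error (dates are UTC)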
+
+ # used for testing
+ _litellm_params = kwargs.get("litellm_params", None) or {}
+ _metadata = _litellm_params.get("metadata", None) or {}
+ if "gcs_log_id" in _metadata:
+ object_name = _metadata["gcs_log_id"]
+
+ return object_name
+
+ async def get_request_response_payload(
+ self,
+ request_id: str,
+ start_time_utc: Optional[datetime],
+ end_time_utc: Optional[datetime],
+ ) -> Optional[dict]:
+ """
+ Get the request and response payload for a given `request_id`
+ Tries current day, next day, and previous day until it finds the payload
+ """
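+        # Example (hypothetical id): a success payload stored on 2024-05-01
+        # for request_id "chatcmpl-123" lives at object
+        # "2024-05-01/chatcmpl-123", which is URL-encoded to
+        # "2024-05-01%2Fchatcmpl-123" before the download call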
+ if start_time_utc is None:
+ raise ValueError(
+ "start_time_utc is required for getting a payload from GCS Bucket"
+ )
+
+ # Try current day, next day, and previous day
+ dates_to_try = [
+ start_time_utc,
+ start_time_utc + timedelta(days=1),
+ start_time_utc - timedelta(days=1),
+ ]
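+        # initialized here so the debug message below can name the last date tried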
+ date_str = None
+ for date in dates_to_try:
+ try:
+ date_str = self._get_object_date_from_datetime(datetime_obj=date)
+ object_name = self._generate_success_object_name(
+ request_date_str=date_str,
+ response_id=request_id,
+ )
+ encoded_object_name = quote(object_name, safe="")
+ response = await self.download_gcs_object(encoded_object_name)
+
+ if response is not None:
+ loaded_response = json.loads(response)
+ return loaded_response
+ except Exception as e:
+ verbose_logger.debug(
+ f"Failed to fetch payload for date {date_str}: {str(e)}"
+ )
+ continue
+
+ return None
+
+ def _generate_success_object_name(
+ self,
+ request_date_str: str,
+ response_id: str,
+ ) -> str:
+ return f"{request_date_str}/{response_id}"
+
+ def _generate_failure_object_name(
+ self,
+ request_date_str: str,
+ ) -> str:
+ return f"{request_date_str}/failure-{uuid.uuid4().hex}"
+
+ def _get_object_date_from_datetime(self, datetime_obj: datetime) -> str:
+ return datetime_obj.strftime("%Y-%m-%d")
+
+ async def async_health_check(self) -> IntegrationHealthCheckStatus:
+        raise NotImplementedError("GCS Bucket does not support health checks")