diff options
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/integrations/gcs_pubsub')
-rw-r--r-- | .venv/lib/python3.12/site-packages/litellm/integrations/gcs_pubsub/pub_sub.py | 203 |
1 files changed, 203 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/integrations/gcs_pubsub/pub_sub.py b/.venv/lib/python3.12/site-packages/litellm/integrations/gcs_pubsub/pub_sub.py new file mode 100644 index 00000000..e94c853f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/integrations/gcs_pubsub/pub_sub.py @@ -0,0 +1,203 @@ +""" +BETA + +This is the PubSub logger for GCS PubSub, this sends LiteLLM SpendLogs Payloads to GCS PubSub. + +Users can use this instead of sending their SpendLogs to their Postgres database. +""" + +import asyncio +import json +import os +import traceback +from typing import TYPE_CHECKING, Any, Dict, List, Optional + +if TYPE_CHECKING: + from litellm.proxy._types import SpendLogsPayload +else: + SpendLogsPayload = Any + +from litellm._logging import verbose_logger +from litellm.integrations.custom_batch_logger import CustomBatchLogger +from litellm.llms.custom_httpx.http_handler import ( + get_async_httpx_client, + httpxSpecialProvider, +) + + +class GcsPubSubLogger(CustomBatchLogger): + def __init__( + self, + project_id: Optional[str] = None, + topic_id: Optional[str] = None, + credentials_path: Optional[str] = None, + **kwargs, + ): + """ + Initialize Google Cloud Pub/Sub publisher + + Args: + project_id (str): Google Cloud project ID + topic_id (str): Pub/Sub topic ID + credentials_path (str, optional): Path to Google Cloud credentials JSON file + """ + from litellm.proxy.utils import _premium_user_check + + _premium_user_check() + + self.async_httpx_client = get_async_httpx_client( + llm_provider=httpxSpecialProvider.LoggingCallback + ) + + self.project_id = project_id or os.getenv("GCS_PUBSUB_PROJECT_ID") + self.topic_id = topic_id or os.getenv("GCS_PUBSUB_TOPIC_ID") + self.path_service_account_json = credentials_path or os.getenv( + "GCS_PATH_SERVICE_ACCOUNT" + ) + + if not self.project_id or not self.topic_id: + raise ValueError("Both project_id and topic_id must be provided") + + self.flush_lock = asyncio.Lock() + super().__init__(**kwargs, flush_lock=self.flush_lock) + asyncio.create_task(self.periodic_flush()) + self.log_queue: List[SpendLogsPayload] = [] + + async def construct_request_headers(self) -> Dict[str, str]: + """Construct authorization headers using Vertex AI auth""" + from litellm import vertex_chat_completion + + _auth_header, vertex_project = ( + await vertex_chat_completion._ensure_access_token_async( + credentials=self.path_service_account_json, + project_id=None, + custom_llm_provider="vertex_ai", + ) + ) + + auth_header, _ = vertex_chat_completion._get_token_and_url( + model="pub-sub", + auth_header=_auth_header, + vertex_credentials=self.path_service_account_json, + vertex_project=vertex_project, + vertex_location=None, + gemini_api_key=None, + stream=None, + custom_llm_provider="vertex_ai", + api_base=None, + ) + + headers = { + "Authorization": f"Bearer {auth_header}", + "Content-Type": "application/json", + } + return headers + + async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): + """ + Async Log success events to GCS PubSub Topic + + - Creates a SpendLogsPayload + - Adds to batch queue + - Flushes based on CustomBatchLogger settings + + Raises: + Raises a NON Blocking verbose_logger.exception if an error occurs + """ + from litellm.proxy.spend_tracking.spend_tracking_utils import ( + get_logging_payload, + ) + from litellm.proxy.utils import _premium_user_check + + _premium_user_check() + + try: + verbose_logger.debug( + "PubSub: Logging - Enters logging function for model %s", kwargs + ) + spend_logs_payload = get_logging_payload( + kwargs=kwargs, + response_obj=response_obj, + start_time=start_time, + end_time=end_time, + ) + self.log_queue.append(spend_logs_payload) + + if len(self.log_queue) >= self.batch_size: + await self.async_send_batch() + + except Exception as e: + verbose_logger.exception( + f"PubSub Layer Error - {str(e)}\n{traceback.format_exc()}" + ) + pass + + async def async_send_batch(self): + """ + Sends the batch of messages to Pub/Sub + """ + try: + if not self.log_queue: + return + + verbose_logger.debug( + f"PubSub - about to flush {len(self.log_queue)} events" + ) + + for message in self.log_queue: + await self.publish_message(message) + + except Exception as e: + verbose_logger.exception( + f"PubSub Error sending batch - {str(e)}\n{traceback.format_exc()}" + ) + finally: + self.log_queue.clear() + + async def publish_message( + self, message: SpendLogsPayload + ) -> Optional[Dict[str, Any]]: + """ + Publish message to Google Cloud Pub/Sub using REST API + + Args: + message: Message to publish (dict or string) + + Returns: + dict: Published message response + """ + try: + headers = await self.construct_request_headers() + + # Prepare message data + if isinstance(message, str): + message_data = message + else: + message_data = json.dumps(message, default=str) + + # Base64 encode the message + import base64 + + encoded_message = base64.b64encode(message_data.encode("utf-8")).decode( + "utf-8" + ) + + # Construct request body + request_body = {"messages": [{"data": encoded_message}]} + + url = f"https://pubsub.googleapis.com/v1/projects/{self.project_id}/topics/{self.topic_id}:publish" + + response = await self.async_httpx_client.post( + url=url, headers=headers, json=request_body + ) + + if response.status_code not in [200, 202]: + verbose_logger.error("Pub/Sub publish error: %s", str(response.text)) + raise Exception(f"Failed to publish message: {response.text}") + + verbose_logger.debug("Pub/Sub response: %s", response.text) + return response.json() + + except Exception as e: + verbose_logger.error("Pub/Sub publish error: %s", str(e)) + return None |