author     S. Solomon Darnell | 2025-03-28 21:52:21 -0500
committer  S. Solomon Darnell | 2025-03-28 21:52:21 -0500
commit     4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree       ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/litellm/proxy/utils.py
parent     cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/proxy/utils.py')
 .venv/lib/python3.12/site-packages/litellm/proxy/utils.py | 2913 +++++++++++++++++
 1 file changed, 2913 insertions(+), 0 deletions(-)
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/utils.py b/.venv/lib/python3.12/site-packages/litellm/proxy/utils.py
new file mode 100644
index 00000000..0e7ae455
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/utils.py
@@ -0,0 +1,2913 @@
+import asyncio
+import copy
+import hashlib
+import json
+import os
+import smtplib
+import threading
+import time
+import traceback
+from datetime import datetime, timedelta
+from email.mime.multipart import MIMEMultipart
+from email.mime.text import MIMEText
+from typing import TYPE_CHECKING, Any, List, Literal, Optional, Union, overload
+
+from litellm.proxy._types import (
+    DB_CONNECTION_ERROR_TYPES,
+    CommonProxyErrors,
+    ProxyErrorTypes,
+    ProxyException,
+)
+from litellm.types.guardrails import GuardrailEventHooks
+
+try:
+    import backoff
+except ImportError:
+    raise ImportError(
+        "backoff is not installed. Please install it via 'pip install backoff'"
+    )
+
+from fastapi import HTTPException, status
+
+import litellm
+import litellm.litellm_core_utils
+import litellm.litellm_core_utils.litellm_logging
+from litellm import (
+    EmbeddingResponse,
+    ImageResponse,
+    ModelResponse,
+    ModelResponseStream,
+    Router,
+)
+from litellm._logging import verbose_proxy_logger
+from litellm._service_logger import ServiceLogging, ServiceTypes
+from litellm.caching.caching import DualCache, RedisCache
+from litellm.exceptions import RejectedRequestError
+from litellm.integrations.custom_guardrail import CustomGuardrail
+from litellm.integrations.custom_logger import CustomLogger
+from litellm.integrations.SlackAlerting.slack_alerting import SlackAlerting
+from litellm.integrations.SlackAlerting.utils import _add_langfuse_trace_id_to_alert
+from litellm.litellm_core_utils.litellm_logging import Logging
+from litellm.llms.custom_httpx.httpx_handler import HTTPHandler
+from litellm.proxy._types import (
+    AlertType,
+    CallInfo,
+    LiteLLM_VerificationTokenView,
+    Member,
+    UserAPIKeyAuth,
+)
+from litellm.proxy.db.create_views import (
+    create_missing_views,
+    should_create_missing_views,
+)
+from litellm.proxy.db.log_db_metrics import log_db_metrics
+from litellm.proxy.db.prisma_client import PrismaWrapper
+from litellm.proxy.hooks.cache_control_check import _PROXY_CacheControlCheck
+from litellm.proxy.hooks.max_budget_limiter import _PROXY_MaxBudgetLimiter
+from litellm.proxy.hooks.parallel_request_limiter import (
+    _PROXY_MaxParallelRequestsHandler,
+)
+from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup
+from litellm.secret_managers.main import str_to_bool
+from litellm.types.integrations.slack_alerting import DEFAULT_ALERT_TYPES
+from litellm.types.utils import CallTypes, LoggedLiteLLMParams
+
+if TYPE_CHECKING:
+    from opentelemetry.trace import Span as _Span
+
+    Span = _Span
+else:
+    Span = Any
+
+
+def print_verbose(print_statement):
+    """
+    Prints the given `print_statement` to the console if `litellm.set_verbose` is True.
+    Also logs the `print_statement` at the debug level using `verbose_proxy_logger`.
+
+    :param print_statement: The statement to be printed and logged.
+    :type print_statement: Any
+    """
+    import traceback
+
+    verbose_proxy_logger.debug("{}\n{}".format(print_statement, traceback.format_exc()))
+    if litellm.set_verbose:
+        print(f"LiteLLM Proxy: {print_statement}")  # noqa
+
+
+def safe_deep_copy(data):
+    """
+    Safe Deep Copy
+
+    The LiteLLM Request has some objects that cannot be pickled / deep copied
+
+    Use this function to safely deep copy the LiteLLM Request
+    """
+    if litellm.safe_memory_mode is True:
+        return data
+
+    litellm_parent_otel_span: Optional[Any] = None
+    # Step 1: Remove the litellm_parent_otel_span
+    litellm_parent_otel_span = None
+    if isinstance(data, dict):
+        # remove litellm_parent_otel_span since this is not picklable
+        if "metadata" in data and "litellm_parent_otel_span" in data["metadata"]:
+            litellm_parent_otel_span = data["metadata"].pop("litellm_parent_otel_span")
+    new_data = copy.deepcopy(data)
+
+    # Step 2: re-add the litellm_parent_otel_span after doing a deep copy
+    if isinstance(data, dict) and litellm_parent_otel_span is not None:
+        if "metadata" in data:
+            data["metadata"]["litellm_parent_otel_span"] = litellm_parent_otel_span
+    return new_data
+
+
+class InternalUsageCache:
+    def __init__(self, dual_cache: DualCache):
+        self.dual_cache: DualCache = dual_cache
+
+    async def async_get_cache(
+        self,
+        key,
+        litellm_parent_otel_span: Union[Span, None],
+        local_only: bool = False,
+        **kwargs,
+    ) -> Any:
+        return await self.dual_cache.async_get_cache(
+            key=key,
+            local_only=local_only,
+            parent_otel_span=litellm_parent_otel_span,
+            **kwargs,
+        )
+
+    async def async_set_cache(
+        self,
+        key,
+        value,
+        litellm_parent_otel_span: Union[Span, None],
+        local_only: bool = False,
+        **kwargs,
+    ) -> None:
+        return await self.dual_cache.async_set_cache(
+            key=key,
+            value=value,
+            local_only=local_only,
+            litellm_parent_otel_span=litellm_parent_otel_span,
+            **kwargs,
+        )
+
+    async def async_batch_set_cache(
+        self,
+        cache_list: List,
+        litellm_parent_otel_span: Union[Span, None],
+        local_only: bool = False,
+        **kwargs,
+    ) -> None:
+        return await self.dual_cache.async_set_cache_pipeline(
+            cache_list=cache_list,
+            local_only=local_only,
+            litellm_parent_otel_span=litellm_parent_otel_span,
+            **kwargs,
+        )
+
+    async def async_batch_get_cache(
+        self,
+        keys: list,
+        parent_otel_span: Optional[Span] = None,
+        local_only: bool = False,
+    ):
+        return await self.dual_cache.async_batch_get_cache(
+            keys=keys,
+            parent_otel_span=parent_otel_span,
+            local_only=local_only,
+        )
+
+    async def async_increment_cache(
+        self,
+        key,
+        value: float,
+        litellm_parent_otel_span: Union[Span, None],
+        local_only: bool = False,
+        **kwargs,
+    ):
+        return await self.dual_cache.async_increment_cache(
+            key=key,
+            value=value,
+            local_only=local_only,
+            parent_otel_span=litellm_parent_otel_span,
+            **kwargs,
+        )
+
+    def set_cache(
+        self,
+        key,
+        value,
+        local_only: bool = False,
+        **kwargs,
+    ) -> None:
+        return self.dual_cache.set_cache(
+            key=key,
+            value=value,
+            local_only=local_only,
+            **kwargs,
+        )
+
+    def get_cache(
+        self,
+        key,
+        local_only: bool = False,
+        **kwargs,
+    ) -> Any:
+        return self.dual_cache.get_cache(
+            key=key,
+            local_only=local_only,
+            **kwargs,
+        )
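A minimal sketch of what safe_deep_copy works around, assuming a request dict whose metadata carries a non-picklable span; FakeSpan is a made-up stand-in for a real OpenTelemetry span:

import copy

class FakeSpan:
    # stand-in for an OpenTelemetry span, which cannot be deep-copied
    def __deepcopy__(self, memo):
        raise TypeError("spans are not deep-copyable")

request = {"model": "gpt-4", "metadata": {"litellm_parent_otel_span": FakeSpan()}}
span = request["metadata"].pop("litellm_parent_otel_span")  # step 1: strip the span
copied = copy.deepcopy(request)                             # now safe to copy
request["metadata"]["litellm_parent_otel_span"] = span      # step 2: restore it
assert "litellm_parent_otel_span" not in copied["metadata"]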
+
+
+### LOGGING ###
+class ProxyLogging:
+    """
+    Logging/Custom Handlers for proxy.
+
+    Implemented mainly to:
+    - log successful/failed db read/writes
+    - support the max parallel request integration
+    """
+
+    def __init__(
+        self,
+        user_api_key_cache: DualCache,
+        premium_user: bool = False,
+    ):
+        ## INITIALIZE LITELLM CALLBACKS ##
+        self.call_details: dict = {}
+        self.call_details["user_api_key_cache"] = user_api_key_cache
+        self.internal_usage_cache: InternalUsageCache = InternalUsageCache(
+            dual_cache=DualCache(default_in_memory_ttl=1)  # ping redis cache every 1s
+        )
+        self.max_parallel_request_limiter = _PROXY_MaxParallelRequestsHandler(
+            self.internal_usage_cache
+        )
+        self.max_budget_limiter = _PROXY_MaxBudgetLimiter()
+        self.cache_control_check = _PROXY_CacheControlCheck()
+        self.alerting: Optional[List] = None
+        self.alerting_threshold: float = 300  # default to 5 min. threshold
+        self.alert_types: List[AlertType] = DEFAULT_ALERT_TYPES
+        self.alert_to_webhook_url: Optional[dict] = None
+        self.slack_alerting_instance: SlackAlerting = SlackAlerting(
+            alerting_threshold=self.alerting_threshold,
+            alerting=self.alerting,
+            internal_usage_cache=self.internal_usage_cache.dual_cache,
+        )
+        self.premium_user = premium_user
+        self.service_logging_obj = ServiceLogging()
+
+    def startup_event(
+        self,
+        llm_router: Optional[Router],
+        redis_usage_cache: Optional[RedisCache],
+    ):
+        """Initialize logging and alerting on proxy startup"""
+        ## UPDATE SLACK ALERTING ##
+        self.slack_alerting_instance.update_values(llm_router=llm_router)
+
+        ## UPDATE INTERNAL USAGE CACHE ##
+        self.update_values(
+            redis_cache=redis_usage_cache
+        )  # used by parallel request limiter for rate limiting keys across instances
+
+        self._init_litellm_callbacks(
+            llm_router=llm_router
+        )  # INITIALIZE LITELLM CALLBACKS ON SERVER STARTUP <- do this to catch any logging errors on startup, not when calls are being made
+
+        if (
+            self.slack_alerting_instance is not None
+            and "daily_reports" in self.slack_alerting_instance.alert_types
+        ):
+            asyncio.create_task(
+                self.slack_alerting_instance._run_scheduled_daily_report(
+                    llm_router=llm_router
+                )
+            )  # RUN DAILY REPORT (if scheduled)
+
+    def update_values(
+        self,
+        alerting: Optional[List] = None,
+        alerting_threshold: Optional[float] = None,
+        redis_cache: Optional[RedisCache] = None,
+        alert_types: Optional[List[AlertType]] = None,
+        alerting_args: Optional[dict] = None,
+        alert_to_webhook_url: Optional[dict] = None,
+    ):
+        updated_slack_alerting: bool = False
+        if alerting is not None:
+            self.alerting = alerting
+            updated_slack_alerting = True
+        if alerting_threshold is not None:
+            self.alerting_threshold = alerting_threshold
+            updated_slack_alerting = True
+        if alert_types is not None:
+            self.alert_types = alert_types
+            updated_slack_alerting = True
+        if alert_to_webhook_url is not None:
+            self.alert_to_webhook_url = alert_to_webhook_url
+            updated_slack_alerting = True
+
+        if updated_slack_alerting is True:
+            self.slack_alerting_instance.update_values(
+                alerting=self.alerting,
+                alerting_threshold=self.alerting_threshold,
+                alert_types=self.alert_types,
+                alerting_args=alerting_args,
+                alert_to_webhook_url=self.alert_to_webhook_url,
+            )
+
+            if self.alerting is not None and "slack" in self.alerting:
+                # NOTE: ENSURE we only add callbacks when alerting is on
+                # We should NOT add callbacks when alerting is off
+                if "daily_reports" in self.alert_types:
+                    litellm.logging_callback_manager.add_litellm_callback(self.slack_alerting_instance)  # type: ignore
+                litellm.logging_callback_manager.add_litellm_success_callback(
+                    self.slack_alerting_instance.response_taking_too_long_callback
+                )
+
+        if redis_cache is not None:
+            self.internal_usage_cache.dual_cache.redis_cache = redis_cache
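A sketch of how this class is wired up at startup, assuming an llm_router and redis_cache built elsewhere; the threshold value is illustrative:

from litellm.caching.caching import DualCache

proxy_logging_obj = ProxyLogging(user_api_key_cache=DualCache())
proxy_logging_obj.update_values(
    alerting=["slack"],        # turns on the Slack callback registration above
    alerting_threshold=120.0,  # alert on requests hanging longer than 2 minutes
)
# inside a running event loop (startup_event schedules asyncio tasks):
# proxy_logging_obj.startup_event(llm_router=llm_router, redis_usage_cache=redis_cache)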
+
+    def _init_litellm_callbacks(self, llm_router: Optional[Router] = None):
+        litellm.logging_callback_manager.add_litellm_callback(self.max_parallel_request_limiter)  # type: ignore
+        litellm.logging_callback_manager.add_litellm_callback(self.max_budget_limiter)  # type: ignore
+        litellm.logging_callback_manager.add_litellm_callback(self.cache_control_check)  # type: ignore
+        litellm.logging_callback_manager.add_litellm_callback(self.service_logging_obj)  # type: ignore
+        for callback in litellm.callbacks:
+            if isinstance(callback, str):
+                callback = litellm.litellm_core_utils.litellm_logging._init_custom_logger_compatible_class(  # type: ignore
+                    callback,
+                    internal_usage_cache=self.internal_usage_cache.dual_cache,
+                    llm_router=llm_router,
+                )
+                if callback is None:
+                    continue
+            if callback not in litellm.input_callback:
+                litellm.input_callback.append(callback)  # type: ignore
+            if callback not in litellm.success_callback:
+                litellm.logging_callback_manager.add_litellm_success_callback(callback)  # type: ignore
+            if callback not in litellm.failure_callback:
+                litellm.logging_callback_manager.add_litellm_failure_callback(callback)  # type: ignore
+            if callback not in litellm._async_success_callback:
+                litellm.logging_callback_manager.add_litellm_async_success_callback(callback)  # type: ignore
+            if callback not in litellm._async_failure_callback:
+                litellm.logging_callback_manager.add_litellm_async_failure_callback(callback)  # type: ignore
+            if callback not in litellm.service_callback:
+                litellm.service_callback.append(callback)  # type: ignore
+
+        if (
+            len(litellm.input_callback) > 0
+            or len(litellm.success_callback) > 0
+            or len(litellm.failure_callback) > 0
+        ):
+            callback_list = list(
+                set(
+                    litellm.input_callback
+                    + litellm.success_callback
+                    + litellm.failure_callback
+                )
+            )
+            litellm.litellm_core_utils.litellm_logging.set_callbacks(
+                callback_list=callback_list
+            )
+
+    async def update_request_status(
+        self, litellm_call_id: str, status: Literal["success", "fail"]
+    ):
+        # only use this if slack alerting is being used
+        if self.alerting is None:
+            return
+
+        # current alerting threshold
+        alerting_threshold: float = self.alerting_threshold
+
+        # add a 100 second buffer to the alerting threshold
+        # ensures we don't send errant hanging request slack alerts
+        alerting_threshold += 100
+
+        await self.internal_usage_cache.async_set_cache(
+            key="request_status:{}".format(litellm_call_id),
+            value=status,
+            local_only=True,
+            ttl=alerting_threshold,
+            litellm_parent_otel_span=None,
+        )
+
+    async def process_pre_call_hook_response(self, response, data, call_type):
+        if isinstance(response, Exception):
+            raise response
+        if isinstance(response, dict):
+            return response
+        if isinstance(response, str):
+            if call_type in ["completion", "text_completion"]:
+                raise RejectedRequestError(
+                    message=response,
+                    model=data.get("model", ""),
+                    llm_provider="",
+                    request_data=data,
+                )
+            else:
+                raise HTTPException(status_code=400, detail={"error": response})
+        return data
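The dispatch rules above in miniature: a dict return replaces the request body, a string rejects it, and an exception is re-raised. A hedged sketch:

async def demo(proxy_logging_obj: ProxyLogging) -> None:
    data = {"model": "gpt-4", "messages": [{"role": "user", "content": "hi"}]}
    # dict -> becomes the new request body
    out = await proxy_logging_obj.process_pre_call_hook_response(
        response={"model": "gpt-4o"}, data=data, call_type="completion"
    )
    assert out == {"model": "gpt-4o"}
    # str + call_type "completion" -> raises RejectedRequestError
    # str + any other call type    -> raises HTTPException(400)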
+
+    # The actual implementation of the function
+    @overload
+    async def pre_call_hook(
+        self,
+        user_api_key_dict: UserAPIKeyAuth,
+        data: None,
+        call_type: Literal[
+            "completion",
+            "text_completion",
+            "embeddings",
+            "image_generation",
+            "moderation",
+            "audio_transcription",
+            "pass_through_endpoint",
+            "rerank",
+        ],
+    ) -> None:
+        pass
+
+    @overload
+    async def pre_call_hook(
+        self,
+        user_api_key_dict: UserAPIKeyAuth,
+        data: dict,
+        call_type: Literal[
+            "completion",
+            "text_completion",
+            "embeddings",
+            "image_generation",
+            "moderation",
+            "audio_transcription",
+            "pass_through_endpoint",
+            "rerank",
+        ],
+    ) -> dict:
+        pass
+
+    async def pre_call_hook(
+        self,
+        user_api_key_dict: UserAPIKeyAuth,
+        data: Optional[dict],
+        call_type: Literal[
+            "completion",
+            "text_completion",
+            "embeddings",
+            "image_generation",
+            "moderation",
+            "audio_transcription",
+            "pass_through_endpoint",
+            "rerank",
+        ],
+    ) -> Optional[dict]:
+        """
+        Allows users to modify/reject the incoming request to the proxy, without having to deal with parsing the request body.
+
+        Covers:
+        1. /chat/completions
+        2. /embeddings
+        3. /image/generation
+        """
+        verbose_proxy_logger.debug("Inside Proxy Logging Pre-call hook!")
+
+        self._init_response_taking_too_long_task(data=data)
+
+        if data is None:
+            return None
+
+        try:
+            for callback in litellm.callbacks:
+
+                _callback = None
+                if isinstance(callback, str):
+                    _callback = litellm.litellm_core_utils.litellm_logging.get_custom_logger_compatible_class(
+                        callback
+                    )
+                else:
+                    _callback = callback  # type: ignore
+                if _callback is not None and isinstance(_callback, CustomGuardrail):
+                    from litellm.types.guardrails import GuardrailEventHooks
+
+                    if (
+                        _callback.should_run_guardrail(
+                            data=data, event_type=GuardrailEventHooks.pre_call
+                        )
+                        is not True
+                    ):
+                        continue
+
+                    response = await _callback.async_pre_call_hook(
+                        user_api_key_dict=user_api_key_dict,
+                        cache=self.call_details["user_api_key_cache"],
+                        data=data,  # type: ignore
+                        call_type=call_type,
+                    )
+                    if response is not None:
+                        data = await self.process_pre_call_hook_response(
+                            response=response, data=data, call_type=call_type
+                        )
+
+                elif (
+                    _callback is not None
+                    and isinstance(_callback, CustomLogger)
+                    and "async_pre_call_hook" in vars(_callback.__class__)
+                    and _callback.__class__.async_pre_call_hook
+                    != CustomLogger.async_pre_call_hook
+                ):
+                    response = await _callback.async_pre_call_hook(
+                        user_api_key_dict=user_api_key_dict,
+                        cache=self.call_details["user_api_key_cache"],
+                        data=data,  # type: ignore
+                        call_type=call_type,
+                    )
+                    if response is not None:
+                        data = await self.process_pre_call_hook_response(
+                            response=response, data=data, call_type=call_type
+                        )
+
+            return data
+        except Exception as e:
+            raise e
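A minimal guardrail that the loop above would pick up; the class name and blocked word are invented for illustration:

from litellm.integrations.custom_guardrail import CustomGuardrail

class BlockWordGuardrail(CustomGuardrail):
    async def async_pre_call_hook(self, user_api_key_dict, cache, data, call_type):
        for message in data.get("messages", []):
            if "forbidden" in str(message.get("content", "")):
                return "Request blocked by guardrail"  # str -> request is rejected
        return data  # dict -> (possibly modified) request continues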
+
+    async def during_call_hook(
+        self,
+        data: dict,
+        user_api_key_dict: UserAPIKeyAuth,
+        call_type: Literal[
+            "completion",
+            "responses",
+            "embeddings",
+            "image_generation",
+            "moderation",
+            "audio_transcription",
+        ],
+    ):
+        """
+        Runs the CustomGuardrail's async_moderation_hook()
+        """
+        for callback in litellm.callbacks:
+            try:
+                if isinstance(callback, CustomGuardrail):
+                    ################################################################
+                    # Check if guardrail should be run for GuardrailEventHooks.during_call hook
+                    ################################################################
+
+                    # V1 implementation - backwards compatibility
+                    if callback.event_hook is None and hasattr(
+                        callback, "moderation_check"
+                    ):
+                        if callback.moderation_check == "pre_call":  # type: ignore
+                            return
+                    else:
+                        # Main - V2 Guardrails implementation
+                        from litellm.types.guardrails import GuardrailEventHooks
+
+                        if (
+                            callback.should_run_guardrail(
+                                data=data, event_type=GuardrailEventHooks.during_call
+                            )
+                            is not True
+                        ):
+                            continue
+                    await callback.async_moderation_hook(
+                        data=data,
+                        user_api_key_dict=user_api_key_dict,
+                        call_type=call_type,
+                    )
+            except Exception as e:
+                raise e
+        return data
+
+    async def failed_tracking_alert(
+        self,
+        error_message: str,
+        failing_model: str,
+    ):
+        if self.alerting is None:
+            return
+
+        if self.slack_alerting_instance:
+            await self.slack_alerting_instance.failed_tracking_alert(
+                error_message=error_message,
+                failing_model=failing_model,
+            )
+
+    async def budget_alerts(
+        self,
+        type: Literal[
+            "token_budget",
+            "user_budget",
+            "soft_budget",
+            "team_budget",
+            "proxy_budget",
+            "projected_limit_exceeded",
+        ],
+        user_info: CallInfo,
+    ):
+        if self.alerting is None:
+            # do nothing if alerting is not switched on
+            return
+        await self.slack_alerting_instance.budget_alerts(
+            type=type,
+            user_info=user_info,
+        )
+
+    async def alerting_handler(
+        self,
+        message: str,
+        level: Literal["Low", "Medium", "High"],
+        alert_type: AlertType,
+        request_data: Optional[dict] = None,
+    ):
+        """
+        Alerting based on thresholds: - https://github.com/BerriAI/litellm/issues/1298
+
+        - Responses taking too long
+        - Requests are hanging
+        - Calls are failing
+        - DB Read/Writes are failing
+        - Proxy Close to max budget
+        - Key Close to max budget
+
+        Parameters:
+            level: str - Low|Medium|High - if calls might fail (Medium) or are failing (High); Currently, no alerts would be 'Low'.
+            message: str - what is the alert about
+        """
+        if self.alerting is None:
+            return
+
+        from datetime import datetime
+
+        # Get the current timestamp
+        current_time = datetime.now().strftime("%H:%M:%S")
+        _proxy_base_url = os.getenv("PROXY_BASE_URL", None)
+        formatted_message = (
+            f"Level: `{level}`\nTimestamp: `{current_time}`\n\nMessage: {message}"
+        )
+        if _proxy_base_url is not None:
+            formatted_message += f"\n\nProxy URL: `{_proxy_base_url}`"
+
+        extra_kwargs = {}
+        alerting_metadata = {}
+        if request_data is not None:
+            _url = await _add_langfuse_trace_id_to_alert(request_data=request_data)
+
+            if _url is not None:
+                extra_kwargs["🪢 Langfuse Trace"] = _url
+                formatted_message += "\n\n🪢 Langfuse Trace: {}".format(_url)
+            if (
+                "metadata" in request_data
+                and request_data["metadata"].get("alerting_metadata", None) is not None
+                and isinstance(request_data["metadata"]["alerting_metadata"], dict)
+            ):
+                alerting_metadata = request_data["metadata"]["alerting_metadata"]
+        for client in self.alerting:
+            if client == "slack":
+                await self.slack_alerting_instance.send_alert(
+                    message=message,
+                    level=level,
+                    alert_type=alert_type,
+                    user_info=None,
+                    alerting_metadata=alerting_metadata,
+                    **extra_kwargs,
+                )
+            elif client == "sentry":
+                if litellm.utils.sentry_sdk_instance is not None:
+                    litellm.utils.sentry_sdk_instance.capture_message(formatted_message)
+                else:
+                    raise Exception("Missing SENTRY_DSN from environment")
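Firing a manual alert through the handler above, as a sketch; assumes alerting=["slack"] was configured and an event loop is running. The message text is illustrative:

from litellm.proxy._types import AlertType

async def warn(proxy_logging_obj: ProxyLogging) -> None:
    await proxy_logging_obj.alerting_handler(
        message="DB read/write call failed: connection timeout",
        level="High",
        alert_type=AlertType.db_exceptions,
        request_data={},
    )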
level="High", + alert_type=AlertType.db_exceptions, + request_data={}, + ) + ) + + if hasattr(self, "service_logging_obj"): + await self.service_logging_obj.async_service_failure_hook( + service=ServiceTypes.DB, + duration=duration, + error=error_message, + call_type=call_type, + ) + + if litellm.utils.capture_exception: + litellm.utils.capture_exception(error=original_exception) + + async def post_call_failure_hook( + self, + request_data: dict, + original_exception: Exception, + user_api_key_dict: UserAPIKeyAuth, + error_type: Optional[ProxyErrorTypes] = None, + route: Optional[str] = None, + ): + """ + Allows users to raise custom exceptions/log when a call fails, without having to deal with parsing Request body. + + Covers: + 1. /chat/completions + 2. /embeddings + 3. /image/generation + """ + + ### ALERTING ### + await self.update_request_status( + litellm_call_id=request_data.get("litellm_call_id", ""), status="fail" + ) + if AlertType.llm_exceptions in self.alert_types and not isinstance( + original_exception, HTTPException + ): + """ + Just alert on LLM API exceptions. Do not alert on user errors + + Related issue - https://github.com/BerriAI/litellm/issues/3395 + """ + litellm_debug_info = getattr(original_exception, "litellm_debug_info", None) + exception_str = str(original_exception) + if litellm_debug_info is not None: + exception_str += litellm_debug_info + + asyncio.create_task( + self.alerting_handler( + message=f"LLM API call failed: `{exception_str}`", + level="High", + alert_type=AlertType.llm_exceptions, + request_data=request_data, + ) + ) + + ### LOGGING ### + if self._is_proxy_only_error( + original_exception=original_exception, error_type=error_type + ): + await self._handle_logging_proxy_only_error( + request_data=request_data, + user_api_key_dict=user_api_key_dict, + route=route, + original_exception=original_exception, + ) + + for callback in litellm.callbacks: + try: + _callback: Optional[CustomLogger] = None + if isinstance(callback, str): + _callback = litellm.litellm_core_utils.litellm_logging.get_custom_logger_compatible_class( + callback + ) + else: + _callback = callback # type: ignore + if _callback is not None and isinstance(_callback, CustomLogger): + asyncio.create_task( + _callback.async_post_call_failure_hook( + request_data=request_data, + user_api_key_dict=user_api_key_dict, + original_exception=original_exception, + ) + ) + except Exception as e: + verbose_proxy_logger.exception( + f"[Non-Blocking] Error in post_call_failure_hook: {e}" + ) + return + + def _is_proxy_only_error( + self, + original_exception: Exception, + error_type: Optional[ProxyErrorTypes] = None, + ) -> bool: + """ + Return True if the error is a Proxy Only Error + + Prevents double logging of LLM API exceptions + + e.g should only return True for: + - Authentication Errors from user_api_key_auth + - HTTP HTTPException (rate limit errors) + """ + return isinstance(original_exception, HTTPException) or ( + error_type == ProxyErrorTypes.auth_error + ) + + async def _handle_logging_proxy_only_error( + self, + request_data: dict, + user_api_key_dict: UserAPIKeyAuth, + route: Optional[str] = None, + original_exception: Optional[Exception] = None, + ): + """ + Handle logging for proxy only errors by calling `litellm_logging_obj.async_failure_handler` + + Is triggered when self._is_proxy_only_error() returns True + """ + litellm_logging_obj: Optional[Logging] = request_data.get( + "litellm_logging_obj", None + ) + if litellm_logging_obj is None: + import uuid + + 
request_data["litellm_call_id"] = str(uuid.uuid4()) + user_api_key_logged_metadata = ( + LiteLLMProxyRequestSetup.get_sanitized_user_information_from_key( + user_api_key_dict=user_api_key_dict + ) + ) + + litellm_logging_obj, data = litellm.utils.function_setup( + original_function=route or "IGNORE_THIS", + rules_obj=litellm.utils.Rules(), + start_time=datetime.now(), + **request_data, + ) + if "metadata" not in request_data: + request_data["metadata"] = {} + request_data["metadata"].update(user_api_key_logged_metadata) + + if litellm_logging_obj is not None: + ## UPDATE LOGGING INPUT + _optional_params = {} + _litellm_params = {} + + litellm_param_keys = LoggedLiteLLMParams.__annotations__.keys() + for k, v in request_data.items(): + if k in litellm_param_keys: + _litellm_params[k] = v + elif k != "model" and k != "user": + _optional_params[k] = v + + litellm_logging_obj.update_environment_variables( + model=request_data.get("model", ""), + user=request_data.get("user", ""), + optional_params=_optional_params, + litellm_params=_litellm_params, + ) + + input: Union[list, str, dict] = "" + if "messages" in request_data and isinstance( + request_data["messages"], list + ): + input = request_data["messages"] + litellm_logging_obj.model_call_details["messages"] = input + litellm_logging_obj.call_type = CallTypes.acompletion.value + elif "prompt" in request_data and isinstance(request_data["prompt"], str): + input = request_data["prompt"] + litellm_logging_obj.model_call_details["prompt"] = input + litellm_logging_obj.call_type = CallTypes.atext_completion.value + elif "input" in request_data and isinstance(request_data["input"], list): + input = request_data["input"] + litellm_logging_obj.model_call_details["input"] = input + litellm_logging_obj.call_type = CallTypes.aembedding.value + litellm_logging_obj.pre_call( + input=input, + api_key="", + ) + + # log the custom exception + await litellm_logging_obj.async_failure_handler( + exception=original_exception, + traceback_exception=traceback.format_exc(), + ) + + threading.Thread( + target=litellm_logging_obj.failure_handler, + args=( + original_exception, + traceback.format_exc(), + ), + ).start() + + async def post_call_success_hook( + self, + data: dict, + response: Union[ModelResponse, EmbeddingResponse, ImageResponse], + user_api_key_dict: UserAPIKeyAuth, + ): + """ + Allow user to modify outgoing data + + Covers: + 1. 
+
+    async def post_call_success_hook(
+        self,
+        data: dict,
+        response: Union[ModelResponse, EmbeddingResponse, ImageResponse],
+        user_api_key_dict: UserAPIKeyAuth,
+    ):
+        """
+        Allow user to modify outgoing data
+
+        Covers:
+        1. /chat/completions
+        """
+
+        for callback in litellm.callbacks:
+            try:
+                _callback: Optional[CustomLogger] = None
+                if isinstance(callback, str):
+                    _callback = litellm.litellm_core_utils.litellm_logging.get_custom_logger_compatible_class(
+                        callback
+                    )
+                else:
+                    _callback = callback  # type: ignore
+
+                if _callback is not None:
+                    ############## Handle Guardrails ########################################
+                    #############################################################################
+                    if isinstance(callback, CustomGuardrail):
+                        # Main - V2 Guardrails implementation
+                        from litellm.types.guardrails import GuardrailEventHooks
+
+                        if (
+                            callback.should_run_guardrail(
+                                data=data, event_type=GuardrailEventHooks.post_call
+                            )
+                            is not True
+                        ):
+                            continue
+
+                        await callback.async_post_call_success_hook(
+                            user_api_key_dict=user_api_key_dict,
+                            data=data,
+                            response=response,
+                        )
+
+                    ############ Handle CustomLogger ###############################
+                    #################################################################
+                    elif isinstance(_callback, CustomLogger):
+                        await _callback.async_post_call_success_hook(
+                            user_api_key_dict=user_api_key_dict,
+                            data=data,
+                            response=response,
+                        )
+            except Exception as e:
+                raise e
+        return response
+
+    async def async_post_call_streaming_hook(
+        self,
+        response: Union[
+            ModelResponse, EmbeddingResponse, ImageResponse, ModelResponseStream
+        ],
+        user_api_key_dict: UserAPIKeyAuth,
+    ):
+        """
+        Allow user to modify outgoing streaming data -> per chunk
+
+        Covers:
+        1. /chat/completions
+        """
+        response_str: Optional[str] = None
+        if isinstance(response, (ModelResponse, ModelResponseStream)):
+            response_str = litellm.get_response_string(response_obj=response)
+        if response_str is not None:
+            for callback in litellm.callbacks:
+                try:
+                    _callback: Optional[CustomLogger] = None
+                    if isinstance(callback, str):
+                        _callback = litellm.litellm_core_utils.litellm_logging.get_custom_logger_compatible_class(
+                            callback
+                        )
+                    else:
+                        _callback = callback  # type: ignore
+                    if _callback is not None and isinstance(_callback, CustomLogger):
+                        await _callback.async_post_call_streaming_hook(
+                            user_api_key_dict=user_api_key_dict, response=response_str
+                        )
+                except Exception as e:
+                    raise e
+        return response
+
+    def async_post_call_streaming_iterator_hook(
+        self,
+        response,
+        user_api_key_dict: UserAPIKeyAuth,
+        request_data: dict,
+    ):
+        """
+        Allow user to modify outgoing streaming data -> Given a whole response iterator.
+        This hook is best used when you need to modify multiple chunks of the response at once.
+
+        Covers:
+        1. /chat/completions
+        """
+        for callback in litellm.callbacks:
+            _callback: Optional[CustomLogger] = None
+            if isinstance(callback, str):
+                _callback = litellm.litellm_core_utils.litellm_logging.get_custom_logger_compatible_class(
+                    callback
+                )
+            else:
+                _callback = callback  # type: ignore
+            if _callback is not None and isinstance(_callback, CustomLogger):
+                if not isinstance(
+                    _callback, CustomGuardrail
+                ) or _callback.should_run_guardrail(
+                    data=request_data, event_type=GuardrailEventHooks.post_call
+                ):
+                    response = _callback.async_post_call_streaming_iterator_hook(
+                        user_api_key_dict=user_api_key_dict,
+                        response=response,
+                        request_data=request_data,
+                    )
+        return response
+
+    async def post_call_streaming_hook(
+        self,
+        response: str,
+        user_api_key_dict: UserAPIKeyAuth,
+    ):
+        """
+        - Check outgoing streaming response up until that point
+        - Run through moderation check
+        - Reject request if it fails moderation check
+        """
+        new_response = copy.deepcopy(response)
+        for callback in litellm.callbacks:
+            try:
+                if isinstance(callback, CustomLogger):
+                    await callback.async_post_call_streaming_hook(
+                        user_api_key_dict=user_api_key_dict, response=new_response
+                    )
+            except Exception as e:
+                raise e
+        return new_response
+
+    def _init_response_taking_too_long_task(self, data: Optional[dict] = None):
+        """
+        Initialize the response taking too long task if user is using slack alerting
+
+        Only run task if user is using slack alerting
+
+        This handles checking for if a request is hanging for too long
+        """
+        ## ALERTING ###
+        if (
+            self.slack_alerting_instance
+            and self.slack_alerting_instance.alerting is not None
+        ):
+            asyncio.create_task(
+                self.slack_alerting_instance.response_taking_too_long(request_data=data)
+            )
+
+
+### DB CONNECTOR ###
+# Define the retry decorator with backoff strategy
+# Function to be called whenever a retry is about to happen
+def on_backoff(details):
+    # The 'tries' key in the details dictionary contains the number of completed tries
+    print_verbose(f"Backing off... this was attempt #{details['tries']}")
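The retry pattern used by every PrismaClient method below, shown standalone as a sketch; flaky_query is hypothetical:

import backoff

@backoff.on_exception(
    backoff.expo,  # exponential wait between attempts
    Exception,     # retry on any exception
    max_tries=3,
    max_time=10,
    on_backoff=on_backoff,  # logs the attempt number, as defined above
)
async def flaky_query():
    ...  # a DB call that may transiently fail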
+
+
+def jsonify_object(data: dict) -> dict:
+    db_data = copy.deepcopy(data)
+
+    for k, v in db_data.items():
+        if isinstance(v, dict):
+            try:
+                db_data[k] = json.dumps(v)
+            except Exception:
+                # This avoids Prisma retrying this 5 times, and making 5 clients
+                db_data[k] = "failed-to-serialize-json"
+    return db_data
+
+
+class PrismaClient:
+    user_list_transactons: dict = {}
+    end_user_list_transactons: dict = {}
+    key_list_transactons: dict = {}
+    team_list_transactons: dict = {}
+    team_member_list_transactons: dict = {}  # key is ["team_id" + "user_id"]
+    org_list_transactons: dict = {}
+    spend_log_transactions: List = []
+
+    def __init__(
+        self,
+        database_url: str,
+        proxy_logging_obj: ProxyLogging,
+        http_client: Optional[Any] = None,
+    ):
+        ## init logging object
+        self.proxy_logging_obj = proxy_logging_obj
+        self.iam_token_db_auth: Optional[bool] = str_to_bool(
+            os.getenv("IAM_TOKEN_DB_AUTH")
+        )
+        verbose_proxy_logger.debug("Creating Prisma Client..")
+        try:
+            from prisma import Prisma  # type: ignore
+        except Exception:
+            raise Exception("Unable to find Prisma binaries.")
+        if http_client is not None:
+            self.db = PrismaWrapper(
+                original_prisma=Prisma(http=http_client),
+                iam_token_db_auth=(
+                    self.iam_token_db_auth
+                    if self.iam_token_db_auth is not None
+                    else False
+                ),
+            )
+        else:
+            self.db = PrismaWrapper(
+                original_prisma=Prisma(),
+                iam_token_db_auth=(
+                    self.iam_token_db_auth
+                    if self.iam_token_db_auth is not None
+                    else False
+                ),
+            )  # Client to connect to Prisma db
+        verbose_proxy_logger.debug("Success - Created Prisma Client")
+
+    def hash_token(self, token: str):
+        # Hash the string using SHA-256
+        hashed_token = hashlib.sha256(token.encode()).hexdigest()
+
+        return hashed_token
+
+    def jsonify_object(self, data: dict) -> dict:
+        db_data = copy.deepcopy(data)
+
+        for k, v in db_data.items():
+            if isinstance(v, dict):
+                try:
+                    db_data[k] = json.dumps(v)
+                except Exception:
+                    # This avoids Prisma retrying this 5 times, and making 5 clients
+                    db_data[k] = "failed-to-serialize-json"
+        return db_data
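A round-trip sketch of the two helpers above: keys are stored as SHA-256 hex digests, and nested dicts are JSON-serialized before hitting Prisma. The token value is made up:

import hashlib
import json

token = "sk-example-not-a-real-key"
hashed = hashlib.sha256(token.encode()).hexdigest()  # what hash_token returns

row = {"token": hashed, "metadata": {"team": "eng"}}
db_row = {k: (json.dumps(v) if isinstance(v, dict) else v) for k, v in row.items()}
assert isinstance(db_row["metadata"], str)  # dict column flattened to JSON text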
+
+    @backoff.on_exception(
+        backoff.expo,
+        Exception,  # base exception to catch for the backoff
+        max_tries=3,  # maximum number of retries
+        max_time=10,  # maximum total time to retry for
+        on_backoff=on_backoff,  # specifying the function to call on backoff
+    )
+    async def check_view_exists(self):
+        """
+        Checks if the LiteLLM_VerificationTokenView and MonthlyGlobalSpend views exist in the user's db.
+
+        LiteLLM_VerificationTokenView: This view is used for getting the token + team data in user_api_key_auth
+
+        MonthlyGlobalSpend: This view is used for the admin view to see global spend for this month
+
+        If the view doesn't exist, one will be created.
+        """
+
+        # Check to see if all of the necessary views exist and if they do, simply return
+        # This is more efficient because it lets us check for all views in one
+        # query instead of multiple queries.
+        try:
+            expected_views = [
+                "LiteLLM_VerificationTokenView",
+                "MonthlyGlobalSpend",
+                "Last30dKeysBySpend",
+                "Last30dModelsBySpend",
+                "MonthlyGlobalSpendPerKey",
+                "MonthlyGlobalSpendPerUserPerKey",
+                "Last30dTopEndUsersSpend",
+                "DailyTagSpend",
+            ]
+            required_view = "LiteLLM_VerificationTokenView"
+            expected_views_str = ", ".join(f"'{view}'" for view in expected_views)
+            pg_schema = os.getenv("DATABASE_SCHEMA", "public")
+            ret = await self.db.query_raw(
+                f"""
+                WITH existing_views AS (
+                    SELECT viewname
+                    FROM pg_views
+                    WHERE schemaname = '{pg_schema}' AND viewname IN (
+                        {expected_views_str}
+                    )
+                )
+                SELECT
+                    (SELECT COUNT(*) FROM existing_views) AS view_count,
+                    ARRAY_AGG(viewname) AS view_names
+                FROM existing_views
+                """
+            )
+            expected_total_views = len(expected_views)
+            if ret[0]["view_count"] == expected_total_views:
+                verbose_proxy_logger.info("All necessary views exist!")
+                return
+            else:
+                ## check if required view exists ##
+                if ret[0]["view_names"] and required_view not in ret[0]["view_names"]:
+                    await self.health_check()  # make sure we can connect to db
+                    await self.db.execute_raw(
+                        """
+                        CREATE VIEW "LiteLLM_VerificationTokenView" AS
+                        SELECT
+                            v.*,
+                            t.spend AS team_spend,
+                            t.max_budget AS team_max_budget,
+                            t.tpm_limit AS team_tpm_limit,
+                            t.rpm_limit AS team_rpm_limit
+                        FROM "LiteLLM_VerificationToken" v
+                        LEFT JOIN "LiteLLM_TeamTable" t ON v.team_id = t.team_id;
+                        """
+                    )
+
+                    verbose_proxy_logger.info(
+                        "LiteLLM_VerificationTokenView Created in DB!"
+                    )
+                else:
+                    should_create_views = await should_create_missing_views(db=self.db)
+                    if should_create_views:
+                        await create_missing_views(db=self.db)
+                    else:
+                        # don't block execution if these views are missing
+                        # Convert lists to sets for efficient difference calculation
+                        ret_view_names_set = (
+                            set(ret[0]["view_names"]) if ret[0]["view_names"] else set()
+                        )
+                        expected_views_set = set(expected_views)
+                        # Find missing views
+                        missing_views = expected_views_set - ret_view_names_set
+                        verbose_proxy_logger.warning(
+                            "\n\n\033[93mNot all views exist in db, needed for UI 'Usage' tab. Missing={}.\nRun 'create_views.py' from https://github.com/BerriAI/litellm/tree/main/db_scripts to create missing views.\033[0m\n".format(
+                                missing_views
+                            )
+                        )
+
+        except Exception:
+            raise
+        return
+
+    @log_db_metrics
+    @backoff.on_exception(
+        backoff.expo,
+        Exception,  # base exception to catch for the backoff
+        max_tries=1,  # maximum number of retries
+        max_time=2,  # maximum total time to retry for
+        on_backoff=on_backoff,  # specifying the function to call on backoff
+    )
+    async def get_generic_data(
+        self,
+        key: str,
+        value: Any,
+        table_name: Literal["users", "keys", "config", "spend"],
+    ):
+        """
+        Generic implementation of get data
+        """
+        start_time = time.time()
+        try:
+            if table_name == "users":
+                response = await self.db.litellm_usertable.find_first(
+                    where={key: value}  # type: ignore
+                )
+            elif table_name == "keys":
+                response = await self.db.litellm_verificationtoken.find_first(  # type: ignore
+                    where={key: value}  # type: ignore
+                )
+            elif table_name == "config":
+                response = await self.db.litellm_config.find_first(  # type: ignore
+                    where={key: value}  # type: ignore
+                )
+            elif table_name == "spend":
+                response = await self.db.litellm_spendlogs.find_first(  # type: ignore
+                    where={key: value}  # type: ignore
+                )
+            return response
+        except Exception as e:
+            import traceback
+
+            error_msg = f"LiteLLM Prisma Client Exception get_generic_data: {str(e)}"
+            verbose_proxy_logger.error(error_msg)
+            error_msg = error_msg + "\nException Type: {}".format(type(e))
+            error_traceback = error_msg + "\n" + traceback.format_exc()
+            end_time = time.time()
+            _duration = end_time - start_time
+            asyncio.create_task(
+                self.proxy_logging_obj.failure_handler(
+                    original_exception=e,
+                    duration=_duration,
+                    traceback_str=error_traceback,
+                    call_type="get_generic_data",
+                )
+            )
+
+            raise e
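Hedged usage sketch for the generic getter above; assumes a connected PrismaClient and an existing user row. The email is illustrative:

async def lookup_user(prisma_client: PrismaClient):
    return await prisma_client.get_generic_data(
        key="user_email",
        value="test@example.com",
        table_name="users",
    )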
Token={token}"}, + ) + response = await self.db.litellm_verificationtoken.find_unique( + where={"token": hashed_token}, # type: ignore + include={"litellm_budget_table": True}, + ) + if response is not None: + # for prisma we need to cast the expires time to str + if response.expires is not None and isinstance( + response.expires, datetime + ): + response.expires = response.expires.isoformat() + else: + # Token does not exist. + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=f"Authentication Error: invalid user key - user key does not exist in db. User Key={token}", + ) + elif query_type == "find_all" and user_id is not None: + response = await self.db.litellm_verificationtoken.find_many( + where={"user_id": user_id}, + include={"litellm_budget_table": True}, + ) + if response is not None and len(response) > 0: + for r in response: + if isinstance(r.expires, datetime): + r.expires = r.expires.isoformat() + elif query_type == "find_all" and team_id is not None: + response = await self.db.litellm_verificationtoken.find_many( + where={"team_id": team_id}, + include={"litellm_budget_table": True}, + ) + if response is not None and len(response) > 0: + for r in response: + if isinstance(r.expires, datetime): + r.expires = r.expires.isoformat() + elif ( + query_type == "find_all" + and expires is not None + and reset_at is not None + ): + response = await self.db.litellm_verificationtoken.find_many( + where={ # type:ignore + "OR": [ + {"expires": None}, + {"expires": {"gt": expires}}, + ], + "budget_reset_at": {"lt": reset_at}, + } + ) + if response is not None and len(response) > 0: + for r in response: + if isinstance(r.expires, datetime): + r.expires = r.expires.isoformat() + elif query_type == "find_all": + where_filter: dict = {} + if token is not None: + where_filter["token"] = {} + if isinstance(token, str): + token = _hash_token_if_needed(token=token) + where_filter["token"]["in"] = [token] + elif isinstance(token, list): + hashed_tokens = [] + for t in token: + assert isinstance(t, str) + if t.startswith("sk-"): + new_token = self.hash_token(token=t) + hashed_tokens.append(new_token) + else: + hashed_tokens.append(t) + where_filter["token"]["in"] = hashed_tokens + response = await self.db.litellm_verificationtoken.find_many( + order={"spend": "desc"}, + where=where_filter, # type: ignore + include={"litellm_budget_table": True}, + ) + if response is not None: + return response + else: + # Token does not exist. 
+            elif (user_id is not None and table_name is None) or (
+                table_name is not None and table_name == "user"
+            ):
+                if query_type == "find_unique":
+                    if key_val is None:
+                        key_val = {"user_id": user_id}
+                    response = await self.db.litellm_usertable.find_unique(  # type: ignore
+                        where=key_val,  # type: ignore
+                        include={"organization_memberships": True},
+                    )
+                elif query_type == "find_all" and key_val is not None:
+                    response = await self.db.litellm_usertable.find_many(
+                        where=key_val  # type: ignore
+                    )  # type: ignore
+                elif query_type == "find_all" and reset_at is not None:
+                    response = await self.db.litellm_usertable.find_many(
+                        where={  # type:ignore
+                            "budget_reset_at": {"lt": reset_at},
+                        }
+                    )
+                elif query_type == "find_all" and user_id_list is not None:
+                    response = await self.db.litellm_usertable.find_many(
+                        where={"user_id": {"in": user_id_list}}
+                    )
+                elif query_type == "find_all":
+                    if expires is not None:
+                        response = await self.db.litellm_usertable.find_many(  # type: ignore
+                            order={"spend": "desc"},
+                            where={  # type:ignore
+                                "OR": [
+                                    {"expires": None},  # type:ignore
+                                    {"expires": {"gt": expires}},  # type:ignore
+                                ],
+                            },
+                        )
+                    else:
+                        # return all users in the table, get their key aliases ordered by spend
+                        sql_query = """
+                        SELECT
+                            u.*,
+                            json_agg(v.key_alias) AS key_aliases
+                        FROM
+                            "LiteLLM_UserTable" u
+                        LEFT JOIN "LiteLLM_VerificationToken" v ON u.user_id = v.user_id
+                        GROUP BY
+                            u.user_id
+                        ORDER BY u.spend DESC
+                        LIMIT $1
+                        OFFSET $2
+                        """
+                        response = await self.db.query_raw(sql_query, limit, offset)
+                return response
+            elif table_name == "spend":
+                verbose_proxy_logger.debug(
+                    "PrismaClient: get_data: table_name == 'spend'"
+                )
+                if key_val is not None:
+                    if query_type == "find_unique":
+                        response = await self.db.litellm_spendlogs.find_unique(  # type: ignore
+                            where={  # type: ignore
+                                key_val["key"]: key_val["value"],  # type: ignore
+                            }
+                        )
+                    elif query_type == "find_all":
+                        response = await self.db.litellm_spendlogs.find_many(  # type: ignore
+                            where={
+                                key_val["key"]: key_val["value"],  # type: ignore
+                            }
+                        )
+                    return response
+                else:
+                    response = await self.db.litellm_spendlogs.find_many(  # type: ignore
+                        order={"startTime": "desc"},
+                    )
+                    return response
+            elif table_name == "team":
+                if query_type == "find_unique":
+                    response = await self.db.litellm_teamtable.find_unique(
+                        where={"team_id": team_id},  # type: ignore
+                        include={"litellm_model_table": True},  # type: ignore
+                    )
+                elif query_type == "find_all" and reset_at is not None:
+                    response = await self.db.litellm_teamtable.find_many(
+                        where={  # type:ignore
+                            "budget_reset_at": {"lt": reset_at},
+                        }
+                    )
+                elif query_type == "find_all" and user_id is not None:
+                    response = await self.db.litellm_teamtable.find_many(
+                        where={
+                            "members": {"has": user_id},
+                        },
+                        include={"litellm_budget_table": True},
+                    )
+                elif query_type == "find_all" and team_id_list is not None:
+                    response = await self.db.litellm_teamtable.find_many(
+                        where={"team_id": {"in": team_id_list}}
+                    )
+                elif query_type == "find_all" and team_id_list is None:
+                    response = await self.db.litellm_teamtable.find_many(take=20)
+                return response
+            elif table_name == "user_notification":
+                if query_type == "find_unique":
+                    response = await self.db.litellm_usernotifications.find_unique(  # type: ignore
+                        where={"user_id": user_id}  # type: ignore
+                    )
+                elif query_type == "find_all":
+                    response = await self.db.litellm_usernotifications.find_many()  # type: ignore
+                return response
+            elif table_name == "combined_view":
+                # check if plain text or hash
+                if token is not None:
+                    if isinstance(token, str):
+                        hashed_token = _hash_token_if_needed(token=token)
+                        verbose_proxy_logger.debug(
+                            f"PrismaClient: find_unique for token: {hashed_token}"
+                        )
+                if query_type == "find_unique":
+                    if token is None:
+                        raise HTTPException(
+                            status_code=400,
+                            detail={"error": f"No token passed in. Token={token}"},
+                        )
+
+                    sql_query = f"""
+                    SELECT
+                        v.*,
+                        t.spend AS team_spend,
+                        t.max_budget AS team_max_budget,
+                        t.tpm_limit AS team_tpm_limit,
+                        t.rpm_limit AS team_rpm_limit,
+                        t.models AS team_models,
+                        t.metadata AS team_metadata,
+                        t.blocked AS team_blocked,
+                        t.team_alias AS team_alias,
+                        t.metadata AS team_metadata,
+                        t.members_with_roles AS team_members_with_roles,
+                        t.organization_id as org_id,
+                        tm.spend AS team_member_spend,
+                        m.aliases AS team_model_aliases,
+                        -- Added comma to separate b.* columns
+                        b.max_budget AS litellm_budget_table_max_budget,
+                        b.tpm_limit AS litellm_budget_table_tpm_limit,
+                        b.rpm_limit AS litellm_budget_table_rpm_limit,
+                        b.model_max_budget as litellm_budget_table_model_max_budget,
+                        b.soft_budget as litellm_budget_table_soft_budget
+                    FROM "LiteLLM_VerificationToken" AS v
+                    LEFT JOIN "LiteLLM_TeamTable" AS t ON v.team_id = t.team_id
+                    LEFT JOIN "LiteLLM_TeamMembership" AS tm ON v.team_id = tm.team_id AND tm.user_id = v.user_id
+                    LEFT JOIN "LiteLLM_ModelTable" m ON t.model_id = m.id
+                    LEFT JOIN "LiteLLM_BudgetTable" AS b ON v.budget_id = b.budget_id
+                    WHERE v.token = '{token}'
+                    """
+
+                    print_verbose("sql_query being made={}".format(sql_query))
+                    response = await self.db.query_first(query=sql_query)
+
+                    if response is not None:
+                        if response["team_models"] is None:
+                            response["team_models"] = []
+                        if response["team_blocked"] is None:
+                            response["team_blocked"] = False
+
+                        team_member: Optional[Member] = None
+                        if (
+                            response["team_members_with_roles"] is not None
+                            and response["user_id"] is not None
+                        ):
+                            ## find the team member corresponding to user id
+                            """
+                            [
+                                {
+                                    "role": "admin",
+                                    "user_id": "default_user_id",
+                                    "user_email": null
+                                },
+                                {
+                                    "role": "user",
+                                    "user_id": null,
+                                    "user_email": "test@email.com"
+                                }
+                            ]
+                            """
+                            for tm in response["team_members_with_roles"]:
+                                if tm.get("user_id") is not None and response[
+                                    "user_id"
+                                ] == tm.get("user_id"):
+                                    team_member = Member(**tm)
+                        response["team_member"] = team_member
+                        response = LiteLLM_VerificationTokenView(
+                            **response, last_refreshed_at=time.time()
+                        )
+                        # for prisma we need to cast the expires time to str
+                        if response.expires is not None and isinstance(
+                            response.expires, datetime
+                        ):
+                            response.expires = response.expires.isoformat()
+                    return response
+        except Exception as e:
+            import traceback
+
+            prisma_query_info = f"LiteLLM Prisma Client Exception: Error with `get_data`. Args passed in: {args_passed_in}"
+            error_msg = prisma_query_info + str(e)
+            print_verbose(error_msg)
+            error_traceback = error_msg + "\n" + traceback.format_exc()
+            verbose_proxy_logger.debug(error_traceback)
+            end_time = time.time()
+            _duration = end_time - start_time
+
+            asyncio.create_task(
+                self.proxy_logging_obj.failure_handler(
+                    original_exception=e,
+                    duration=_duration,
+                    call_type="get_data",
+                    traceback_str=error_traceback,
+                )
+            )
+            raise e
+
+    def jsonify_team_object(self, db_data: dict):
+        db_data = self.jsonify_object(data=db_data)
+        if db_data.get("members_with_roles", None) is not None and isinstance(
+            db_data["members_with_roles"], list
+        ):
+            db_data["members_with_roles"] = json.dumps(db_data["members_with_roles"])
+        return db_data
+
+    # Define a retrying strategy with exponential backoff
+    @backoff.on_exception(
+        backoff.expo,
+        Exception,  # base exception to catch for the backoff
+        max_tries=3,  # maximum number of retries
+        max_time=10,  # maximum total time to retry for
+        on_backoff=on_backoff,  # specifying the function to call on backoff
+    )
+    async def insert_data(  # noqa: PLR0915
+        self,
+        data: dict,
+        table_name: Literal[
+            "user", "key", "config", "spend", "team", "user_notification"
+        ],
+    ):
+        """
+        Add a key to the database. If it already exists, do nothing.
+        """
+        start_time = time.time()
+        try:
+            verbose_proxy_logger.debug("PrismaClient: insert_data: %s", data)
+            if table_name == "key":
+                token = data["token"]
+                hashed_token = self.hash_token(token=token)
+                db_data = self.jsonify_object(data=data)
+                db_data["token"] = hashed_token
+                print_verbose(
+                    "PrismaClient: Before upsert into litellm_verificationtoken"
+                )
+                new_verification_token = await self.db.litellm_verificationtoken.upsert(  # type: ignore
+                    where={
+                        "token": hashed_token,
+                    },
+                    data={
+                        "create": {**db_data},  # type: ignore
+                        "update": {},  # don't do anything if it already exists
+                    },
+                    include={"litellm_budget_table": True},
+                )
+                verbose_proxy_logger.info("Data Inserted into Keys Table")
+                return new_verification_token
+            elif table_name == "user":
+                db_data = self.jsonify_object(data=data)
+                try:
+                    new_user_row = await self.db.litellm_usertable.upsert(
+                        where={"user_id": data["user_id"]},
+                        data={
+                            "create": {**db_data},  # type: ignore
+                            "update": {},  # don't do anything if it already exists
+                        },
+                    )
+                except Exception as e:
+                    if (
+                        "Foreign key constraint failed on the field: `LiteLLM_UserTable_organization_id_fkey (index)`"
+                        in str(e)
+                    ):
+                        raise HTTPException(
+                            status_code=400,
+                            detail={
+                                "error": f"Foreign Key Constraint failed. Organization ID={db_data['organization_id']} does not exist in LiteLLM_OrganizationTable. Create via `/organization/new`."
+                            },
+                        )
+                    raise e
+                verbose_proxy_logger.info("Data Inserted into User Table")
+                return new_user_row
+            elif table_name == "team":
+                db_data = self.jsonify_team_object(db_data=data)
+                new_team_row = await self.db.litellm_teamtable.upsert(
+                    where={"team_id": data["team_id"]},
+                    data={
+                        "create": {**db_data},  # type: ignore
+                        "update": {},  # don't do anything if it already exists
+                    },
+                )
+                verbose_proxy_logger.info("Data Inserted into Team Table")
+                return new_team_row
+            elif table_name == "config":
+                """
+                For each param,
+                get the existing table values
+
+                Add the new values
+
+                Update DB
+                """
+                tasks = []
+                for k, v in data.items():
+                    updated_data = v
+                    updated_data = json.dumps(updated_data)
+                    updated_table_row = self.db.litellm_config.upsert(
+                        where={"param_name": k},  # type: ignore
+                        data={
+                            "create": {"param_name": k, "param_value": updated_data},  # type: ignore
+                            "update": {"param_value": updated_data},
+                        },
+                    )
+
+                    tasks.append(updated_table_row)
+                await asyncio.gather(*tasks)
+                verbose_proxy_logger.info("Data Inserted into Config Table")
+            elif table_name == "spend":
+                db_data = self.jsonify_object(data=data)
+                new_spend_row = await self.db.litellm_spendlogs.upsert(
+                    where={"request_id": data["request_id"]},
+                    data={
+                        "create": {**db_data},  # type: ignore
+                        "update": {},  # don't do anything if it already exists
+                    },
+                )
+                verbose_proxy_logger.info("Data Inserted into Spend Table")
+                return new_spend_row
+            elif table_name == "user_notification":
+                db_data = self.jsonify_object(data=data)
+                new_user_notification_row = (
+                    await self.db.litellm_usernotifications.upsert(  # type: ignore
+                        where={"request_id": data["request_id"]},
+                        data={
+                            "create": {**db_data},  # type: ignore
+                            "update": {},  # don't do anything if it already exists
+                        },
+                    )
+                )
+                verbose_proxy_logger.info("Data Inserted into Model Request Table")
+                return new_user_notification_row
+
+        except Exception as e:
+            import traceback
+
+            error_msg = f"LiteLLM Prisma Client Exception in insert_data: {str(e)}"
+            print_verbose(error_msg)
+            error_traceback = error_msg + "\n" + traceback.format_exc()
+            end_time = time.time()
+            _duration = end_time - start_time
+            asyncio.create_task(
+                self.proxy_logging_obj.failure_handler(
+                    original_exception=e,
+                    duration=_duration,
+                    call_type="insert_data",
+                    traceback_str=error_traceback,
+                )
+            )
+            raise e
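A sketch of the idempotent insert pattern above: upsert with an empty "update" clause means "create if missing, otherwise leave the row untouched". Values are illustrative:

async def seed_key(prisma_client: PrismaClient):
    await prisma_client.insert_data(
        data={"token": "sk-example", "user_id": "u-123"},
        table_name="key",  # the token is hashed before the upsert, per the code above
    )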
+
+    # Define a retrying strategy with exponential backoff
+    @backoff.on_exception(
+        backoff.expo,
+        Exception,  # base exception to catch for the backoff
+        max_tries=3,  # maximum number of retries
+        max_time=10,  # maximum total time to retry for
+        on_backoff=on_backoff,  # specifying the function to call on backoff
+    )
+    async def update_data(  # noqa: PLR0915
+        self,
+        token: Optional[str] = None,
+        data: dict = {},
+        data_list: Optional[List] = None,
+        user_id: Optional[str] = None,
+        team_id: Optional[str] = None,
+        query_type: Literal["update", "update_many"] = "update",
+        table_name: Optional[Literal["user", "key", "config", "spend", "team"]] = None,
+        update_key_values: Optional[dict] = None,
+        update_key_values_custom_query: Optional[dict] = None,
+    ):
+        """
+        Update existing data
+        """
+        verbose_proxy_logger.debug(
+            f"PrismaClient: update_data, table_name: {table_name}"
+        )
+        start_time = time.time()
+        try:
+            db_data = self.jsonify_object(data=data)
+            if update_key_values is not None:
+                update_key_values = self.jsonify_object(data=update_key_values)
+            if token is not None:
+                print_verbose(f"token: {token}")
+                # check if plain text or hash
+                token = _hash_token_if_needed(token=token)
+                db_data["token"] = token
+                response = await self.db.litellm_verificationtoken.update(
+                    where={"token": token},  # type: ignore
+                    data={**db_data},  # type: ignore
+                )
+                verbose_proxy_logger.debug(
+                    "\033[91m"
+                    + f"DB Token Table update succeeded {response}"
+                    + "\033[0m"
+                )
+                _data: dict = {}
+                if response is not None:
+                    try:
+                        _data = response.model_dump()  # type: ignore
+                    except Exception:
+                        _data = response.dict()
+                return {"token": token, "data": _data}
+            elif (
+                user_id is not None
+                or (table_name is not None and table_name == "user")
+                and query_type == "update"
+            ):
+                """
+                If data['spend'] + data['user'], update the user table with spend info as well
+                """
+                if user_id is None:
+                    user_id = db_data["user_id"]
+                if update_key_values is None:
+                    if update_key_values_custom_query is not None:
+                        update_key_values = update_key_values_custom_query
+                    else:
+                        update_key_values = db_data
+                update_user_row = await self.db.litellm_usertable.upsert(
+                    where={"user_id": user_id},  # type: ignore
+                    data={
+                        "create": {**db_data},  # type: ignore
+                        "update": {
+                            **update_key_values  # type: ignore
+                        },  # just update user-specified values, if it already exists
+                    },
+                )
+                verbose_proxy_logger.info(
+                    "\033[91m"
+                    + f"DB User Table - update succeeded {update_user_row}"
+                    + "\033[0m"
+                )
+                return {"user_id": user_id, "data": update_user_row}
+            elif (
+                team_id is not None
+                or (table_name is not None and table_name == "team")
+                and query_type == "update"
+            ):
+                """
+                If data['spend'] + data['user'], update the user table with spend info as well
+                """
+                if team_id is None:
+                    team_id = db_data["team_id"]
+                if update_key_values is None:
+                    update_key_values = db_data
+                if "team_id" not in db_data and team_id is not None:
+                    db_data["team_id"] = team_id
+                if "members_with_roles" in db_data and isinstance(
+                    db_data["members_with_roles"], list
+                ):
+                    db_data["members_with_roles"] = json.dumps(
+                        db_data["members_with_roles"]
+                    )
+                if "members_with_roles" in update_key_values and isinstance(
+                    update_key_values["members_with_roles"], list
+                ):
+                    update_key_values["members_with_roles"] = json.dumps(
+                        update_key_values["members_with_roles"]
+                    )
+                update_team_row = await self.db.litellm_teamtable.upsert(
+                    where={"team_id": team_id},  # type: ignore
+                    data={
+                        "create": {**db_data},  # type: ignore
+                        "update": {
+                            **update_key_values  # type: ignore
+                        },  # just update user-specified values, if it already exists
+                    },
+                )
+                verbose_proxy_logger.info(
+                    "\033[91m"
+                    + f"DB Team Table - update succeeded {update_team_row}"
+                    + "\033[0m"
+                )
+                return {"team_id": team_id, "data": update_team_row}
+            elif (
+                table_name is not None
+                and table_name == "key"
+                and query_type == "update_many"
+                and data_list is not None
+                and isinstance(data_list, list)
+            ):
+                """
+                Batch write update queries
+                """
+                batcher = self.db.batch_()
+                for idx, t in enumerate(data_list):
+                    # check if plain text or hash
+                    if t.token.startswith("sk-"):  # type: ignore
+                        t.token = self.hash_token(token=t.token)  # type: ignore
+                    try:
+                        data_json = self.jsonify_object(
+                            data=t.model_dump(exclude_none=True)
+                        )
+                    except Exception:
+                        data_json = self.jsonify_object(data=t.dict(exclude_none=True))
+                    batcher.litellm_verificationtoken.update(
+                        where={"token": t.token},  # type: ignore
+                        data={**data_json},  # type: ignore
+                    )
+                await batcher.commit()
+                print_verbose(
+                    "\033[91m" + "DB Token Table update succeeded" + "\033[0m"
+                )
"update_many" + and data_list is not None + and isinstance(data_list, list) + ): + """ + Batch write update queries + """ + batcher = self.db.batch_() + for idx, user in enumerate(data_list): + try: + data_json = self.jsonify_object( + data=user.model_dump(exclude_none=True) + ) + except Exception: + data_json = self.jsonify_object(data=user.dict()) + batcher.litellm_usertable.upsert( + where={"user_id": user.user_id}, # type: ignore + data={ + "create": {**data_json}, # type: ignore + "update": { + **data_json # type: ignore + }, # just update user-specified values, if it already exists + }, + ) + await batcher.commit() + verbose_proxy_logger.info( + "\033[91m" + "DB User Table Batch update succeeded" + "\033[0m" + ) + elif ( + table_name is not None + and table_name == "team" + and query_type == "update_many" + and data_list is not None + and isinstance(data_list, list) + ): + # Batch write update queries + batcher = self.db.batch_() + for idx, team in enumerate(data_list): + try: + data_json = self.jsonify_team_object( + db_data=team.model_dump(exclude_none=True) + ) + except Exception: + data_json = self.jsonify_object( + data=team.dict(exclude_none=True) + ) + batcher.litellm_teamtable.upsert( + where={"team_id": team.team_id}, # type: ignore + data={ + "create": {**data_json}, # type: ignore + "update": { + **data_json # type: ignore + }, # just update user-specified values, if it already exists + }, + ) + await batcher.commit() + verbose_proxy_logger.info( + "\033[91m" + "DB Team Table Batch update succeeded" + "\033[0m" + ) + + except Exception as e: + import traceback + + error_msg = f"LiteLLM Prisma Client Exception - update_data: {str(e)}" + print_verbose(error_msg) + error_traceback = error_msg + "\n" + traceback.format_exc() + end_time = time.time() + _duration = end_time - start_time + asyncio.create_task( + self.proxy_logging_obj.failure_handler( + original_exception=e, + duration=_duration, + call_type="update_data", + traceback_str=error_traceback, + ) + ) + raise e + + # Define a retrying strategy with exponential backoff + @backoff.on_exception( + backoff.expo, + Exception, # base exception to catch for the backoff + max_tries=3, # maximum number of retries + max_time=10, # maximum total time to retry for + on_backoff=on_backoff, # specifying the function to call on backoff + ) + async def delete_data( + self, + tokens: Optional[List] = None, + team_id_list: Optional[List] = None, + table_name: Optional[Literal["user", "key", "config", "spend", "team"]] = None, + user_id: Optional[str] = None, + ): + """ + Allow user to delete a key(s) + + Ensure user owns that key, unless admin. 
+ """ + start_time = time.time() + try: + if tokens is not None and isinstance(tokens, List): + hashed_tokens = [] + for token in tokens: + if isinstance(token, str) and token.startswith("sk-"): + hashed_token = self.hash_token(token=token) + else: + hashed_token = token + hashed_tokens.append(hashed_token) + filter_query: dict = {} + if user_id is not None: + filter_query = { + "AND": [{"token": {"in": hashed_tokens}}, {"user_id": user_id}] + } + else: + filter_query = {"token": {"in": hashed_tokens}} + + deleted_tokens = await self.db.litellm_verificationtoken.delete_many( + where=filter_query # type: ignore + ) + verbose_proxy_logger.debug("deleted_tokens: %s", deleted_tokens) + return {"deleted_keys": deleted_tokens} + elif ( + table_name == "team" + and team_id_list is not None + and isinstance(team_id_list, List) + ): + # admin only endpoint -> `/team/delete` + await self.db.litellm_teamtable.delete_many( + where={"team_id": {"in": team_id_list}} + ) + return {"deleted_teams": team_id_list} + elif ( + table_name == "key" + and team_id_list is not None + and isinstance(team_id_list, List) + ): + # admin only endpoint -> `/team/delete` + await self.db.litellm_verificationtoken.delete_many( + where={"team_id": {"in": team_id_list}} + ) + except Exception as e: + import traceback + + error_msg = f"LiteLLM Prisma Client Exception - delete_data: {str(e)}" + print_verbose(error_msg) + error_traceback = error_msg + "\n" + traceback.format_exc() + end_time = time.time() + _duration = end_time - start_time + asyncio.create_task( + self.proxy_logging_obj.failure_handler( + original_exception=e, + duration=_duration, + call_type="delete_data", + traceback_str=error_traceback, + ) + ) + raise e + + # Define a retrying strategy with exponential backoff + @backoff.on_exception( + backoff.expo, + Exception, # base exception to catch for the backoff + max_tries=3, # maximum number of retries + max_time=10, # maximum total time to retry for + on_backoff=on_backoff, # specifying the function to call on backoff + ) + async def connect(self): + start_time = time.time() + try: + verbose_proxy_logger.debug( + "PrismaClient: connect() called Attempting to Connect to DB" + ) + if self.db.is_connected() is False: + verbose_proxy_logger.debug( + "PrismaClient: DB not connected, Attempting to Connect to DB" + ) + await self.db.connect() + except Exception as e: + import traceback + + error_msg = f"LiteLLM Prisma Client Exception connect(): {str(e)}" + print_verbose(error_msg) + error_traceback = error_msg + "\n" + traceback.format_exc() + end_time = time.time() + _duration = end_time - start_time + asyncio.create_task( + self.proxy_logging_obj.failure_handler( + original_exception=e, + duration=_duration, + call_type="connect", + traceback_str=error_traceback, + ) + ) + raise e + + # Define a retrying strategy with exponential backoff + @backoff.on_exception( + backoff.expo, + Exception, # base exception to catch for the backoff + max_tries=3, # maximum number of retries + max_time=10, # maximum total time to retry for + on_backoff=on_backoff, # specifying the function to call on backoff + ) + async def disconnect(self): + start_time = time.time() + try: + await self.db.disconnect() + except Exception as e: + import traceback + + error_msg = f"LiteLLM Prisma Client Exception disconnect(): {str(e)}" + print_verbose(error_msg) + error_traceback = error_msg + "\n" + traceback.format_exc() + end_time = time.time() + _duration = end_time - start_time + asyncio.create_task( + self.proxy_logging_obj.failure_handler( + 
original_exception=e,
+                    duration=_duration,
+                    call_type="disconnect",
+                    traceback_str=error_traceback,
+                )
+            )
+            raise e
+
+    async def health_check(self):
+        """
+        Health check endpoint for the Prisma client.
+        """
+        start_time = time.time()
+        try:
+            sql_query = "SELECT 1"
+
+            # Execute a simple raw query to verify the DB is reachable
+            response = await self.db.query_raw(sql_query)
+            return response
+        except Exception as e:
+            import traceback
+
+            error_msg = f"LiteLLM Prisma Client Exception health_check(): {str(e)}"
+            print_verbose(error_msg)
+            error_traceback = error_msg + "\n" + traceback.format_exc()
+            end_time = time.time()
+            _duration = end_time - start_time
+            asyncio.create_task(
+                self.proxy_logging_obj.failure_handler(
+                    original_exception=e,
+                    duration=_duration,
+                    call_type="health_check",
+                    traceback_str=error_traceback,
+                )
+            )
+            raise e
+
+    async def _get_spend_logs_row_count(self) -> int:
+        try:
+            sql_query = """
+                SELECT reltuples::BIGINT
+                FROM pg_class
+                WHERE oid = '"LiteLLM_SpendLogs"'::regclass;
+            """
+            result = await self.db.query_raw(query=sql_query)
+            return result[0]["reltuples"]
+        except Exception as e:
+            verbose_proxy_logger.error(
+                f"Error getting LiteLLM_SpendLogs row count: {e}"
+            )
+            return 0
+
+    async def _set_spend_logs_row_count_in_proxy_state(self) -> None:
+        """
+        Set the `LiteLLM_SpendLogs` row count in proxy state.
+
+        This is used later to determine if we should run expensive UI Usage queries.
+        """
+        from litellm.proxy.proxy_server import proxy_state
+
+        _num_spend_logs_rows = await self._get_spend_logs_row_count()
+        proxy_state.set_proxy_state_variable(
+            variable_name="spend_logs_row_count",
+            value=_num_spend_logs_rows,
+        )
+
+
+### HELPER FUNCTIONS ###
+async def _cache_user_row(user_id: str, cache: DualCache, db: PrismaClient):
+    """
+    Check if a user_id exists in the cache; if not, retrieve it from the DB
+    and cache it.
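+
+    Example (illustrative only -- assumes the proxy's initialized cache and DB
+    handles, here called `user_api_key_cache` and `prisma_client`):
+        await _cache_user_row(
+            user_id="user-1234", cache=user_api_key_cache, db=prisma_client
+        )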
+ """ + cache_key = f"{user_id}_user_api_key_user_id" + response = cache.get_cache(key=cache_key) + if response is None: # Cache miss + user_row = await db.get_data(user_id=user_id) + if user_row is not None: + print_verbose(f"User Row: {user_row}, type = {type(user_row)}") + if hasattr(user_row, "model_dump_json") and callable( + getattr(user_row, "model_dump_json") + ): + cache_value = user_row.model_dump_json() + cache.set_cache( + key=cache_key, value=cache_value, ttl=600 + ) # store for 10 minutes + return + + +async def send_email(receiver_email, subject, html): + """ + smtp_host, + smtp_port, + smtp_username, + smtp_password, + sender_name, + sender_email, + """ + ## SERVER SETUP ## + + smtp_host = os.getenv("SMTP_HOST") + smtp_port = int(os.getenv("SMTP_PORT", "587")) # default to port 587 + smtp_username = os.getenv("SMTP_USERNAME") + smtp_password = os.getenv("SMTP_PASSWORD") + sender_email = os.getenv("SMTP_SENDER_EMAIL", None) + if sender_email is None: + raise ValueError("Trying to use SMTP, but SMTP_SENDER_EMAIL is not set") + + ## EMAIL SETUP ## + email_message = MIMEMultipart() + email_message["From"] = sender_email + email_message["To"] = receiver_email + email_message["Subject"] = subject + verbose_proxy_logger.debug( + "sending email from %s to %s", sender_email, receiver_email + ) + + if smtp_host is None: + raise ValueError("Trying to use SMTP, but SMTP_HOST is not set") + + # Attach the body to the email + email_message.attach(MIMEText(html, "html")) + + try: + # Establish a secure connection with the SMTP server + with smtplib.SMTP(smtp_host, smtp_port) as server: # type: ignore + if os.getenv("SMTP_TLS", "True") != "False": + server.starttls() + + # Login to your email account only if smtp_username and smtp_password are provided + if smtp_username and smtp_password: + server.login(smtp_username, smtp_password) # type: ignore + + # Send the email + server.send_message(email_message) + + except Exception as e: + print_verbose("An error occurred while sending the email:" + str(e)) + + +def hash_token(token: str): + import hashlib + + # Hash the string using SHA-256 + hashed_token = hashlib.sha256(token.encode()).hexdigest() + + return hashed_token + + +def _hash_token_if_needed(token: str) -> str: + """ + Hash the token if it's a string and starts with "sk-" + + Else return the token as is + """ + if token.startswith("sk-"): + return hash_token(token=token) + else: + return token + + +class ProxyUpdateSpend: + @staticmethod + async def update_end_user_spend( + n_retry_times: int, prisma_client: PrismaClient, proxy_logging_obj: ProxyLogging + ): + for i in range(n_retry_times + 1): + start_time = time.time() + try: + async with prisma_client.db.tx( + timeout=timedelta(seconds=60) + ) as transaction: + async with transaction.batch_() as batcher: + for ( + end_user_id, + response_cost, + ) in prisma_client.end_user_list_transactons.items(): + if litellm.max_end_user_budget is not None: + pass + batcher.litellm_endusertable.upsert( + where={"user_id": end_user_id}, + data={ + "create": { + "user_id": end_user_id, + "spend": response_cost, + "blocked": False, + }, + "update": {"spend": {"increment": response_cost}}, + }, + ) + + break + except DB_CONNECTION_ERROR_TYPES as e: + if i >= n_retry_times: # If we've reached the maximum number of retries + _raise_failed_update_spend_exception( + e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj + ) + # Optionally, sleep for a bit before retrying + await asyncio.sleep(2**i) # Exponential backoff + except Exception as 
e: + _raise_failed_update_spend_exception( + e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj + ) + finally: + prisma_client.end_user_list_transactons = ( + {} + ) # reset the end user list transactions - prevent bad data from causing issues + + @staticmethod + async def update_spend_logs( + n_retry_times: int, + prisma_client: PrismaClient, + db_writer_client: Optional[HTTPHandler], + proxy_logging_obj: ProxyLogging, + ): + BATCH_SIZE = 100 # Preferred size of each batch to write to the database + MAX_LOGS_PER_INTERVAL = ( + 1000 # Maximum number of logs to flush in a single interval + ) + # Get initial logs to process + logs_to_process = prisma_client.spend_log_transactions[:MAX_LOGS_PER_INTERVAL] + start_time = time.time() + try: + for i in range(n_retry_times + 1): + try: + base_url = os.getenv("SPEND_LOGS_URL", None) + if ( + len(logs_to_process) > 0 + and base_url is not None + and db_writer_client is not None + ): + if not base_url.endswith("/"): + base_url += "/" + verbose_proxy_logger.debug("base_url: {}".format(base_url)) + response = await db_writer_client.post( + url=base_url + "spend/update", + data=json.dumps(logs_to_process), + headers={"Content-Type": "application/json"}, + ) + if response.status_code == 200: + prisma_client.spend_log_transactions = ( + prisma_client.spend_log_transactions[ + len(logs_to_process) : + ] + ) + else: + for j in range(0, len(logs_to_process), BATCH_SIZE): + batch = logs_to_process[j : j + BATCH_SIZE] + batch_with_dates = [ + prisma_client.jsonify_object({**entry}) + for entry in batch + ] + await prisma_client.db.litellm_spendlogs.create_many( + data=batch_with_dates, skip_duplicates=True + ) + verbose_proxy_logger.debug( + f"Flushed {len(batch)} logs to the DB." + ) + + prisma_client.spend_log_transactions = ( + prisma_client.spend_log_transactions[len(logs_to_process) :] + ) + verbose_proxy_logger.debug( + f"{len(logs_to_process)} logs processed. Remaining in queue: {len(prisma_client.spend_log_transactions)}" + ) + break + except DB_CONNECTION_ERROR_TYPES: + if i is None: + i = 0 + if i >= n_retry_times: + raise + await asyncio.sleep(2**i) + except Exception as e: + prisma_client.spend_log_transactions = prisma_client.spend_log_transactions[ + len(logs_to_process) : + ] + _raise_failed_update_spend_exception( + e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj + ) + + @staticmethod + def disable_spend_updates() -> bool: + """ + returns True if should not update spend in db + Skips writing spend logs and updates to key, team, user spend to DB + """ + from litellm.proxy.proxy_server import general_settings + + if general_settings.get("disable_spend_updates") is True: + return True + return False + + +async def update_spend( # noqa: PLR0915 + prisma_client: PrismaClient, + db_writer_client: Optional[HTTPHandler], + proxy_logging_obj: ProxyLogging, +): + """ + Batch write updates to db. + + Triggered every minute. 
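+
+    Illustrative call (assumed setup -- the proxy's initialized clients):
+        await update_spend(
+            prisma_client=prisma_client,
+            db_writer_client=None,
+            proxy_logging_obj=proxy_logging_obj,
+        )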
+ + Requires: + user_id_list: dict, + keys_list: list, + team_list: list, + spend_logs: list, + """ + n_retry_times = 3 + i = None + ### UPDATE USER TABLE ### + if len(prisma_client.user_list_transactons.keys()) > 0: + for i in range(n_retry_times + 1): + start_time = time.time() + try: + async with prisma_client.db.tx( + timeout=timedelta(seconds=60) + ) as transaction: + async with transaction.batch_() as batcher: + for ( + user_id, + response_cost, + ) in prisma_client.user_list_transactons.items(): + batcher.litellm_usertable.update_many( + where={"user_id": user_id}, + data={"spend": {"increment": response_cost}}, + ) + prisma_client.user_list_transactons = ( + {} + ) # Clear the remaining transactions after processing all batches in the loop. + break + except DB_CONNECTION_ERROR_TYPES as e: + if i >= n_retry_times: # If we've reached the maximum number of retries + _raise_failed_update_spend_exception( + e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj + ) + # Optionally, sleep for a bit before retrying + await asyncio.sleep(2**i) # Exponential backoff + except Exception as e: + _raise_failed_update_spend_exception( + e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj + ) + + ### UPDATE END-USER TABLE ### + verbose_proxy_logger.debug( + "End-User Spend transactions: {}".format( + len(prisma_client.end_user_list_transactons.keys()) + ) + ) + if len(prisma_client.end_user_list_transactons.keys()) > 0: + await ProxyUpdateSpend.update_end_user_spend( + n_retry_times=n_retry_times, + prisma_client=prisma_client, + proxy_logging_obj=proxy_logging_obj, + ) + ### UPDATE KEY TABLE ### + verbose_proxy_logger.debug( + "KEY Spend transactions: {}".format( + len(prisma_client.key_list_transactons.keys()) + ) + ) + if len(prisma_client.key_list_transactons.keys()) > 0: + for i in range(n_retry_times + 1): + start_time = time.time() + try: + async with prisma_client.db.tx( + timeout=timedelta(seconds=60) + ) as transaction: + async with transaction.batch_() as batcher: + for ( + token, + response_cost, + ) in prisma_client.key_list_transactons.items(): + batcher.litellm_verificationtoken.update_many( # 'update_many' prevents error from being raised if no row exists + where={"token": token}, + data={"spend": {"increment": response_cost}}, + ) + prisma_client.key_list_transactons = ( + {} + ) # Clear the remaining transactions after processing all batches in the loop. 
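+                # Success -- exit the retry loop. On DB_CONNECTION_ERROR_TYPES the
+                # except-branch below sleeps 2**i seconds before the next attempt,
+                # i.e. 1s, 2s, 4s for n_retry_times=3 (simple exponential backoff).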
+ break + except DB_CONNECTION_ERROR_TYPES as e: + if i >= n_retry_times: # If we've reached the maximum number of retries + _raise_failed_update_spend_exception( + e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj + ) + # Optionally, sleep for a bit before retrying + await asyncio.sleep(2**i) # Exponential backoff + except Exception as e: + _raise_failed_update_spend_exception( + e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj + ) + + ### UPDATE TEAM TABLE ### + verbose_proxy_logger.debug( + "Team Spend transactions: {}".format( + len(prisma_client.team_list_transactons.keys()) + ) + ) + if len(prisma_client.team_list_transactons.keys()) > 0: + for i in range(n_retry_times + 1): + start_time = time.time() + try: + async with prisma_client.db.tx( + timeout=timedelta(seconds=60) + ) as transaction: + async with transaction.batch_() as batcher: + for ( + team_id, + response_cost, + ) in prisma_client.team_list_transactons.items(): + verbose_proxy_logger.debug( + "Updating spend for team id={} by {}".format( + team_id, response_cost + ) + ) + batcher.litellm_teamtable.update_many( # 'update_many' prevents error from being raised if no row exists + where={"team_id": team_id}, + data={"spend": {"increment": response_cost}}, + ) + prisma_client.team_list_transactons = ( + {} + ) # Clear the remaining transactions after processing all batches in the loop. + break + except DB_CONNECTION_ERROR_TYPES as e: + if i >= n_retry_times: # If we've reached the maximum number of retries + _raise_failed_update_spend_exception( + e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj + ) + # Optionally, sleep for a bit before retrying + await asyncio.sleep(2**i) # Exponential backoff + except Exception as e: + _raise_failed_update_spend_exception( + e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj + ) + + ### UPDATE TEAM Membership TABLE with spend ### + if len(prisma_client.team_member_list_transactons.keys()) > 0: + for i in range(n_retry_times + 1): + start_time = time.time() + try: + async with prisma_client.db.tx( + timeout=timedelta(seconds=60) + ) as transaction: + async with transaction.batch_() as batcher: + for ( + key, + response_cost, + ) in prisma_client.team_member_list_transactons.items(): + # key is "team_id::<value>::user_id::<value>" + team_id = key.split("::")[1] + user_id = key.split("::")[3] + + batcher.litellm_teammembership.update_many( # 'update_many' prevents error from being raised if no row exists + where={"team_id": team_id, "user_id": user_id}, + data={"spend": {"increment": response_cost}}, + ) + prisma_client.team_member_list_transactons = ( + {} + ) # Clear the remaining transactions after processing all batches in the loop. 
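+                # Illustrative: a transaction key like
+                # "team_id::team-abc::user_id::user-123" splits on "::" into
+                # ["team_id", "team-abc", "user_id", "user-123"], so indices 1 and 3
+                # (used above) recover the team_id and user_id.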
+ break + except DB_CONNECTION_ERROR_TYPES as e: + if i >= n_retry_times: # If we've reached the maximum number of retries + _raise_failed_update_spend_exception( + e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj + ) + # Optionally, sleep for a bit before retrying + await asyncio.sleep(2**i) # Exponential backoff + except Exception as e: + _raise_failed_update_spend_exception( + e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj + ) + + ### UPDATE ORG TABLE ### + if len(prisma_client.org_list_transactons.keys()) > 0: + for i in range(n_retry_times + 1): + start_time = time.time() + try: + async with prisma_client.db.tx( + timeout=timedelta(seconds=60) + ) as transaction: + async with transaction.batch_() as batcher: + for ( + org_id, + response_cost, + ) in prisma_client.org_list_transactons.items(): + batcher.litellm_organizationtable.update_many( # 'update_many' prevents error from being raised if no row exists + where={"organization_id": org_id}, + data={"spend": {"increment": response_cost}}, + ) + prisma_client.org_list_transactons = ( + {} + ) # Clear the remaining transactions after processing all batches in the loop. + break + except DB_CONNECTION_ERROR_TYPES as e: + if i >= n_retry_times: # If we've reached the maximum number of retries + _raise_failed_update_spend_exception( + e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj + ) + # Optionally, sleep for a bit before retrying + await asyncio.sleep(2**i) # Exponential backoff + except Exception as e: + _raise_failed_update_spend_exception( + e=e, start_time=start_time, proxy_logging_obj=proxy_logging_obj + ) + + ### UPDATE SPEND LOGS ### + verbose_proxy_logger.debug( + "Spend Logs transactions: {}".format(len(prisma_client.spend_log_transactions)) + ) + + if len(prisma_client.spend_log_transactions) > 0: + await ProxyUpdateSpend.update_spend_logs( + n_retry_times=n_retry_times, + prisma_client=prisma_client, + proxy_logging_obj=proxy_logging_obj, + db_writer_client=db_writer_client, + ) + + +def _raise_failed_update_spend_exception( + e: Exception, start_time: float, proxy_logging_obj: ProxyLogging +): + """ + Raise an exception for failed update spend logs + + - Calls proxy_logging_obj.failure_handler to log the error + - Ensures error messages says "Non-Blocking" + """ + import traceback + + error_msg = ( + f"[Non-Blocking]LiteLLM Prisma Client Exception - update spend logs: {str(e)}" + ) + error_traceback = error_msg + "\n" + traceback.format_exc() + end_time = time.time() + _duration = end_time - start_time + asyncio.create_task( + proxy_logging_obj.failure_handler( + original_exception=e, + duration=_duration, + call_type="update_spend", + traceback_str=error_traceback, + ) + ) + raise e + + +def _is_projected_spend_over_limit( + current_spend: float, soft_budget_limit: Optional[float] +): + from datetime import date + + if soft_budget_limit is None: + # If there's no limit, we can't exceed it. + return False + + today = date.today() + + # Finding the first day of the next month, then subtracting one day to get the end of the current month. 
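+    # Worked example (illustrative): today = 2024-02-10 -> end_month =
+    # date(2024, 3, 1) - timedelta(days=1) = 2024-02-29 (leap year), so
+    # remaining_days = 19.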
+    if today.month == 12:  # December edge case
+        end_month = date(today.year + 1, 1, 1) - timedelta(days=1)
+    else:
+        end_month = date(today.year, today.month + 1, 1) - timedelta(days=1)
+
+    remaining_days = (end_month - today).days
+
+    # Check for the start of the month to avoid division by zero
+    if today.day == 1:
+        daily_spend_estimate = current_spend
+    else:
+        daily_spend_estimate = current_spend / (today.day - 1)
+
+    # Total projected spend for the month
+    projected_spend = current_spend + (daily_spend_estimate * remaining_days)
+
+    if projected_spend > soft_budget_limit:
+        print_verbose("Projected spend exceeds soft budget limit!")
+        return True
+    return False
+
+
+def _get_projected_spend_over_limit(
+    current_spend: float, soft_budget_limit: Optional[float]
+) -> Optional[tuple]:
+    import datetime
+
+    if soft_budget_limit is None:
+        return None
+
+    today = datetime.date.today()
+    if today.month == 12:  # December edge case, mirroring the check above
+        end_month = datetime.date(today.year + 1, 1, 1) - datetime.timedelta(days=1)
+    else:
+        end_month = datetime.date(today.year, today.month + 1, 1) - datetime.timedelta(
+            days=1
+        )
+    remaining_days = (end_month - today).days
+
+    # current spend is assumed to cover the month up to (but not including) today;
+    # guard against division by zero on the 1st of the month
+    if today.day == 1:
+        daily_spend = current_spend
+    else:
+        daily_spend = current_spend / (today.day - 1)
+    projected_spend = daily_spend * remaining_days
+
+    if projected_spend > soft_budget_limit:
+        approx_days = soft_budget_limit / daily_spend
+        limit_exceed_date = today + datetime.timedelta(days=approx_days)
+
+        # return the projected spend and the date the limit will be exceeded
+        return projected_spend, limit_exceed_date
+
+    return None
+
+
+def _is_valid_team_configs(team_id=None, team_config=None, request_data=None):
+    if team_id is None or team_config is None or request_data is None:
+        return
+    # check if valid model called for team
+    if "models" in team_config:
+        valid_models = team_config.pop("models")
+        model_in_request = request_data["model"]
+        if model_in_request not in valid_models:
+            raise Exception(
+                f"Invalid model for team {team_id}: {model_in_request}. Valid models for team are: {valid_models}\n"
+            )
+    return
+
+
+def _to_ns(dt):
+    return int(dt.timestamp() * 1e9)
+
+
+def get_error_message_str(e: Exception) -> str:
+    error_message = ""
+    if isinstance(e, HTTPException):
+        if isinstance(e.detail, str):
+            error_message = e.detail
+        elif isinstance(e.detail, dict):
+            error_message = json.dumps(e.detail)
+        elif hasattr(e, "message"):
+            _error = getattr(e, "message", None)
+            if isinstance(_error, str):
+                error_message = _error
+            elif isinstance(_error, dict):
+                error_message = json.dumps(_error)
+        else:
+            error_message = str(e)
+    else:
+        error_message = str(e)
+    return error_message
+
+
+def _get_redoc_url() -> str:
+    """
+    Get the redoc URL from the environment variables.
+
+    - If REDOC_URL is set, return it.
+    - Otherwise, default to "/redoc".
+    """
+    return os.getenv("REDOC_URL", "/redoc")
+
+
+def _get_docs_url() -> Optional[str]:
+    """
+    Get the docs URL from the environment variables.
+
+    - If DOCS_URL is set, return it.
+    - If NO_DOCS is True, return None.
+    - Otherwise, default to "/".
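+
+    Examples (illustrative):
+        DOCS_URL="/swagger"                -> "/swagger"
+        DOCS_URL unset, NO_DOCS="True"     -> None
+        neither set                        -> "/"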
+    """
+    docs_url = os.getenv("DOCS_URL", None)
+    if docs_url:
+        return docs_url
+
+    if os.getenv("NO_DOCS", "False") == "True":
+        return None
+
+    # default to "/"
+    return "/"
+
+
+def handle_exception_on_proxy(e: Exception) -> ProxyException:
+    """
+    Returns the given exception as a ProxyException; this ensures all exceptions
+    are OpenAI-API compatible.
+    """
+    from fastapi import status
+
+    if isinstance(e, HTTPException):
+        return ProxyException(
+            message=getattr(e, "detail", f"error({str(e)})"),
+            type=ProxyErrorTypes.internal_server_error,
+            param=getattr(e, "param", "None"),
+            code=getattr(e, "status_code", status.HTTP_500_INTERNAL_SERVER_ERROR),
+        )
+    elif isinstance(e, ProxyException):
+        return e
+    return ProxyException(
+        message="Internal Server Error, " + str(e),
+        type=ProxyErrorTypes.internal_server_error,
+        param=getattr(e, "param", "None"),
+        code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+    )
+
+
+def _premium_user_check():
+    """
+    Raises an HTTPException if the user is not a premium user.
+    """
+    from litellm.proxy.proxy_server import premium_user
+
+    if not premium_user:
+        raise HTTPException(
+            status_code=403,
+            detail={
+                "error": f"This feature is only available for LiteLLM Enterprise users. {CommonProxyErrors.not_premium_user.value}"
+            },
+        )
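+
+
+# Illustrative usage sketch (an assumption, not code from this module): route
+# handlers can funnel failures through `handle_exception_on_proxy` so clients
+# always receive an OpenAI-compatible error shape:
+#
+#     try:
+#         ...  # endpoint logic
+#     except Exception as e:
+#         raise handle_exception_on_proxy(e)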