aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/proxy_track_cost_callback.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/litellm/proxy/hooks/proxy_track_cost_callback.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two versions of R2R are hereHEADmaster
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/proxy/hooks/proxy_track_cost_callback.py')
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/proxy/hooks/proxy_track_cost_callback.py246
1 files changed, 246 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/proxy_track_cost_callback.py b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/proxy_track_cost_callback.py
new file mode 100644
index 00000000..e8a94732
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/proxy_track_cost_callback.py
@@ -0,0 +1,246 @@
+import asyncio
+import traceback
+from datetime import datetime
+from typing import Any, Optional, Union, cast
+
+import litellm
+from litellm._logging import verbose_proxy_logger
+from litellm.integrations.custom_logger import CustomLogger
+from litellm.litellm_core_utils.core_helpers import (
+ _get_parent_otel_span_from_kwargs,
+ get_litellm_metadata_from_kwargs,
+)
+from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup
+from litellm.proxy._types import UserAPIKeyAuth
+from litellm.proxy.auth.auth_checks import log_db_metrics
+from litellm.proxy.utils import ProxyUpdateSpend
+from litellm.types.utils import (
+ StandardLoggingPayload,
+ StandardLoggingUserAPIKeyMetadata,
+)
+from litellm.utils import get_end_user_id_for_cost_tracking
+
+
class _ProxyDBLogger(CustomLogger):
    """Proxy-side DB logger.

    Tracks per-request spend in the proxy database/cache on success, and
    (unless disabled via ``general_settings.disable_error_logs``) records a
    zero-cost failure row for failed requests.
    """

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        """Success hook: delegate to the cost-tracking callback."""
        await self._PROXY_track_cost_callback(
            kwargs, response_obj, start_time, end_time
        )

    async def async_post_call_failure_hook(
        self,
        request_data: dict,
        original_exception: Exception,
        user_api_key_dict: UserAPIKeyAuth,
    ):
        """Failure hook: write a zero-cost error row to the DB.

        Enriches ``request_data['litellm_params']['metadata']`` with the key's
        identity fields, a ``status`` of ``"failure"``, and structured error
        information, then calls ``update_database`` with ``response_cost=0.0``.
        No-op when error logging is disabled in ``general_settings``.
        """
        from litellm.proxy.proxy_server import update_database

        if _ProxyDBLogger._should_track_errors_in_db() is False:
            return

        # Identity metadata for the failing key, merged over any metadata the
        # request already carried.
        _metadata = dict(
            StandardLoggingUserAPIKeyMetadata(
                user_api_key_hash=user_api_key_dict.api_key,
                user_api_key_alias=user_api_key_dict.key_alias,
                user_api_key_user_email=user_api_key_dict.user_email,
                user_api_key_user_id=user_api_key_dict.user_id,
                user_api_key_team_id=user_api_key_dict.team_id,
                user_api_key_org_id=user_api_key_dict.org_id,
                user_api_key_team_alias=user_api_key_dict.team_alias,
                user_api_key_end_user_id=user_api_key_dict.end_user_id,
            )
        )
        _metadata["user_api_key"] = user_api_key_dict.api_key
        _metadata["status"] = "failure"
        _metadata["error_information"] = (
            StandardLoggingPayloadSetup.get_error_information(
                original_exception=original_exception,
            )
        )

        existing_metadata: dict = request_data.get("metadata", None) or {}
        existing_metadata.update(_metadata)

        if "litellm_params" not in request_data:
            request_data["litellm_params"] = {}
        request_data["litellm_params"]["proxy_server_request"] = (
            request_data.get("proxy_server_request") or {}
        )
        request_data["litellm_params"]["metadata"] = existing_metadata
        await update_database(
            token=user_api_key_dict.api_key,
            response_cost=0.0,  # failed request: no spend, but the row is logged
            user_id=user_api_key_dict.user_id,
            end_user_id=user_api_key_dict.end_user_id,
            team_id=user_api_key_dict.team_id,
            kwargs=request_data,
            completion_response=original_exception,
            start_time=datetime.now(),
            end_time=datetime.now(),
            org_id=user_api_key_dict.org_id,
        )

    @log_db_metrics
    async def _PROXY_track_cost_callback(
        self,
        kwargs,  # kwargs to completion
        completion_response: Optional[
            Union[litellm.ModelResponse, Any]
        ],  # response from completion
        start_time=None,
        end_time=None,  # start/end time for completion
    ):
        """Track the cost of a completed request.

        Reads the response cost from the standard logging payload (falling
        back to ``kwargs['response_cost']``), zeroes it on cache hits, then —
        if the request can be attributed to a key/user/team/end-user —
        updates the database, schedules a cache update, and fires the
        customer spend alert. Raises (and converts into a
        ``failed_tracking_alert``) when cost tracking is impossible.
        """
        from litellm.proxy.proxy_server import (
            prisma_client,
            proxy_logging_obj,
            update_cache,
            update_database,
        )

        verbose_proxy_logger.debug("INSIDE _PROXY_track_cost_callback")
        try:
            verbose_proxy_logger.debug(
                f"kwargs stream: {kwargs.get('stream', None)} + complete streaming response: {kwargs.get('complete_streaming_response', None)}"
            )
            parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs=kwargs)
            litellm_params = kwargs.get("litellm_params", {}) or {}
            end_user_id = get_end_user_id_for_cost_tracking(litellm_params)
            metadata = get_litellm_metadata_from_kwargs(kwargs=kwargs)
            user_id = cast(Optional[str], metadata.get("user_api_key_user_id", None))
            team_id = cast(Optional[str], metadata.get("user_api_key_team_id", None))
            org_id = cast(Optional[str], metadata.get("user_api_key_org_id", None))
            key_alias = cast(Optional[str], metadata.get("user_api_key_alias", None))
            end_user_max_budget = metadata.get("user_api_end_user_max_budget", None)
            sl_object: Optional[StandardLoggingPayload] = kwargs.get(
                "standard_logging_object", None
            )
            # Prefer the cost computed on the standard logging payload; fall
            # back to the raw kwargs value when no payload was built.
            response_cost = (
                sl_object.get("response_cost", None)
                if sl_object is not None
                else kwargs.get("response_cost", None)
            )

            if response_cost is not None:
                user_api_key = metadata.get("user_api_key", None)
                if kwargs.get("cache_hit", False) is True:
                    # Cache hits cost nothing, but we still record the event.
                    response_cost = 0.0
                    verbose_proxy_logger.info(
                        f"Cache Hit: response_cost {response_cost}, for user_id {user_id}"
                    )

                verbose_proxy_logger.debug(
                    f"user_api_key {user_api_key}, prisma_client: {prisma_client}"
                )
                if _should_track_cost_callback(
                    user_api_key=user_api_key,
                    user_id=user_id,
                    team_id=team_id,
                    end_user_id=end_user_id,
                ):
                    ## UPDATE DATABASE
                    await update_database(
                        token=user_api_key,
                        response_cost=response_cost,
                        user_id=user_id,
                        end_user_id=end_user_id,
                        team_id=team_id,
                        kwargs=kwargs,
                        completion_response=completion_response,
                        start_time=start_time,
                        end_time=end_time,
                        org_id=org_id,
                    )

                    # update cache (fire-and-forget; DB update above is awaited)
                    asyncio.create_task(
                        update_cache(
                            token=user_api_key,
                            user_id=user_id,
                            end_user_id=end_user_id,
                            response_cost=response_cost,
                            team_id=team_id,
                            parent_otel_span=parent_otel_span,
                        )
                    )

                    await proxy_logging_obj.slack_alerting_instance.customer_spend_alert(
                        token=user_api_key,
                        key_alias=key_alias,
                        end_user_id=end_user_id,
                        response_cost=response_cost,
                        max_budget=end_user_max_budget,
                    )
                else:
                    raise Exception(
                        "User API key and team id and user id missing from custom callback."
                    )
            else:
                # No cost available. For streams, only complain once the
                # complete streaming response has arrived; otherwise the cost
                # is legitimately not known yet.
                # NOTE: .get() instead of ["stream"] so a missing key surfaces
                # as a cost-tracking failure rather than a bare KeyError.
                if kwargs.get("stream") is not True or (
                    kwargs.get("stream") is True
                    and "complete_streaming_response" in kwargs
                ):
                    if sl_object is not None:
                        cost_tracking_failure_debug_info: Union[dict, str] = (
                            sl_object["response_cost_failure_debug_info"]  # type: ignore
                            or "response_cost_failure_debug_info is None in standard_logging_object"
                        )
                    else:
                        cost_tracking_failure_debug_info = (
                            "standard_logging_object not found"
                        )
                    model = kwargs.get("model")
                    raise Exception(
                        f"Cost tracking failed for model={model}.\nDebug info - {cost_tracking_failure_debug_info}\nAdd custom pricing - https://docs.litellm.ai/docs/proxy/custom_pricing"
                    )
        except Exception as e:
            # Any failure above becomes a tracking alert; never propagate out
            # of the logging callback.
            error_msg = f"Error in tracking cost callback - {str(e)}\n Traceback:{traceback.format_exc()}"
            model = kwargs.get("model", "")
            metadata = kwargs.get("litellm_params", {}).get("metadata", {})
            error_msg += f"\n Args to _PROXY_track_cost_callback\n model: {model}\n metadata: {metadata}\n"
            asyncio.create_task(
                proxy_logging_obj.failed_tracking_alert(
                    error_message=error_msg,
                    failing_model=model,
                )
            )
            verbose_proxy_logger.exception(
                "Error in tracking cost callback - %s", str(e)
            )

    @staticmethod
    def _should_track_errors_in_db() -> bool:
        """
        Returns True if errors should be tracked in the database

        By default, errors are tracked in the database

        If users want to disable error tracking, they can set the disable_error_logs flag in the general_settings
        """
        from litellm.proxy.proxy_server import general_settings

        if general_settings.get("disable_error_logs") is True:
            return False
        # BUGFIX: was a bare `return` (i.e. None). The docstring promises True
        # by default, and a truthiness-based caller would otherwise wrongly
        # conclude errors should NOT be tracked.
        return True
+
+
def _should_track_cost_callback(
    user_api_key: Optional[str],
    user_id: Optional[str],
    team_id: Optional[str],
    end_user_id: Optional[str],
) -> bool:
    """Decide whether the cost-tracking callback should run for this request.

    Tracking is skipped outright when the proxy has spend updates disabled.
    Otherwise it runs as long as the spend can be attributed to at least one
    of: an API key, a user, a team, or an end user.
    """
    # Operator opted out of spend tracking entirely.
    if ProxyUpdateSpend.disable_spend_updates() is True:
        return False

    # Track only when there is someone to attribute the cost to.
    attribution_ids = (user_api_key, user_id, team_id, end_user_id)
    return any(identifier is not None for identifier in attribution_ids)