diff options
author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
---|---|---|
committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) | |
tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/litellm/proxy/hooks/proxy_track_cost_callback.py | |
parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff) | |
download | gn-ai-master.tar.gz |
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/proxy/hooks/proxy_track_cost_callback.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/litellm/proxy/hooks/proxy_track_cost_callback.py | 246 |
1 files changed, 246 insertions, 0 deletions
# Reconstructed from the diff that adds
# .venv/lib/python3.12/site-packages/litellm/proxy/hooks/proxy_track_cost_callback.py
# (new file, 246 lines).
import asyncio
import traceback
from datetime import datetime
from typing import Any, Optional, Union, cast

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.core_helpers import (
    _get_parent_otel_span_from_kwargs,
    get_litellm_metadata_from_kwargs,
)
from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.auth.auth_checks import log_db_metrics
from litellm.proxy.utils import ProxyUpdateSpend
from litellm.types.utils import (
    StandardLoggingPayload,
    StandardLoggingUserAPIKeyMetadata,
)
from litellm.utils import get_end_user_id_for_cost_tracking


class _ProxyDBLogger(CustomLogger):
    """Custom logger that records request spend (and failures) via the proxy.

    Successful calls are routed to ``_PROXY_track_cost_callback``; failed
    calls are written with a zero cost by ``async_post_call_failure_hook``.
    """

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        """Forward a successful call to the cost-tracking callback."""
        await self._PROXY_track_cost_callback(
            kwargs, response_obj, start_time, end_time
        )

    async def async_post_call_failure_hook(
        self,
        request_data: dict,
        original_exception: Exception,
        user_api_key_dict: UserAPIKeyAuth,
    ):
        """Record a failed request through ``update_database`` with zero cost.

        Does nothing when error tracking is disabled (see
        ``_should_track_errors_in_db``). Otherwise the key/user/team
        identifiers from ``user_api_key_dict`` plus the error information are
        merged into the request metadata, and ``request_data`` is passed on as
        the ``kwargs`` of ``update_database``.

        Note: ``request_data`` (and its ``metadata`` dict, when present) is
        mutated in place.
        """
        # Imported lazily, as in the rest of this module's proxy_server usage.
        from litellm.proxy.proxy_server import update_database

        if _ProxyDBLogger._should_track_errors_in_db() is False:
            return

        _metadata = dict(
            StandardLoggingUserAPIKeyMetadata(
                user_api_key_hash=user_api_key_dict.api_key,
                user_api_key_alias=user_api_key_dict.key_alias,
                user_api_key_user_email=user_api_key_dict.user_email,
                user_api_key_user_id=user_api_key_dict.user_id,
                user_api_key_team_id=user_api_key_dict.team_id,
                user_api_key_org_id=user_api_key_dict.org_id,
                user_api_key_team_alias=user_api_key_dict.team_alias,
                user_api_key_end_user_id=user_api_key_dict.end_user_id,
            )
        )
        _metadata["user_api_key"] = user_api_key_dict.api_key
        _metadata["status"] = "failure"
        _metadata["error_information"] = (
            StandardLoggingPayloadSetup.get_error_information(
                original_exception=original_exception,
            )
        )

        existing_metadata: dict = request_data.get("metadata", None) or {}
        existing_metadata.update(_metadata)

        # Guard against both a missing key and an explicit ``None`` value so
        # the item assignments below cannot fail.
        if request_data.get("litellm_params") is None:
            request_data["litellm_params"] = {}
        request_data["litellm_params"]["proxy_server_request"] = (
            request_data.get("proxy_server_request") or {}
        )
        request_data["litellm_params"]["metadata"] = existing_metadata
        await update_database(
            token=user_api_key_dict.api_key,
            response_cost=0.0,
            user_id=user_api_key_dict.user_id,
            end_user_id=user_api_key_dict.end_user_id,
            team_id=user_api_key_dict.team_id,
            kwargs=request_data,
            completion_response=original_exception,
            start_time=datetime.now(),
            end_time=datetime.now(),
            org_id=user_api_key_dict.org_id,
        )

    @log_db_metrics
    async def _PROXY_track_cost_callback(
        self,
        kwargs,  # kwargs to completion
        completion_response: Optional[
            Union[litellm.ModelResponse, Any]
        ],  # response from completion
        start_time=None,
        end_time=None,  # start/end time for completion
    ):
        """Write the cost of a completed call to the database and cache.

        The cost is read from ``kwargs["standard_logging_object"]`` when that
        object is present, otherwise from ``kwargs["response_cost"]``. A cache
        hit is recorded with a cost of ``0.0``.

        Any exception raised while tracking is caught here: it is reported via
        ``proxy_logging_obj.failed_tracking_alert`` and logged, and is not
        propagated to the caller.
        """
        from litellm.proxy.proxy_server import (
            prisma_client,
            proxy_logging_obj,
            update_cache,
            update_database,
        )

        verbose_proxy_logger.debug("INSIDE _PROXY_track_cost_callback")
        try:
            verbose_proxy_logger.debug(
                f"kwargs stream: {kwargs.get('stream', None)} + complete streaming response: {kwargs.get('complete_streaming_response', None)}"
            )
            parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs=kwargs)
            litellm_params = kwargs.get("litellm_params", {}) or {}
            end_user_id = get_end_user_id_for_cost_tracking(litellm_params)
            metadata = get_litellm_metadata_from_kwargs(kwargs=kwargs)
            user_id = cast(Optional[str], metadata.get("user_api_key_user_id", None))
            team_id = cast(Optional[str], metadata.get("user_api_key_team_id", None))
            org_id = cast(Optional[str], metadata.get("user_api_key_org_id", None))
            key_alias = cast(Optional[str], metadata.get("user_api_key_alias", None))
            end_user_max_budget = metadata.get("user_api_end_user_max_budget", None)
            sl_object: Optional[StandardLoggingPayload] = kwargs.get(
                "standard_logging_object", None
            )
            response_cost = (
                sl_object.get("response_cost", None)
                if sl_object is not None
                else kwargs.get("response_cost", None)
            )

            if response_cost is not None:
                user_api_key = metadata.get("user_api_key", None)
                if kwargs.get("cache_hit", False) is True:
                    response_cost = 0.0
                    verbose_proxy_logger.info(
                        f"Cache Hit: response_cost {response_cost}, for user_id {user_id}"
                    )

                verbose_proxy_logger.debug(
                    f"user_api_key {user_api_key}, prisma_client: {prisma_client}"
                )
                if _should_track_cost_callback(
                    user_api_key=user_api_key,
                    user_id=user_id,
                    team_id=team_id,
                    end_user_id=end_user_id,
                ):
                    ## UPDATE DATABASE
                    await update_database(
                        token=user_api_key,
                        response_cost=response_cost,
                        user_id=user_id,
                        end_user_id=end_user_id,
                        team_id=team_id,
                        kwargs=kwargs,
                        completion_response=completion_response,
                        start_time=start_time,
                        end_time=end_time,
                        org_id=org_id,
                    )

                    # update cache (fire-and-forget; not awaited here)
                    asyncio.create_task(
                        update_cache(
                            token=user_api_key,
                            user_id=user_id,
                            end_user_id=end_user_id,
                            response_cost=response_cost,
                            team_id=team_id,
                            parent_otel_span=parent_otel_span,
                        )
                    )

                    await proxy_logging_obj.slack_alerting_instance.customer_spend_alert(
                        token=user_api_key,
                        key_alias=key_alias,
                        end_user_id=end_user_id,
                        response_cost=response_cost,
                        max_budget=end_user_max_budget,
                    )
                else:
                    raise Exception(
                        "User API key and team id and user id missing from custom callback."
                    )
            else:
                # A missing cost is only an error once the full response is
                # available: non-streaming calls, or streaming calls that carry
                # the assembled response. ``.get`` is used so that an absent
                # "stream" key is treated as non-streaming instead of raising a
                # KeyError that would hide the real diagnostic below.
                stream = kwargs.get("stream")
                if stream is not True or "complete_streaming_response" in kwargs:
                    if sl_object is not None:
                        cost_tracking_failure_debug_info: Union[dict, str] = (
                            sl_object.get("response_cost_failure_debug_info")  # type: ignore
                            or "response_cost_failure_debug_info is None in standard_logging_object"
                        )
                    else:
                        cost_tracking_failure_debug_info = (
                            "standard_logging_object not found"
                        )
                    model = kwargs.get("model")
                    raise Exception(
                        f"Cost tracking failed for model={model}.\nDebug info - {cost_tracking_failure_debug_info}\nAdd custom pricing - https://docs.litellm.ai/docs/proxy/custom_pricing"
                    )
        except Exception as e:
            error_msg = f"Error in tracking cost callback - {str(e)}\n Traceback:{traceback.format_exc()}"
            model = kwargs.get("model", "")
            # ``litellm_params`` may be present but None; never let the error
            # handler itself fail, or the alert below would be lost.
            metadata = (kwargs.get("litellm_params") or {}).get("metadata", {})
            error_msg += f"\n Args to _PROXY_track_cost_callback\n model: {model}\n metadata: {metadata}\n"
            asyncio.create_task(
                proxy_logging_obj.failed_tracking_alert(
                    error_message=error_msg,
                    failing_model=model,
                )
            )
            verbose_proxy_logger.exception(
                "Error in tracking cost callback - %s", str(e)
            )

    @staticmethod
    def _should_track_errors_in_db() -> bool:
        """
        Returns True if errors should be tracked in the database

        By default, errors are tracked in the database

        If users want to disable error tracking, they can set the disable_error_logs flag in the general_settings
        """
        from litellm.proxy.proxy_server import general_settings

        if general_settings.get("disable_error_logs") is True:
            return False
        # Explicit True (previously a bare ``return`` yielding None) so the
        # result matches the documented contract and is safe to use in a
        # plain truthiness check.
        return True


def _should_track_cost_callback(
    user_api_key: Optional[str],
    user_id: Optional[str],
    team_id: Optional[str],
    end_user_id: Optional[str],
) -> bool:
    """
    Determine if the cost callback should be tracked based on the kwargs

    Returns False when spend updates are disabled; otherwise True as soon as
    at least one of the given identifiers is not None.
    """

    # don't run track cost callback if user opted into disabling spend
    if ProxyUpdateSpend.disable_spend_updates() is True:
        return False

    if (
        user_api_key is not None
        or user_id is not None
        or team_id is not None
        or end_user_id is not None
    ):
        return True
    return False