about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/proxy_track_cost_callback.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/proxy/hooks/proxy_track_cost_callback.py')
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/proxy/hooks/proxy_track_cost_callback.py246
1 files changed, 246 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/proxy_track_cost_callback.py b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/proxy_track_cost_callback.py
new file mode 100644
index 00000000..e8a94732
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/hooks/proxy_track_cost_callback.py
@@ -0,0 +1,246 @@
+import asyncio
+import traceback
+from datetime import datetime
+from typing import Any, Optional, Union, cast
+
+import litellm
+from litellm._logging import verbose_proxy_logger
+from litellm.integrations.custom_logger import CustomLogger
+from litellm.litellm_core_utils.core_helpers import (
+    _get_parent_otel_span_from_kwargs,
+    get_litellm_metadata_from_kwargs,
+)
+from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup
+from litellm.proxy._types import UserAPIKeyAuth
+from litellm.proxy.auth.auth_checks import log_db_metrics
+from litellm.proxy.utils import ProxyUpdateSpend
+from litellm.types.utils import (
+    StandardLoggingPayload,
+    StandardLoggingUserAPIKeyMetadata,
+)
+from litellm.utils import get_end_user_id_for_cost_tracking
+
+
+class _ProxyDBLogger(CustomLogger):
+    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+        await self._PROXY_track_cost_callback(
+            kwargs, response_obj, start_time, end_time
+        )
+
+    async def async_post_call_failure_hook(
+        self,
+        request_data: dict,
+        original_exception: Exception,
+        user_api_key_dict: UserAPIKeyAuth,
+    ):
+        from litellm.proxy.proxy_server import update_database
+
+        if _ProxyDBLogger._should_track_errors_in_db() is False:
+            return
+
+        _metadata = dict(
+            StandardLoggingUserAPIKeyMetadata(
+                user_api_key_hash=user_api_key_dict.api_key,
+                user_api_key_alias=user_api_key_dict.key_alias,
+                user_api_key_user_email=user_api_key_dict.user_email,
+                user_api_key_user_id=user_api_key_dict.user_id,
+                user_api_key_team_id=user_api_key_dict.team_id,
+                user_api_key_org_id=user_api_key_dict.org_id,
+                user_api_key_team_alias=user_api_key_dict.team_alias,
+                user_api_key_end_user_id=user_api_key_dict.end_user_id,
+            )
+        )
+        _metadata["user_api_key"] = user_api_key_dict.api_key
+        _metadata["status"] = "failure"
+        _metadata["error_information"] = (
+            StandardLoggingPayloadSetup.get_error_information(
+                original_exception=original_exception,
+            )
+        )
+
+        existing_metadata: dict = request_data.get("metadata", None) or {}
+        existing_metadata.update(_metadata)
+
+        if "litellm_params" not in request_data:
+            request_data["litellm_params"] = {}
+        request_data["litellm_params"]["proxy_server_request"] = (
+            request_data.get("proxy_server_request") or {}
+        )
+        request_data["litellm_params"]["metadata"] = existing_metadata
+        await update_database(
+            token=user_api_key_dict.api_key,
+            response_cost=0.0,
+            user_id=user_api_key_dict.user_id,
+            end_user_id=user_api_key_dict.end_user_id,
+            team_id=user_api_key_dict.team_id,
+            kwargs=request_data,
+            completion_response=original_exception,
+            start_time=datetime.now(),
+            end_time=datetime.now(),
+            org_id=user_api_key_dict.org_id,
+        )
+
+    @log_db_metrics
+    async def _PROXY_track_cost_callback(
+        self,
+        kwargs,  # kwargs to completion
+        completion_response: Optional[
+            Union[litellm.ModelResponse, Any]
+        ],  # response from completion
+        start_time=None,
+        end_time=None,  # start/end time for completion
+    ):
+        from litellm.proxy.proxy_server import (
+            prisma_client,
+            proxy_logging_obj,
+            update_cache,
+            update_database,
+        )
+
+        verbose_proxy_logger.debug("INSIDE _PROXY_track_cost_callback")
+        try:
+            verbose_proxy_logger.debug(
+                f"kwargs stream: {kwargs.get('stream', None)} + complete streaming response: {kwargs.get('complete_streaming_response', None)}"
+            )
+            parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs=kwargs)
+            litellm_params = kwargs.get("litellm_params", {}) or {}
+            end_user_id = get_end_user_id_for_cost_tracking(litellm_params)
+            metadata = get_litellm_metadata_from_kwargs(kwargs=kwargs)
+            user_id = cast(Optional[str], metadata.get("user_api_key_user_id", None))
+            team_id = cast(Optional[str], metadata.get("user_api_key_team_id", None))
+            org_id = cast(Optional[str], metadata.get("user_api_key_org_id", None))
+            key_alias = cast(Optional[str], metadata.get("user_api_key_alias", None))
+            end_user_max_budget = metadata.get("user_api_end_user_max_budget", None)
+            sl_object: Optional[StandardLoggingPayload] = kwargs.get(
+                "standard_logging_object", None
+            )
+            response_cost = (
+                sl_object.get("response_cost", None)
+                if sl_object is not None
+                else kwargs.get("response_cost", None)
+            )
+
+            if response_cost is not None:
+                user_api_key = metadata.get("user_api_key", None)
+                if kwargs.get("cache_hit", False) is True:
+                    response_cost = 0.0
+                    verbose_proxy_logger.info(
+                        f"Cache Hit: response_cost {response_cost}, for user_id {user_id}"
+                    )
+
+                verbose_proxy_logger.debug(
+                    f"user_api_key {user_api_key}, prisma_client: {prisma_client}"
+                )
+                if _should_track_cost_callback(
+                    user_api_key=user_api_key,
+                    user_id=user_id,
+                    team_id=team_id,
+                    end_user_id=end_user_id,
+                ):
+                    ## UPDATE DATABASE
+                    await update_database(
+                        token=user_api_key,
+                        response_cost=response_cost,
+                        user_id=user_id,
+                        end_user_id=end_user_id,
+                        team_id=team_id,
+                        kwargs=kwargs,
+                        completion_response=completion_response,
+                        start_time=start_time,
+                        end_time=end_time,
+                        org_id=org_id,
+                    )
+
+                    # update cache
+                    asyncio.create_task(
+                        update_cache(
+                            token=user_api_key,
+                            user_id=user_id,
+                            end_user_id=end_user_id,
+                            response_cost=response_cost,
+                            team_id=team_id,
+                            parent_otel_span=parent_otel_span,
+                        )
+                    )
+
+                    await proxy_logging_obj.slack_alerting_instance.customer_spend_alert(
+                        token=user_api_key,
+                        key_alias=key_alias,
+                        end_user_id=end_user_id,
+                        response_cost=response_cost,
+                        max_budget=end_user_max_budget,
+                    )
+                else:
+                    raise Exception(
+                        "User API key and team id and user id missing from custom callback."
+                    )
+            else:
+                if kwargs["stream"] is not True or (
+                    kwargs["stream"] is True and "complete_streaming_response" in kwargs
+                ):
+                    if sl_object is not None:
+                        cost_tracking_failure_debug_info: Union[dict, str] = (
+                            sl_object["response_cost_failure_debug_info"]  # type: ignore
+                            or "response_cost_failure_debug_info is None in standard_logging_object"
+                        )
+                    else:
+                        cost_tracking_failure_debug_info = (
+                            "standard_logging_object not found"
+                        )
+                    model = kwargs.get("model")
+                    raise Exception(
+                        f"Cost tracking failed for model={model}.\nDebug info - {cost_tracking_failure_debug_info}\nAdd custom pricing - https://docs.litellm.ai/docs/proxy/custom_pricing"
+                    )
+        except Exception as e:
+            error_msg = f"Error in tracking cost callback - {str(e)}\n Traceback:{traceback.format_exc()}"
+            model = kwargs.get("model", "")
+            metadata = kwargs.get("litellm_params", {}).get("metadata", {})
+            error_msg += f"\n Args to _PROXY_track_cost_callback\n model: {model}\n metadata: {metadata}\n"
+            asyncio.create_task(
+                proxy_logging_obj.failed_tracking_alert(
+                    error_message=error_msg,
+                    failing_model=model,
+                )
+            )
+            verbose_proxy_logger.exception(
+                "Error in tracking cost callback - %s", str(e)
+            )
+
+    @staticmethod
+    def _should_track_errors_in_db():
+        """
+        Returns True if errors should be tracked in the database
+
+        By default, errors are tracked in the database
+
+        If users want to disable error tracking, they can set the disable_error_logs flag in the general_settings
+        """
+        from litellm.proxy.proxy_server import general_settings
+
+        if general_settings.get("disable_error_logs") is True:
+            return False
+        return
+
+
+def _should_track_cost_callback(
+    user_api_key: Optional[str],
+    user_id: Optional[str],
+    team_id: Optional[str],
+    end_user_id: Optional[str],
+) -> bool:
+    """
+    Determine if the cost callback should be tracked based on the kwargs
+    """
+
+    # don't run track cost callback if user opted into disabling spend
+    if ProxyUpdateSpend.disable_spend_updates() is True:
+        return False
+
+    if (
+        user_api_key is not None
+        or user_id is not None
+        or team_id is not None
+        or end_user_id is not None
+    ):
+        return True
+    return False