Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/proxy/spend_tracking/spend_tracking_utils.py')
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/proxy/spend_tracking/spend_tracking_utils.py  386
1 file changed, 386 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/spend_tracking/spend_tracking_utils.py b/.venv/lib/python3.12/site-packages/litellm/proxy/spend_tracking/spend_tracking_utils.py
new file mode 100644
index 00000000..6e9a0880
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/spend_tracking/spend_tracking_utils.py
@@ -0,0 +1,386 @@
+import hashlib
+import json
+import secrets
+from datetime import datetime, timezone
+from typing import Any, List, Optional, cast
+
+from pydantic import BaseModel
+
+import litellm
+from litellm._logging import verbose_proxy_logger
+from litellm.litellm_core_utils.core_helpers import get_litellm_metadata_from_kwargs
+from litellm.proxy._types import SpendLogsMetadata, SpendLogsPayload
+from litellm.proxy.utils import PrismaClient, hash_token
+from litellm.types.utils import StandardLoggingPayload
+from litellm.utils import get_end_user_id_for_cost_tracking
+
+
+def _is_master_key(api_key: str, _master_key: Optional[str]) -> bool:
+ if _master_key is None:
+ return False
+
+ ## string comparison
+ is_master_key = secrets.compare_digest(api_key, _master_key)
+ if is_master_key:
+ return True
+
+ ## hash comparison
+ is_master_key = secrets.compare_digest(api_key, hash_token(_master_key))
+ if is_master_key:
+ return True
+
+ return False
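+
+# Illustrative usage (assumed values, not from this file): the check accepts
+# either the raw master key or its hashed form, compared in constant time via
+# secrets.compare_digest to avoid timing side channels:
+#   _is_master_key("sk-1234", _master_key="sk-1234")              # True, string match
+#   _is_master_key(hash_token("sk-1234"), _master_key="sk-1234")  # True, hash match
+#   _is_master_key("sk-other", _master_key="sk-1234")             # False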
+
+
+def _get_spend_logs_metadata(
+ metadata: Optional[dict],
+ applied_guardrails: Optional[List[str]] = None,
+ batch_models: Optional[List[str]] = None,
+) -> SpendLogsMetadata:
+ if metadata is None:
+ return SpendLogsMetadata(
+ user_api_key=None,
+ user_api_key_alias=None,
+ user_api_key_team_id=None,
+ user_api_key_org_id=None,
+ user_api_key_user_id=None,
+ user_api_key_team_alias=None,
+ spend_logs_metadata=None,
+ requester_ip_address=None,
+ additional_usage_values=None,
+ applied_guardrails=None,
+            status="success",
+ error_information=None,
+ proxy_server_request=None,
+ batch_models=None,
+ )
+ verbose_proxy_logger.debug(
+ "getting payload for SpendLogs, available keys in metadata: "
+ + str(list(metadata.keys()))
+ )
+
+ # Filter the metadata dictionary to include only the specified keys
+ clean_metadata = SpendLogsMetadata(
+ **{ # type: ignore
+ key: metadata[key]
+ for key in SpendLogsMetadata.__annotations__.keys()
+ if key in metadata
+ }
+ )
+ clean_metadata["applied_guardrails"] = applied_guardrails
+ clean_metadata["batch_models"] = batch_models
+ return clean_metadata
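+
+# Illustrative filtering (assumed input): only keys declared in
+# SpendLogsMetadata's annotations survive; anything else is dropped:
+#   _get_spend_logs_metadata({"user_api_key": "hashed-key", "internal_note": "x"})
+#   # -> SpendLogsMetadata with user_api_key="hashed-key"; "internal_note" discarded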
+
+
+def generate_hash_from_response(response_obj: Any) -> str:
+ """
+ Generate a stable hash from a response object.
+
+ Args:
+ response_obj: The response object to hash (can be dict, list, etc.)
+
+ Returns:
+ A hex string representation of the MD5 hash
+ """
+ try:
+ # Create a stable JSON string of the entire response object
+ # Sort keys to ensure consistent ordering
+ json_str = json.dumps(response_obj, sort_keys=True)
+
+ # Generate a hash of the response object
+ unique_hash = hashlib.md5(json_str.encode()).hexdigest()
+ return unique_hash
+ except Exception:
+ # Return a fallback hash if serialization fails
+ return hashlib.md5(str(response_obj).encode()).hexdigest()
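+
+# Illustrative property (not from the original file): sort_keys=True makes the
+# hash order-insensitive, and MD5 is used here only for stable IDs, not security:
+#   generate_hash_from_response({"a": 1, "b": 2}) == generate_hash_from_response(
+#       {"b": 2, "a": 1}
+#   )  # True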
+
+
+def get_spend_logs_id(
+ call_type: str, response_obj: dict, kwargs: dict
+) -> Optional[str]:
+ if call_type == "aretrieve_batch":
+ # Generate a hash from the response object
+ id: Optional[str] = generate_hash_from_response(response_obj)
+ else:
+ id = cast(Optional[str], response_obj.get("id")) or cast(
+ Optional[str], kwargs.get("litellm_call_id")
+ )
+ return id
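+
+# Illustrative resolution order (assumed inputs): batch retrievals get a content
+# hash; everything else prefers the provider response id, then the call id:
+#   get_spend_logs_id("acompletion", {"id": "chatcmpl-1"}, {})           # "chatcmpl-1"
+#   get_spend_logs_id("acompletion", {}, {"litellm_call_id": "uuid-1"})  # "uuid-1"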
+
+
+def get_logging_payload( # noqa: PLR0915
+ kwargs, response_obj, start_time, end_time
+) -> SpendLogsPayload:
+ from litellm.proxy.proxy_server import general_settings, master_key
+
+ if kwargs is None:
+ kwargs = {}
+ if response_obj is None or (
+ not isinstance(response_obj, BaseModel) and not isinstance(response_obj, dict)
+ ):
+ response_obj = {}
+    # standardize this function for use across s3, dynamoDB, and langfuse logging
+ litellm_params = kwargs.get("litellm_params", {})
+ metadata = get_litellm_metadata_from_kwargs(kwargs)
+ metadata = _add_proxy_server_request_to_metadata(
+ metadata=metadata, litellm_params=litellm_params
+ )
+ completion_start_time = kwargs.get("completion_start_time", end_time)
+ call_type = kwargs.get("call_type")
+ cache_hit = kwargs.get("cache_hit", False)
+ usage = cast(dict, response_obj).get("usage", None) or {}
+ if isinstance(usage, litellm.Usage):
+ usage = dict(usage)
+
+ if isinstance(response_obj, dict):
+ response_obj_dict = response_obj
+ elif isinstance(response_obj, BaseModel):
+ response_obj_dict = response_obj.model_dump()
+ else:
+ response_obj_dict = {}
+
+ id = get_spend_logs_id(call_type or "acompletion", response_obj_dict, kwargs)
+ standard_logging_payload = cast(
+ Optional[StandardLoggingPayload], kwargs.get("standard_logging_object", None)
+ )
+
+ end_user_id = get_end_user_id_for_cost_tracking(litellm_params)
+
+ api_key = metadata.get("user_api_key", "")
+
+ if api_key is not None and isinstance(api_key, str):
+ if api_key.startswith("sk-"):
+ # hash the api_key
+ api_key = hash_token(api_key)
+ if (
+ _is_master_key(api_key=api_key, _master_key=master_key)
+ and general_settings.get("disable_adding_master_key_hash_to_db") is True
+ ):
+            api_key = "litellm_proxy_master_key"  # use a known alias if the user disabled storing the master key in the db
+
+ if (
+ standard_logging_payload is not None
+ ): # [TODO] migrate completely to sl payload. currently missing pass-through endpoint data
+ api_key = (
+ api_key
+ or standard_logging_payload["metadata"].get("user_api_key_hash")
+ or ""
+ )
+ end_user_id = end_user_id or standard_logging_payload["metadata"].get(
+ "user_api_key_end_user_id"
+ )
+ else:
+ api_key = ""
+ request_tags = (
+ json.dumps(metadata.get("tags", []))
+ if isinstance(metadata.get("tags", []), list)
+ else "[]"
+ )
+ if (
+ _is_master_key(api_key=api_key, _master_key=master_key)
+ and general_settings.get("disable_adding_master_key_hash_to_db") is True
+ ):
+        api_key = "litellm_proxy_master_key"  # use a known alias if the user disabled storing the master key in the db
+
+ _model_id = metadata.get("model_info", {}).get("id", "")
+ _model_group = metadata.get("model_group", "")
+
+ # clean up litellm metadata
+ clean_metadata = _get_spend_logs_metadata(
+ metadata,
+ applied_guardrails=(
+ standard_logging_payload["metadata"].get("applied_guardrails", None)
+ if standard_logging_payload is not None
+ else None
+ ),
+ batch_models=(
+ standard_logging_payload.get("hidden_params", {}).get("batch_models", None)
+ if standard_logging_payload is not None
+ else None
+ ),
+ )
+
+ special_usage_fields = ["completion_tokens", "prompt_tokens", "total_tokens"]
+ additional_usage_values = {}
+ for k, v in usage.items():
+ if k not in special_usage_fields:
+ if isinstance(v, BaseModel):
+ v = v.model_dump()
+ additional_usage_values.update({k: v})
+ clean_metadata["additional_usage_values"] = additional_usage_values
+
+ if litellm.cache is not None:
+ cache_key = litellm.cache.get_cache_key(**kwargs)
+ else:
+ cache_key = "Cache OFF"
+ if cache_hit is True:
+ import time
+
+ id = f"{id}_cache_hit{time.time()}" # SpendLogs does not allow duplicate request_id
+ try:
+ payload: SpendLogsPayload = SpendLogsPayload(
+ request_id=str(id),
+ call_type=call_type or "",
+ api_key=str(api_key),
+ cache_hit=str(cache_hit),
+ startTime=_ensure_datetime_utc(start_time),
+ endTime=_ensure_datetime_utc(end_time),
+ completionStartTime=_ensure_datetime_utc(completion_start_time),
+ model=kwargs.get("model", "") or "",
+ user=metadata.get("user_api_key_user_id", "") or "",
+ team_id=metadata.get("user_api_key_team_id", "") or "",
+ metadata=json.dumps(clean_metadata),
+ cache_key=cache_key,
+ spend=kwargs.get("response_cost", 0),
+ total_tokens=usage.get("total_tokens", 0),
+ prompt_tokens=usage.get("prompt_tokens", 0),
+ completion_tokens=usage.get("completion_tokens", 0),
+ request_tags=request_tags,
+ end_user=end_user_id or "",
+ api_base=litellm_params.get("api_base", ""),
+ model_group=_model_group,
+ model_id=_model_id,
+ requester_ip_address=clean_metadata.get("requester_ip_address", None),
+ custom_llm_provider=kwargs.get("custom_llm_provider", ""),
+ messages=_get_messages_for_spend_logs_payload(
+ standard_logging_payload=standard_logging_payload, metadata=metadata
+ ),
+ response=_get_response_for_spend_logs_payload(standard_logging_payload),
+ )
+
+ verbose_proxy_logger.debug(
+ "SpendTable: created payload - payload: %s\n\n",
+ json.dumps(payload, indent=4, default=str),
+ )
+
+ return payload
+ except Exception as e:
+ verbose_proxy_logger.exception(
+ "Error creating spendlogs object - {}".format(str(e))
+ )
+ raise e
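+
+# Illustrative call (assumed argument shapes; in practice litellm's logging
+# callbacks supply these, and start/end are timezone-aware datetimes):
+#   payload = get_logging_payload(
+#       kwargs={"call_type": "acompletion", "litellm_params": {}, "response_cost": 0.002},
+#       response_obj={"id": "chatcmpl-1", "usage": {"total_tokens": 42}},
+#       start_time=start,
+#       end_time=end,
+#   )
+#   # payload["request_id"] == "chatcmpl-1"; payload["spend"] == 0.002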
+
+
+def _ensure_datetime_utc(timestamp: datetime) -> datetime:
+    """Ensure a datetime is timezone-aware and normalized to UTC."""
+ timestamp = timestamp.astimezone(timezone.utc)
+ return timestamp
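+
+# Note: astimezone() interprets naive datetimes as local time before converting
+# to UTC, so callers should pass timezone-aware timestamps. Illustrative:
+#   _ensure_datetime_utc(datetime(2024, 1, 1, tzinfo=timezone.utc))  # returned unchanged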
+
+
+async def get_spend_by_team_and_customer(
+    start_date: datetime,
+    end_date: datetime,
+ team_id: str,
+ customer_id: str,
+ prisma_client: PrismaClient,
+):
+ sql_query = """
+ WITH SpendByModelApiKey AS (
+ SELECT
+ date_trunc('day', sl."startTime") AS group_by_day,
+ COALESCE(tt.team_alias, 'Unassigned Team') AS team_name,
+ sl.end_user AS customer,
+ sl.model,
+ sl.api_key,
+ SUM(sl.spend) AS model_api_spend,
+ SUM(sl.total_tokens) AS model_api_tokens
+ FROM
+ "LiteLLM_SpendLogs" sl
+ LEFT JOIN
+ "LiteLLM_TeamTable" tt
+ ON
+ sl.team_id = tt.team_id
+ WHERE
+ sl."startTime" BETWEEN $1::date AND $2::date
+ AND sl.team_id = $3
+ AND sl.end_user = $4
+ GROUP BY
+ date_trunc('day', sl."startTime"),
+ tt.team_alias,
+ sl.end_user,
+ sl.model,
+ sl.api_key
+ )
+ SELECT
+ group_by_day,
+ jsonb_agg(jsonb_build_object(
+ 'team_name', team_name,
+ 'customer', customer,
+ 'total_spend', total_spend,
+ 'metadata', metadata
+ )) AS teams_customers
+ FROM (
+ SELECT
+ group_by_day,
+ team_name,
+ customer,
+ SUM(model_api_spend) AS total_spend,
+ jsonb_agg(jsonb_build_object(
+ 'model', model,
+ 'api_key', api_key,
+ 'spend', model_api_spend,
+ 'total_tokens', model_api_tokens
+ )) AS metadata
+ FROM
+ SpendByModelApiKey
+ GROUP BY
+ group_by_day,
+ team_name,
+ customer
+ ) AS aggregated
+ GROUP BY
+ group_by_day
+ ORDER BY
+ group_by_day;
+ """
+
+ db_response = await prisma_client.db.query_raw(
+ sql_query, start_date, end_date, team_id, customer_id
+ )
+ if db_response is None:
+ return []
+
+ return db_response
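+
+# Illustrative call (assumed ids): yields one row per day, each carrying a jsonb
+# array of team/customer aggregates with per-model/per-api-key breakdowns:
+#   rows = await get_spend_by_team_and_customer(
+#       start_date, end_date, "team-1", "customer-1", prisma_client
+#   )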
+
+
+def _get_messages_for_spend_logs_payload(
+ standard_logging_payload: Optional[StandardLoggingPayload],
+ metadata: Optional[dict] = None,
+) -> str:
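+    # Currently a stub: messages are never persisted here, so every SpendLogs
+    # row gets an empty JSON object regardless of the arguments.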
+ return "{}"
+
+
+def _add_proxy_server_request_to_metadata(
+ metadata: dict,
+ litellm_params: dict,
+) -> dict:
+    """
+    Attach the raw request body to metadata only when
+    _should_store_prompts_and_responses_in_spend_logs() is True.
+    """
+ if _should_store_prompts_and_responses_in_spend_logs():
+ _proxy_server_request = cast(
+ Optional[dict], litellm_params.get("proxy_server_request", {})
+ )
+ if _proxy_server_request is not None:
+ _request_body = _proxy_server_request.get("body", {}) or {}
+ _request_body_json_str = json.dumps(_request_body, default=str)
+ metadata["proxy_server_request"] = _request_body_json_str
+ return metadata
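+
+# Note: only the request *body* is stored (JSON-serialized under
+# metadata["proxy_server_request"]); headers and other request attributes are
+# never copied into SpendLogs.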
+
+
+def _get_response_for_spend_logs_payload(
+ payload: Optional[StandardLoggingPayload],
+) -> str:
+ if payload is None:
+ return "{}"
+ if _should_store_prompts_and_responses_in_spend_logs():
+ return json.dumps(payload.get("response", {}))
+ return "{}"
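+
+# Illustrative: the serialized provider response is stored only when prompt and
+# response storage is enabled; otherwise "{}" keeps the SpendLogs row small:
+#   _get_response_for_spend_logs_payload(None)  # "{}"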
+
+
+def _should_store_prompts_and_responses_in_spend_logs() -> bool:
+ from litellm.proxy.proxy_server import general_settings
+
+ return general_settings.get("store_prompts_in_spend_logs") is True
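+
+# Minimal proxy config sketch (assumed YAML layout of the proxy's
+# general_settings block, shown as a comment for reference):
+#
+#   general_settings:
+#     store_prompts_in_spend_logs: true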