about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/litellm/proxy/spend_tracking/spend_tracking_utils.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/proxy/spend_tracking/spend_tracking_utils.py')
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/proxy/spend_tracking/spend_tracking_utils.py386
1 files changed, 386 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/spend_tracking/spend_tracking_utils.py b/.venv/lib/python3.12/site-packages/litellm/proxy/spend_tracking/spend_tracking_utils.py
new file mode 100644
index 00000000..6e9a0880
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/spend_tracking/spend_tracking_utils.py
@@ -0,0 +1,386 @@
+import hashlib
+import json
+import secrets
+from datetime import datetime
+from datetime import datetime as dt
+from datetime import timezone
+from typing import Any, List, Optional, cast
+
+from pydantic import BaseModel
+
+import litellm
+from litellm._logging import verbose_proxy_logger
+from litellm.litellm_core_utils.core_helpers import get_litellm_metadata_from_kwargs
+from litellm.proxy._types import SpendLogsMetadata, SpendLogsPayload
+from litellm.proxy.utils import PrismaClient, hash_token
+from litellm.types.utils import StandardLoggingPayload
+from litellm.utils import get_end_user_id_for_cost_tracking
+
+
+def _is_master_key(api_key: str, _master_key: Optional[str]) -> bool:
+    if _master_key is None:
+        return False
+
+    ## string comparison
+    is_master_key = secrets.compare_digest(api_key, _master_key)
+    if is_master_key:
+        return True
+
+    ## hash comparison
+    is_master_key = secrets.compare_digest(api_key, hash_token(_master_key))
+    if is_master_key:
+        return True
+
+    return False
+
+
+def _get_spend_logs_metadata(
+    metadata: Optional[dict],
+    applied_guardrails: Optional[List[str]] = None,
+    batch_models: Optional[List[str]] = None,
+) -> SpendLogsMetadata:
+    if metadata is None:
+        return SpendLogsMetadata(
+            user_api_key=None,
+            user_api_key_alias=None,
+            user_api_key_team_id=None,
+            user_api_key_org_id=None,
+            user_api_key_user_id=None,
+            user_api_key_team_alias=None,
+            spend_logs_metadata=None,
+            requester_ip_address=None,
+            additional_usage_values=None,
+            applied_guardrails=None,
+            status=None or "success",
+            error_information=None,
+            proxy_server_request=None,
+            batch_models=None,
+        )
+    verbose_proxy_logger.debug(
+        "getting payload for SpendLogs, available keys in metadata: "
+        + str(list(metadata.keys()))
+    )
+
+    # Filter the metadata dictionary to include only the specified keys
+    clean_metadata = SpendLogsMetadata(
+        **{  # type: ignore
+            key: metadata[key]
+            for key in SpendLogsMetadata.__annotations__.keys()
+            if key in metadata
+        }
+    )
+    clean_metadata["applied_guardrails"] = applied_guardrails
+    clean_metadata["batch_models"] = batch_models
+    return clean_metadata
+
+
+def generate_hash_from_response(response_obj: Any) -> str:
+    """
+    Generate a stable hash from a response object.
+
+    Args:
+        response_obj: The response object to hash (can be dict, list, etc.)
+
+    Returns:
+        A hex string representation of the MD5 hash
+    """
+    try:
+        # Create a stable JSON string of the entire response object
+        # Sort keys to ensure consistent ordering
+        json_str = json.dumps(response_obj, sort_keys=True)
+
+        # Generate a hash of the response object
+        unique_hash = hashlib.md5(json_str.encode()).hexdigest()
+        return unique_hash
+    except Exception:
+        # Return a fallback hash if serialization fails
+        return hashlib.md5(str(response_obj).encode()).hexdigest()
+
+
+def get_spend_logs_id(
+    call_type: str, response_obj: dict, kwargs: dict
+) -> Optional[str]:
+    if call_type == "aretrieve_batch":
+        # Generate a hash from the response object
+        id: Optional[str] = generate_hash_from_response(response_obj)
+    else:
+        id = cast(Optional[str], response_obj.get("id")) or cast(
+            Optional[str], kwargs.get("litellm_call_id")
+        )
+    return id
+
+
+def get_logging_payload(  # noqa: PLR0915
+    kwargs, response_obj, start_time, end_time
+) -> SpendLogsPayload:
+    from litellm.proxy.proxy_server import general_settings, master_key
+
+    if kwargs is None:
+        kwargs = {}
+    if response_obj is None or (
+        not isinstance(response_obj, BaseModel) and not isinstance(response_obj, dict)
+    ):
+        response_obj = {}
+    # standardize this function to be used across, s3, dynamoDB, langfuse logging
+    litellm_params = kwargs.get("litellm_params", {})
+    metadata = get_litellm_metadata_from_kwargs(kwargs)
+    metadata = _add_proxy_server_request_to_metadata(
+        metadata=metadata, litellm_params=litellm_params
+    )
+    completion_start_time = kwargs.get("completion_start_time", end_time)
+    call_type = kwargs.get("call_type")
+    cache_hit = kwargs.get("cache_hit", False)
+    usage = cast(dict, response_obj).get("usage", None) or {}
+    if isinstance(usage, litellm.Usage):
+        usage = dict(usage)
+
+    if isinstance(response_obj, dict):
+        response_obj_dict = response_obj
+    elif isinstance(response_obj, BaseModel):
+        response_obj_dict = response_obj.model_dump()
+    else:
+        response_obj_dict = {}
+
+    id = get_spend_logs_id(call_type or "acompletion", response_obj_dict, kwargs)
+    standard_logging_payload = cast(
+        Optional[StandardLoggingPayload], kwargs.get("standard_logging_object", None)
+    )
+
+    end_user_id = get_end_user_id_for_cost_tracking(litellm_params)
+
+    api_key = metadata.get("user_api_key", "")
+
+    if api_key is not None and isinstance(api_key, str):
+        if api_key.startswith("sk-"):
+            # hash the api_key
+            api_key = hash_token(api_key)
+        if (
+            _is_master_key(api_key=api_key, _master_key=master_key)
+            and general_settings.get("disable_adding_master_key_hash_to_db") is True
+        ):
+            api_key = "litellm_proxy_master_key"  # use a known alias, if the user disabled storing master key in db
+
+    if (
+        standard_logging_payload is not None
+    ):  # [TODO] migrate completely to sl payload. currently missing pass-through endpoint data
+        api_key = (
+            api_key
+            or standard_logging_payload["metadata"].get("user_api_key_hash")
+            or ""
+        )
+        end_user_id = end_user_id or standard_logging_payload["metadata"].get(
+            "user_api_key_end_user_id"
+        )
+    else:
+        api_key = ""
+    request_tags = (
+        json.dumps(metadata.get("tags", []))
+        if isinstance(metadata.get("tags", []), list)
+        else "[]"
+    )
+    if (
+        _is_master_key(api_key=api_key, _master_key=master_key)
+        and general_settings.get("disable_adding_master_key_hash_to_db") is True
+    ):
+        api_key = "litellm_proxy_master_key"  # use a known alias, if the user disabled storing master key in db
+
+    _model_id = metadata.get("model_info", {}).get("id", "")
+    _model_group = metadata.get("model_group", "")
+
+    # clean up litellm metadata
+    clean_metadata = _get_spend_logs_metadata(
+        metadata,
+        applied_guardrails=(
+            standard_logging_payload["metadata"].get("applied_guardrails", None)
+            if standard_logging_payload is not None
+            else None
+        ),
+        batch_models=(
+            standard_logging_payload.get("hidden_params", {}).get("batch_models", None)
+            if standard_logging_payload is not None
+            else None
+        ),
+    )
+
+    special_usage_fields = ["completion_tokens", "prompt_tokens", "total_tokens"]
+    additional_usage_values = {}
+    for k, v in usage.items():
+        if k not in special_usage_fields:
+            if isinstance(v, BaseModel):
+                v = v.model_dump()
+            additional_usage_values.update({k: v})
+    clean_metadata["additional_usage_values"] = additional_usage_values
+
+    if litellm.cache is not None:
+        cache_key = litellm.cache.get_cache_key(**kwargs)
+    else:
+        cache_key = "Cache OFF"
+    if cache_hit is True:
+        import time
+
+        id = f"{id}_cache_hit{time.time()}"  # SpendLogs does not allow duplicate request_id
+    try:
+        payload: SpendLogsPayload = SpendLogsPayload(
+            request_id=str(id),
+            call_type=call_type or "",
+            api_key=str(api_key),
+            cache_hit=str(cache_hit),
+            startTime=_ensure_datetime_utc(start_time),
+            endTime=_ensure_datetime_utc(end_time),
+            completionStartTime=_ensure_datetime_utc(completion_start_time),
+            model=kwargs.get("model", "") or "",
+            user=metadata.get("user_api_key_user_id", "") or "",
+            team_id=metadata.get("user_api_key_team_id", "") or "",
+            metadata=json.dumps(clean_metadata),
+            cache_key=cache_key,
+            spend=kwargs.get("response_cost", 0),
+            total_tokens=usage.get("total_tokens", 0),
+            prompt_tokens=usage.get("prompt_tokens", 0),
+            completion_tokens=usage.get("completion_tokens", 0),
+            request_tags=request_tags,
+            end_user=end_user_id or "",
+            api_base=litellm_params.get("api_base", ""),
+            model_group=_model_group,
+            model_id=_model_id,
+            requester_ip_address=clean_metadata.get("requester_ip_address", None),
+            custom_llm_provider=kwargs.get("custom_llm_provider", ""),
+            messages=_get_messages_for_spend_logs_payload(
+                standard_logging_payload=standard_logging_payload, metadata=metadata
+            ),
+            response=_get_response_for_spend_logs_payload(standard_logging_payload),
+        )
+
+        verbose_proxy_logger.debug(
+            "SpendTable: created payload - payload: %s\n\n",
+            json.dumps(payload, indent=4, default=str),
+        )
+
+        return payload
+    except Exception as e:
+        verbose_proxy_logger.exception(
+            "Error creating spendlogs object - {}".format(str(e))
+        )
+        raise e
+
+
+def _ensure_datetime_utc(timestamp: datetime) -> datetime:
+    """Helper to ensure datetime is in UTC"""
+    timestamp = timestamp.astimezone(timezone.utc)
+    return timestamp
+
+
+async def get_spend_by_team_and_customer(
+    start_date: dt,
+    end_date: dt,
+    team_id: str,
+    customer_id: str,
+    prisma_client: PrismaClient,
+):
+    sql_query = """
+    WITH SpendByModelApiKey AS (
+        SELECT
+            date_trunc('day', sl."startTime") AS group_by_day,
+            COALESCE(tt.team_alias, 'Unassigned Team') AS team_name,
+            sl.end_user AS customer,
+            sl.model,
+            sl.api_key,
+            SUM(sl.spend) AS model_api_spend,
+            SUM(sl.total_tokens) AS model_api_tokens
+        FROM 
+            "LiteLLM_SpendLogs" sl
+        LEFT JOIN 
+            "LiteLLM_TeamTable" tt 
+        ON 
+            sl.team_id = tt.team_id
+        WHERE
+            sl."startTime" BETWEEN $1::date AND $2::date
+            AND sl.team_id = $3
+            AND sl.end_user = $4
+        GROUP BY
+            date_trunc('day', sl."startTime"),
+            tt.team_alias,
+            sl.end_user,
+            sl.model,
+            sl.api_key
+    )
+        SELECT
+            group_by_day,
+            jsonb_agg(jsonb_build_object(
+                'team_name', team_name,
+                'customer', customer,
+                'total_spend', total_spend,
+                'metadata', metadata
+            )) AS teams_customers
+        FROM (
+            SELECT
+                group_by_day,
+                team_name,
+                customer,
+                SUM(model_api_spend) AS total_spend,
+                jsonb_agg(jsonb_build_object(
+                    'model', model,
+                    'api_key', api_key,
+                    'spend', model_api_spend,
+                    'total_tokens', model_api_tokens
+                )) AS metadata
+            FROM 
+                SpendByModelApiKey
+            GROUP BY
+                group_by_day,
+                team_name,
+                customer
+        ) AS aggregated
+        GROUP BY
+            group_by_day
+        ORDER BY
+            group_by_day;
+    """
+
+    db_response = await prisma_client.db.query_raw(
+        sql_query, start_date, end_date, team_id, customer_id
+    )
+    if db_response is None:
+        return []
+
+    return db_response
+
+
+def _get_messages_for_spend_logs_payload(
+    standard_logging_payload: Optional[StandardLoggingPayload],
+    metadata: Optional[dict] = None,
+) -> str:
+    return "{}"
+
+
+def _add_proxy_server_request_to_metadata(
+    metadata: dict,
+    litellm_params: dict,
+) -> dict:
+    """
+    Only store if _should_store_prompts_and_responses_in_spend_logs() is True
+    """
+    if _should_store_prompts_and_responses_in_spend_logs():
+        _proxy_server_request = cast(
+            Optional[dict], litellm_params.get("proxy_server_request", {})
+        )
+        if _proxy_server_request is not None:
+            _request_body = _proxy_server_request.get("body", {}) or {}
+            _request_body_json_str = json.dumps(_request_body, default=str)
+            metadata["proxy_server_request"] = _request_body_json_str
+    return metadata
+
+
+def _get_response_for_spend_logs_payload(
+    payload: Optional[StandardLoggingPayload],
+) -> str:
+    if payload is None:
+        return "{}"
+    if _should_store_prompts_and_responses_in_spend_logs():
+        return json.dumps(payload.get("response", {}))
+    return "{}"
+
+
+def _should_store_prompts_and_responses_in_spend_logs() -> bool:
+    from litellm.proxy.proxy_server import general_settings
+
+    return general_settings.get("store_prompts_in_spend_logs") is True