aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/litellm/budget_manager.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/budget_manager.py')
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/budget_manager.py222
1 files changed, 222 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/budget_manager.py b/.venv/lib/python3.12/site-packages/litellm/budget_manager.py
new file mode 100644
index 00000000..e664c4f4
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/budget_manager.py
@@ -0,0 +1,222 @@
+# +-----------------------------------------------+
+# | |
+# | NOT PROXY BUDGET MANAGER |
+# | proxy budget manager is in proxy_server.py |
+# | |
+# +-----------------------------------------------+
+#
+# Thank you users! We ❤️ you! - Krrish & Ishaan
+
+import json
+import os
+import threading
+import time
+from typing import Literal, Optional
+
+import litellm
+from litellm.utils import ModelResponse
+
+
+class BudgetManager:
+ def __init__(
+ self,
+ project_name: str,
+ client_type: str = "local",
+ api_base: Optional[str] = None,
+ headers: Optional[dict] = None,
+ ):
+ self.client_type = client_type
+ self.project_name = project_name
+ self.api_base = api_base or "https://api.litellm.ai"
+ self.headers = headers or {"Content-Type": "application/json"}
+ ## load the data or init the initial dictionaries
+ self.load_data()
+
+ def print_verbose(self, print_statement):
+ try:
+ if litellm.set_verbose:
+ import logging
+
+ logging.info(print_statement)
+ except Exception:
+ pass
+
+ def load_data(self):
+ if self.client_type == "local":
+ # Check if user dict file exists
+ if os.path.isfile("user_cost.json"):
+ # Load the user dict
+ with open("user_cost.json", "r") as json_file:
+ self.user_dict = json.load(json_file)
+ else:
+ self.print_verbose("User Dictionary not found!")
+ self.user_dict = {}
+ self.print_verbose(f"user dict from local: {self.user_dict}")
+ elif self.client_type == "hosted":
+ # Load the user_dict from hosted db
+ url = self.api_base + "/get_budget"
+ data = {"project_name": self.project_name}
+ response = litellm.module_level_client.post(
+ url, headers=self.headers, json=data
+ )
+ response = response.json()
+ if response["status"] == "error":
+ self.user_dict = (
+ {}
+ ) # assume this means the user dict hasn't been stored yet
+ else:
+ self.user_dict = response["data"]
+
+ def create_budget(
+ self,
+ total_budget: float,
+ user: str,
+ duration: Optional[Literal["daily", "weekly", "monthly", "yearly"]] = None,
+ created_at: float = time.time(),
+ ):
+ self.user_dict[user] = {"total_budget": total_budget}
+ if duration is None:
+ return self.user_dict[user]
+
+ if duration == "daily":
+ duration_in_days = 1
+ elif duration == "weekly":
+ duration_in_days = 7
+ elif duration == "monthly":
+ duration_in_days = 28
+ elif duration == "yearly":
+ duration_in_days = 365
+ else:
+ raise ValueError(
+ """duration needs to be one of ["daily", "weekly", "monthly", "yearly"]"""
+ )
+ self.user_dict[user] = {
+ "total_budget": total_budget,
+ "duration": duration_in_days,
+ "created_at": created_at,
+ "last_updated_at": created_at,
+ }
+ self._save_data_thread() # [Non-Blocking] Update persistent storage without blocking execution
+ return self.user_dict[user]
+
+ def projected_cost(self, model: str, messages: list, user: str):
+ text = "".join(message["content"] for message in messages)
+ prompt_tokens = litellm.token_counter(model=model, text=text)
+ prompt_cost, _ = litellm.cost_per_token(
+ model=model, prompt_tokens=prompt_tokens, completion_tokens=0
+ )
+ current_cost = self.user_dict[user].get("current_cost", 0)
+ projected_cost = prompt_cost + current_cost
+ return projected_cost
+
+ def get_total_budget(self, user: str):
+ return self.user_dict[user]["total_budget"]
+
+ def update_cost(
+ self,
+ user: str,
+ completion_obj: Optional[ModelResponse] = None,
+ model: Optional[str] = None,
+ input_text: Optional[str] = None,
+ output_text: Optional[str] = None,
+ ):
+ if model and input_text and output_text:
+ prompt_tokens = litellm.token_counter(
+ model=model, messages=[{"role": "user", "content": input_text}]
+ )
+ completion_tokens = litellm.token_counter(
+ model=model, messages=[{"role": "user", "content": output_text}]
+ )
+ (
+ prompt_tokens_cost_usd_dollar,
+ completion_tokens_cost_usd_dollar,
+ ) = litellm.cost_per_token(
+ model=model,
+ prompt_tokens=prompt_tokens,
+ completion_tokens=completion_tokens,
+ )
+ cost = prompt_tokens_cost_usd_dollar + completion_tokens_cost_usd_dollar
+ elif completion_obj:
+ cost = litellm.completion_cost(completion_response=completion_obj)
+ model = completion_obj[
+ "model"
+ ] # if this throws an error try, model = completion_obj['model']
+ else:
+ raise ValueError(
+ "Either a chat completion object or the text response needs to be passed in. Learn more - https://docs.litellm.ai/docs/budget_manager"
+ )
+
+ self.user_dict[user]["current_cost"] = cost + self.user_dict[user].get(
+ "current_cost", 0
+ )
+ if "model_cost" in self.user_dict[user]:
+ self.user_dict[user]["model_cost"][model] = cost + self.user_dict[user][
+ "model_cost"
+ ].get(model, 0)
+ else:
+ self.user_dict[user]["model_cost"] = {model: cost}
+
+ self._save_data_thread() # [Non-Blocking] Update persistent storage without blocking execution
+ return {"user": self.user_dict[user]}
+
+ def get_current_cost(self, user):
+ return self.user_dict[user].get("current_cost", 0)
+
+ def get_model_cost(self, user):
+ return self.user_dict[user].get("model_cost", 0)
+
+ def is_valid_user(self, user: str) -> bool:
+ return user in self.user_dict
+
+ def get_users(self):
+ return list(self.user_dict.keys())
+
+ def reset_cost(self, user):
+ self.user_dict[user]["current_cost"] = 0
+ self.user_dict[user]["model_cost"] = {}
+ return {"user": self.user_dict[user]}
+
+ def reset_on_duration(self, user: str):
+ # Get current and creation time
+ last_updated_at = self.user_dict[user]["last_updated_at"]
+ current_time = time.time()
+
+ # Convert duration from days to seconds
+ duration_in_seconds = self.user_dict[user]["duration"] * 24 * 60 * 60
+
+ # Check if duration has elapsed
+ if current_time - last_updated_at >= duration_in_seconds:
+ # Reset cost if duration has elapsed and update the creation time
+ self.reset_cost(user)
+ self.user_dict[user]["last_updated_at"] = current_time
+ self._save_data_thread() # Save the data
+
+ def update_budget_all_users(self):
+ for user in self.get_users():
+ if "duration" in self.user_dict[user]:
+ self.reset_on_duration(user)
+
+ def _save_data_thread(self):
+ thread = threading.Thread(
+ target=self.save_data
+ ) # [Non-Blocking]: saves data without blocking execution
+ thread.start()
+
+ def save_data(self):
+ if self.client_type == "local":
+ import json
+
+ # save the user dict
+ with open("user_cost.json", "w") as json_file:
+ json.dump(
+ self.user_dict, json_file, indent=4
+ ) # Indent for pretty formatting
+ return {"status": "success"}
+ elif self.client_type == "hosted":
+ url = self.api_base + "/set_budget"
+ data = {"project_name": self.project_name, "user_dict": self.user_dict}
+ response = litellm.module_level_client.post(
+ url, headers=self.headers, json=data
+ )
+ response = response.json()
+ return response