diff options
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/budget_manager.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/litellm/budget_manager.py | 222 |
1 files changed, 222 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/budget_manager.py b/.venv/lib/python3.12/site-packages/litellm/budget_manager.py new file mode 100644 index 00000000..e664c4f4 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/budget_manager.py @@ -0,0 +1,222 @@ +# +-----------------------------------------------+ +# | | +# | NOT PROXY BUDGET MANAGER | +# | proxy budget manager is in proxy_server.py | +# | | +# +-----------------------------------------------+ +# +# Thank you users! We ❤️ you! - Krrish & Ishaan + +import json +import os +import threading +import time +from typing import Literal, Optional + +import litellm +from litellm.utils import ModelResponse + + +class BudgetManager: + def __init__( + self, + project_name: str, + client_type: str = "local", + api_base: Optional[str] = None, + headers: Optional[dict] = None, + ): + self.client_type = client_type + self.project_name = project_name + self.api_base = api_base or "https://api.litellm.ai" + self.headers = headers or {"Content-Type": "application/json"} + ## load the data or init the initial dictionaries + self.load_data() + + def print_verbose(self, print_statement): + try: + if litellm.set_verbose: + import logging + + logging.info(print_statement) + except Exception: + pass + + def load_data(self): + if self.client_type == "local": + # Check if user dict file exists + if os.path.isfile("user_cost.json"): + # Load the user dict + with open("user_cost.json", "r") as json_file: + self.user_dict = json.load(json_file) + else: + self.print_verbose("User Dictionary not found!") + self.user_dict = {} + self.print_verbose(f"user dict from local: {self.user_dict}") + elif self.client_type == "hosted": + # Load the user_dict from hosted db + url = self.api_base + "/get_budget" + data = {"project_name": self.project_name} + response = litellm.module_level_client.post( + url, headers=self.headers, json=data + ) + response = response.json() + if response["status"] == "error": + self.user_dict = ( + {} + ) # assume this means the user dict hasn't been stored yet + else: + self.user_dict = response["data"] + + def create_budget( + self, + total_budget: float, + user: str, + duration: Optional[Literal["daily", "weekly", "monthly", "yearly"]] = None, + created_at: float = time.time(), + ): + self.user_dict[user] = {"total_budget": total_budget} + if duration is None: + return self.user_dict[user] + + if duration == "daily": + duration_in_days = 1 + elif duration == "weekly": + duration_in_days = 7 + elif duration == "monthly": + duration_in_days = 28 + elif duration == "yearly": + duration_in_days = 365 + else: + raise ValueError( + """duration needs to be one of ["daily", "weekly", "monthly", "yearly"]""" + ) + self.user_dict[user] = { + "total_budget": total_budget, + "duration": duration_in_days, + "created_at": created_at, + "last_updated_at": created_at, + } + self._save_data_thread() # [Non-Blocking] Update persistent storage without blocking execution + return self.user_dict[user] + + def projected_cost(self, model: str, messages: list, user: str): + text = "".join(message["content"] for message in messages) + prompt_tokens = litellm.token_counter(model=model, text=text) + prompt_cost, _ = litellm.cost_per_token( + model=model, prompt_tokens=prompt_tokens, completion_tokens=0 + ) + current_cost = self.user_dict[user].get("current_cost", 0) + projected_cost = prompt_cost + current_cost + return projected_cost + + def get_total_budget(self, user: str): + return self.user_dict[user]["total_budget"] + + def update_cost( + self, + user: str, + completion_obj: Optional[ModelResponse] = None, + model: Optional[str] = None, + input_text: Optional[str] = None, + output_text: Optional[str] = None, + ): + if model and input_text and output_text: + prompt_tokens = litellm.token_counter( + model=model, messages=[{"role": "user", "content": input_text}] + ) + completion_tokens = litellm.token_counter( + model=model, messages=[{"role": "user", "content": output_text}] + ) + ( + prompt_tokens_cost_usd_dollar, + completion_tokens_cost_usd_dollar, + ) = litellm.cost_per_token( + model=model, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + ) + cost = prompt_tokens_cost_usd_dollar + completion_tokens_cost_usd_dollar + elif completion_obj: + cost = litellm.completion_cost(completion_response=completion_obj) + model = completion_obj[ + "model" + ] # if this throws an error try, model = completion_obj['model'] + else: + raise ValueError( + "Either a chat completion object or the text response needs to be passed in. Learn more - https://docs.litellm.ai/docs/budget_manager" + ) + + self.user_dict[user]["current_cost"] = cost + self.user_dict[user].get( + "current_cost", 0 + ) + if "model_cost" in self.user_dict[user]: + self.user_dict[user]["model_cost"][model] = cost + self.user_dict[user][ + "model_cost" + ].get(model, 0) + else: + self.user_dict[user]["model_cost"] = {model: cost} + + self._save_data_thread() # [Non-Blocking] Update persistent storage without blocking execution + return {"user": self.user_dict[user]} + + def get_current_cost(self, user): + return self.user_dict[user].get("current_cost", 0) + + def get_model_cost(self, user): + return self.user_dict[user].get("model_cost", 0) + + def is_valid_user(self, user: str) -> bool: + return user in self.user_dict + + def get_users(self): + return list(self.user_dict.keys()) + + def reset_cost(self, user): + self.user_dict[user]["current_cost"] = 0 + self.user_dict[user]["model_cost"] = {} + return {"user": self.user_dict[user]} + + def reset_on_duration(self, user: str): + # Get current and creation time + last_updated_at = self.user_dict[user]["last_updated_at"] + current_time = time.time() + + # Convert duration from days to seconds + duration_in_seconds = self.user_dict[user]["duration"] * 24 * 60 * 60 + + # Check if duration has elapsed + if current_time - last_updated_at >= duration_in_seconds: + # Reset cost if duration has elapsed and update the creation time + self.reset_cost(user) + self.user_dict[user]["last_updated_at"] = current_time + self._save_data_thread() # Save the data + + def update_budget_all_users(self): + for user in self.get_users(): + if "duration" in self.user_dict[user]: + self.reset_on_duration(user) + + def _save_data_thread(self): + thread = threading.Thread( + target=self.save_data + ) # [Non-Blocking]: saves data without blocking execution + thread.start() + + def save_data(self): + if self.client_type == "local": + import json + + # save the user dict + with open("user_cost.json", "w") as json_file: + json.dump( + self.user_dict, json_file, indent=4 + ) # Indent for pretty formatting + return {"status": "success"} + elif self.client_type == "hosted": + url = self.api_base + "/set_budget" + data = {"project_name": self.project_name, "user_dict": self.user_dict} + response = litellm.module_level_client.post( + url, headers=self.headers, json=data + ) + response = response.json() + return response |