Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/proxy/common_utils')
10 files changed, 1651 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/common_utils/admin_ui_utils.py b/.venv/lib/python3.12/site-packages/litellm/proxy/common_utils/admin_ui_utils.py new file mode 100644 index 00000000..204032ac --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/proxy/common_utils/admin_ui_utils.py @@ -0,0 +1,240 @@ +import os + + +def show_missing_vars_in_env(): + from fastapi.responses import HTMLResponse + + from litellm.proxy.proxy_server import master_key, prisma_client + + if prisma_client is None and master_key is None: + return HTMLResponse( + content=missing_keys_form( + missing_key_names="DATABASE_URL, LITELLM_MASTER_KEY" + ), + status_code=200, + ) + if prisma_client is None: + return HTMLResponse( + content=missing_keys_form(missing_key_names="DATABASE_URL"), status_code=200 + ) + + if master_key is None: + return HTMLResponse( + content=missing_keys_form(missing_key_names="LITELLM_MASTER_KEY"), + status_code=200, + ) + return None + + +# LiteLLM Admin UI - Non SSO Login +url_to_redirect_to = os.getenv("PROXY_BASE_URL", "") +url_to_redirect_to += "/login" +html_form = f""" +<!DOCTYPE html> +<html> +<head> + <title>LiteLLM Login</title> + <style> + body {{ + font-family: Arial, sans-serif; + background-color: #f4f4f4; + margin: 0; + padding: 0; + display: flex; + justify-content: center; + align-items: center; + height: 100vh; + }} + + form {{ + background-color: #fff; + padding: 20px; + border-radius: 8px; + box-shadow: 0 0 10px rgba(0, 0, 0, 0.1); + }} + + label {{ + display: block; + margin-bottom: 8px; + }} + + input {{ + width: 100%; + padding: 8px; + margin-bottom: 16px; + box-sizing: border-box; + border: 1px solid #ccc; + border-radius: 4px; + }} + + input[type="submit"] {{ + background-color: #4caf50; + color: #fff; + cursor: pointer; + }} + + input[type="submit"]:hover {{ + background-color: #45a049; + }} + </style> +</head> +<body> + <form action="{url_to_redirect_to}" method="post"> + <h2>LiteLLM Login</h2> + + <p>By default Username is "admin" and Password is your set LiteLLM Proxy `MASTER_KEY`</p> + <p>If you need to set UI credentials / SSO docs here: <a href="https://docs.litellm.ai/docs/proxy/ui" target="_blank">https://docs.litellm.ai/docs/proxy/ui</a></p> + <br> + <label for="username">Username:</label> + <input type="text" id="username" name="username" required> + <label for="password">Password:</label> + <input type="password" id="password" name="password" required> + <input type="submit" value="Submit"> + </form> +""" + + +def missing_keys_form(missing_key_names: str): + missing_keys_html_form = """ + <!DOCTYPE html> + <html lang="en"> + <head> + <meta charset="UTF-8"> + <meta name="viewport" content="width=device-width, initial-scale=1.0"> + <style> + body {{ + font-family: Arial, sans-serif; + background-color: #f4f4f9; + color: #333; + margin: 20px; + line-height: 1.6; + }} + .container {{ + max-width: 800px; + margin: auto; + padding: 20px; + background: #fff; + border: 1px solid #ddd; + border-radius: 5px; + box-shadow: 0 0 10px rgba(0, 0, 0, 0.1); + }} + h1 {{ + font-size: 24px; + margin-bottom: 20px; + }} + pre {{ + background: #f8f8f8; + padding: 1px; + border: 1px solid #ccc; + border-radius: 4px; + overflow-x: auto; + font-size: 14px; + }} + .env-var {{ + font-weight: normal; + }} + .comment {{ + font-weight: normal; + color: #777; + }} + </style> + <title>Environment Setup Instructions</title> + </head> + <body> + <div class="container"> + <h1>Environment Setup Instructions</h1> + <p>Please add the following variables to your 
environment variables:</p> + <pre> + <span class="env-var">LITELLM_MASTER_KEY="sk-1234"</span> <span class="comment"># Your master key for the proxy server. Can use this to send /chat/completion requests etc</span> + <span class="env-var">LITELLM_SALT_KEY="sk-XXXXXXXX"</span> <span class="comment"># Can NOT CHANGE THIS ONCE SET - It is used to encrypt/decrypt credentials stored in DB. If value of 'LITELLM_SALT_KEY' changes your models cannot be retrieved from DB</span> + <span class="env-var">DATABASE_URL="postgres://..."</span> <span class="comment"># Need a postgres database? (Check out Supabase, Neon, etc)</span> + <span class="comment">## OPTIONAL ##</span> + <span class="env-var">PORT=4000</span> <span class="comment"># DO THIS FOR RENDER/RAILWAY</span> + <span class="env-var">STORE_MODEL_IN_DB="True"</span> <span class="comment"># Allow storing models in db</span> + </pre> + <h1>Missing Environment Variables</h1> + <p>{missing_keys}</p> + </div> + + <div class="container"> + <h1>Need Help? Support</h1> + <p>Discord: <a href="https://discord.com/invite/wuPM9dRgDw" target="_blank">https://discord.com/invite/wuPM9dRgDw</a></p> + <p>Docs: <a href="https://docs.litellm.ai/docs/" target="_blank">https://docs.litellm.ai/docs/</a></p> + </div> + </body> + </html> + """ + return missing_keys_html_form.format(missing_keys=missing_key_names) + + +def admin_ui_disabled(): + from fastapi.responses import HTMLResponse + + ui_disabled_html = """ + <!DOCTYPE html> + <html lang="en"> + <head> + <meta charset="UTF-8"> + <meta name="viewport" content="width=device-width, initial-scale=1.0"> + <style> + body {{ + font-family: Arial, sans-serif; + background-color: #f4f4f9; + color: #333; + margin: 20px; + line-height: 1.6; + }} + .container {{ + max-width: 800px; + margin: auto; + padding: 20px; + background: #fff; + border: 1px solid #ddd; + border-radius: 5px; + box-shadow: 0 0 10px rgba(0, 0, 0, 0.1); + }} + h1 {{ + font-size: 24px; + margin-bottom: 20px; + }} + pre {{ + background: #f8f8f8; + padding: 1px; + border: 1px solid #ccc; + border-radius: 4px; + overflow-x: auto; + font-size: 14px; + }} + .env-var {{ + font-weight: normal; + }} + .comment {{ + font-weight: normal; + color: #777; + }} + </style> + <title>Admin UI Disabled</title> + </head> + <body> + <div class="container"> + <h1>Admin UI is Disabled</h1> + <p>The Admin UI has been disabled by the administrator. To re-enable it, please update the following environment variable:</p> + <pre> + <span class="env-var">DISABLE_ADMIN_UI="False"</span> <span class="comment"># Set this to "False" to enable the Admin UI.</span> + </pre> + <p>After making this change, restart the application for it to take effect.</p> + </div> + + <div class="container"> + <h1>Need Help? 
Support</h1> + <p>Discord: <a href="https://discord.com/invite/wuPM9dRgDw" target="_blank">https://discord.com/invite/wuPM9dRgDw</a></p> + <p>Docs: <a href="https://docs.litellm.ai/docs/" target="_blank">https://docs.litellm.ai/docs/</a></p> + </div> + </body> + </html> + """ + + return HTMLResponse( + content=ui_disabled_html, + status_code=200, + ) diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/common_utils/callback_utils.py b/.venv/lib/python3.12/site-packages/litellm/proxy/common_utils/callback_utils.py new file mode 100644 index 00000000..2280e72e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/proxy/common_utils/callback_utils.py @@ -0,0 +1,319 @@ +from typing import Any, Dict, List, Optional + +import litellm +from litellm import get_secret +from litellm._logging import verbose_proxy_logger +from litellm.proxy._types import CommonProxyErrors, LiteLLMPromptInjectionParams +from litellm.proxy.types_utils.utils import get_instance_fn + +blue_color_code = "\033[94m" +reset_color_code = "\033[0m" + + +def initialize_callbacks_on_proxy( # noqa: PLR0915 + value: Any, + premium_user: bool, + config_file_path: str, + litellm_settings: dict, + callback_specific_params: dict = {}, +): + from litellm.proxy.proxy_server import prisma_client + + verbose_proxy_logger.debug( + f"{blue_color_code}initializing callbacks={value} on proxy{reset_color_code}" + ) + if isinstance(value, list): + imported_list: List[Any] = [] + for callback in value: # ["presidio", <my-custom-callback>] + if ( + isinstance(callback, str) + and callback in litellm._known_custom_logger_compatible_callbacks + ): + imported_list.append(callback) + elif isinstance(callback, str) and callback == "presidio": + from litellm.proxy.guardrails.guardrail_hooks.presidio import ( + _OPTIONAL_PresidioPIIMasking, + ) + + presidio_logging_only: Optional[bool] = litellm_settings.get( + "presidio_logging_only", None + ) + if presidio_logging_only is not None: + presidio_logging_only = bool( + presidio_logging_only + ) # validate boolean given + + _presidio_params = {} + if "presidio" in callback_specific_params and isinstance( + callback_specific_params["presidio"], dict + ): + _presidio_params = callback_specific_params["presidio"] + + params: Dict[str, Any] = { + "logging_only": presidio_logging_only, + **_presidio_params, + } + pii_masking_object = _OPTIONAL_PresidioPIIMasking(**params) + imported_list.append(pii_masking_object) + elif isinstance(callback, str) and callback == "llamaguard_moderations": + from enterprise.enterprise_hooks.llama_guard import ( + _ENTERPRISE_LlamaGuard, + ) + + if premium_user is not True: + raise Exception( + "Trying to use Llama Guard" + + CommonProxyErrors.not_premium_user.value + ) + + llama_guard_object = _ENTERPRISE_LlamaGuard() + imported_list.append(llama_guard_object) + elif isinstance(callback, str) and callback == "hide_secrets": + from enterprise.enterprise_hooks.secret_detection import ( + _ENTERPRISE_SecretDetection, + ) + + if premium_user is not True: + raise Exception( + "Trying to use secret hiding" + + CommonProxyErrors.not_premium_user.value + ) + + _secret_detection_object = _ENTERPRISE_SecretDetection() + imported_list.append(_secret_detection_object) + elif isinstance(callback, str) and callback == "openai_moderations": + from enterprise.enterprise_hooks.openai_moderation import ( + _ENTERPRISE_OpenAI_Moderation, + ) + + if premium_user is not True: + raise Exception( + "Trying to use OpenAI Moderations Check" + + CommonProxyErrors.not_premium_user.value + 
) + + openai_moderations_object = _ENTERPRISE_OpenAI_Moderation() + imported_list.append(openai_moderations_object) + elif isinstance(callback, str) and callback == "lakera_prompt_injection": + from litellm.proxy.guardrails.guardrail_hooks.lakera_ai import ( + lakeraAI_Moderation, + ) + + init_params = {} + if "lakera_prompt_injection" in callback_specific_params: + init_params = callback_specific_params["lakera_prompt_injection"] + lakera_moderations_object = lakeraAI_Moderation(**init_params) + imported_list.append(lakera_moderations_object) + elif isinstance(callback, str) and callback == "aporia_prompt_injection": + from litellm.proxy.guardrails.guardrail_hooks.aporia_ai import ( + AporiaGuardrail, + ) + + aporia_guardrail_object = AporiaGuardrail() + imported_list.append(aporia_guardrail_object) + elif isinstance(callback, str) and callback == "google_text_moderation": + from enterprise.enterprise_hooks.google_text_moderation import ( + _ENTERPRISE_GoogleTextModeration, + ) + + if premium_user is not True: + raise Exception( + "Trying to use Google Text Moderation" + + CommonProxyErrors.not_premium_user.value + ) + + google_text_moderation_obj = _ENTERPRISE_GoogleTextModeration() + imported_list.append(google_text_moderation_obj) + elif isinstance(callback, str) and callback == "llmguard_moderations": + from enterprise.enterprise_hooks.llm_guard import _ENTERPRISE_LLMGuard + + if premium_user is not True: + raise Exception( + "Trying to use Llm Guard" + + CommonProxyErrors.not_premium_user.value + ) + + llm_guard_moderation_obj = _ENTERPRISE_LLMGuard() + imported_list.append(llm_guard_moderation_obj) + elif isinstance(callback, str) and callback == "blocked_user_check": + from enterprise.enterprise_hooks.blocked_user_list import ( + _ENTERPRISE_BlockedUserList, + ) + + if premium_user is not True: + raise Exception( + "Trying to use ENTERPRISE BlockedUser" + + CommonProxyErrors.not_premium_user.value + ) + + blocked_user_list = _ENTERPRISE_BlockedUserList( + prisma_client=prisma_client + ) + imported_list.append(blocked_user_list) + elif isinstance(callback, str) and callback == "banned_keywords": + from enterprise.enterprise_hooks.banned_keywords import ( + _ENTERPRISE_BannedKeywords, + ) + + if premium_user is not True: + raise Exception( + "Trying to use ENTERPRISE BannedKeyword" + + CommonProxyErrors.not_premium_user.value + ) + + banned_keywords_obj = _ENTERPRISE_BannedKeywords() + imported_list.append(banned_keywords_obj) + elif isinstance(callback, str) and callback == "detect_prompt_injection": + from litellm.proxy.hooks.prompt_injection_detection import ( + _OPTIONAL_PromptInjectionDetection, + ) + + prompt_injection_params = None + if "prompt_injection_params" in litellm_settings: + prompt_injection_params_in_config = litellm_settings[ + "prompt_injection_params" + ] + prompt_injection_params = LiteLLMPromptInjectionParams( + **prompt_injection_params_in_config + ) + + prompt_injection_detection_obj = _OPTIONAL_PromptInjectionDetection( + prompt_injection_params=prompt_injection_params, + ) + imported_list.append(prompt_injection_detection_obj) + elif isinstance(callback, str) and callback == "batch_redis_requests": + from litellm.proxy.hooks.batch_redis_get import ( + _PROXY_BatchRedisRequests, + ) + + batch_redis_obj = _PROXY_BatchRedisRequests() + imported_list.append(batch_redis_obj) + elif isinstance(callback, str) and callback == "azure_content_safety": + from litellm.proxy.hooks.azure_content_safety import ( + _PROXY_AzureContentSafety, + ) + + 
azure_content_safety_params = litellm_settings[ + "azure_content_safety_params" + ] + for k, v in azure_content_safety_params.items(): + if ( + v is not None + and isinstance(v, str) + and v.startswith("os.environ/") + ): + azure_content_safety_params[k] = get_secret(v) + + azure_content_safety_obj = _PROXY_AzureContentSafety( + **azure_content_safety_params, + ) + imported_list.append(azure_content_safety_obj) + else: + verbose_proxy_logger.debug( + f"{blue_color_code} attempting to import custom calback={callback} {reset_color_code}" + ) + imported_list.append( + get_instance_fn( + value=callback, + config_file_path=config_file_path, + ) + ) + if isinstance(litellm.callbacks, list): + litellm.callbacks.extend(imported_list) + else: + litellm.callbacks = imported_list # type: ignore + + if "prometheus" in value: + if premium_user is not True: + verbose_proxy_logger.warning( + f"Prometheus metrics are only available for premium users. {CommonProxyErrors.not_premium_user.value}" + ) + from litellm.proxy.proxy_server import app + + verbose_proxy_logger.debug("Starting Prometheus Metrics on /metrics") + from prometheus_client import make_asgi_app + + # Add prometheus asgi middleware to route /metrics requests + metrics_app = make_asgi_app() + app.mount("/metrics", metrics_app) + else: + litellm.callbacks = [ + get_instance_fn( + value=value, + config_file_path=config_file_path, + ) + ] + verbose_proxy_logger.debug( + f"{blue_color_code} Initialized Callbacks - {litellm.callbacks} {reset_color_code}" + ) + + +def get_model_group_from_litellm_kwargs(kwargs: dict) -> Optional[str]: + _litellm_params = kwargs.get("litellm_params", None) or {} + _metadata = _litellm_params.get("metadata", None) or {} + _model_group = _metadata.get("model_group", None) + if _model_group is not None: + return _model_group + + return None + + +def get_model_group_from_request_data(data: dict) -> Optional[str]: + _metadata = data.get("metadata", None) or {} + _model_group = _metadata.get("model_group", None) + if _model_group is not None: + return _model_group + + return None + + +def get_remaining_tokens_and_requests_from_request_data(data: Dict) -> Dict[str, str]: + """ + Helper function to return x-litellm-key-remaining-tokens-{model_group} and x-litellm-key-remaining-requests-{model_group} + + Returns {} when api_key + model rpm/tpm limit is not set + + """ + headers = {} + _metadata = data.get("metadata", None) or {} + model_group = get_model_group_from_request_data(data) + + # Remaining Requests + remaining_requests_variable_name = f"litellm-key-remaining-requests-{model_group}" + remaining_requests = _metadata.get(remaining_requests_variable_name, None) + if remaining_requests: + headers[f"x-litellm-key-remaining-requests-{model_group}"] = remaining_requests + + # Remaining Tokens + remaining_tokens_variable_name = f"litellm-key-remaining-tokens-{model_group}" + remaining_tokens = _metadata.get(remaining_tokens_variable_name, None) + if remaining_tokens: + headers[f"x-litellm-key-remaining-tokens-{model_group}"] = remaining_tokens + + return headers + + +def get_logging_caching_headers(request_data: Dict) -> Optional[Dict]: + _metadata = request_data.get("metadata", None) or {} + headers = {} + if "applied_guardrails" in _metadata: + headers["x-litellm-applied-guardrails"] = ",".join( + _metadata["applied_guardrails"] + ) + + if "semantic-similarity" in _metadata: + headers["x-litellm-semantic-similarity"] = str(_metadata["semantic-similarity"]) + + return headers + + +def 
add_guardrail_to_applied_guardrails_header( + request_data: Dict, guardrail_name: Optional[str] +): + if guardrail_name is None: + return + _metadata = request_data.get("metadata", None) or {} + if "applied_guardrails" in _metadata: + _metadata["applied_guardrails"].append(guardrail_name) + else: + _metadata["applied_guardrails"] = [guardrail_name] diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/common_utils/debug_utils.py b/.venv/lib/python3.12/site-packages/litellm/proxy/common_utils/debug_utils.py new file mode 100644 index 00000000..fdfbe0cb --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/proxy/common_utils/debug_utils.py @@ -0,0 +1,242 @@ +# Start tracing memory allocations +import json +import os +import tracemalloc + +from fastapi import APIRouter + +from litellm import get_secret_str +from litellm._logging import verbose_proxy_logger + +router = APIRouter() + +if os.environ.get("LITELLM_PROFILE", "false").lower() == "true": + try: + import objgraph # type: ignore + + print("growth of objects") # noqa + objgraph.show_growth() + print("\n\nMost common types") # noqa + objgraph.show_most_common_types() + roots = objgraph.get_leaking_objects() + print("\n\nLeaking objects") # noqa + objgraph.show_most_common_types(objects=roots) + except ImportError: + raise ImportError( + "objgraph not found. Please install objgraph to use this feature." + ) + + tracemalloc.start(10) + + @router.get("/memory-usage", include_in_schema=False) + async def memory_usage(): + # Take a snapshot of the current memory usage + snapshot = tracemalloc.take_snapshot() + top_stats = snapshot.statistics("lineno") + verbose_proxy_logger.debug("TOP STATS: %s", top_stats) + + # Get the top 50 memory usage lines + top_50 = top_stats[:50] + result = [] + for stat in top_50: + result.append(f"{stat.traceback.format(limit=10)}: {stat.size / 1024} KiB") + + return {"top_50_memory_usage": result} + + +@router.get("/memory-usage-in-mem-cache", include_in_schema=False) +async def memory_usage_in_mem_cache(): + # returns the size of all in-memory caches on the proxy server + """ + 1. user_api_key_cache + 2. router_cache + 3. proxy_logging_cache + 4. internal_usage_cache + """ + from litellm.proxy.proxy_server import ( + llm_router, + proxy_logging_obj, + user_api_key_cache, + ) + + if llm_router is None: + num_items_in_llm_router_cache = 0 + else: + num_items_in_llm_router_cache = len( + llm_router.cache.in_memory_cache.cache_dict + ) + len(llm_router.cache.in_memory_cache.ttl_dict) + + num_items_in_user_api_key_cache = len( + user_api_key_cache.in_memory_cache.cache_dict + ) + len(user_api_key_cache.in_memory_cache.ttl_dict) + + num_items_in_proxy_logging_obj_cache = len( + proxy_logging_obj.internal_usage_cache.dual_cache.in_memory_cache.cache_dict + ) + len(proxy_logging_obj.internal_usage_cache.dual_cache.in_memory_cache.ttl_dict) + + return { + "num_items_in_user_api_key_cache": num_items_in_user_api_key_cache, + "num_items_in_llm_router_cache": num_items_in_llm_router_cache, + "num_items_in_proxy_logging_obj_cache": num_items_in_proxy_logging_obj_cache, + } + + +@router.get("/memory-usage-in-mem-cache-items", include_in_schema=False) +async def memory_usage_in_mem_cache_items(): + # returns the size of all in-memory caches on the proxy server + """ + 1. user_api_key_cache + 2. router_cache + 3. proxy_logging_cache + 4. 
internal_usage_cache + """ + from litellm.proxy.proxy_server import ( + llm_router, + proxy_logging_obj, + user_api_key_cache, + ) + + if llm_router is None: + llm_router_in_memory_cache_dict = {} + llm_router_in_memory_ttl_dict = {} + else: + llm_router_in_memory_cache_dict = llm_router.cache.in_memory_cache.cache_dict + llm_router_in_memory_ttl_dict = llm_router.cache.in_memory_cache.ttl_dict + + return { + "user_api_key_cache": user_api_key_cache.in_memory_cache.cache_dict, + "user_api_key_ttl": user_api_key_cache.in_memory_cache.ttl_dict, + "llm_router_cache": llm_router_in_memory_cache_dict, + "llm_router_ttl": llm_router_in_memory_ttl_dict, + "proxy_logging_obj_cache": proxy_logging_obj.internal_usage_cache.dual_cache.in_memory_cache.cache_dict, + "proxy_logging_obj_ttl": proxy_logging_obj.internal_usage_cache.dual_cache.in_memory_cache.ttl_dict, + } + + +@router.get("/otel-spans", include_in_schema=False) +async def get_otel_spans(): + from litellm.proxy.proxy_server import open_telemetry_logger + + if open_telemetry_logger is None: + return { + "otel_spans": [], + "spans_grouped_by_parent": {}, + "most_recent_parent": None, + } + + otel_exporter = open_telemetry_logger.OTEL_EXPORTER + if hasattr(otel_exporter, "get_finished_spans"): + recorded_spans = otel_exporter.get_finished_spans() # type: ignore + else: + recorded_spans = [] + + print("Spans: ", recorded_spans) # noqa + + most_recent_parent = None + most_recent_start_time = 1000000 + spans_grouped_by_parent = {} + for span in recorded_spans: + if span.parent is not None: + parent_trace_id = span.parent.trace_id + if parent_trace_id not in spans_grouped_by_parent: + spans_grouped_by_parent[parent_trace_id] = [] + spans_grouped_by_parent[parent_trace_id].append(span.name) + + # check time of span + if span.start_time > most_recent_start_time: + most_recent_parent = parent_trace_id + most_recent_start_time = span.start_time + + # these are otel spans - get the span name + span_names = [span.name for span in recorded_spans] + return { + "otel_spans": span_names, + "spans_grouped_by_parent": spans_grouped_by_parent, + "most_recent_parent": most_recent_parent, + } + + +# Helper functions for debugging +def init_verbose_loggers(): + try: + worker_config = get_secret_str("WORKER_CONFIG") + # if not, assume it's a json string + if worker_config is None: + return + if os.path.isfile(worker_config): + return + _settings = json.loads(worker_config) + if not isinstance(_settings, dict): + return + + debug = _settings.get("debug", None) + detailed_debug = _settings.get("detailed_debug", None) + if debug is True: # this needs to be first, so users can see Router init debugg + import logging + + from litellm._logging import ( + verbose_logger, + verbose_proxy_logger, + verbose_router_logger, + ) + + # this must ALWAYS remain logging.INFO, DO NOT MODIFY THIS + verbose_logger.setLevel(level=logging.INFO) # sets package logs to info + verbose_router_logger.setLevel( + level=logging.INFO + ) # set router logs to info + verbose_proxy_logger.setLevel(level=logging.INFO) # set proxy logs to info + if detailed_debug is True: + import logging + + from litellm._logging import ( + verbose_logger, + verbose_proxy_logger, + verbose_router_logger, + ) + + verbose_logger.setLevel(level=logging.DEBUG) # set package log to debug + verbose_router_logger.setLevel( + level=logging.DEBUG + ) # set router logs to debug + verbose_proxy_logger.setLevel( + level=logging.DEBUG + ) # set proxy logs to debug + elif debug is False and detailed_debug is False: + # users 
can control proxy debugging using env variable = 'LITELLM_LOG' + litellm_log_setting = os.environ.get("LITELLM_LOG", "") + if litellm_log_setting is not None: + if litellm_log_setting.upper() == "INFO": + import logging + + from litellm._logging import ( + verbose_proxy_logger, + verbose_router_logger, + ) + + # this must ALWAYS remain logging.INFO, DO NOT MODIFY THIS + + verbose_router_logger.setLevel( + level=logging.INFO + ) # set router logs to info + verbose_proxy_logger.setLevel( + level=logging.INFO + ) # set proxy logs to info + elif litellm_log_setting.upper() == "DEBUG": + import logging + + from litellm._logging import ( + verbose_proxy_logger, + verbose_router_logger, + ) + + verbose_router_logger.setLevel( + level=logging.DEBUG + ) # set router logs to info + verbose_proxy_logger.setLevel( + level=logging.DEBUG + ) # set proxy logs to debug + except Exception as e: + import logging + + logging.warning(f"Failed to init verbose loggers: {str(e)}") diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/common_utils/encrypt_decrypt_utils.py b/.venv/lib/python3.12/site-packages/litellm/proxy/common_utils/encrypt_decrypt_utils.py new file mode 100644 index 00000000..ec9279a0 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/proxy/common_utils/encrypt_decrypt_utils.py @@ -0,0 +1,104 @@ +import base64 +import os +from typing import Optional + +from litellm._logging import verbose_proxy_logger + + +def _get_salt_key(): + from litellm.proxy.proxy_server import master_key + + salt_key = os.getenv("LITELLM_SALT_KEY", None) + + if salt_key is None: + verbose_proxy_logger.debug( + "LITELLM_SALT_KEY is None using master_key to encrypt/decrypt secrets stored in DB" + ) + + salt_key = master_key + + return salt_key + + +def encrypt_value_helper(value: str, new_encryption_key: Optional[str] = None): + + signing_key = new_encryption_key or _get_salt_key() + + try: + if isinstance(value, str): + encrypted_value = encrypt_value(value=value, signing_key=signing_key) # type: ignore + encrypted_value = base64.b64encode(encrypted_value).decode("utf-8") + + return encrypted_value + + verbose_proxy_logger.debug( + f"Invalid value type passed to encrypt_value: {type(value)} for Value: {value}\n Value must be a string" + ) + # if it's not a string - do not encrypt it and return the value + return value + except Exception as e: + raise e + + +def decrypt_value_helper(value: str): + + signing_key = _get_salt_key() + + try: + if isinstance(value, str): + decoded_b64 = base64.b64decode(value) + value = decrypt_value(value=decoded_b64, signing_key=signing_key) # type: ignore + return value + + # if it's not str - do not decrypt it, return the value + return value + except Exception as e: + import traceback + + traceback.print_stack() + verbose_proxy_logger.error( + f"Error decrypting value, Did your master_key/salt key change recently? \nError: {str(e)}\nSet permanent salt key - https://docs.litellm.ai/docs/proxy/prod#5-set-litellm-salt-key" + ) + # [Non-Blocking Exception. 
- this should not block decrypting other values] + pass + + +def encrypt_value(value: str, signing_key: str): + import hashlib + + import nacl.secret + import nacl.utils + + # get 32 byte master key # + hash_object = hashlib.sha256(signing_key.encode()) + hash_bytes = hash_object.digest() + + # initialize secret box # + box = nacl.secret.SecretBox(hash_bytes) + + # encode message # + value_bytes = value.encode("utf-8") + + encrypted = box.encrypt(value_bytes) + + return encrypted + + +def decrypt_value(value: bytes, signing_key: str) -> str: + import hashlib + + import nacl.secret + import nacl.utils + + # get 32 byte master key # + hash_object = hashlib.sha256(signing_key.encode()) + hash_bytes = hash_object.digest() + + # initialize secret box # + box = nacl.secret.SecretBox(hash_bytes) + + # Convert the bytes object to a string + plaintext = box.decrypt(value) + + plaintext = plaintext.decode("utf-8") # type: ignore + return plaintext # type: ignore diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/common_utils/http_parsing_utils.py b/.venv/lib/python3.12/site-packages/litellm/proxy/common_utils/http_parsing_utils.py new file mode 100644 index 00000000..5736ee21 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/proxy/common_utils/http_parsing_utils.py @@ -0,0 +1,182 @@ +import json +from typing import Dict, List, Optional + +import orjson +from fastapi import Request, UploadFile, status + +from litellm._logging import verbose_proxy_logger +from litellm.types.router import Deployment + + +async def _read_request_body(request: Optional[Request]) -> Dict: + """ + Safely read the request body and parse it as JSON. + + Parameters: + - request: The request object to read the body from + + Returns: + - dict: Parsed request data as a dictionary or an empty dictionary if parsing fails + """ + try: + if request is None: + return {} + + # Check if we already read and parsed the body + _cached_request_body: Optional[dict] = _safe_get_request_parsed_body( + request=request + ) + if _cached_request_body is not None: + return _cached_request_body + + _request_headers: dict = _safe_get_request_headers(request=request) + content_type = _request_headers.get("content-type", "") + + if "form" in content_type: + parsed_body = dict(await request.form()) + else: + # Read the request body + body = await request.body() + + # Return empty dict if body is empty or None + if not body: + parsed_body = {} + else: + try: + parsed_body = orjson.loads(body) + except orjson.JSONDecodeError: + # Fall back to the standard json module which is more forgiving + # First decode bytes to string if needed + body_str = body.decode("utf-8") if isinstance(body, bytes) else body + + # Replace invalid surrogate pairs + import re + + # This regex finds incomplete surrogate pairs + body_str = re.sub( + r"[\uD800-\uDBFF](?![\uDC00-\uDFFF])", "", body_str + ) + # This regex finds low surrogates without high surrogates + body_str = re.sub( + r"(?<![\uD800-\uDBFF])[\uDC00-\uDFFF]", "", body_str + ) + + parsed_body = json.loads(body_str) + + # Cache the parsed result + _safe_set_request_parsed_body(request=request, parsed_body=parsed_body) + return parsed_body + + except (json.JSONDecodeError, orjson.JSONDecodeError): + verbose_proxy_logger.exception("Invalid JSON payload received.") + return {} + except Exception as e: + # Catch unexpected errors to avoid crashes + verbose_proxy_logger.exception( + "Unexpected error reading request body - {}".format(e) + ) + return {} + + +def 
_safe_get_request_parsed_body(request: Optional[Request]) -> Optional[dict]: + if request is None: + return None + if hasattr(request, "scope") and "parsed_body" in request.scope: + return request.scope["parsed_body"] + return None + + +def _safe_set_request_parsed_body( + request: Optional[Request], + parsed_body: dict, +) -> None: + try: + if request is None: + return + request.scope["parsed_body"] = parsed_body + except Exception as e: + verbose_proxy_logger.debug( + "Unexpected error setting request parsed body - {}".format(e) + ) + + +def _safe_get_request_headers(request: Optional[Request]) -> dict: + """ + [Non-Blocking] Safely get the request headers + """ + try: + if request is None: + return {} + return dict(request.headers) + except Exception as e: + verbose_proxy_logger.debug( + "Unexpected error reading request headers - {}".format(e) + ) + return {} + + +def check_file_size_under_limit( + request_data: dict, + file: UploadFile, + router_model_names: List[str], +) -> bool: + """ + Check if any files passed in request are under max_file_size_mb + + Returns True -> when file size is under max_file_size_mb limit + Raises ProxyException -> when file size is over max_file_size_mb limit or not a premium_user + """ + from litellm.proxy.proxy_server import ( + CommonProxyErrors, + ProxyException, + llm_router, + premium_user, + ) + + file_contents_size = file.size or 0 + file_content_size_in_mb = file_contents_size / (1024 * 1024) + if "metadata" not in request_data: + request_data["metadata"] = {} + request_data["metadata"]["file_size_in_mb"] = file_content_size_in_mb + max_file_size_mb = None + + if llm_router is not None and request_data["model"] in router_model_names: + try: + deployment: Optional[Deployment] = ( + llm_router.get_deployment_by_model_group_name( + model_group_name=request_data["model"] + ) + ) + if ( + deployment + and deployment.litellm_params is not None + and deployment.litellm_params.max_file_size_mb is not None + ): + max_file_size_mb = deployment.litellm_params.max_file_size_mb + except Exception as e: + verbose_proxy_logger.error( + "Got error when checking file size: %s", (str(e)) + ) + + if max_file_size_mb is not None: + verbose_proxy_logger.debug( + "Checking file size, file content size=%s, max_file_size_mb=%s", + file_content_size_in_mb, + max_file_size_mb, + ) + if not premium_user: + raise ProxyException( + message=f"Tried setting max_file_size_mb for /audio/transcriptions. {CommonProxyErrors.not_premium_user.value}", + code=status.HTTP_400_BAD_REQUEST, + type="bad_request", + param="file", + ) + if file_content_size_in_mb > max_file_size_mb: + raise ProxyException( + message=f"File size is too large. Please check your file size. Passed file size: {file_content_size_in_mb} MB. 
Max file size: {max_file_size_mb} MB", + code=status.HTTP_400_BAD_REQUEST, + type="bad_request", + param="file", + ) + + return True diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/common_utils/load_config_utils.py b/.venv/lib/python3.12/site-packages/litellm/proxy/common_utils/load_config_utils.py new file mode 100644 index 00000000..38e7b3f3 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/proxy/common_utils/load_config_utils.py @@ -0,0 +1,76 @@ +import yaml + +from litellm._logging import verbose_proxy_logger + + +def get_file_contents_from_s3(bucket_name, object_key): + try: + # v0 rely on boto3 for authentication - allowing boto3 to handle IAM credentials etc + import tempfile + + import boto3 + from botocore.credentials import Credentials + + from litellm.main import bedrock_converse_chat_completion + + credentials: Credentials = bedrock_converse_chat_completion.get_credentials() + s3_client = boto3.client( + "s3", + aws_access_key_id=credentials.access_key, + aws_secret_access_key=credentials.secret_key, + aws_session_token=credentials.token, # Optional, if using temporary credentials + ) + verbose_proxy_logger.debug( + f"Retrieving {object_key} from S3 bucket: {bucket_name}" + ) + response = s3_client.get_object(Bucket=bucket_name, Key=object_key) + verbose_proxy_logger.debug(f"Response: {response}") + + # Read the file contents + file_contents = response["Body"].read().decode("utf-8") + verbose_proxy_logger.debug("File contents retrieved from S3") + + # Create a temporary file with YAML extension + with tempfile.NamedTemporaryFile(delete=False, suffix=".yaml") as temp_file: + temp_file.write(file_contents.encode("utf-8")) + temp_file_path = temp_file.name + verbose_proxy_logger.debug(f"File stored temporarily at: {temp_file_path}") + + # Load the YAML file content + with open(temp_file_path, "r") as yaml_file: + config = yaml.safe_load(yaml_file) + + return config + except ImportError as e: + # this is most likely if a user is not using the litellm docker container + verbose_proxy_logger.error(f"ImportError: {str(e)}") + pass + except Exception as e: + verbose_proxy_logger.error(f"Error retrieving file contents: {str(e)}") + return None + + +async def get_config_file_contents_from_gcs(bucket_name, object_key): + try: + from litellm.integrations.gcs_bucket.gcs_bucket import GCSBucketLogger + + gcs_bucket = GCSBucketLogger( + bucket_name=bucket_name, + ) + file_contents = await gcs_bucket.download_gcs_object(object_key) + if file_contents is None: + raise Exception(f"File contents are None for {object_key}") + # file_contentis is a bytes object, so we need to convert it to yaml + file_contents = file_contents.decode("utf-8") + # convert to yaml + config = yaml.safe_load(file_contents) + return config + + except Exception as e: + verbose_proxy_logger.error(f"Error retrieving file contents: {str(e)}") + return None + + +# # Example usage +# bucket_name = 'litellm-proxy' +# object_key = 'litellm_proxy_config.yaml' diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/common_utils/openai_endpoint_utils.py b/.venv/lib/python3.12/site-packages/litellm/proxy/common_utils/openai_endpoint_utils.py new file mode 100644 index 00000000..316a8427 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/proxy/common_utils/openai_endpoint_utils.py @@ -0,0 +1,39 @@ +""" +Contains utils used by OpenAI compatible endpoints +""" + +from typing import Optional + +from fastapi import Request + +from litellm.proxy.common_utils.http_parsing_utils import 
_read_request_body + + +def remove_sensitive_info_from_deployment(deployment_dict: dict) -> dict: + """ + Removes sensitive information from a deployment dictionary. + + Args: + deployment_dict (dict): The deployment dictionary to remove sensitive information from. + + Returns: + dict: The modified deployment dictionary with sensitive information removed. + """ + deployment_dict["litellm_params"].pop("api_key", None) + deployment_dict["litellm_params"].pop("vertex_credentials", None) + deployment_dict["litellm_params"].pop("aws_access_key_id", None) + deployment_dict["litellm_params"].pop("aws_secret_access_key", None) + + return deployment_dict + + +async def get_custom_llm_provider_from_request_body(request: Request) -> Optional[str]: + """ + Get the `custom_llm_provider` from the request body + + Safely reads the request body + """ + request_body: dict = await _read_request_body(request=request) or {} + if "custom_llm_provider" in request_body: + return request_body["custom_llm_provider"] + return None diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/common_utils/proxy_state.py b/.venv/lib/python3.12/site-packages/litellm/proxy/common_utils/proxy_state.py new file mode 100644 index 00000000..edd18c60 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/proxy/common_utils/proxy_state.py @@ -0,0 +1,36 @@ +""" +This file is used to store the state variables of the proxy server. + +Example: `spend_logs_row_count` is used to store the number of rows in the `LiteLLM_SpendLogs` table. +""" + +from typing import Any, Literal + +from litellm.proxy._types import ProxyStateVariables + + +class ProxyState: + """ + Proxy state class has get/set methods for Proxy state variables. + """ + + # Note: mypy does not recognize when we fetch ProxyStateVariables.annotations.keys(), so we also need to add the valid keys here + valid_keys_literal = Literal["spend_logs_row_count"] + + def __init__(self) -> None: + self.proxy_state_variables: ProxyStateVariables = ProxyStateVariables( + spend_logs_row_count=0, + ) + + def get_proxy_state_variable( + self, + variable_name: valid_keys_literal, + ) -> Any: + return self.proxy_state_variables.get(variable_name, None) + + def set_proxy_state_variable( + self, + variable_name: valid_keys_literal, + value: Any, + ) -> None: + self.proxy_state_variables[variable_name] = value diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/common_utils/reset_budget_job.py b/.venv/lib/python3.12/site-packages/litellm/proxy/common_utils/reset_budget_job.py new file mode 100644 index 00000000..1d50002f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/proxy/common_utils/reset_budget_job.py @@ -0,0 +1,365 @@ +import asyncio +import json +import time +from datetime import datetime, timedelta +from typing import List, Literal, Optional, Union + +from litellm._logging import verbose_proxy_logger +from litellm.litellm_core_utils.duration_parser import duration_in_seconds +from litellm.proxy._types import ( + LiteLLM_TeamTable, + LiteLLM_UserTable, + LiteLLM_VerificationToken, +) +from litellm.proxy.utils import PrismaClient, ProxyLogging +from litellm.types.services import ServiceTypes + + +class ResetBudgetJob: + """ + Resets the budget for all the keys, users, and teams that need it + """ + + def __init__(self, proxy_logging_obj: ProxyLogging, prisma_client: PrismaClient): + self.proxy_logging_obj: ProxyLogging = proxy_logging_obj + self.prisma_client: PrismaClient = prisma_client + + async def reset_budget( + self, + ): + """ + Gets all the 
non-expired keys for a db, which need spend to be reset + + Resets their spend + + Updates db + """ + if self.prisma_client is not None: + ### RESET KEY BUDGET ### + await self.reset_budget_for_litellm_keys() + + ### RESET USER BUDGET ### + await self.reset_budget_for_litellm_users() + + ## Reset Team Budget + await self.reset_budget_for_litellm_teams() + + async def reset_budget_for_litellm_keys(self): + """ + Resets the budget for all the litellm keys + + Catches Exceptions and logs them + """ + now = datetime.utcnow() + start_time = time.time() + keys_to_reset: Optional[List[LiteLLM_VerificationToken]] = None + try: + keys_to_reset = await self.prisma_client.get_data( + table_name="key", query_type="find_all", expires=now, reset_at=now + ) + verbose_proxy_logger.debug( + "Keys to reset %s", json.dumps(keys_to_reset, indent=4, default=str) + ) + updated_keys: List[LiteLLM_VerificationToken] = [] + failed_keys = [] + if keys_to_reset is not None and len(keys_to_reset) > 0: + for key in keys_to_reset: + try: + updated_key = await ResetBudgetJob._reset_budget_for_key( + key=key, current_time=now + ) + if updated_key is not None: + updated_keys.append(updated_key) + else: + failed_keys.append( + {"key": key, "error": "Returned None without exception"} + ) + except Exception as e: + failed_keys.append({"key": key, "error": str(e)}) + verbose_proxy_logger.exception( + "Failed to reset budget for key: %s", key + ) + + verbose_proxy_logger.debug( + "Updated keys %s", json.dumps(updated_keys, indent=4, default=str) + ) + + if updated_keys: + await self.prisma_client.update_data( + query_type="update_many", + data_list=updated_keys, + table_name="key", + ) + + end_time = time.time() + if len(failed_keys) > 0: # If any keys failed to reset + raise Exception( + f"Failed to reset {len(failed_keys)} keys: {json.dumps(failed_keys, default=str)}" + ) + + asyncio.create_task( + self.proxy_logging_obj.service_logging_obj.async_service_success_hook( + service=ServiceTypes.RESET_BUDGET_JOB, + duration=end_time - start_time, + call_type="reset_budget_keys", + start_time=start_time, + end_time=end_time, + event_metadata={ + "num_keys_found": len(keys_to_reset) if keys_to_reset else 0, + "keys_found": json.dumps(keys_to_reset, indent=4, default=str), + "num_keys_updated": len(updated_keys), + "keys_updated": json.dumps(updated_keys, indent=4, default=str), + "num_keys_failed": len(failed_keys), + "keys_failed": json.dumps(failed_keys, indent=4, default=str), + }, + ) + ) + except Exception as e: + end_time = time.time() + asyncio.create_task( + self.proxy_logging_obj.service_logging_obj.async_service_failure_hook( + service=ServiceTypes.RESET_BUDGET_JOB, + duration=end_time - start_time, + error=e, + call_type="reset_budget_keys", + start_time=start_time, + end_time=end_time, + event_metadata={ + "num_keys_found": len(keys_to_reset) if keys_to_reset else 0, + "keys_found": json.dumps(keys_to_reset, indent=4, default=str), + }, + ) + ) + verbose_proxy_logger.exception("Failed to reset budget for keys: %s", e) + + async def reset_budget_for_litellm_users(self): + """ + Resets the budget for all LiteLLM Internal Users if their budget has expired + """ + now = datetime.utcnow() + start_time = time.time() + users_to_reset: Optional[List[LiteLLM_UserTable]] = None + try: + users_to_reset = await self.prisma_client.get_data( + table_name="user", query_type="find_all", reset_at=now + ) + updated_users: List[LiteLLM_UserTable] = [] + failed_users = [] + if users_to_reset is not None and len(users_to_reset) > 0: + for 
user in users_to_reset: + try: + updated_user = await ResetBudgetJob._reset_budget_for_user( + user=user, current_time=now + ) + if updated_user is not None: + updated_users.append(updated_user) + else: + failed_users.append( + { + "user": user, + "error": "Returned None without exception", + } + ) + except Exception as e: + failed_users.append({"user": user, "error": str(e)}) + verbose_proxy_logger.exception( + "Failed to reset budget for user: %s", user + ) + + verbose_proxy_logger.debug( + "Updated users %s", json.dumps(updated_users, indent=4, default=str) + ) + if updated_users: + await self.prisma_client.update_data( + query_type="update_many", + data_list=updated_users, + table_name="user", + ) + + end_time = time.time() + if len(failed_users) > 0: # If any users failed to reset + raise Exception( + f"Failed to reset {len(failed_users)} users: {json.dumps(failed_users, default=str)}" + ) + + asyncio.create_task( + self.proxy_logging_obj.service_logging_obj.async_service_success_hook( + service=ServiceTypes.RESET_BUDGET_JOB, + duration=end_time - start_time, + call_type="reset_budget_users", + start_time=start_time, + end_time=end_time, + event_metadata={ + "num_users_found": len(users_to_reset) if users_to_reset else 0, + "users_found": json.dumps( + users_to_reset, indent=4, default=str + ), + "num_users_updated": len(updated_users), + "users_updated": json.dumps( + updated_users, indent=4, default=str + ), + "num_users_failed": len(failed_users), + "users_failed": json.dumps(failed_users, indent=4, default=str), + }, + ) + ) + except Exception as e: + end_time = time.time() + asyncio.create_task( + self.proxy_logging_obj.service_logging_obj.async_service_failure_hook( + service=ServiceTypes.RESET_BUDGET_JOB, + duration=end_time - start_time, + error=e, + call_type="reset_budget_users", + start_time=start_time, + end_time=end_time, + event_metadata={ + "num_users_found": len(users_to_reset) if users_to_reset else 0, + "users_found": json.dumps( + users_to_reset, indent=4, default=str + ), + }, + ) + ) + verbose_proxy_logger.exception("Failed to reset budget for users: %s", e) + + async def reset_budget_for_litellm_teams(self): + """ + Resets the budget for all LiteLLM Internal Teams if their budget has expired + """ + now = datetime.utcnow() + start_time = time.time() + teams_to_reset: Optional[List[LiteLLM_TeamTable]] = None + try: + teams_to_reset = await self.prisma_client.get_data( + table_name="team", query_type="find_all", reset_at=now + ) + updated_teams: List[LiteLLM_TeamTable] = [] + failed_teams = [] + if teams_to_reset is not None and len(teams_to_reset) > 0: + for team in teams_to_reset: + try: + updated_team = await ResetBudgetJob._reset_budget_for_team( + team=team, current_time=now + ) + if updated_team is not None: + updated_teams.append(updated_team) + else: + failed_teams.append( + { + "team": team, + "error": "Returned None without exception", + } + ) + except Exception as e: + failed_teams.append({"team": team, "error": str(e)}) + verbose_proxy_logger.exception( + "Failed to reset budget for team: %s", team + ) + + verbose_proxy_logger.debug( + "Updated teams %s", json.dumps(updated_teams, indent=4, default=str) + ) + if updated_teams: + await self.prisma_client.update_data( + query_type="update_many", + data_list=updated_teams, + table_name="team", + ) + + end_time = time.time() + if len(failed_teams) > 0: # If any teams failed to reset + raise Exception( + f"Failed to reset {len(failed_teams)} teams: {json.dumps(failed_teams, default=str)}" + ) + + 
asyncio.create_task( + self.proxy_logging_obj.service_logging_obj.async_service_success_hook( + service=ServiceTypes.RESET_BUDGET_JOB, + duration=end_time - start_time, + call_type="reset_budget_teams", + start_time=start_time, + end_time=end_time, + event_metadata={ + "num_teams_found": len(teams_to_reset) if teams_to_reset else 0, + "teams_found": json.dumps( + teams_to_reset, indent=4, default=str + ), + "num_teams_updated": len(updated_teams), + "teams_updated": json.dumps( + updated_teams, indent=4, default=str + ), + "num_teams_failed": len(failed_teams), + "teams_failed": json.dumps(failed_teams, indent=4, default=str), + }, + ) + ) + except Exception as e: + end_time = time.time() + asyncio.create_task( + self.proxy_logging_obj.service_logging_obj.async_service_failure_hook( + service=ServiceTypes.RESET_BUDGET_JOB, + duration=end_time - start_time, + error=e, + call_type="reset_budget_teams", + start_time=start_time, + end_time=end_time, + event_metadata={ + "num_teams_found": len(teams_to_reset) if teams_to_reset else 0, + "teams_found": json.dumps( + teams_to_reset, indent=4, default=str + ), + }, + ) + ) + verbose_proxy_logger.exception("Failed to reset budget for teams: %s", e) + + @staticmethod + async def _reset_budget_common( + item: Union[LiteLLM_TeamTable, LiteLLM_UserTable, LiteLLM_VerificationToken], + current_time: datetime, + item_type: Literal["key", "team", "user"], + ): + """ + In-place, updates spend=0, and sets budget_reset_at to current_time + budget_duration + + Common logic for resetting budget for a team, user, or key + """ + try: + item.spend = 0.0 + if hasattr(item, "budget_duration") and item.budget_duration is not None: + duration_s = duration_in_seconds(duration=item.budget_duration) + item.budget_reset_at = current_time + timedelta(seconds=duration_s) + return item + except Exception as e: + verbose_proxy_logger.exception( + "Error resetting budget for %s: %s. 
Item: %s", item_type, e, item + ) + raise e + + @staticmethod + async def _reset_budget_for_team( + team: LiteLLM_TeamTable, current_time: datetime + ) -> Optional[LiteLLM_TeamTable]: + await ResetBudgetJob._reset_budget_common( + item=team, current_time=current_time, item_type="team" + ) + return team + + @staticmethod + async def _reset_budget_for_user( + user: LiteLLM_UserTable, current_time: datetime + ) -> Optional[LiteLLM_UserTable]: + await ResetBudgetJob._reset_budget_common( + item=user, current_time=current_time, item_type="user" + ) + return user + + @staticmethod + async def _reset_budget_for_key( + key: LiteLLM_VerificationToken, current_time: datetime + ) -> Optional[LiteLLM_VerificationToken]: + await ResetBudgetJob._reset_budget_common( + item=key, current_time=current_time, item_type="key" + ) + return key diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/common_utils/swagger_utils.py b/.venv/lib/python3.12/site-packages/litellm/proxy/common_utils/swagger_utils.py new file mode 100644 index 00000000..75a64707 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/proxy/common_utils/swagger_utils.py @@ -0,0 +1,48 @@ +from typing import Any, Dict + +from pydantic import BaseModel, Field + +from litellm.exceptions import LITELLM_EXCEPTION_TYPES + + +class ErrorResponse(BaseModel): + detail: Dict[str, Any] = Field( + ..., + example={ # type: ignore + "error": { + "message": "Error message", + "type": "error_type", + "param": "error_param", + "code": "error_code", + } + }, + ) + + +# Define a function to get the status code +def get_status_code(exception): + if hasattr(exception, "status_code"): + return exception.status_code + # Default status codes for exceptions without a status_code attribute + if exception.__name__ == "Timeout": + return 408 # Request Timeout + if exception.__name__ == "APIConnectionError": + return 503 # Service Unavailable + return 500 # Internal Server Error as default + + +# Create error responses +ERROR_RESPONSES = { + get_status_code(exception): { + "model": ErrorResponse, + "description": exception.__doc__ or exception.__name__, + } + for exception in LITELLM_EXCEPTION_TYPES +} + +# Ensure we have a 500 error response +if 500 not in ERROR_RESPONSES: + ERROR_RESPONSES[500] = { + "model": ErrorResponse, + "description": "Internal Server Error", + } |
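The `ERROR_RESPONSES` mapping built in `swagger_utils.py` above is shaped for FastAPI's `responses=` route parameter (a dict of status code to `{"model": ..., "description": ...}`). A minimal sketch of how a proxy route might consume it is shown below; the route path and handler are hypothetical illustrations, not part of this diff.

```python
# Hypothetical route showing how ERROR_RESPONSES could feed FastAPI's OpenAPI docs.
# The path and handler body are illustrative only; only the import reflects the diff above.
from fastapi import FastAPI

from litellm.proxy.common_utils.swagger_utils import ERROR_RESPONSES

app = FastAPI()


@app.post("/v1/chat/completions", responses=ERROR_RESPONSES)
async def chat_completions(request_body: dict) -> dict:
    # Each status code in ERROR_RESPONSES is documented with the ErrorResponse
    # schema, so the generated Swagger UI shows the proxy's error envelope
    # for timeouts (408), connection errors (503), and the 500 fallback.
    return {"status": "ok"}
```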