Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/proxy/proxy_server.py')
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/proxy/proxy_server.py  8185
1 file changed, 8185 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/proxy_server.py b/.venv/lib/python3.12/site-packages/litellm/proxy/proxy_server.py
new file mode 100644
index 00000000..ae1c8d18
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/proxy_server.py
@@ -0,0 +1,8185 @@
+import asyncio
+import copy
+import inspect
+import io
+import os
+import random
+import secrets
+import subprocess
+import sys
+import time
+import traceback
+import uuid
+import warnings
+from datetime import datetime, timedelta
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ List,
+ Optional,
+ Tuple,
+ cast,
+ get_args,
+ get_origin,
+ get_type_hints,
+)
+
+from litellm.types.utils import (
+ ModelResponse,
+ ModelResponseStream,
+ TextCompletionResponse,
+)
+
+if TYPE_CHECKING:
+ from opentelemetry.trace import Span as _Span
+
+ from litellm.integrations.opentelemetry import OpenTelemetry
+
+ Span = _Span
+else:
+ Span = Any
+ OpenTelemetry = Any
+
+
+def showwarning(message, category, filename, lineno, file=None, line=None):
+ traceback_info = f"{filename}:{lineno}: {category.__name__}: {message}\n"
+ if file is not None:
+ file.write(traceback_info)
+
+
+warnings.showwarning = showwarning
+warnings.filterwarnings("default", category=UserWarning)
+
+# Your client code here
+
+
+messages: list = []
+sys.path.insert(
+ 0, os.path.abspath("../..")
+) # Adds the parent directory to the system path - for litellm local dev
+
+try:
+ import logging
+
+ import backoff
+ import fastapi
+ import orjson
+ import yaml # type: ignore
+ from apscheduler.schedulers.asyncio import AsyncIOScheduler
+except ImportError as e:
+ raise ImportError(f"Missing dependency {e}. Run `pip install 'litellm[proxy]'`")
+
+list_of_messages = [
+ "'The thing I wish you improved is...'",
+ "'A feature I really want is...'",
+ "'The worst thing about this product is...'",
+ "'This product would be better if...'",
+ "'I don't like how this works...'",
+ "'It would help me if you could add...'",
+ "'This feature doesn't meet my needs because...'",
+ "'I get frustrated when the product...'",
+]
+
+
+def generate_feedback_box():
+ box_width = 60
+
+ # Select a random message
+ message = random.choice(list_of_messages)
+
+ print() # noqa
+ print("\033[1;37m" + "#" + "-" * box_width + "#\033[0m") # noqa
+ print("\033[1;37m" + "#" + " " * box_width + "#\033[0m") # noqa
+ print("\033[1;37m" + "# {:^59} #\033[0m".format(message)) # noqa
+ print( # noqa
+ "\033[1;37m"
+ + "# {:^59} #\033[0m".format("https://github.com/BerriAI/litellm/issues/new")
+ ) # noqa
+ print("\033[1;37m" + "#" + " " * box_width + "#\033[0m") # noqa
+ print("\033[1;37m" + "#" + "-" * box_width + "#\033[0m") # noqa
+ print() # noqa
+ print(" Thank you for using LiteLLM! - Krrish & Ishaan") # noqa
+ print() # noqa
+ print() # noqa
+ print() # noqa
+ print( # noqa
+ "\033[1;31mGive Feedback / Get Help: https://github.com/BerriAI/litellm/issues/new\033[0m"
+ ) # noqa
+ print() # noqa
+ print() # noqa
+
+
+from collections import defaultdict
+from contextlib import asynccontextmanager
+
+import litellm
+from litellm import Router
+from litellm._logging import verbose_proxy_logger, verbose_router_logger
+from litellm.caching.caching import DualCache, RedisCache
+from litellm.constants import LITELLM_PROXY_ADMIN_NAME
+from litellm.exceptions import RejectedRequestError
+from litellm.integrations.SlackAlerting.slack_alerting import SlackAlerting
+from litellm.litellm_core_utils.core_helpers import (
+ _get_parent_otel_span_from_kwargs,
+ get_litellm_metadata_from_kwargs,
+)
+from litellm.litellm_core_utils.credential_accessor import CredentialAccessor
+from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.proxy._types import *
+from litellm.proxy.analytics_endpoints.analytics_endpoints import (
+ router as analytics_router,
+)
+from litellm.proxy.anthropic_endpoints.endpoints import router as anthropic_router
+from litellm.proxy.auth.auth_checks import get_team_object, log_db_metrics
+from litellm.proxy.auth.auth_utils import check_response_size_is_safe
+from litellm.proxy.auth.handle_jwt import JWTHandler
+from litellm.proxy.auth.litellm_license import LicenseCheck
+from litellm.proxy.auth.model_checks import (
+ get_complete_model_list,
+ get_key_models,
+ get_team_models,
+)
+from litellm.proxy.auth.user_api_key_auth import (
+ user_api_key_auth,
+ user_api_key_auth_websocket,
+)
+from litellm.proxy.batches_endpoints.endpoints import router as batches_router
+
+## Import All Misc routes here ##
+from litellm.proxy.caching_routes import router as caching_router
+from litellm.proxy.common_request_processing import ProxyBaseLLMRequestProcessing
+from litellm.proxy.common_utils.admin_ui_utils import html_form
+from litellm.proxy.common_utils.callback_utils import initialize_callbacks_on_proxy
+from litellm.proxy.common_utils.debug_utils import init_verbose_loggers
+from litellm.proxy.common_utils.debug_utils import router as debugging_endpoints_router
+from litellm.proxy.common_utils.encrypt_decrypt_utils import (
+ decrypt_value_helper,
+ encrypt_value_helper,
+)
+from litellm.proxy.common_utils.http_parsing_utils import (
+ _read_request_body,
+ check_file_size_under_limit,
+)
+from litellm.proxy.common_utils.load_config_utils import (
+ get_config_file_contents_from_gcs,
+ get_file_contents_from_s3,
+)
+from litellm.proxy.common_utils.openai_endpoint_utils import (
+ remove_sensitive_info_from_deployment,
+)
+from litellm.proxy.common_utils.proxy_state import ProxyState
+from litellm.proxy.common_utils.reset_budget_job import ResetBudgetJob
+from litellm.proxy.common_utils.swagger_utils import ERROR_RESPONSES
+from litellm.proxy.credential_endpoints.endpoints import router as credential_router
+from litellm.proxy.fine_tuning_endpoints.endpoints import router as fine_tuning_router
+from litellm.proxy.fine_tuning_endpoints.endpoints import set_fine_tuning_config
+from litellm.proxy.guardrails.guardrail_endpoints import router as guardrails_router
+from litellm.proxy.guardrails.init_guardrails import (
+ init_guardrails_v2,
+ initialize_guardrails,
+)
+from litellm.proxy.health_check import perform_health_check
+from litellm.proxy.health_endpoints._health_endpoints import router as health_router
+from litellm.proxy.hooks.model_max_budget_limiter import (
+ _PROXY_VirtualKeyModelMaxBudgetLimiter,
+)
+from litellm.proxy.hooks.prompt_injection_detection import (
+ _OPTIONAL_PromptInjectionDetection,
+)
+from litellm.proxy.hooks.proxy_track_cost_callback import _ProxyDBLogger
+from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
+from litellm.proxy.management_endpoints.budget_management_endpoints import (
+ router as budget_management_router,
+)
+from litellm.proxy.management_endpoints.customer_endpoints import (
+ router as customer_router,
+)
+from litellm.proxy.management_endpoints.internal_user_endpoints import (
+ router as internal_user_router,
+)
+from litellm.proxy.management_endpoints.internal_user_endpoints import user_update
+from litellm.proxy.management_endpoints.key_management_endpoints import (
+ delete_verification_tokens,
+ duration_in_seconds,
+ generate_key_helper_fn,
+)
+from litellm.proxy.management_endpoints.key_management_endpoints import (
+ router as key_management_router,
+)
+from litellm.proxy.management_endpoints.model_management_endpoints import (
+ _add_model_to_db,
+ _add_team_model_to_db,
+ check_if_team_id_matches_key,
+)
+from litellm.proxy.management_endpoints.model_management_endpoints import (
+ router as model_management_router,
+)
+from litellm.proxy.management_endpoints.organization_endpoints import (
+ router as organization_router,
+)
+from litellm.proxy.management_endpoints.team_callback_endpoints import (
+ router as team_callback_router,
+)
+from litellm.proxy.management_endpoints.team_endpoints import router as team_router
+from litellm.proxy.management_endpoints.team_endpoints import (
+ update_team,
+ validate_membership,
+)
+from litellm.proxy.management_endpoints.ui_sso import (
+ get_disabled_non_admin_personal_key_creation,
+)
+from litellm.proxy.management_endpoints.ui_sso import router as ui_sso_router
+from litellm.proxy.management_helpers.audit_logs import create_audit_log_for_update
+from litellm.proxy.openai_files_endpoints.files_endpoints import (
+ router as openai_files_router,
+)
+from litellm.proxy.openai_files_endpoints.files_endpoints import set_files_config
+from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import (
+ router as llm_passthrough_router,
+)
+from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
+ initialize_pass_through_endpoints,
+)
+from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
+ router as pass_through_router,
+)
+from litellm.proxy.rerank_endpoints.endpoints import router as rerank_router
+from litellm.proxy.response_api_endpoints.endpoints import router as response_router
+from litellm.proxy.route_llm_request import route_request
+from litellm.proxy.spend_tracking.spend_management_endpoints import (
+ router as spend_management_router,
+)
+from litellm.proxy.spend_tracking.spend_tracking_utils import get_logging_payload
+from litellm.proxy.types_utils.utils import get_instance_fn
+from litellm.proxy.ui_crud_endpoints.proxy_setting_endpoints import (
+ router as ui_crud_endpoints_router,
+)
+from litellm.proxy.utils import (
+ PrismaClient,
+ ProxyLogging,
+ ProxyUpdateSpend,
+ _cache_user_row,
+ _get_docs_url,
+ _get_projected_spend_over_limit,
+ _get_redoc_url,
+ _is_projected_spend_over_limit,
+ _is_valid_team_configs,
+ get_error_message_str,
+ hash_token,
+ update_spend,
+)
+from litellm.proxy.vertex_ai_endpoints.langfuse_endpoints import (
+ router as langfuse_router,
+)
+from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import router as vertex_router
+from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import set_default_vertex_config
+from litellm.router import (
+ AssistantsTypedDict,
+ Deployment,
+ LiteLLM_Params,
+ ModelGroupInfo,
+)
+from litellm.scheduler import DefaultPriorities, FlowItem, Scheduler
+from litellm.secret_managers.aws_secret_manager import load_aws_kms
+from litellm.secret_managers.google_kms import load_google_kms
+from litellm.secret_managers.main import (
+ get_secret,
+ get_secret_bool,
+ get_secret_str,
+ str_to_bool,
+)
+from litellm.types.integrations.slack_alerting import SlackAlertingArgs
+from litellm.types.llms.anthropic import (
+ AnthropicMessagesRequest,
+ AnthropicResponse,
+ AnthropicResponseContentBlockText,
+ AnthropicResponseUsageBlock,
+)
+from litellm.types.llms.openai import HttpxBinaryResponseContent
+from litellm.types.router import DeploymentTypedDict
+from litellm.types.router import ModelInfo as RouterModelInfo
+from litellm.types.router import RouterGeneralSettings, updateDeployment
+from litellm.types.utils import CredentialItem, CustomHuggingfaceTokenizer
+from litellm.types.utils import ModelInfo as ModelMapInfo
+from litellm.types.utils import RawRequestTypedDict, StandardLoggingPayload
+from litellm.utils import _add_custom_logger_callback_to_specific_event
+
+try:
+ from litellm._version import version
+except Exception:
+ version = "0.0.0"
+litellm.suppress_debug_info = True
+import json
+from typing import Union
+
+from fastapi import (
+ Depends,
+ FastAPI,
+ File,
+ Form,
+ Header,
+ HTTPException,
+ Path,
+ Query,
+ Request,
+ Response,
+ UploadFile,
+ status,
+)
+from fastapi.encoders import jsonable_encoder
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.openapi.utils import get_openapi
+from fastapi.responses import (
+ FileResponse,
+ JSONResponse,
+ ORJSONResponse,
+ RedirectResponse,
+ StreamingResponse,
+)
+from fastapi.routing import APIRouter
+from fastapi.security import OAuth2PasswordBearer
+from fastapi.security.api_key import APIKeyHeader
+from fastapi.staticfiles import StaticFiles
+
+# import enterprise folder
+try:
+ # when using litellm cli
+ import litellm.proxy.enterprise as enterprise
+except Exception:
+ # when using litellm docker image
+ try:
+ import enterprise # type: ignore
+ except Exception:
+ pass
+
+server_root_path = os.getenv("SERVER_ROOT_PATH", "")
+_license_check = LicenseCheck()
+premium_user: bool = _license_check.is_premium()
+global_max_parallel_request_retries_env: Optional[str] = os.getenv(
+ "LITELLM_GLOBAL_MAX_PARALLEL_REQUEST_RETRIES"
+)
+proxy_state = ProxyState()
+if global_max_parallel_request_retries_env is None:
+ global_max_parallel_request_retries: int = 3
+else:
+ global_max_parallel_request_retries = int(global_max_parallel_request_retries_env)
+
+global_max_parallel_request_retry_timeout_env: Optional[str] = os.getenv(
+ "LITELLM_GLOBAL_MAX_PARALLEL_REQUEST_RETRY_TIMEOUT"
+)
+if global_max_parallel_request_retry_timeout_env is None:
+ global_max_parallel_request_retry_timeout: float = 60.0
+else:
+ global_max_parallel_request_retry_timeout = float(
+ global_max_parallel_request_retry_timeout_env
+ )
+
+ui_link = f"{server_root_path}/ui/"
+ui_message = (
+ f"šŸ‘‰ [```LiteLLM Admin Panel on /ui```]({ui_link}). Create, Edit Keys with SSO"
+)
+ui_message += "\n\nšŸ’ø [```LiteLLM Model Cost Map```](https://models.litellm.ai/)."
+
+custom_swagger_message = "[**Customize Swagger Docs**](https://docs.litellm.ai/docs/proxy/enterprise#swagger-docs---custom-routes--branding)"
+
+### CUSTOM BRANDING [ENTERPRISE FEATURE] ###
+_title = os.getenv("DOCS_TITLE", "LiteLLM API") if premium_user else "LiteLLM API"
+_description = (
+ os.getenv(
+ "DOCS_DESCRIPTION",
+ f"Enterprise Edition \n\nProxy Server to call 100+ LLMs in the OpenAI format. {custom_swagger_message}\n\n{ui_message}",
+ )
+ if premium_user
+ else f"Proxy Server to call 100+ LLMs in the OpenAI format. {custom_swagger_message}\n\n{ui_message}"
+)
+
+
+def cleanup_router_config_variables():
+ global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, user_custom_key_generate, user_custom_sso, use_background_health_checks, health_check_interval, prisma_client
+
+ # Set all variables to None
+ master_key = None
+ user_config_file_path = None
+ otel_logging = None
+ user_custom_auth = None
+ user_custom_auth_path = None
+ user_custom_key_generate = None
+ user_custom_sso = None
+ use_background_health_checks = None
+ health_check_interval = None
+ prisma_client = None
+
+
+async def proxy_shutdown_event():
+ global prisma_client, master_key, user_custom_auth, user_custom_key_generate
+ verbose_proxy_logger.info("Shutting down LiteLLM Proxy Server")
+ if prisma_client:
+ verbose_proxy_logger.debug("Disconnecting from Prisma")
+ await prisma_client.disconnect()
+
+ if litellm.cache is not None:
+ await litellm.cache.disconnect()
+
+ await jwt_handler.close()
+
+ if db_writer_client is not None:
+ await db_writer_client.close()
+
+ # flush remaining langfuse logs
+ if "langfuse" in litellm.success_callback:
+ try:
+ # flush langfuse logs on shutdown
+ from litellm.utils import langFuseLogger
+
+ if langFuseLogger is not None:
+ langFuseLogger.Langfuse.flush()
+ except Exception:
+ # [DO NOT BLOCK shutdown events for this]
+ pass
+
+ ## RESET CUSTOM VARIABLES ##
+ cleanup_router_config_variables()
+
+
+@asynccontextmanager
+async def proxy_startup_event(app: FastAPI):
+ global prisma_client, master_key, use_background_health_checks, llm_router, llm_model_list, general_settings, proxy_budget_rescheduler_min_time, proxy_budget_rescheduler_max_time, litellm_proxy_admin_name, db_writer_client, store_model_in_db, premium_user, _license_check
+ import json
+
+ init_verbose_loggers()
+ ### LOAD MASTER KEY ###
+ # check if master key set in environment - load from there
+ master_key = get_secret("LITELLM_MASTER_KEY", None) # type: ignore
+ # check if DATABASE_URL in environment - load from there
+ if prisma_client is None:
+ _db_url: Optional[str] = get_secret("DATABASE_URL", None) # type: ignore
+ prisma_client = await ProxyStartupEvent._setup_prisma_client(
+ database_url=_db_url,
+ proxy_logging_obj=proxy_logging_obj,
+ user_api_key_cache=user_api_key_cache,
+ )
+
+ ## CHECK PREMIUM USER
+ verbose_proxy_logger.debug(
+ "litellm.proxy.proxy_server.py::startup() - CHECKING PREMIUM USER - {}".format(
+ premium_user
+ )
+ )
+ if premium_user is False:
+ premium_user = _license_check.is_premium()
+
+ ### LOAD CONFIG ###
+ worker_config: Optional[Union[str, dict]] = get_secret("WORKER_CONFIG") # type: ignore
+ env_config_yaml: Optional[str] = get_secret_str("CONFIG_FILE_PATH")
+ verbose_proxy_logger.debug("worker_config: %s", worker_config)
+ # check if it's a valid file path
+ if env_config_yaml is not None:
+ if os.path.isfile(env_config_yaml) and proxy_config.is_yaml(
+ config_file_path=env_config_yaml
+ ):
+ (
+ llm_router,
+ llm_model_list,
+ general_settings,
+ ) = await proxy_config.load_config(
+ router=llm_router, config_file_path=env_config_yaml
+ )
+ elif worker_config is not None:
+ if (
+ isinstance(worker_config, str)
+ and os.path.isfile(worker_config)
+ and proxy_config.is_yaml(config_file_path=worker_config)
+ ):
+ (
+ llm_router,
+ llm_model_list,
+ general_settings,
+ ) = await proxy_config.load_config(
+ router=llm_router, config_file_path=worker_config
+ )
+ elif os.environ.get("LITELLM_CONFIG_BUCKET_NAME") is not None and isinstance(
+ worker_config, str
+ ):
+ (
+ llm_router,
+ llm_model_list,
+ general_settings,
+ ) = await proxy_config.load_config(
+ router=llm_router, config_file_path=worker_config
+ )
+ elif isinstance(worker_config, dict):
+ await initialize(**worker_config)
+ else:
+ # if not, assume it's a json string
+ worker_config = json.loads(worker_config)
+ if isinstance(worker_config, dict):
+ await initialize(**worker_config)
+
+ ProxyStartupEvent._initialize_startup_logging(
+ llm_router=llm_router,
+ proxy_logging_obj=proxy_logging_obj,
+ redis_usage_cache=redis_usage_cache,
+ )
+
+ ## JWT AUTH ##
+ ProxyStartupEvent._initialize_jwt_auth(
+ general_settings=general_settings,
+ prisma_client=prisma_client,
+ user_api_key_cache=user_api_key_cache,
+ )
+
+ if use_background_health_checks:
+ asyncio.create_task(
+ _run_background_health_check()
+ ) # start the background health check coroutine.
+
+ if prompt_injection_detection_obj is not None: # [TODO] - REFACTOR THIS
+ prompt_injection_detection_obj.update_environment(router=llm_router)
+
+ verbose_proxy_logger.debug("prisma_client: %s", prisma_client)
+ if prisma_client is not None and litellm.max_budget > 0:
+ ProxyStartupEvent._add_proxy_budget_to_db(
+ litellm_proxy_budget_name=litellm_proxy_admin_name
+ )
+
+ ### START BATCH WRITING DB + CHECKING NEW MODELS###
+ if prisma_client is not None:
+ await ProxyStartupEvent.initialize_scheduled_background_jobs(
+ general_settings=general_settings,
+ prisma_client=prisma_client,
+ proxy_budget_rescheduler_min_time=proxy_budget_rescheduler_min_time,
+ proxy_budget_rescheduler_max_time=proxy_budget_rescheduler_max_time,
+ proxy_batch_write_at=proxy_batch_write_at,
+ proxy_logging_obj=proxy_logging_obj,
+ )
+ ## [Optional] Initialize dd tracer
+ ProxyStartupEvent._init_dd_tracer()
+
+ # End of startup event
+ yield
+
+ # Shutdown event
+ await proxy_shutdown_event()
+
+
+app = FastAPI(
+ docs_url=_get_docs_url(),
+ redoc_url=_get_redoc_url(),
+ title=_title,
+ description=_description,
+ version=version,
+ root_path=server_root_path, # check if user passed root path, FastAPI defaults this value to ""
+ lifespan=proxy_startup_event,
+)
+
+
+### CUSTOM API DOCS [ENTERPRISE FEATURE] ###
+# Custom OpenAPI schema generator to include only selected routes
+from fastapi.routing import APIWebSocketRoute
+
+
+def get_openapi_schema():
+ if app.openapi_schema:
+ return app.openapi_schema
+
+ openapi_schema = get_openapi(
+ title=app.title,
+ version=app.version,
+ description=app.description,
+ routes=app.routes,
+ )
+
+ # Find all WebSocket routes
+ websocket_routes = [
+ route for route in app.routes if isinstance(route, APIWebSocketRoute)
+ ]
+
+ # Add each WebSocket route to the schema
+ for route in websocket_routes:
+ # Get the base path without query parameters
+ base_path = route.path.split("{")[0].rstrip("?")
+
+ # Extract parameters from the route
+ parameters = []
+ if hasattr(route, "dependant"):
+ for param in route.dependant.query_params:
+ parameters.append(
+ {
+ "name": param.name,
+ "in": "query",
+ "required": param.required,
+ "schema": {
+ "type": "string"
+ }, # You can make this more specific if needed
+ }
+ )
+
+ openapi_schema["paths"][base_path] = {
+ "get": {
+ "summary": f"WebSocket: {route.name or base_path}",
+ "description": "WebSocket connection endpoint",
+ "operationId": f"websocket_{route.name or base_path.replace('/', '_')}",
+ "parameters": parameters,
+ "responses": {"101": {"description": "WebSocket Protocol Switched"}},
+ "tags": ["WebSocket"],
+ }
+ }
+
+ app.openapi_schema = openapi_schema
+ return app.openapi_schema
+
+
+def custom_openapi():
+ if app.openapi_schema:
+ return app.openapi_schema
+ openapi_schema = get_openapi_schema()
+
+ # Filter routes to include only specific ones
+ openai_routes = LiteLLMRoutes.openai_routes.value
+ paths_to_include: dict = {}
+ for route in openai_routes:
+ if route in openapi_schema["paths"]:
+ paths_to_include[route] = openapi_schema["paths"][route]
+ openapi_schema["paths"] = paths_to_include
+ app.openapi_schema = openapi_schema
+ return app.openapi_schema
+
+
+if os.getenv("DOCS_FILTERED", "False") == "True" and premium_user:
+ app.openapi = custom_openapi # type: ignore
+
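+# Illustrative sketch (placeholder values): the filtered Swagger docs above only take
+# effect for premium users, e.g.
+#
+#   export DOCS_FILTERED="True"              # compared as an exact string above
+#   export LITELLM_LICENSE="<your-license>"  # placeholder; read by LicenseCheck to set premium_user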
+
+class UserAPIKeyCacheTTLEnum(enum.Enum):
+ in_memory_cache_ttl = 60 # 1 min ttl ## configure via `general_settings::user_api_key_cache_ttl: <your-value>`
+
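+# Illustrative config sketch (value is a placeholder): the in-memory TTL above can be
+# overridden via the proxy config, per the comment on the enum member, e.g.
+#
+#   general_settings:
+#     user_api_key_cache_ttl: 120   # seconds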
+
+@app.exception_handler(ProxyException)
+async def openai_exception_handler(request: Request, exc: ProxyException):
+ # NOTE: DO NOT MODIFY THIS, it's crucial to map to OpenAI exceptions
+ headers = exc.headers
+ return JSONResponse(
+ status_code=(
+ int(exc.code) if exc.code else status.HTTP_500_INTERNAL_SERVER_ERROR
+ ),
+ content={
+ "error": {
+ "message": exc.message,
+ "type": exc.type,
+ "param": exc.param,
+ "code": exc.code,
+ }
+ },
+ headers=headers,
+ )
+
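+# Illustrative response shape produced by the handler above (placeholder values):
+#
+#   {"error": {"message": "Authentication Error", "type": "auth_error",
+#              "param": null, "code": "401"}}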
+
+router = APIRouter()
+origins = ["*"]
+
+# get current directory
+try:
+ current_dir = os.path.dirname(os.path.abspath(__file__))
+ ui_path = os.path.join(current_dir, "_experimental", "out")
+ app.mount("/ui", StaticFiles(directory=ui_path, html=True), name="ui")
+ # Iterate through files in the UI directory
+ for filename in os.listdir(ui_path):
+ if filename.endswith(".html") and filename != "index.html":
+ # Create a folder with the same name as the HTML file
+ folder_name = os.path.splitext(filename)[0]
+ folder_path = os.path.join(ui_path, folder_name)
+ os.makedirs(folder_path, exist_ok=True)
+
+ # Move the HTML file into the folder and rename it to 'index.html'
+ src = os.path.join(ui_path, filename)
+ dst = os.path.join(folder_path, "index.html")
+ os.rename(src, dst)
+
+ if server_root_path != "":
+ print( # noqa
+ f"server_root_path is set, forwarding any /ui requests to {server_root_path}/ui"
+ ) # noqa
+ if os.getenv("PROXY_BASE_URL") is None:
+ os.environ["PROXY_BASE_URL"] = server_root_path
+
+ @app.middleware("http")
+ async def redirect_ui_middleware(request: Request, call_next):
+ if request.url.path.startswith("/ui"):
+ new_url = str(request.url).replace("/ui", f"{server_root_path}/ui", 1)
+ return RedirectResponse(new_url)
+ return await call_next(request)
+
+except Exception:
+ pass
+# current_dir = os.path.dirname(os.path.abspath(__file__))
+# ui_path = os.path.join(current_dir, "_experimental", "out")
+# # Mount this test directory instead
+# app.mount("/ui", StaticFiles(directory=ui_path, html=True), name="ui")
+
+
+app.add_middleware(
+ CORSMiddleware,
+ allow_origins=origins,
+ allow_credentials=True,
+ allow_methods=["*"],
+ allow_headers=["*"],
+)
+
+
+from typing import Dict
+
+user_api_base = None
+user_model = None
+user_debug = False
+user_max_tokens = None
+user_request_timeout = None
+user_temperature = None
+user_telemetry = True
+user_config = None
+user_headers = None
+user_config_file_path: Optional[str] = None
+local_logging = True # writes logs to a local api_log.json file for debugging
+experimental = False
+#### GLOBAL VARIABLES ####
+llm_router: Optional[Router] = None
+llm_model_list: Optional[list] = None
+general_settings: dict = {}
+callback_settings: dict = {}
+log_file = "api_log.json"
+worker_config = None
+master_key: Optional[str] = None
+otel_logging = False
+prisma_client: Optional[PrismaClient] = None
+user_api_key_cache = DualCache(
+ default_in_memory_ttl=UserAPIKeyCacheTTLEnum.in_memory_cache_ttl.value
+)
+model_max_budget_limiter = _PROXY_VirtualKeyModelMaxBudgetLimiter(
+ dual_cache=user_api_key_cache
+)
+litellm.logging_callback_manager.add_litellm_callback(model_max_budget_limiter)
+redis_usage_cache: Optional[RedisCache] = (
+ None # redis cache used for tracking spend, tpm/rpm limits
+)
+user_custom_auth = None
+user_custom_key_generate = None
+user_custom_sso = None
+use_background_health_checks = None
+use_queue = False
+health_check_interval = None
+health_check_details = None
+health_check_results = {}
+queue: List = []
+litellm_proxy_budget_name = "litellm-proxy-budget"
+litellm_proxy_admin_name = LITELLM_PROXY_ADMIN_NAME
+ui_access_mode: Literal["admin", "all"] = "all"
+proxy_budget_rescheduler_min_time = 597
+proxy_budget_rescheduler_max_time = 605
+proxy_batch_write_at = 10 # in seconds
+litellm_master_key_hash = None
+disable_spend_logs = False
+jwt_handler = JWTHandler()
+prompt_injection_detection_obj: Optional[_OPTIONAL_PromptInjectionDetection] = None
+store_model_in_db: bool = False
+open_telemetry_logger: Optional[OpenTelemetry] = None
+### INITIALIZE GLOBAL LOGGING OBJECT ###
+proxy_logging_obj = ProxyLogging(
+ user_api_key_cache=user_api_key_cache, premium_user=premium_user
+)
+### REDIS QUEUE ###
+async_result = None
+celery_app_conn = None
+celery_fn = None # Redis Queue for handling requests
+### DB WRITER ###
+db_writer_client: Optional[AsyncHTTPHandler] = None
+### logger ###
+
+
+async def check_request_disconnection(request: Request, llm_api_call_task):
+ """
+ Asynchronously checks if the request is disconnected at regular intervals.
+ If the request is disconnected
+ - cancel the litellm.router task
+ - raises an HTTPException with status code 499 and detail "Client disconnected the request".
+
+ Parameters:
+ - request: Request: The request object to check for disconnection.
+ Returns:
+ - None
+ """
+
+ # only run this loop for 10 minutes - if the task never gets cancelled, we don't want the server accumulating long-lived while loops
+ start_time = time.time()
+ while time.time() - start_time < 600:
+ await asyncio.sleep(1)
+ if await request.is_disconnected():
+
+ # cancel the LLM API Call task if any passed - this is passed from individual providers
+ # Example OpenAI, Azure, VertexAI etc
+ llm_api_call_task.cancel()
+
+ raise HTTPException(
+ status_code=499,
+ detail="Client disconnected the request",
+ )
+
+
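+# Hedged usage sketch for check_request_disconnection above (names are illustrative,
+# not a call made at this point in the module): the watcher runs alongside an
+# in-flight LLM call so the call can be cancelled if the client disconnects.
+#
+#   llm_task = asyncio.create_task(llm_router.acompletion(**data))   # hypothetical call
+#   asyncio.create_task(check_request_disconnection(request, llm_task))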
+def _resolve_typed_dict_type(typ):
+ """Resolve the actual TypedDict class from a potentially wrapped type."""
+ from typing_extensions import _TypedDictMeta # type: ignore
+
+ origin = get_origin(typ)
+ if origin is Union: # Check if it's a Union (like Optional)
+ for arg in get_args(typ):
+ if isinstance(arg, _TypedDictMeta):
+ return arg
+ elif isinstance(typ, type) and issubclass(typ, dict):
+ return typ
+ return None
+
+
+def _resolve_pydantic_type(typ) -> List:
+ """Resolve the actual TypedDict class from a potentially wrapped type."""
+ origin = get_origin(typ)
+ typs = []
+ if origin is Union: # Check if it's a Union (like Optional)
+ for arg in get_args(typ):
+ if (
+ arg is not None
+ and not isinstance(arg, type(None))
+ and "NoneType" not in str(arg)
+ ):
+ typs.append(arg)
+ elif isinstance(typ, type) and issubclass(typ, BaseModel):
+ return [typ]
+ return typs
+
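+# Illustrative sketch for the resolvers above (MyTypedDict / MyModel are hypothetical):
+# both helpers unwrap Optional[...] / Union[..., None] annotations.
+#
+#   _resolve_typed_dict_type(Optional[MyTypedDict])  # -> MyTypedDict
+#   _resolve_pydantic_type(Optional[MyModel])        # -> [MyModel]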
+
+def load_from_azure_key_vault(use_azure_key_vault: bool = False):
+ if use_azure_key_vault is False:
+ return
+
+ try:
+ from azure.identity import DefaultAzureCredential
+ from azure.keyvault.secrets import SecretClient
+
+ # Set your Azure Key Vault URI
+ KVUri = os.getenv("AZURE_KEY_VAULT_URI", None)
+
+ if KVUri is None:
+ raise Exception(
+ "Error when loading keys from Azure Key Vault: AZURE_KEY_VAULT_URI is not set."
+ )
+
+ credential = DefaultAzureCredential()
+
+ # Create the SecretClient using the credential
+ client = SecretClient(vault_url=KVUri, credential=credential)
+
+ litellm.secret_manager_client = client
+ litellm._key_management_system = KeyManagementSystem.AZURE_KEY_VAULT
+ except Exception as e:
+ _error_str = str(e)
+ verbose_proxy_logger.exception(
+ "Error when loading keys from Azure Key Vault: %s .Ensure you run `pip install azure-identity azure-keyvault-secrets`",
+ _error_str,
+ )
+
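+# Illustrative setup sketch for load_from_azure_key_vault above (the vault URI is a
+# made-up value):
+#
+#   os.environ["AZURE_KEY_VAULT_URI"] = "https://my-vault.vault.azure.net/"
+#   load_from_azure_key_vault(use_azure_key_vault=True)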
+
+def cost_tracking():
+ global prisma_client
+ if prisma_client is not None:
+ litellm.logging_callback_manager.add_litellm_callback(_ProxyDBLogger())
+
+
+def _set_spend_logs_payload(
+ payload: Union[dict, SpendLogsPayload],
+ prisma_client: PrismaClient,
+ spend_logs_url: Optional[str] = None,
+):
+ verbose_proxy_logger.info(
+ "Writing spend log to db - request_id: {}, spend: {}".format(
+ payload.get("request_id"), payload.get("spend")
+ )
+ )
+ if prisma_client is not None and spend_logs_url is not None:
+ if isinstance(payload["startTime"], datetime):
+ payload["startTime"] = payload["startTime"].isoformat()
+ if isinstance(payload["endTime"], datetime):
+ payload["endTime"] = payload["endTime"].isoformat()
+ prisma_client.spend_log_transactions.append(payload)
+ elif prisma_client is not None:
+ prisma_client.spend_log_transactions.append(payload)
+ return prisma_client
+
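+# Illustrative call sketch for _set_spend_logs_payload above (payload fields abbreviated;
+# a real SpendLogsPayload carries more keys): when SPEND_LOGS_URL is set, datetime fields
+# are serialized before the payload is queued.
+#
+#   prisma_client = _set_spend_logs_payload(
+#       payload={"request_id": "req-123", "spend": 0.002,
+#                "startTime": datetime.now(), "endTime": datetime.now()},
+#       prisma_client=prisma_client,
+#       spend_logs_url=os.getenv("SPEND_LOGS_URL"),
+#   )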
+
+async def update_database( # noqa: PLR0915
+ token,
+ response_cost,
+ user_id=None,
+ end_user_id=None,
+ team_id=None,
+ kwargs=None,
+ completion_response=None,
+ start_time=None,
+ end_time=None,
+ org_id=None,
+):
+ try:
+ global prisma_client
+ verbose_proxy_logger.debug(
+ f"Enters prisma db call, response_cost: {response_cost}, token: {token}; user_id: {user_id}; team_id: {team_id}"
+ )
+ if ProxyUpdateSpend.disable_spend_updates() is True:
+ return
+ if token is not None and isinstance(token, str) and token.startswith("sk-"):
+ hashed_token = hash_token(token=token)
+ else:
+ hashed_token = token
+
+ ### UPDATE USER SPEND ###
+ async def _update_user_db():
+ """
+ - Update that user's row
+ - Update litellm-proxy-budget row (global proxy spend)
+ """
+ ## if an end-user is passed in, do an upsert - we can't guarantee they already exist in db
+ existing_user_obj = await user_api_key_cache.async_get_cache(key=user_id)
+ if existing_user_obj is not None and isinstance(existing_user_obj, dict):
+ existing_user_obj = LiteLLM_UserTable(**existing_user_obj)
+ try:
+ if prisma_client is not None: # update
+ user_ids = [user_id]
+ if (
+ litellm.max_budget > 0
+ ): # track global proxy budget, if user set max budget
+ user_ids.append(litellm_proxy_budget_name)
+ ### KEY CHANGE ###
+ for _id in user_ids:
+ if _id is not None:
+ prisma_client.user_list_transactons[_id] = (
+ response_cost
+ + prisma_client.user_list_transactons.get(_id, 0)
+ )
+ if end_user_id is not None:
+ prisma_client.end_user_list_transactons[end_user_id] = (
+ response_cost
+ + prisma_client.end_user_list_transactons.get(
+ end_user_id, 0
+ )
+ )
+ except Exception as e:
+ verbose_proxy_logger.info(
+ "\033[91m"
+ + f"Update User DB call failed to execute {str(e)}\n{traceback.format_exc()}"
+ )
+
+ ### UPDATE KEY SPEND ###
+ async def _update_key_db():
+ try:
+ verbose_proxy_logger.debug(
+ f"adding spend to key db. Response cost: {response_cost}. Token: {hashed_token}."
+ )
+ if hashed_token is None:
+ return
+ if prisma_client is not None:
+ prisma_client.key_list_transactons[hashed_token] = (
+ response_cost
+ + prisma_client.key_list_transactons.get(hashed_token, 0)
+ )
+ except Exception as e:
+ verbose_proxy_logger.exception(
+ f"Update Key DB Call failed to execute - {str(e)}"
+ )
+ raise e
+
+ ### UPDATE SPEND LOGS ###
+ async def _insert_spend_log_to_db():
+ try:
+ global prisma_client
+ if prisma_client is not None:
+ # Helper to generate payload to log
+ payload = get_logging_payload(
+ kwargs=kwargs,
+ response_obj=completion_response,
+ start_time=start_time,
+ end_time=end_time,
+ )
+ payload["spend"] = response_cost
+ prisma_client = _set_spend_logs_payload(
+ payload=payload,
+ spend_logs_url=os.getenv("SPEND_LOGS_URL"),
+ prisma_client=prisma_client,
+ )
+ except Exception as e:
+ verbose_proxy_logger.debug(
+ f"Update Spend Logs DB failed to execute - {str(e)}\n{traceback.format_exc()}"
+ )
+ raise e
+
+ ### UPDATE TEAM SPEND ###
+ async def _update_team_db():
+ try:
+ verbose_proxy_logger.debug(
+ f"adding spend to team db. Response cost: {response_cost}. team_id: {team_id}."
+ )
+ if team_id is None:
+ verbose_proxy_logger.debug(
+ "track_cost_callback: team_id is None. Not tracking spend for team"
+ )
+ return
+ if prisma_client is not None:
+ prisma_client.team_list_transactons[team_id] = (
+ response_cost
+ + prisma_client.team_list_transactons.get(team_id, 0)
+ )
+
+ try:
+ # Track spend of the team member within this team
+ # key is "team_id::<value>::user_id::<value>"
+ team_member_key = f"team_id::{team_id}::user_id::{user_id}"
+ prisma_client.team_member_list_transactons[team_member_key] = (
+ response_cost
+ + prisma_client.team_member_list_transactons.get(
+ team_member_key, 0
+ )
+ )
+ except Exception:
+ pass
+ except Exception as e:
+ verbose_proxy_logger.info(
+ f"Update Team DB failed to execute - {str(e)}\n{traceback.format_exc()}"
+ )
+ raise e
+
+ ### UPDATE ORG SPEND ###
+ async def _update_org_db():
+ try:
+ verbose_proxy_logger.debug(
+ "adding spend to org db. Response cost: {}. org_id: {}.".format(
+ response_cost, org_id
+ )
+ )
+ if org_id is None:
+ verbose_proxy_logger.debug(
+ "track_cost_callback: org_id is None. Not tracking spend for org"
+ )
+ return
+ if prisma_client is not None:
+ prisma_client.org_list_transactons[org_id] = (
+ response_cost
+ + prisma_client.org_list_transactons.get(org_id, 0)
+ )
+ except Exception as e:
+ verbose_proxy_logger.info(
+ f"Update Org DB failed to execute - {str(e)}\n{traceback.format_exc()}"
+ )
+ raise e
+
+ asyncio.create_task(_update_user_db())
+ asyncio.create_task(_update_key_db())
+ asyncio.create_task(_update_team_db())
+ asyncio.create_task(_update_org_db())
+ # asyncio.create_task(_insert_spend_log_to_db())
+ if disable_spend_logs is False:
+ await _insert_spend_log_to_db()
+ else:
+ verbose_proxy_logger.info(
+ "disable_spend_logs=True. Skipping writing spend logs to db. Other spend updates - Key/User/Team table will still occur."
+ )
+
+ verbose_proxy_logger.debug("Runs spend update on all tables")
+ except Exception:
+ verbose_proxy_logger.debug(
+ f"Error updating Prisma database: {traceback.format_exc()}"
+ )
+
+
+async def update_cache( # noqa: PLR0915
+ token: Optional[str],
+ user_id: Optional[str],
+ end_user_id: Optional[str],
+ team_id: Optional[str],
+ response_cost: Optional[float],
+ parent_otel_span: Optional[Span], # type: ignore
+):
+ """
+ Use this to update the cache with new user spend.
+
+ Put any alerting logic in here.
+ """
+
+ values_to_update_in_cache: List[Tuple[Any, Any]] = []
+
+ ### UPDATE KEY SPEND ###
+ async def _update_key_cache(token: str, response_cost: float):
+ # Fetch the existing cost for the given token
+ if isinstance(token, str) and token.startswith("sk-"):
+ hashed_token = hash_token(token=token)
+ else:
+ hashed_token = token
+ verbose_proxy_logger.debug("_update_key_cache: hashed_token=%s", hashed_token)
+ existing_spend_obj: LiteLLM_VerificationTokenView = await user_api_key_cache.async_get_cache(key=hashed_token) # type: ignore
+ verbose_proxy_logger.debug(
+ f"_update_key_cache: existing_spend_obj={existing_spend_obj}"
+ )
+ verbose_proxy_logger.debug(
+ f"_update_key_cache: existing spend: {existing_spend_obj}"
+ )
+ if existing_spend_obj is None:
+ return
+ else:
+ existing_spend = existing_spend_obj.spend
+ # Calculate the new cost by adding the existing cost and response_cost
+ new_spend = existing_spend + response_cost
+
+ ## CHECK IF USER PROJECTED SPEND > SOFT LIMIT
+ if (
+ existing_spend_obj.soft_budget_cooldown is False
+ and existing_spend_obj.litellm_budget_table is not None
+ and (
+ _is_projected_spend_over_limit(
+ current_spend=new_spend,
+ soft_budget_limit=existing_spend_obj.litellm_budget_table[
+ "soft_budget"
+ ],
+ )
+ is True
+ )
+ ):
+ projected_spend, projected_exceeded_date = _get_projected_spend_over_limit(
+ current_spend=new_spend,
+ soft_budget_limit=existing_spend_obj.litellm_budget_table.get(
+ "soft_budget", None
+ ),
+ ) # type: ignore
+ soft_limit = existing_spend_obj.litellm_budget_table.get(
+ "soft_budget", float("inf")
+ )
+ call_info = CallInfo(
+ token=existing_spend_obj.token or "",
+ spend=new_spend,
+ key_alias=existing_spend_obj.key_alias,
+ max_budget=soft_limit,
+ user_id=existing_spend_obj.user_id,
+ projected_spend=projected_spend,
+ projected_exceeded_date=projected_exceeded_date,
+ )
+ # alert user
+ asyncio.create_task(
+ proxy_logging_obj.budget_alerts(
+ type="projected_limit_exceeded",
+ user_info=call_info,
+ )
+ )
+ # set cooldown on alert
+
+ if (
+ existing_spend_obj is not None
+ and getattr(existing_spend_obj, "team_spend", None) is not None
+ ):
+ existing_team_spend = existing_spend_obj.team_spend or 0
+ # Calculate the new cost by adding the existing cost and response_cost
+ existing_spend_obj.team_spend = existing_team_spend + response_cost
+
+ if (
+ existing_spend_obj is not None
+ and getattr(existing_spend_obj, "team_member_spend", None) is not None
+ ):
+ existing_team_member_spend = existing_spend_obj.team_member_spend or 0
+ # Calculate the new cost by adding the existing cost and response_cost
+ existing_spend_obj.team_member_spend = (
+ existing_team_member_spend + response_cost
+ )
+
+ # Update the cost column for the given token
+ existing_spend_obj.spend = new_spend
+ values_to_update_in_cache.append((hashed_token, existing_spend_obj))
+
+ ### UPDATE USER SPEND ###
+ async def _update_user_cache():
+ ## UPDATE CACHE FOR USER ID + GLOBAL PROXY
+ user_ids = [user_id]
+ try:
+ for _id in user_ids:
+ # Fetch the existing cost for the given user
+ if _id is None:
+ continue
+ existing_spend_obj = await user_api_key_cache.async_get_cache(key=_id)
+ if existing_spend_obj is None:
+ # do nothing if there is no cache value
+ return
+ verbose_proxy_logger.debug(
+ f"_update_user_db: existing spend: {existing_spend_obj}; response_cost: {response_cost}"
+ )
+
+ if isinstance(existing_spend_obj, dict):
+ existing_spend = existing_spend_obj["spend"]
+ else:
+ existing_spend = existing_spend_obj.spend
+ # Calculate the new cost by adding the existing cost and response_cost
+ new_spend = existing_spend + response_cost
+
+ # Update the cost column for the given user
+ if isinstance(existing_spend_obj, dict):
+ existing_spend_obj["spend"] = new_spend
+ values_to_update_in_cache.append((_id, existing_spend_obj))
+ else:
+ existing_spend_obj.spend = new_spend
+ values_to_update_in_cache.append((_id, existing_spend_obj.json()))
+ ## UPDATE GLOBAL PROXY ##
+ global_proxy_spend = await user_api_key_cache.async_get_cache(
+ key="{}:spend".format(litellm_proxy_admin_name)
+ )
+ if global_proxy_spend is None:
+ # do nothing if not in cache
+ return
+ elif response_cost is not None and global_proxy_spend is not None:
+ increment = global_proxy_spend + response_cost
+ values_to_update_in_cache.append(
+ ("{}:spend".format(litellm_proxy_admin_name), increment)
+ )
+ except Exception as e:
+ verbose_proxy_logger.debug(
+ f"An error occurred updating user cache: {str(e)}\n\n{traceback.format_exc()}"
+ )
+
+ ### UPDATE END-USER SPEND ###
+ async def _update_end_user_cache():
+ if end_user_id is None or response_cost is None:
+ return
+
+ _id = "end_user_id:{}".format(end_user_id)
+ try:
+ # Fetch the existing cost for the given user
+ existing_spend_obj = await user_api_key_cache.async_get_cache(key=_id)
+ if existing_spend_obj is None:
+ # if user does not exist in LiteLLM_UserTable, create a new user
+ # do nothing if end-user not in api key cache
+ return
+ verbose_proxy_logger.debug(
+ f"_update_end_user_db: existing spend: {existing_spend_obj}; response_cost: {response_cost}"
+ )
+ if existing_spend_obj is None:
+ existing_spend = 0
+ else:
+ if isinstance(existing_spend_obj, dict):
+ existing_spend = existing_spend_obj["spend"]
+ else:
+ existing_spend = existing_spend_obj.spend
+ # Calculate the new cost by adding the existing cost and response_cost
+ new_spend = existing_spend + response_cost
+
+ # Update the cost column for the given user
+ if isinstance(existing_spend_obj, dict):
+ existing_spend_obj["spend"] = new_spend
+ values_to_update_in_cache.append((_id, existing_spend_obj))
+ else:
+ existing_spend_obj.spend = new_spend
+ values_to_update_in_cache.append((_id, existing_spend_obj.json()))
+ except Exception as e:
+ verbose_proxy_logger.exception(
+ f"An error occurred updating end user cache: {str(e)}"
+ )
+
+ ### UPDATE TEAM SPEND ###
+ async def _update_team_cache():
+ if team_id is None or response_cost is None:
+ return
+
+ _id = "team_id:{}".format(team_id)
+ try:
+ # Fetch the existing cost for the given user
+ existing_spend_obj: Optional[LiteLLM_TeamTable] = (
+ await user_api_key_cache.async_get_cache(key=_id)
+ )
+ if existing_spend_obj is None:
+ # do nothing if team not in api key cache
+ return
+ verbose_proxy_logger.debug(
+ f"_update_team_db: existing spend: {existing_spend_obj}; response_cost: {response_cost}"
+ )
+ if existing_spend_obj is None:
+ existing_spend: Optional[float] = 0.0
+ else:
+ if isinstance(existing_spend_obj, dict):
+ existing_spend = existing_spend_obj["spend"]
+ else:
+ existing_spend = existing_spend_obj.spend
+
+ if existing_spend is None:
+ existing_spend = 0.0
+ # Calculate the new cost by adding the existing cost and response_cost
+ new_spend = existing_spend + response_cost
+
+ # Update the cost column for the given user
+ if isinstance(existing_spend_obj, dict):
+ existing_spend_obj["spend"] = new_spend
+ values_to_update_in_cache.append((_id, existing_spend_obj))
+ else:
+ existing_spend_obj.spend = new_spend
+ values_to_update_in_cache.append((_id, existing_spend_obj))
+ except Exception as e:
+ verbose_proxy_logger.exception(
+ f"An error occurred updating end user cache: {str(e)}"
+ )
+
+ if token is not None and response_cost is not None:
+ await _update_key_cache(token=token, response_cost=response_cost)
+
+ if user_id is not None:
+ await _update_user_cache()
+
+ if end_user_id is not None:
+ await _update_end_user_cache()
+
+ if team_id is not None:
+ await _update_team_cache()
+
+ asyncio.create_task(
+ user_api_key_cache.async_set_cache_pipeline(
+ cache_list=values_to_update_in_cache,
+ ttl=60,
+ litellm_parent_otel_span=parent_otel_span,
+ )
+ )
+
+
+def run_ollama_serve():
+ try:
+ command = ["ollama", "serve"]
+
+ with open(os.devnull, "w") as devnull:
+ subprocess.Popen(command, stdout=devnull, stderr=devnull)
+ except Exception as e:
+ verbose_proxy_logger.debug(
+ f"""
+ LiteLLM Warning: proxy started with `ollama` model\n`ollama serve` failed with Exception: {e}. \nEnsure you run `ollama serve`
+ """
+ )
+
+
+async def _run_background_health_check():
+ """
+ Periodically run health checks in the background on the endpoints.
+
+ Update health_check_results, based on this.
+ """
+ global health_check_results, llm_model_list, health_check_interval, health_check_details
+
+ # make 1 deep copy of llm_model_list -> use this for all background health checks
+ _llm_model_list = copy.deepcopy(llm_model_list)
+
+ if _llm_model_list is None:
+ return
+
+ while True:
+ healthy_endpoints, unhealthy_endpoints = await perform_health_check(
+ model_list=_llm_model_list, details=health_check_details
+ )
+
+ # Update the global variable with the health check results
+ health_check_results["healthy_endpoints"] = healthy_endpoints
+ health_check_results["unhealthy_endpoints"] = unhealthy_endpoints
+ health_check_results["healthy_count"] = len(healthy_endpoints)
+ health_check_results["unhealthy_count"] = len(unhealthy_endpoints)
+
+ if health_check_interval is not None and isinstance(
+ health_check_interval, float
+ ):
+ await asyncio.sleep(health_check_interval)
+
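+# Illustrative config sketch (key names assumed from the globals above; values made up):
+#
+#   general_settings:
+#     background_health_checks: true
+#     health_check_interval: 300    # seconds between background checks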
+
+class StreamingCallbackError(Exception):
+ pass
+
+
+class ProxyConfig:
+ """
+ Abstraction class on top of config loading/updating logic. Gives us one place to control all config updating logic.
+ """
+
+ def __init__(self) -> None:
+ self.config: Dict[str, Any] = {}
+
+ def is_yaml(self, config_file_path: str) -> bool:
+ if not os.path.isfile(config_file_path):
+ return False
+
+ _, file_extension = os.path.splitext(config_file_path)
+ return file_extension.lower() == ".yaml" or file_extension.lower() == ".yml"
+
+ def _load_yaml_file(self, file_path: str) -> dict:
+ """
+ Load and parse a YAML file
+ """
+ try:
+ with open(file_path, "r") as file:
+ return yaml.safe_load(file) or {}
+ except Exception as e:
+ raise Exception(f"Error loading yaml file {file_path}: {str(e)}")
+
+ async def _get_config_from_file(
+ self, config_file_path: Optional[str] = None
+ ) -> dict:
+ """
+ Given a config file path, load the config from the file.
+ Args:
+ config_file_path (str): path to the config file
+ Returns:
+ dict: config
+ """
+ global prisma_client, user_config_file_path
+
+ file_path = config_file_path or user_config_file_path
+ if config_file_path is not None:
+ user_config_file_path = config_file_path
+ # Load existing config
+ ## Yaml
+ if os.path.exists(f"{file_path}"):
+ with open(f"{file_path}", "r") as config_file:
+ config = yaml.safe_load(config_file)
+ elif file_path is not None:
+ raise Exception(f"Config file not found: {file_path}")
+ else:
+ config = {
+ "model_list": [],
+ "general_settings": {},
+ "router_settings": {},
+ "litellm_settings": {},
+ }
+
+ # Process includes
+ config = self._process_includes(
+ config=config, base_dir=os.path.dirname(os.path.abspath(file_path or ""))
+ )
+
+ verbose_proxy_logger.debug(f"loaded config={json.dumps(config, indent=4)}")
+ return config
+
+ def _process_includes(self, config: dict, base_dir: str) -> dict:
+ """
+ Process includes by appending their contents to the main config
+
+ Handles nested config.yamls with `include` section
+
+ Example config: This will get the contents from files in `include` and append it
+ ```yaml
+ include:
+ - model_config.yaml
+
+ litellm_settings:
+ callbacks: ["prometheus"]
+ ```
+ """
+ if "include" not in config:
+ return config
+
+ if not isinstance(config["include"], list):
+ raise ValueError("'include' must be a list of file paths")
+
+ # Load and append all included files
+ for include_file in config["include"]:
+ file_path = os.path.join(base_dir, include_file)
+ if not os.path.exists(file_path):
+ raise FileNotFoundError(f"Included file not found: {file_path}")
+
+ included_config = self._load_yaml_file(file_path)
+ # Simply update/extend the main config with included config
+ for key, value in included_config.items():
+ if isinstance(value, list) and key in config:
+ config[key].extend(value)
+ else:
+ config[key] = value
+
+ # Remove the include directive
+ del config["include"]
+ return config
+
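+ # Illustrative include sketch (file and model names are placeholders): with
+ # `include: [model_config.yaml]` in the main config, a model_config.yaml such as
+ #
+ #   model_list:
+ #     - model_name: gpt-4o
+ #       litellm_params:
+ #         model: openai/gpt-4o
+ #
+ # is merged in; list-valued keys like model_list are extended rather than replaced.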
+ async def save_config(self, new_config: dict):
+ global prisma_client, general_settings, user_config_file_path, store_model_in_db
+ # Load existing config
+ ## DB - writes valid config to db
+ """
+ - Do not write restricted params like 'api_key' to the database
+ - if api_key is passed, save that to the local environment or connected secret manager (maybe expose `litellm.save_secret()`)
+ """
+ if prisma_client is not None and (
+ general_settings.get("store_model_in_db", False) is True
+ or store_model_in_db
+ ):
+ # if using - db for config - models are in ModelTable
+ new_config.pop("model_list", None)
+ await prisma_client.insert_data(data=new_config, table_name="config")
+ else:
+ # Save the updated config - if user is not using a DB
+ ## YAML
+ with open(f"{user_config_file_path}", "w") as config_file:
+ yaml.dump(new_config, config_file, default_flow_style=False)
+
+ def _check_for_os_environ_vars(
+ self, config: dict, depth: int = 0, max_depth: int = 10
+ ) -> dict:
+ """
+ Check for os.environ/ variables in the config and replace them with the actual values.
+ Includes a depth limit to prevent infinite recursion.
+
+ Args:
+ config (dict): The configuration dictionary to process.
+ depth (int): Current recursion depth.
+ max_depth (int): Maximum allowed recursion depth.
+
+ Returns:
+ dict: Processed configuration dictionary.
+ """
+ if depth > max_depth:
+ verbose_proxy_logger.warning(
+ f"Maximum recursion depth ({max_depth}) reached while processing config."
+ )
+ return config
+
+ for key, value in config.items():
+ if isinstance(value, dict):
+ config[key] = self._check_for_os_environ_vars(
+ config=value, depth=depth + 1, max_depth=max_depth
+ )
+ elif isinstance(value, list):
+ for item in value:
+ if isinstance(item, dict):
+ item = self._check_for_os_environ_vars(
+ config=item, depth=depth + 1, max_depth=max_depth
+ )
+ # if the value is a string and starts with "os.environ/" - then it's an environment variable
+ elif isinstance(value, str) and value.startswith("os.environ/"):
+ config[key] = get_secret(value)
+ return config
+
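+ # Illustrative sketch of the os.environ/ convention handled above (placeholder values):
+ #
+ #   litellm_settings:
+ #     cache_params:
+ #       host: os.environ/REDIS_HOST          # replaced via get_secret(...) at load time
+ #       password: os.environ/REDIS_PASSWORD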
+ def _get_team_config(self, team_id: str, all_teams_config: List[Dict]) -> Dict:
+ team_config: dict = {}
+ for team in all_teams_config:
+ if "team_id" not in team:
+ raise Exception(f"team_id missing from team: {team}")
+ if team_id == team["team_id"]:
+ team_config = team
+ break
+ for k, v in team_config.items():
+ if isinstance(v, str) and v.startswith("os.environ/"):
+ team_config[k] = get_secret(v)
+ return team_config
+
+ def load_team_config(self, team_id: str):
+ """
+ - for a given team id
+ - return the relevant completion() call params
+ """
+
+ # load existing config
+ config = self.get_config_state()
+
+ ## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..)
+ litellm_settings = config.get("litellm_settings", {})
+ all_teams_config = litellm_settings.get("default_team_settings", None)
+ if all_teams_config is None:
+ return {}
+ team_config = self._get_team_config(
+ team_id=team_id, all_teams_config=all_teams_config
+ )
+ return team_config
+
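+ # Illustrative team-settings sketch read by load_team_config above (team id and
+ # callback are placeholders):
+ #
+ #   litellm_settings:
+ #     default_team_settings:
+ #       - team_id: team-1
+ #         success_callback: ["langfuse"]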
+ def _init_cache(
+ self,
+ cache_params: dict,
+ ):
+ global redis_usage_cache, llm_router
+ from litellm import Cache
+
+ if "default_in_memory_ttl" in cache_params:
+ litellm.default_in_memory_ttl = cache_params["default_in_memory_ttl"]
+
+ if "default_redis_ttl" in cache_params:
+ litellm.default_redis_ttl = cache_params["default_redis_ttl"]
+
+ litellm.cache = Cache(**cache_params)
+
+ if litellm.cache is not None and isinstance(litellm.cache.cache, RedisCache):
+ ## INIT PROXY REDIS USAGE CLIENT ##
+ redis_usage_cache = litellm.cache.cache
+
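+ # Illustrative cache_params sketch for _init_cache above (placeholder values):
+ #
+ #   self._init_cache(cache_params={
+ #       "type": "redis",
+ #       "host": "localhost",
+ #       "port": 6379,
+ #       "default_in_memory_ttl": 60,
+ #       "default_redis_ttl": 3600,
+ #   })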
+ async def get_config(self, config_file_path: Optional[str] = None) -> dict:
+ """
+ Load config file
+ Supports reading from:
+ - .yaml file paths
+ - LiteLLM connected DB
+ - GCS
+ - S3
+
+ Args:
+ config_file_path (str): path to the config file
+ Returns:
+ dict: config
+
+ """
+ global prisma_client, store_model_in_db
+ # Load existing config
+
+ if os.environ.get("LITELLM_CONFIG_BUCKET_NAME") is not None:
+ bucket_name = os.environ.get("LITELLM_CONFIG_BUCKET_NAME")
+ object_key = os.environ.get("LITELLM_CONFIG_BUCKET_OBJECT_KEY")
+ bucket_type = os.environ.get("LITELLM_CONFIG_BUCKET_TYPE")
+ verbose_proxy_logger.debug(
+ "bucket_name: %s, object_key: %s", bucket_name, object_key
+ )
+ if bucket_type == "gcs":
+ config = await get_config_file_contents_from_gcs(
+ bucket_name=bucket_name, object_key=object_key
+ )
+ else:
+ config = get_file_contents_from_s3(
+ bucket_name=bucket_name, object_key=object_key
+ )
+
+ if config is None:
+ raise Exception("Unable to load config from given source.")
+ else:
+ # default to file
+ config = await self._get_config_from_file(config_file_path=config_file_path)
+ ## UPDATE CONFIG WITH DB
+ if prisma_client is not None and store_model_in_db is True:
+ config = await self._update_config_from_db(
+ config=config,
+ prisma_client=prisma_client,
+ store_model_in_db=store_model_in_db,
+ )
+
+ ## PRINT YAML FOR CONFIRMING IT WORKS
+ printed_yaml = copy.deepcopy(config)
+ printed_yaml.pop("environment_variables", None)
+
+ config = self._check_for_os_environ_vars(config=config)
+
+ self.update_config_state(config=config)
+ return config
+
+ def update_config_state(self, config: dict):
+ self.config = config
+
+ def get_config_state(self):
+ """
+ Returns a deep copy of the config,
+
+ Do this, to avoid mutating the config state outside of allowed methods
+ """
+ try:
+ return copy.deepcopy(self.config)
+ except Exception as e:
+ verbose_proxy_logger.debug(
+ "ProxyConfig:get_config_state(): Error returning copy of config state. self.config={}\nError: {}".format(
+ self.config, e
+ )
+ )
+ return {}
+
+ def load_credential_list(self, config: dict) -> List[CredentialItem]:
+ """
+ Load the credential list from the database
+ """
+ credential_list_dict = config.get("credential_list")
+ credential_list = []
+ if credential_list_dict:
+ credential_list = [CredentialItem(**cred) for cred in credential_list_dict]
+ return credential_list
+
+ async def load_config( # noqa: PLR0915
+ self, router: Optional[litellm.Router], config_file_path: str
+ ):
+ """
+ Load config values into proxy global state
+ """
+ global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, user_custom_key_generate, user_custom_sso, use_background_health_checks, health_check_interval, use_queue, proxy_budget_rescheduler_max_time, proxy_budget_rescheduler_min_time, ui_access_mode, litellm_master_key_hash, proxy_batch_write_at, disable_spend_logs, prompt_injection_detection_obj, redis_usage_cache, store_model_in_db, premium_user, open_telemetry_logger, health_check_details, callback_settings
+
+ config: dict = await self.get_config(config_file_path=config_file_path)
+
+ ## ENVIRONMENT VARIABLES
+ environment_variables = config.get("environment_variables", None)
+ if environment_variables:
+ for key, value in environment_variables.items():
+ os.environ[key] = str(get_secret(secret_name=key, default_value=value))
+
+ # check if litellm_license in general_settings
+ if "LITELLM_LICENSE" in environment_variables:
+ _license_check.license_str = os.getenv("LITELLM_LICENSE", None)
+ premium_user = _license_check.is_premium()
+
+ ## Callback settings
+ callback_settings = config.get("callback_settings", None)
+
+ ## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..)
+ litellm_settings = config.get("litellm_settings", None)
+ if litellm_settings is None:
+ litellm_settings = {}
+ if litellm_settings:
+ # ANSI escape code for blue text
+ blue_color_code = "\033[94m"
+ reset_color_code = "\033[0m"
+ for key, value in litellm_settings.items():
+ if key == "cache" and value is True:
+ print(f"{blue_color_code}\nSetting Cache on Proxy") # noqa
+ from litellm.caching.caching import Cache
+
+ cache_params = {}
+ if "cache_params" in litellm_settings:
+ cache_params_in_config = litellm_settings["cache_params"]
+ # overwrite cache_params with cache_params_in_config
+ cache_params.update(cache_params_in_config)
+
+ cache_type = cache_params.get("type", "redis")
+
+ verbose_proxy_logger.debug("passed cache type=%s", cache_type)
+
+ if (
+ cache_type == "redis" or cache_type == "redis-semantic"
+ ) and len(cache_params.keys()) == 0:
+ cache_host = get_secret("REDIS_HOST", None)
+ cache_port = get_secret("REDIS_PORT", None)
+ cache_password = None
+ cache_params.update(
+ {
+ "type": cache_type,
+ "host": cache_host,
+ "port": cache_port,
+ }
+ )
+
+ if get_secret("REDIS_PASSWORD", None) is not None:
+ cache_password = get_secret("REDIS_PASSWORD", None)
+ cache_params.update(
+ {
+ "password": cache_password,
+ }
+ )
+
+ # Assuming cache_type, cache_host, cache_port, and cache_password are strings
+ verbose_proxy_logger.debug(
+ "%sCache Type:%s %s",
+ blue_color_code,
+ reset_color_code,
+ cache_type,
+ )
+ verbose_proxy_logger.debug(
+ "%sCache Host:%s %s",
+ blue_color_code,
+ reset_color_code,
+ cache_host,
+ )
+ verbose_proxy_logger.debug(
+ "%sCache Port:%s %s",
+ blue_color_code,
+ reset_color_code,
+ cache_port,
+ )
+ verbose_proxy_logger.debug(
+ "%sCache Password:%s %s",
+ blue_color_code,
+ reset_color_code,
+ cache_password,
+ )
+ if cache_type == "redis-semantic":
+ # by default this should always be async
+ cache_params.update({"redis_semantic_cache_use_async": True})
+
+ # users can pass os.environ/ variables on the proxy - we should read them from the env
+ for key, value in cache_params.items():
+ if type(value) is str and value.startswith("os.environ/"):
+ cache_params[key] = get_secret(value)
+
+ ## to pass a complete url, or set ssl=True, etc. just set it as `os.environ[REDIS_URL] = <your-redis-url>`, _redis.py checks for REDIS specific environment variables
+ self._init_cache(cache_params=cache_params)
+ if litellm.cache is not None:
+ verbose_proxy_logger.debug(
+ f"{blue_color_code}Set Cache on LiteLLM Proxy{reset_color_code}"
+ )
+ elif key == "cache" and value is False:
+ pass
+ elif key == "guardrails":
+ guardrail_name_config_map = initialize_guardrails(
+ guardrails_config=value,
+ premium_user=premium_user,
+ config_file_path=config_file_path,
+ litellm_settings=litellm_settings,
+ )
+
+ litellm.guardrail_name_config_map = guardrail_name_config_map
+ elif key == "callbacks":
+
+ initialize_callbacks_on_proxy(
+ value=value,
+ premium_user=premium_user,
+ config_file_path=config_file_path,
+ litellm_settings=litellm_settings,
+ )
+
+ elif key == "post_call_rules":
+ litellm.post_call_rules = [
+ get_instance_fn(value=value, config_file_path=config_file_path)
+ ]
+ verbose_proxy_logger.debug(
+ f"litellm.post_call_rules: {litellm.post_call_rules}"
+ )
+ elif key == "max_internal_user_budget":
+ litellm.max_internal_user_budget = float(value) # type: ignore
+ elif key == "default_max_internal_user_budget":
+ litellm.default_max_internal_user_budget = float(value)
+ if litellm.max_internal_user_budget is None:
+ litellm.max_internal_user_budget = (
+ litellm.default_max_internal_user_budget
+ )
+ elif key == "custom_provider_map":
+ from litellm.utils import custom_llm_setup
+
+ litellm.custom_provider_map = [
+ {
+ "provider": item["provider"],
+ "custom_handler": get_instance_fn(
+ value=item["custom_handler"],
+ config_file_path=config_file_path,
+ ),
+ }
+ for item in value
+ ]
+
+ custom_llm_setup()
+ elif key == "success_callback":
+ litellm.success_callback = []
+
+ # initialize success callbacks
+ for callback in value:
+ # user passed a dotted path, e.g. custom_callbacks.async_on_success_logger - we need to import that function
+ if "." in callback:
+ litellm.logging_callback_manager.add_litellm_success_callback(
+ get_instance_fn(value=callback)
+ )
+ # these are litellm callbacks - "langfuse", "sentry", "wandb"
+ else:
+ litellm.logging_callback_manager.add_litellm_success_callback(
+ callback
+ )
+ if "prometheus" in callback:
+ verbose_proxy_logger.debug(
+ "Starting Prometheus Metrics on /metrics"
+ )
+ from prometheus_client import make_asgi_app
+
+ # Add prometheus asgi middleware to route /metrics requests
+ metrics_app = make_asgi_app()
+ app.mount("/metrics", metrics_app)
+ print( # noqa
+ f"{blue_color_code} Initialized Success Callbacks - {litellm.success_callback} {reset_color_code}"
+ ) # noqa
+ elif key == "failure_callback":
+ litellm.failure_callback = []
+
+ # initialize failure callbacks
+ for callback in value:
+ # user passed a dotted path, e.g. custom_callbacks.async_on_success_logger - we need to import that function
+ if "." in callback:
+ litellm.logging_callback_manager.add_litellm_failure_callback(
+ get_instance_fn(value=callback)
+ )
+ # these are litellm callbacks - "langfuse", "sentry", "wandb"
+ else:
+ litellm.logging_callback_manager.add_litellm_failure_callback(
+ callback
+ )
+ print( # noqa
+ f"{blue_color_code} Initialized Failure Callbacks - {litellm.failure_callback} {reset_color_code}"
+ ) # noqa
+ elif key == "cache_params":
+ # this is set in the cache branch
+ # see usage here: https://docs.litellm.ai/docs/proxy/caching
+ pass
+ elif key == "default_team_settings":
+ for idx, team_setting in enumerate(
+ value
+ ): # run through pydantic validation
+ try:
+ TeamDefaultSettings(**team_setting)
+ except Exception:
+ if isinstance(team_setting, dict):
+ raise Exception(
+ f"team_id missing from default_team_settings at index={idx}\npassed in value={team_setting.keys()}"
+ )
+ raise Exception(
+ f"team_id missing from default_team_settings at index={idx}\npassed in value={type(team_setting)}"
+ )
+ verbose_proxy_logger.debug(
+ f"{blue_color_code} setting litellm.{key}={value}{reset_color_code}"
+ )
+ setattr(litellm, key, value)
+ elif key == "upperbound_key_generate_params":
+ if value is not None and isinstance(value, dict):
+ for _k, _v in value.items():
+ if isinstance(_v, str) and _v.startswith("os.environ/"):
+ value[_k] = get_secret(_v)
+ litellm.upperbound_key_generate_params = (
+ LiteLLM_UpperboundKeyGenerateParams(**value)
+ )
+ else:
+ raise Exception(
+ f"Invalid value set for upperbound_key_generate_params - value={value}"
+ )
+ else:
+ verbose_proxy_logger.debug(
+ f"{blue_color_code} setting litellm.{key}={value}{reset_color_code}"
+ )
+ setattr(litellm, key, value)
+
+ ## GENERAL SERVER SETTINGS (e.g. master key,..) # do this after initializing litellm, to ensure sentry logging works for proxylogging
+ general_settings = config.get("general_settings", {})
+ if general_settings is None:
+ general_settings = {}
+ if general_settings:
+ ### LOAD SECRET MANAGER ###
+ key_management_system = general_settings.get("key_management_system", None)
+ self.initialize_secret_manager(key_management_system=key_management_system)
+ key_management_settings = general_settings.get(
+ "key_management_settings", None
+ )
+ if key_management_settings is not None:
+ litellm._key_management_settings = KeyManagementSettings(
+ **key_management_settings
+ )
+ ### [DEPRECATED] LOAD FROM GOOGLE KMS ### old way of loading from google kms
+ use_google_kms = general_settings.get("use_google_kms", False)
+ load_google_kms(use_google_kms=use_google_kms)
+ ### [DEPRECATED] LOAD FROM AZURE KEY VAULT ### old way of loading from azure secret manager
+ use_azure_key_vault = general_settings.get("use_azure_key_vault", False)
+ load_from_azure_key_vault(use_azure_key_vault=use_azure_key_vault)
+ ### ALERTING ###
+ self._load_alerting_settings(general_settings=general_settings)
+ ### CONNECT TO DATABASE ###
+ database_url = general_settings.get("database_url", None)
+ if database_url and database_url.startswith("os.environ/"):
+ verbose_proxy_logger.debug("GOING INTO LITELLM.GET_SECRET!")
+ database_url = get_secret(database_url)
+ verbose_proxy_logger.debug("RETRIEVED DB URL: %s", database_url)
+ ### MASTER KEY ###
+ master_key = general_settings.get(
+ "master_key", get_secret("LITELLM_MASTER_KEY", None)
+ )
+
+ if master_key and master_key.startswith("os.environ/"):
+ master_key = get_secret(master_key) # type: ignore
+ if not isinstance(master_key, str):
+ raise Exception(
+ "Master key must be a string. Current type - {}".format(
+ type(master_key)
+ )
+ )
+
+ if master_key is not None and isinstance(master_key, str):
+ litellm_master_key_hash = hash_token(master_key)
+ ### USER API KEY CACHE IN-MEMORY TTL ###
+ user_api_key_cache_ttl = general_settings.get(
+ "user_api_key_cache_ttl", None
+ )
+ if user_api_key_cache_ttl is not None:
+ user_api_key_cache.update_cache_ttl(
+ default_in_memory_ttl=float(user_api_key_cache_ttl),
+ default_redis_ttl=None, # user_api_key_cache is an in-memory cache
+ )
+ ### STORE MODEL IN DB ### feature flag for `/model/new`
+ store_model_in_db = general_settings.get("store_model_in_db", False)
+ if store_model_in_db is None:
+ store_model_in_db = False
+ ### CUSTOM API KEY AUTH ###
+ ## pass filepath
+ custom_auth = general_settings.get("custom_auth", None)
+ if custom_auth is not None:
+ user_custom_auth = get_instance_fn(
+ value=custom_auth, config_file_path=config_file_path
+ )
+
+ custom_key_generate = general_settings.get("custom_key_generate", None)
+ if custom_key_generate is not None:
+ user_custom_key_generate = get_instance_fn(
+ value=custom_key_generate, config_file_path=config_file_path
+ )
+
+ custom_sso = general_settings.get("custom_sso", None)
+ if custom_sso is not None:
+ user_custom_sso = get_instance_fn(
+ value=custom_sso, config_file_path=config_file_path
+ )
+
+ ## pass through endpoints
+ if general_settings.get("pass_through_endpoints", None) is not None:
+ await initialize_pass_through_endpoints(
+ pass_through_endpoints=general_settings["pass_through_endpoints"]
+ )
+ ## ADMIN UI ACCESS ##
+ ui_access_mode = general_settings.get(
+ "ui_access_mode", "all"
+ ) # can be either ["admin_only" or "all"]
+ ### ALLOWED IP ###
+ allowed_ips = general_settings.get("allowed_ips", None)
+ if allowed_ips is not None and premium_user is False:
+ raise ValueError(
+ "allowed_ips is an Enterprise Feature. Please add a valid LITELLM_LICENSE to your envionment."
+ )
+ ## BUDGET RESCHEDULER ##
+ proxy_budget_rescheduler_min_time = general_settings.get(
+ "proxy_budget_rescheduler_min_time", proxy_budget_rescheduler_min_time
+ )
+ proxy_budget_rescheduler_max_time = general_settings.get(
+ "proxy_budget_rescheduler_max_time", proxy_budget_rescheduler_max_time
+ )
+ ## BATCH WRITER ##
+ proxy_batch_write_at = general_settings.get(
+ "proxy_batch_write_at", proxy_batch_write_at
+ )
+ ## DISABLE SPEND LOGS ## - gives a perf improvement
+ disable_spend_logs = general_settings.get(
+ "disable_spend_logs", disable_spend_logs
+ )
+ ### BACKGROUND HEALTH CHECKS ###
+ # Enable background health checks
+ use_background_health_checks = general_settings.get(
+ "background_health_checks", False
+ )
+ health_check_interval = general_settings.get("health_check_interval", 300)
+ health_check_details = general_settings.get("health_check_details", True)
+
+ ### RBAC ###
+ rbac_role_permissions = general_settings.get("role_permissions", None)
+ if rbac_role_permissions is not None:
+ general_settings["role_permissions"] = [ # validate role permissions
+ RoleBasedPermissions(**role_permission)
+ for role_permission in rbac_role_permissions
+ ]
+
+ ## check if user has set a premium feature in general_settings
+ if (
+ general_settings.get("enforced_params") is not None
+ and premium_user is not True
+ ):
+ raise ValueError(
+ "Trying to use `enforced_params`"
+ + CommonProxyErrors.not_premium_user.value
+ )
+
+ # check if litellm_license in general_settings
+ if "litellm_license" in general_settings:
+ _license_check.license_str = general_settings["litellm_license"]
+ premium_user = _license_check.is_premium()
+
+ router_params: dict = {
+ "cache_responses": litellm.cache
+ is not None, # cache if user passed in cache values
+ }
+ ## MODEL LIST
+ model_list = config.get("model_list", None)
+ if model_list:
+ router_params["model_list"] = model_list
+ print( # noqa
+ "\033[32mLiteLLM: Proxy initialized with Config, Set models:\033[0m"
+ ) # noqa
+ for model in model_list:
+ ### LOAD FROM os.environ/ ###
+ for k, v in model["litellm_params"].items():
+ if isinstance(v, str) and v.startswith("os.environ/"):
+ model["litellm_params"][k] = get_secret(v)
+ print(f"\033[32m {model.get('model_name', '')}\033[0m") # noqa
+ litellm_model_name = model["litellm_params"]["model"]
+ litellm_model_api_base = model["litellm_params"].get("api_base", None)
+ if "ollama" in litellm_model_name and litellm_model_api_base is None:
+ run_ollama_serve()
+
+ ## ASSISTANT SETTINGS
+ assistants_config: Optional[AssistantsTypedDict] = None
+ assistant_settings = config.get("assistant_settings", None)
+ if assistant_settings:
+ for k, v in assistant_settings["litellm_params"].items():
+ if isinstance(v, str) and v.startswith("os.environ/"):
+ _v = v.replace("os.environ/", "")
+ v = os.getenv(_v)
+ assistant_settings["litellm_params"][k] = v
+ assistants_config = AssistantsTypedDict(**assistant_settings) # type: ignore
+
+ ## /fine_tuning/jobs endpoints config
+ finetuning_config = config.get("finetune_settings", None)
+ set_fine_tuning_config(config=finetuning_config)
+
+ ## /files endpoint config
+ files_config = config.get("files_settings", None)
+ set_files_config(config=files_config)
+
+ ## default config for vertex ai routes
+ default_vertex_config = config.get("default_vertex_config", None)
+ set_default_vertex_config(config=default_vertex_config)
+
+ ## ROUTER SETTINGS (e.g. routing_strategy, ...)
+ router_settings = config.get("router_settings", None)
+ if router_settings and isinstance(router_settings, dict):
+ arg_spec = inspect.getfullargspec(litellm.Router)
+ # model list already set
+ exclude_args = {
+ "self",
+ "model_list",
+ }
+
+ available_args = [x for x in arg_spec.args if x not in exclude_args]
+
+ for k, v in router_settings.items():
+ if k in available_args:
+ router_params[k] = v
+ router = litellm.Router(
+ **router_params,
+ assistants_config=assistants_config,
+ router_general_settings=RouterGeneralSettings(
+ async_only_mode=True # only init async clients
+ ),
+ ) # type:ignore
+
+ if redis_usage_cache is not None and router.cache.redis_cache is None:
+ router._update_redis_cache(cache=redis_usage_cache)
+
+ # Guardrail settings
+ guardrails_v2: Optional[List[Dict]] = None
+
+ if config is not None:
+ guardrails_v2 = config.get("guardrails", None)
+ if guardrails_v2:
+ init_guardrails_v2(
+ all_guardrails=guardrails_v2, config_file_path=config_file_path
+ )
+
+ ## CREDENTIALS
+ credential_list_dict = self.load_credential_list(config=config)
+ litellm.credential_list = credential_list_dict
+ return router, router.get_model_list(), general_settings
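+
+ # Illustrative only - a minimal proxy config yaml consumed by load_config(); the model,
+ # callbacks and settings below are hypothetical placeholders:
+ #
+ #   model_list:
+ #     - model_name: gpt-4o
+ #       litellm_params:
+ #         model: openai/gpt-4o
+ #         api_key: os.environ/OPENAI_API_KEY
+ #   litellm_settings:
+ #     drop_params: true
+ #     success_callback: ["langfuse"]
+ #   general_settings:
+ #     master_key: os.environ/LITELLM_MASTER_KEY
+ #     database_url: os.environ/DATABASE_URL
+ #   router_settings:
+ #     routing_strategy: simple-shuffle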
+
+ def _load_alerting_settings(self, general_settings: dict):
+ """
+ Initialize alerting settings
+ """
+ from litellm.litellm_core_utils.litellm_logging import (
+ _init_custom_logger_compatible_class,
+ )
+
+ _alerting_callbacks = general_settings.get("alerting", None)
+ verbose_proxy_logger.debug(f"_alerting_callbacks: {general_settings}")
+ if _alerting_callbacks is None:
+ return
+ for _alert in _alerting_callbacks:
+ if _alert == "slack":
+ # [OLD] v0 implementation
+ proxy_logging_obj.update_values(
+ alerting=general_settings.get("alerting", None),
+ alerting_threshold=general_settings.get("alerting_threshold", 600),
+ alert_types=general_settings.get("alert_types", None),
+ alert_to_webhook_url=general_settings.get(
+ "alert_to_webhook_url", None
+ ),
+ alerting_args=general_settings.get("alerting_args", None),
+ redis_cache=redis_usage_cache,
+ )
+ else:
+ # [NEW] v1 implementation - init as a custom logger
+ if _alert in litellm._known_custom_logger_compatible_callbacks:
+ _logger = _init_custom_logger_compatible_class(
+ logging_integration=_alert,
+ internal_usage_cache=None,
+ llm_router=None,
+ custom_logger_init_args={
+ "alerting_args": general_settings.get("alerting_args", None)
+ },
+ )
+ if _logger is not None:
+ litellm.logging_callback_manager.add_litellm_callback(_logger)
+ pass
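+
+ # Illustrative only - an alerting block under general_settings; the values below are
+ # hypothetical placeholders:
+ #
+ #   general_settings:
+ #     alerting: ["slack"]
+ #     alerting_threshold: 300   # seconds before a request is flagged as slow
+ #     alert_types: ["llm_exceptions", "budget_alerts"]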
+
+ def initialize_secret_manager(self, key_management_system: Optional[str]):
+ """
+ Initialize the relevant secret manager if `key_management_system` is provided
+ """
+ if key_management_system is not None:
+ if key_management_system == KeyManagementSystem.AZURE_KEY_VAULT.value:
+ ### LOAD FROM AZURE KEY VAULT ###
+ load_from_azure_key_vault(use_azure_key_vault=True)
+ elif key_management_system == KeyManagementSystem.GOOGLE_KMS.value:
+ ### LOAD FROM GOOGLE KMS ###
+ load_google_kms(use_google_kms=True)
+ elif (
+ key_management_system
+ == KeyManagementSystem.AWS_SECRET_MANAGER.value # noqa: F405
+ ):
+ from litellm.secret_managers.aws_secret_manager_v2 import (
+ AWSSecretsManagerV2,
+ )
+
+ AWSSecretsManagerV2.load_aws_secret_manager(use_aws_secret_manager=True)
+ elif key_management_system == KeyManagementSystem.AWS_KMS.value:
+ load_aws_kms(use_aws_kms=True)
+ elif (
+ key_management_system == KeyManagementSystem.GOOGLE_SECRET_MANAGER.value
+ ):
+ from litellm.secret_managers.google_secret_manager import (
+ GoogleSecretManager,
+ )
+
+ GoogleSecretManager()
+ elif key_management_system == KeyManagementSystem.HASHICORP_VAULT.value:
+ from litellm.secret_managers.hashicorp_secret_manager import (
+ HashicorpSecretManager,
+ )
+
+ HashicorpSecretManager()
+ else:
+ raise ValueError("Invalid Key Management System selected")
+
+ def get_model_info_with_id(self, model, db_model=False) -> RouterModelInfo:
+ """
+ Common logic across add + delete router models
+ Parameters:
+ - model -> the deployment being added / deleted
+ - db_model -> flag for differentiating model stored in db vs. config -> used on UI
+
+ Return model info w/ id
+ """
+ _id: Optional[str] = getattr(model, "model_id", None)
+ if _id is not None:
+ model.model_info["id"] = _id
+ model.model_info["db_model"] = True
+
+ if premium_user is True:
+ # seeing "created_at", "updated_at", "created_by", "updated_by" is a LiteLLM Enterprise Feature
+ model.model_info["created_at"] = getattr(model, "created_at", None)
+ model.model_info["updated_at"] = getattr(model, "updated_at", None)
+ model.model_info["created_by"] = getattr(model, "created_by", None)
+ model.model_info["updated_by"] = getattr(model, "updated_by", None)
+
+ if model.model_info is not None and isinstance(model.model_info, dict):
+ if "id" not in model.model_info:
+ model.model_info["id"] = model.model_id
+ if "db_model" in model.model_info and model.model_info["db_model"] is False:
+ model.model_info["db_model"] = db_model
+ _model_info = RouterModelInfo(**model.model_info)
+
+ else:
+ _model_info = RouterModelInfo(id=model.model_id, db_model=db_model)
+ return _model_info
+
+ async def _delete_deployment(self, db_models: list) -> int:
+ """
+ (Helper function of add deployment) -> combined to reduce prisma db calls
+
+ - Create an all-up list of model ids (db + config)
+ - Compare all up list to router model id's
+ - Remove any that are missing
+
+ Return:
+ - int - returns number of deleted deployments
+ """
+ global user_config_file_path, llm_router
+ combined_id_list = []
+
+ ## BASE CASES ##
+ # if llm_router is None or db_models is empty, return 0
+ if llm_router is None or len(db_models) == 0:
+ return 0
+
+ ## DB MODELS ##
+ for m in db_models:
+ model_info = self.get_model_info_with_id(model=m)
+ if model_info.id is not None:
+ combined_id_list.append(model_info.id)
+
+ ## CONFIG MODELS ##
+ config = await self.get_config(config_file_path=user_config_file_path)
+ model_list = config.get("model_list", None)
+ if model_list:
+ for model in model_list:
+ ### LOAD FROM os.environ/ ###
+ for k, v in model["litellm_params"].items():
+ if isinstance(v, str) and v.startswith("os.environ/"):
+ model["litellm_params"][k] = get_secret(v)
+
+ ## check if they have model-id's ##
+ model_id = model.get("model_info", {}).get("id", None)
+ if model_id is None:
+ ## else - generate stable id's ##
+ model_id = llm_router._generate_model_id(
+ model_group=model["model_name"],
+ litellm_params=model["litellm_params"],
+ )
+ combined_id_list.append(model_id) # ADD CONFIG MODEL TO COMBINED LIST
+
+ router_model_ids = llm_router.get_model_ids()
+ # Check for model IDs in llm_router not present in combined_id_list and delete them
+
+ deleted_deployments = 0
+ for model_id in router_model_ids:
+ if model_id not in combined_id_list:
+ is_deleted = llm_router.delete_deployment(id=model_id)
+ if is_deleted is not None:
+ deleted_deployments += 1
+ return deleted_deployments
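+
+ # Illustrative reconciliation (hypothetical ids):
+ #   router model ids      = ["id-1", "id-2", "id-3"]
+ #   db + config model ids = ["id-1", "id-3"]
+ #   -> "id-2" is removed from the router and the method returns 1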
+
+ def _add_deployment(self, db_models: list) -> int:
+ """
+ Iterate through db models
+
+ for any not in router - add them.
+
+ Return - number of deployments added
+ """
+ import base64
+
+ if master_key is None or not isinstance(master_key, str):
+ raise Exception(
+ f"Master key is not initialized or formatted. master_key={master_key}"
+ )
+
+ if llm_router is None:
+ return 0
+
+ added_models = 0
+ ## ADD MODEL LOGIC
+ for m in db_models:
+ _litellm_params = m.litellm_params
+ if isinstance(_litellm_params, dict):
+ # decrypt values
+ for k, v in _litellm_params.items():
+ if isinstance(v, str):
+ # decrypt value
+ _value = decrypt_value_helper(value=v)
+ if _value is None:
+ raise Exception("Unable to decrypt value={}".format(v))
+ # sanity check if string > size 0
+ if len(_value) > 0:
+ _litellm_params[k] = _value
+ _litellm_params = LiteLLM_Params(**_litellm_params)
+
+ else:
+ verbose_proxy_logger.error(
+ f"Invalid model added to proxy db. Invalid litellm params. litellm_params={_litellm_params}"
+ )
+ continue # skip to next model
+ _model_info = self.get_model_info_with_id(
+ model=m, db_model=True
+ ) ## šŸ‘ˆ FLAG = True for db_models
+
+ added = llm_router.upsert_deployment(
+ deployment=Deployment(
+ model_name=m.model_name,
+ litellm_params=_litellm_params,
+ model_info=_model_info,
+ )
+ )
+
+ if added is not None:
+ added_models += 1
+ return added_models
+
+ def decrypt_model_list_from_db(self, new_models: list) -> list:
+ _model_list: list = []
+ for m in new_models:
+ _litellm_params = m.litellm_params
+ if isinstance(_litellm_params, dict):
+ # decrypt values
+ for k, v in _litellm_params.items():
+ decrypted_value = decrypt_value_helper(value=v)
+ _litellm_params[k] = decrypted_value
+ _litellm_params = LiteLLM_Params(**_litellm_params)
+ else:
+ verbose_proxy_logger.error(
+ f"Invalid model added to proxy db. Invalid litellm params. litellm_params={_litellm_params}"
+ )
+ continue # skip to next model
+
+ _model_info = self.get_model_info_with_id(model=m)
+ _model_list.append(
+ Deployment(
+ model_name=m.model_name,
+ litellm_params=_litellm_params,
+ model_info=_model_info,
+ ).to_json(exclude_none=True)
+ )
+
+ return _model_list
+
+ async def _update_llm_router(
+ self,
+ new_models: list,
+ proxy_logging_obj: ProxyLogging,
+ ):
+ global llm_router, llm_model_list, master_key, general_settings
+
+ try:
+ if llm_router is None and master_key is not None:
+ verbose_proxy_logger.debug(f"len new_models: {len(new_models)}")
+
+ _model_list: list = self.decrypt_model_list_from_db(
+ new_models=new_models
+ )
+ if len(_model_list) > 0:
+ verbose_proxy_logger.debug(f"_model_list: {_model_list}")
+ llm_router = litellm.Router(
+ model_list=_model_list,
+ router_general_settings=RouterGeneralSettings(
+ async_only_mode=True # only init async clients
+ ),
+ )
+ verbose_proxy_logger.debug(f"updated llm_router: {llm_router}")
+ else:
+ verbose_proxy_logger.debug(f"len new_models: {len(new_models)}")
+ ## DELETE MODEL LOGIC
+ await self._delete_deployment(db_models=new_models)
+
+ ## ADD MODEL LOGIC
+ self._add_deployment(db_models=new_models)
+
+ except Exception as e:
+ verbose_proxy_logger.exception(
+ f"Error adding/deleting model to llm_router: {str(e)}"
+ )
+
+ if llm_router is not None:
+ llm_model_list = llm_router.get_model_list()
+
+ # check if user set any callbacks in Config Table
+ config_data = await proxy_config.get_config()
+ self._add_callbacks_from_db_config(config_data)
+
+ # we need to set env variables too
+ self._add_environment_variables_from_db_config(config_data)
+
+ # router settings
+ await self._add_router_settings_from_db_config(
+ config_data=config_data, llm_router=llm_router, prisma_client=prisma_client
+ )
+
+ # general settings
+ self._add_general_settings_from_db_config(
+ config_data=config_data,
+ general_settings=general_settings,
+ proxy_logging_obj=proxy_logging_obj,
+ )
+
+ def _add_callbacks_from_db_config(self, config_data: dict) -> None:
+ """
+ Adds callbacks from DB config to litellm
+ """
+ litellm_settings = config_data.get("litellm_settings", {}) or {}
+ success_callbacks = litellm_settings.get("success_callback", None)
+ failure_callbacks = litellm_settings.get("failure_callback", None)
+
+ if success_callbacks is not None and isinstance(success_callbacks, list):
+ for success_callback in success_callbacks:
+ if (
+ success_callback
+ in litellm._known_custom_logger_compatible_callbacks
+ ):
+ _add_custom_logger_callback_to_specific_event(
+ success_callback, "success"
+ )
+ elif success_callback not in litellm.success_callback:
+ litellm.logging_callback_manager.add_litellm_success_callback(
+ success_callback
+ )
+
+ # Add failure callbacks from DB to litellm
+ if failure_callbacks is not None and isinstance(failure_callbacks, list):
+ for failure_callback in failure_callbacks:
+ if (
+ failure_callback
+ in litellm._known_custom_logger_compatible_callbacks
+ ):
+ _add_custom_logger_callback_to_specific_event(
+ failure_callback, "failure"
+ )
+ elif failure_callback not in litellm.failure_callback:
+ litellm.logging_callback_manager.add_litellm_failure_callback(
+ failure_callback
+ )
+
+ def _add_environment_variables_from_db_config(self, config_data: dict) -> None:
+ """
+ Adds environment variables from DB config to litellm
+ """
+ environment_variables = config_data.get("environment_variables", {})
+ self._decrypt_and_set_db_env_variables(environment_variables)
+
+ def _encrypt_env_variables(
+ self, environment_variables: dict, new_encryption_key: Optional[str] = None
+ ) -> dict:
+ """
+ Encrypts a dictionary of environment variables and returns them.
+ """
+ encrypted_env_vars = {}
+ for k, v in environment_variables.items():
+ encrypted_value = encrypt_value_helper(
+ value=v, new_encryption_key=new_encryption_key
+ )
+ encrypted_env_vars[k] = encrypted_value
+ return encrypted_env_vars
+
+ def _decrypt_and_set_db_env_variables(self, environment_variables: dict) -> dict:
+ """
+ Decrypts a dictionary of environment variables and then sets them in the environment
+
+ Args:
+ environment_variables: dict - dictionary of environment variables to decrypt and set
+ eg. `{"LANGFUSE_PUBLIC_KEY": "kFiKa1VZukMmD8RB6WXB9F......."}`
+ """
+ decrypted_env_vars = {}
+ for k, v in environment_variables.items():
+ try:
+ decrypted_value = decrypt_value_helper(value=v)
+ if decrypted_value is not None:
+ os.environ[k] = decrypted_value
+ decrypted_env_vars[k] = decrypted_value
+ except Exception as e:
+ verbose_proxy_logger.error(
+ "Error setting env variable: %s - %s", k, str(e)
+ )
+ return decrypted_env_vars
+
+ async def _add_router_settings_from_db_config(
+ self,
+ config_data: dict,
+ llm_router: Optional[Router],
+ prisma_client: Optional[PrismaClient],
+ ) -> None:
+ """
+ Adds router settings from DB config to litellm proxy
+ """
+ if llm_router is not None and prisma_client is not None:
+ db_router_settings = await prisma_client.db.litellm_config.find_first(
+ where={"param_name": "router_settings"}
+ )
+ if (
+ db_router_settings is not None
+ and db_router_settings.param_value is not None
+ ):
+ _router_settings = db_router_settings.param_value
+ llm_router.update_settings(**_router_settings)
+
+ def _add_general_settings_from_db_config(
+ self, config_data: dict, general_settings: dict, proxy_logging_obj: ProxyLogging
+ ) -> None:
+ """
+ Adds general settings from DB config to litellm proxy
+
+ Args:
+ config_data: dict
+ general_settings: dict - global general_settings currently in use
+ proxy_logging_obj: ProxyLogging
+ """
+ _general_settings = config_data.get("general_settings", {})
+ if "alerting" in _general_settings:
+ if (
+ general_settings is not None
+ and general_settings.get("alerting", None) is not None
+ and isinstance(general_settings["alerting"], list)
+ and _general_settings.get("alerting", None) is not None
+ and isinstance(_general_settings["alerting"], list)
+ ):
+ verbose_proxy_logger.debug(
+ "Overriding Default 'alerting' values with db 'alerting' values."
+ )
+ general_settings["alerting"] = _general_settings[
+ "alerting"
+ ] # override yaml values with db
+ proxy_logging_obj.alerting = general_settings["alerting"]
+ proxy_logging_obj.slack_alerting_instance.alerting = general_settings[
+ "alerting"
+ ]
+ elif general_settings is None:
+ general_settings = {}
+ general_settings["alerting"] = _general_settings["alerting"]
+ proxy_logging_obj.alerting = general_settings["alerting"]
+ proxy_logging_obj.slack_alerting_instance.alerting = general_settings[
+ "alerting"
+ ]
+ elif isinstance(general_settings, dict):
+ general_settings["alerting"] = _general_settings["alerting"]
+ proxy_logging_obj.alerting = general_settings["alerting"]
+ proxy_logging_obj.slack_alerting_instance.alerting = general_settings[
+ "alerting"
+ ]
+
+ if "alert_types" in _general_settings:
+ general_settings["alert_types"] = _general_settings["alert_types"]
+ proxy_logging_obj.alert_types = general_settings["alert_types"]
+ proxy_logging_obj.slack_alerting_instance.update_values(
+ alert_types=general_settings["alert_types"], llm_router=llm_router
+ )
+
+ if "alert_to_webhook_url" in _general_settings:
+ general_settings["alert_to_webhook_url"] = _general_settings[
+ "alert_to_webhook_url"
+ ]
+ proxy_logging_obj.slack_alerting_instance.update_values(
+ alert_to_webhook_url=general_settings["alert_to_webhook_url"],
+ llm_router=llm_router,
+ )
+
+ async def _update_general_settings(self, db_general_settings: Optional[Json]):
+ """
+ Pull from DB, read general settings value
+ """
+ global general_settings
+ if db_general_settings is None:
+ return
+ _general_settings = dict(db_general_settings)
+ ## MAX PARALLEL REQUESTS ##
+ if "max_parallel_requests" in _general_settings:
+ general_settings["max_parallel_requests"] = _general_settings[
+ "max_parallel_requests"
+ ]
+
+ if "global_max_parallel_requests" in _general_settings:
+ general_settings["global_max_parallel_requests"] = _general_settings[
+ "global_max_parallel_requests"
+ ]
+
+ ## ALERTING ARGS ##
+ if "alerting_args" in _general_settings:
+ general_settings["alerting_args"] = _general_settings["alerting_args"]
+ proxy_logging_obj.slack_alerting_instance.update_values(
+ alerting_args=general_settings["alerting_args"],
+ )
+
+ ## PASS-THROUGH ENDPOINTS ##
+ if "pass_through_endpoints" in _general_settings:
+ general_settings["pass_through_endpoints"] = _general_settings[
+ "pass_through_endpoints"
+ ]
+ await initialize_pass_through_endpoints(
+ pass_through_endpoints=general_settings["pass_through_endpoints"]
+ )
+
+ def _update_config_fields(
+ self,
+ current_config: dict,
+ param_name: Literal[
+ "general_settings",
+ "router_settings",
+ "litellm_settings",
+ "environment_variables",
+ ],
+ db_param_value: Any,
+ ) -> dict:
+ """
+ Updates the config fields with the new values from the DB
+
+ Args:
+ current_config (dict): Current configuration dictionary to update
+ param_name (Literal): Name of the parameter to update
+ db_param_value (Any): New value from the database
+
+ Returns:
+ dict: Updated configuration dictionary
+ """
+ if param_name == "environment_variables":
+ self._decrypt_and_set_db_env_variables(db_param_value)
+ return current_config
+
+ # If param doesn't exist in config, add it
+ if param_name not in current_config:
+ current_config[param_name] = db_param_value
+ return current_config
+
+ # For dictionary values, update only non-empty values
+ if isinstance(current_config[param_name], dict):
+ # Only keep non None values from db_param_value
+ non_empty_values = {k: v for k, v in db_param_value.items() if v}
+
+ # Update the config with non-empty values
+ current_config[param_name].update(non_empty_values)
+ else:
+ current_config[param_name] = db_param_value
+ return current_config
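+
+ # Illustrative merge behaviour (hypothetical values):
+ #   current_config = {"litellm_settings": {"drop_params": True, "num_retries": 3}}
+ #   db_param_value = {"num_retries": 5, "request_timeout": None}
+ #   result         = {"litellm_settings": {"drop_params": True, "num_retries": 5}}
+ # Only truthy DB values override the existing config; falsy values are ignored.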
+
+ async def _update_config_from_db(
+ self,
+ prisma_client: PrismaClient,
+ config: dict,
+ store_model_in_db: Optional[bool],
+ ):
+ if store_model_in_db is not True:
+ verbose_proxy_logger.info(
+ "'store_model_in_db' is not True, skipping db updates"
+ )
+ return config
+
+ _tasks = []
+ keys = [
+ "general_settings",
+ "router_settings",
+ "litellm_settings",
+ "environment_variables",
+ ]
+ for k in keys:
+ response = prisma_client.get_generic_data(
+ key="param_name", value=k, table_name="config"
+ )
+ _tasks.append(response)
+
+ responses = await asyncio.gather(*_tasks)
+ for response in responses:
+ if response is None:
+ continue
+
+ param_name = getattr(response, "param_name", None)
+ param_value = getattr(response, "param_value", None)
+ verbose_proxy_logger.debug(
+ f"param_name={param_name}, param_value={param_value}"
+ )
+
+ if param_name is not None and param_value is not None:
+ config = self._update_config_fields(
+ current_config=config,
+ param_name=param_name,
+ db_param_value=param_value,
+ )
+
+ return config
+
+ async def _get_models_from_db(self, prisma_client: PrismaClient) -> list:
+ try:
+ new_models = await prisma_client.db.litellm_proxymodeltable.find_many()
+ except Exception as e:
+ verbose_proxy_logger.exception(
+ "litellm.proxy_server.py::add_deployment() - Error getting new models from DB - {}".format(
+ str(e)
+ )
+ )
+ new_models = []
+
+ return new_models
+
+ async def add_deployment(
+ self,
+ prisma_client: PrismaClient,
+ proxy_logging_obj: ProxyLogging,
+ ):
+ """
+ - Check db for new models
+ - Check if model id's in router already
+ - If not, add to router
+ """
+ global llm_router, llm_model_list, master_key, general_settings
+
+ try:
+ if master_key is None or not isinstance(master_key, str):
+ raise ValueError(
+ f"Master key is not initialized or formatted. master_key={master_key}"
+ )
+
+ new_models = await self._get_models_from_db(prisma_client=prisma_client)
+
+ # update llm router
+ await self._update_llm_router(
+ new_models=new_models, proxy_logging_obj=proxy_logging_obj
+ )
+
+ db_general_settings = await prisma_client.db.litellm_config.find_first(
+ where={"param_name": "general_settings"}
+ )
+
+ # update general settings
+ if db_general_settings is not None:
+ await self._update_general_settings(
+ db_general_settings=db_general_settings.param_value,
+ )
+
+ except Exception as e:
+ verbose_proxy_logger.exception(
+ "litellm.proxy.proxy_server.py::ProxyConfig:add_deployment - {}".format(
+ str(e)
+ )
+ )
+
+ def decrypt_credentials(self, credential: Union[dict, BaseModel]) -> CredentialItem:
+ if isinstance(credential, dict):
+ credential_object = CredentialItem(**credential)
+ elif isinstance(credential, BaseModel):
+ credential_object = CredentialItem(**credential.model_dump())
+
+ decrypted_credential_values = {}
+ for k, v in credential_object.credential_values.items():
+ decrypted_credential_values[k] = decrypt_value_helper(v) or v
+
+ credential_object.credential_values = decrypted_credential_values
+ return credential_object
+
+ async def delete_credentials(self, db_credentials: List[CredentialItem]):
+ """
+ Create all-up list of db credentials + local credentials
+ Compare to the litellm.credential_list
+ Delete any from litellm.credential_list that are not in the all-up list
+ """
+ ## CONFIG credentials ##
+ config = await self.get_config(config_file_path=user_config_file_path)
+ credential_list = self.load_credential_list(config=config)
+
+ ## COMBINED LIST ##
+ combined_list = db_credentials + credential_list
+
+ ## DELETE ##
+ idx_to_delete = []
+ for idx, credential in enumerate(litellm.credential_list):
+ if credential.credential_name not in [
+ cred.credential_name for cred in combined_list
+ ]:
+ idx_to_delete.append(idx)
+ for idx in sorted(idx_to_delete, reverse=True):
+ litellm.credential_list.pop(idx)
+
+ async def get_credentials(self, prisma_client: PrismaClient):
+ try:
+ credentials = await prisma_client.db.litellm_credentialstable.find_many()
+ credentials = [self.decrypt_credentials(cred) for cred in credentials]
+ await self.delete_credentials(
+ credentials
+ ) # delete credentials that are not in the all-up list
+ CredentialAccessor.upsert_credentials(
+ credentials
+ ) # upsert credentials that are in the all-up list
+ except Exception as e:
+ verbose_proxy_logger.exception(
+ "litellm.proxy_server.py::get_credentials() - Error getting credentials from DB - {}".format(
+ str(e)
+ )
+ )
+ return []
+
+
+proxy_config = ProxyConfig()
+
+
+def save_worker_config(**data):
+ import json
+
+ os.environ["WORKER_CONFIG"] = json.dumps(data)
+
+
+async def initialize( # noqa: PLR0915
+ model=None,
+ alias=None,
+ api_base=None,
+ api_version=None,
+ debug=False,
+ detailed_debug=False,
+ temperature=None,
+ max_tokens=None,
+ request_timeout=600,
+ max_budget=None,
+ telemetry=False,
+ drop_params=True,
+ add_function_to_prompt=True,
+ headers=None,
+ save=False,
+ use_queue=False,
+ config=None,
+):
+ global user_model, user_api_base, user_debug, user_detailed_debug, user_user_max_tokens, user_request_timeout, user_temperature, user_telemetry, user_headers, experimental, llm_model_list, llm_router, general_settings, master_key, user_custom_auth, prisma_client
+ if os.getenv("LITELLM_DONT_SHOW_FEEDBACK_BOX", "").lower() != "true":
+ generate_feedback_box()
+ user_model = model
+ user_debug = debug
+ if debug is True: # this needs to be first, so users can see Router init debug logs
+ import logging
+
+ from litellm._logging import (
+ verbose_logger,
+ verbose_proxy_logger,
+ verbose_router_logger,
+ )
+
+ # this must ALWAYS remain logging.INFO, DO NOT MODIFY THIS
+ verbose_logger.setLevel(level=logging.INFO) # sets package logs to info
+ verbose_router_logger.setLevel(level=logging.INFO) # set router logs to info
+ verbose_proxy_logger.setLevel(level=logging.INFO) # set proxy logs to info
+ if detailed_debug is True:
+ import logging
+
+ from litellm._logging import (
+ verbose_logger,
+ verbose_proxy_logger,
+ verbose_router_logger,
+ )
+
+ verbose_logger.setLevel(level=logging.DEBUG) # set package log to debug
+ verbose_router_logger.setLevel(level=logging.DEBUG) # set router logs to debug
+ verbose_proxy_logger.setLevel(level=logging.DEBUG) # set proxy logs to debug
+ elif debug is False and detailed_debug is False:
+ # users can control proxy debugging using env variable = 'LITELLM_LOG'
+ litellm_log_setting = os.environ.get("LITELLM_LOG", "")
+ if litellm_log_setting is not None:
+ if litellm_log_setting.upper() == "INFO":
+ import logging
+
+ from litellm._logging import verbose_proxy_logger, verbose_router_logger
+
+ # this must ALWAYS remain logging.INFO, DO NOT MODIFY THIS
+
+ verbose_router_logger.setLevel(
+ level=logging.INFO
+ ) # set router logs to info
+ verbose_proxy_logger.setLevel(
+ level=logging.INFO
+ ) # set proxy logs to info
+ elif litellm_log_setting.upper() == "DEBUG":
+ import logging
+
+ from litellm._logging import verbose_proxy_logger, verbose_router_logger
+
+ verbose_router_logger.setLevel(
+ level=logging.DEBUG
+ ) # set router logs to debug
+ verbose_proxy_logger.setLevel(
+ level=logging.DEBUG
+ ) # set proxy logs to debug
+ dynamic_config = {"general": {}, user_model: {}}
+ if config:
+ (
+ llm_router,
+ llm_model_list,
+ general_settings,
+ ) = await proxy_config.load_config(router=llm_router, config_file_path=config)
+ if headers: # model-specific param
+ user_headers = headers
+ dynamic_config[user_model]["headers"] = headers
+ if api_base: # model-specific param
+ user_api_base = api_base
+ dynamic_config[user_model]["api_base"] = api_base
+ if api_version:
+ os.environ["AZURE_API_VERSION"] = (
+ api_version # set this for azure - litellm can read this from the env
+ )
+ if max_tokens: # model-specific param
+ dynamic_config[user_model]["max_tokens"] = max_tokens
+ if temperature: # model-specific param
+ user_temperature = temperature
+ dynamic_config[user_model]["temperature"] = temperature
+ if request_timeout:
+ user_request_timeout = request_timeout
+ dynamic_config[user_model]["request_timeout"] = request_timeout
+ if alias: # model-specific param
+ dynamic_config[user_model]["alias"] = alias
+ if drop_params is True: # litellm-specific param
+ litellm.drop_params = True
+ dynamic_config["general"]["drop_params"] = True
+ if add_function_to_prompt is True: # litellm-specific param
+ litellm.add_function_to_prompt = True
+ dynamic_config["general"]["add_function_to_prompt"] = True
+ if max_budget: # litellm-specific param
+ litellm.max_budget = max_budget
+ dynamic_config["general"]["max_budget"] = max_budget
+ if experimental:
+ pass
+ user_telemetry = telemetry
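+
+ # Illustrative usage (hypothetical arguments) - typically called by the CLI before the
+ # server starts accepting requests:
+ #
+ #   await initialize(model="gpt-4o", debug=True, config="proxy_config.yaml")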
+
+
+# for streaming
+def data_generator(response):
+ verbose_proxy_logger.debug("inside generator")
+ for chunk in response:
+ verbose_proxy_logger.debug("returned chunk: %s", chunk)
+ try:
+ yield f"data: {json.dumps(chunk.dict())}\n\n"
+ except Exception:
+ yield f"data: {json.dumps(chunk)}\n\n"
+
+
+async def async_assistants_data_generator(
+ response, user_api_key_dict: UserAPIKeyAuth, request_data: dict
+):
+ verbose_proxy_logger.debug("inside generator")
+ try:
+ time.time()
+ async with response as chunk:
+
+ ### CALL HOOKS ### - modify outgoing data
+ chunk = await proxy_logging_obj.async_post_call_streaming_hook(
+ user_api_key_dict=user_api_key_dict, response=chunk
+ )
+
+ # chunk = chunk.model_dump_json(exclude_none=True)
+ async for c in chunk: # type: ignore
+ c = c.model_dump_json(exclude_none=True)
+ try:
+ yield f"data: {c}\n\n"
+ except Exception as e:
+ yield f"data: {str(e)}\n\n"
+
+ # Streaming is done, yield the [DONE] chunk
+ done_message = "[DONE]"
+ yield f"data: {done_message}\n\n"
+ except Exception as e:
+ verbose_proxy_logger.exception(
+ "litellm.proxy.proxy_server.async_assistants_data_generator(): Exception occured - {}".format(
+ str(e)
+ )
+ )
+ await proxy_logging_obj.post_call_failure_hook(
+ user_api_key_dict=user_api_key_dict,
+ original_exception=e,
+ request_data=request_data,
+ )
+ verbose_proxy_logger.debug(
+ f"\033[1;31mAn error occurred: {e}\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`"
+ )
+ if isinstance(e, HTTPException):
+ raise e
+ else:
+ error_traceback = traceback.format_exc()
+ error_msg = f"{str(e)}\n\n{error_traceback}"
+
+ proxy_exception = ProxyException(
+ message=getattr(e, "message", error_msg),
+ type=getattr(e, "type", "None"),
+ param=getattr(e, "param", "None"),
+ code=getattr(e, "status_code", 500),
+ )
+ error_returned = json.dumps({"error": proxy_exception.to_dict()})
+ yield f"data: {error_returned}\n\n"
+
+
+async def async_data_generator(
+ response, user_api_key_dict: UserAPIKeyAuth, request_data: dict
+):
+ verbose_proxy_logger.debug("inside generator")
+ try:
+ async for chunk in proxy_logging_obj.async_post_call_streaming_iterator_hook(
+ user_api_key_dict=user_api_key_dict,
+ response=response,
+ request_data=request_data,
+ ):
+ verbose_proxy_logger.debug(
+ "async_data_generator: received streaming chunk - {}".format(chunk)
+ )
+ ### CALL HOOKS ### - modify outgoing data
+ chunk = await proxy_logging_obj.async_post_call_streaming_hook(
+ user_api_key_dict=user_api_key_dict, response=chunk
+ )
+
+ if isinstance(chunk, BaseModel):
+ chunk = chunk.model_dump_json(exclude_none=True, exclude_unset=True)
+
+ try:
+ yield f"data: {chunk}\n\n"
+ except Exception as e:
+ yield f"data: {str(e)}\n\n"
+
+ # Streaming is done, yield the [DONE] chunk
+ done_message = "[DONE]"
+ yield f"data: {done_message}\n\n"
+ except Exception as e:
+ verbose_proxy_logger.exception(
+ "litellm.proxy.proxy_server.async_data_generator(): Exception occured - {}".format(
+ str(e)
+ )
+ )
+ await proxy_logging_obj.post_call_failure_hook(
+ user_api_key_dict=user_api_key_dict,
+ original_exception=e,
+ request_data=request_data,
+ )
+ verbose_proxy_logger.debug(
+ f"\033[1;31mAn error occurred: {e}\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`"
+ )
+
+ if isinstance(e, HTTPException):
+ raise e
+ elif isinstance(e, StreamingCallbackError):
+ error_msg = str(e)
+ else:
+ error_traceback = traceback.format_exc()
+ error_msg = f"{str(e)}\n\n{error_traceback}"
+
+ proxy_exception = ProxyException(
+ message=getattr(e, "message", error_msg),
+ type=getattr(e, "type", "None"),
+ param=getattr(e, "param", "None"),
+ code=getattr(e, "status_code", 500),
+ )
+ error_returned = json.dumps({"error": proxy_exception.to_dict()})
+ yield f"data: {error_returned}\n\n"
+
+
+def select_data_generator(
+ response, user_api_key_dict: UserAPIKeyAuth, request_data: dict
+):
+ return async_data_generator(
+ response=response,
+ user_api_key_dict=user_api_key_dict,
+ request_data=request_data,
+ )
+
+
+def get_litellm_model_info(model: dict = {}):
+ model_info = model.get("model_info", {})
+ model_to_lookup = model.get("litellm_params", {}).get("model", None)
+ try:
+ if "azure" in model_to_lookup:
+ model_to_lookup = model_info.get("base_model", None)
+ litellm_model_info = litellm.get_model_info(model_to_lookup)
+ return litellm_model_info
+ except Exception:
+ # this should not block returning on /model/info
+ # if litellm does not have info on the model it should return {}
+ return {}
+
+
+def on_backoff(details):
+ # The 'tries' key in the details dictionary contains the number of completed tries
+ verbose_proxy_logger.debug("Backing off... this was attempt # %s", details["tries"])
+
+
+def giveup(e):
+ result = not (
+ isinstance(e, ProxyException)
+ and getattr(e, "message", None) is not None
+ and isinstance(e.message, str)
+ and "Max parallel request limit reached" in e.message
+ )
+
+ if (
+ general_settings.get("disable_retry_on_max_parallel_request_limit_error")
+ is True
+ ):
+ return True # giveup if queuing max parallel request limits is disabled
+
+ if result:
+ verbose_proxy_logger.info(json.dumps({"event": "giveup", "exception": str(e)}))
+ return result
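+
+ # Illustrative behaviour (hypothetical exceptions):
+ #   a ProxyException whose message contains "Max parallel request limit reached" -> giveup() returns False (retry)
+ #   any other exception                                                          -> giveup() returns True (stop retrying)
+ # If general_settings["disable_retry_on_max_parallel_request_limit_error"] is True, we always give up.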
+
+
+class ProxyStartupEvent:
+ @classmethod
+ def _initialize_startup_logging(
+ cls,
+ llm_router: Optional[Router],
+ proxy_logging_obj: ProxyLogging,
+ redis_usage_cache: Optional[RedisCache],
+ ):
+ """Initialize logging and alerting on startup"""
+ ## COST TRACKING ##
+ cost_tracking()
+
+ proxy_logging_obj.startup_event(
+ llm_router=llm_router, redis_usage_cache=redis_usage_cache
+ )
+
+ @classmethod
+ def _initialize_jwt_auth(
+ cls,
+ general_settings: dict,
+ prisma_client: Optional[PrismaClient],
+ user_api_key_cache: DualCache,
+ ):
+ """Initialize JWT auth on startup"""
+ if general_settings.get("litellm_jwtauth", None) is not None:
+ for k, v in general_settings["litellm_jwtauth"].items():
+ if isinstance(v, str) and v.startswith("os.environ/"):
+ general_settings["litellm_jwtauth"][k] = get_secret(v)
+ litellm_jwtauth = LiteLLM_JWTAuth(**general_settings["litellm_jwtauth"])
+ else:
+ litellm_jwtauth = LiteLLM_JWTAuth()
+ jwt_handler.update_environment(
+ prisma_client=prisma_client,
+ user_api_key_cache=user_api_key_cache,
+ litellm_jwtauth=litellm_jwtauth,
+ )
+
+ @classmethod
+ def _add_proxy_budget_to_db(cls, litellm_proxy_budget_name: str):
+ """Adds a global proxy budget to db"""
+ if litellm.budget_duration is None:
+ raise Exception(
+ "budget_duration not set on Proxy. budget_duration is required to use max_budget."
+ )
+
+ # add proxy budget to db in the user table
+ asyncio.create_task(
+ generate_key_helper_fn( # type: ignore
+ request_type="user",
+ user_id=litellm_proxy_budget_name,
+ duration=None,
+ models=[],
+ aliases={},
+ config={},
+ spend=0,
+ max_budget=litellm.max_budget,
+ budget_duration=litellm.budget_duration,
+ query_type="update_data",
+ update_key_values={
+ "max_budget": litellm.max_budget,
+ "budget_duration": litellm.budget_duration,
+ },
+ )
+ )
+
+ @classmethod
+ async def initialize_scheduled_background_jobs(
+ cls,
+ general_settings: dict,
+ prisma_client: PrismaClient,
+ proxy_budget_rescheduler_min_time: int,
+ proxy_budget_rescheduler_max_time: int,
+ proxy_batch_write_at: int,
+ proxy_logging_obj: ProxyLogging,
+ ):
+ """Initializes scheduled background jobs"""
+ global store_model_in_db
+ scheduler = AsyncIOScheduler()
+ interval = random.randint(
+ proxy_budget_rescheduler_min_time, proxy_budget_rescheduler_max_time
+ ) # random interval, so multiple workers avoid resetting budget at the same time
+ batch_writing_interval = random.randint(
+ proxy_batch_write_at - 3, proxy_batch_write_at + 3
+ ) # random interval, so multiple workers avoid batch writing at the same time
+
+ ### RESET BUDGET ###
+ if general_settings.get("disable_reset_budget", False) is False:
+ budget_reset_job = ResetBudgetJob(
+ proxy_logging_obj=proxy_logging_obj,
+ prisma_client=prisma_client,
+ )
+ scheduler.add_job(
+ budget_reset_job.reset_budget,
+ "interval",
+ seconds=interval,
+ )
+
+ ### UPDATE SPEND ###
+ scheduler.add_job(
+ update_spend,
+ "interval",
+ seconds=batch_writing_interval,
+ args=[prisma_client, db_writer_client, proxy_logging_obj],
+ )
+
+ ### ADD NEW MODELS ###
+ store_model_in_db = (
+ get_secret_bool("STORE_MODEL_IN_DB", store_model_in_db) or store_model_in_db
+ )
+
+ if store_model_in_db is True:
+ scheduler.add_job(
+ proxy_config.add_deployment,
+ "interval",
+ seconds=10,
+ args=[prisma_client, proxy_logging_obj],
+ )
+
+ # this will load all existing models on proxy startup
+ await proxy_config.add_deployment(
+ prisma_client=prisma_client, proxy_logging_obj=proxy_logging_obj
+ )
+
+ ### GET STORED CREDENTIALS ###
+ scheduler.add_job(
+ proxy_config.get_credentials,
+ "interval",
+ seconds=10,
+ args=[prisma_client],
+ )
+ await proxy_config.get_credentials(prisma_client=prisma_client)
+ if (
+ proxy_logging_obj is not None
+ and proxy_logging_obj.slack_alerting_instance.alerting is not None
+ and prisma_client is not None
+ ):
+ print("Alerting: Initializing Weekly/Monthly Spend Reports") # noqa
+ ### Schedule weekly/monthly spend reports ###
+ spend_report_frequency: str = (
+ general_settings.get("spend_report_frequency", "7d") or "7d"
+ )
+
+ # Parse the frequency
+ days = int(spend_report_frequency[:-1])
+ if spend_report_frequency[-1].lower() != "d":
+ raise ValueError(
+ "spend_report_frequency must be specified in days, e.g., '1d', '7d'"
+ )
+
+ scheduler.add_job(
+ proxy_logging_obj.slack_alerting_instance.send_weekly_spend_report,
+ "interval",
+ days=days,
+ next_run_time=datetime.now()
+ + timedelta(seconds=10), # Start 10 seconds from now
+ args=[spend_report_frequency],
+ )
+
+ scheduler.add_job(
+ proxy_logging_obj.slack_alerting_instance.send_monthly_spend_report,
+ "cron",
+ day=1,
+ )
+
+ # Beta Feature - only used when prometheus api is in .env
+ if os.getenv("PROMETHEUS_URL"):
+ from zoneinfo import ZoneInfo
+
+ scheduler.add_job(
+ proxy_logging_obj.slack_alerting_instance.send_fallback_stats_from_prometheus,
+ "cron",
+ hour=9,
+ minute=0,
+ timezone=ZoneInfo("America/Los_Angeles"), # Pacific Time
+ )
+ await proxy_logging_obj.slack_alerting_instance.send_fallback_stats_from_prometheus()
+
+ scheduler.start()
+
+ @classmethod
+ async def _setup_prisma_client(
+ cls,
+ database_url: Optional[str],
+ proxy_logging_obj: ProxyLogging,
+ user_api_key_cache: DualCache,
+ ) -> Optional[PrismaClient]:
+ """
+ - Sets up prisma client
+ - Adds necessary views to proxy
+ """
+ prisma_client: Optional[PrismaClient] = None
+ if database_url is not None:
+ try:
+ prisma_client = PrismaClient(
+ database_url=database_url, proxy_logging_obj=proxy_logging_obj
+ )
+ except Exception as e:
+ raise e
+
+ await prisma_client.connect()
+
+ ## Add necessary views to proxy ##
+ asyncio.create_task(
+ prisma_client.check_view_exists()
+ ) # check if all necessary views exist. Don't block execution
+
+ asyncio.create_task(
+ prisma_client._set_spend_logs_row_count_in_proxy_state()
+ ) # set the spend logs row count in proxy state. Don't block execution
+
+ # run a health check to ensure the DB is ready
+ if (
+ get_secret_bool("DISABLE_PRISMA_HEALTH_CHECK_ON_STARTUP", False)
+ is not True
+ ):
+ await prisma_client.health_check()
+ return prisma_client
+
+ @classmethod
+ def _init_dd_tracer(cls):
+ """
+ Initialize dd tracer - if `USE_DDTRACE=true` in .env
+
+ DD tracer is used to trace Python applications.
+ Doc: https://docs.datadoghq.com/tracing/trace_collection/automatic_instrumentation/dd_libraries/python/
+ """
+ from litellm.litellm_core_utils.dd_tracing import _should_use_dd_tracer
+
+ if _should_use_dd_tracer():
+ import ddtrace
+
+ ddtrace.patch_all(logging=True, openai=False)
+
+
+#### API ENDPOINTS ####
+@router.get(
+ "/v1/models", dependencies=[Depends(user_api_key_auth)], tags=["model management"]
+)
+@router.get(
+ "/models", dependencies=[Depends(user_api_key_auth)], tags=["model management"]
+) # if project requires model list
+async def model_list(
+ user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+ return_wildcard_routes: Optional[bool] = False,
+ team_id: Optional[str] = None,
+):
+ """
+ Use `/model/info` to get detailed model information, e.g. pricing, mode, etc.
+
+ This is just for compatibility with openai projects like aider.
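+
+ Example request (illustrative - the API key is a placeholder):
+
+ ```bash
+ curl -X GET http://localhost:4000/v1/models \
+ -H "Authorization: Bearer sk-1234"
+ ```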
+ """
+ global llm_model_list, general_settings, llm_router, prisma_client, user_api_key_cache, proxy_logging_obj
+ all_models = []
+ model_access_groups: Dict[str, List[str]] = defaultdict(list)
+ ## CHECK IF MODEL RESTRICTIONS ARE SET AT KEY/TEAM LEVEL ##
+ if llm_router is None:
+ proxy_model_list = []
+ else:
+ proxy_model_list = llm_router.get_model_names()
+ model_access_groups = llm_router.get_model_access_groups()
+ key_models = get_key_models(
+ user_api_key_dict=user_api_key_dict,
+ proxy_model_list=proxy_model_list,
+ model_access_groups=model_access_groups,
+ )
+
+ team_models: List[str] = user_api_key_dict.team_models
+
+ if team_id:
+ team_object = await get_team_object(
+ team_id=team_id,
+ prisma_client=prisma_client,
+ user_api_key_cache=user_api_key_cache,
+ proxy_logging_obj=proxy_logging_obj,
+ )
+ validate_membership(user_api_key_dict=user_api_key_dict, team_table=team_object)
+ team_models = team_object.models
+
+ team_models = get_team_models(
+ team_models=team_models,
+ proxy_model_list=proxy_model_list,
+ model_access_groups=model_access_groups,
+ )
+
+ all_models = get_complete_model_list(
+ key_models=key_models if not team_models else [],
+ team_models=team_models,
+ proxy_model_list=proxy_model_list,
+ user_model=user_model,
+ infer_model_from_keys=general_settings.get("infer_model_from_keys", False),
+ return_wildcard_routes=return_wildcard_routes,
+ )
+
+ return dict(
+ data=[
+ {
+ "id": model,
+ "object": "model",
+ "created": 1677610602,
+ "owned_by": "openai",
+ }
+ for model in all_models
+ ],
+ object="list",
+ )
+
+
+@router.post(
+ "/v1/chat/completions",
+ dependencies=[Depends(user_api_key_auth)],
+ tags=["chat/completions"],
+)
+@router.post(
+ "/chat/completions",
+ dependencies=[Depends(user_api_key_auth)],
+ tags=["chat/completions"],
+)
+@router.post(
+ "/engines/{model:path}/chat/completions",
+ dependencies=[Depends(user_api_key_auth)],
+ tags=["chat/completions"],
+)
+@router.post(
+ "/openai/deployments/{model:path}/chat/completions",
+ dependencies=[Depends(user_api_key_auth)],
+ tags=["chat/completions"],
+ responses={200: {"description": "Successful response"}, **ERROR_RESPONSES},
+) # azure compatible endpoint
+@backoff.on_exception(
+ backoff.expo,
+ Exception, # base exception to catch for the backoff
+ max_tries=global_max_parallel_request_retries, # maximum number of retries
+ max_time=global_max_parallel_request_retry_timeout, # maximum total time to retry for
+ on_backoff=on_backoff, # specifying the function to call on backoff
+ giveup=giveup,
+ logger=verbose_proxy_logger,
+)
+async def chat_completion( # noqa: PLR0915
+ request: Request,
+ fastapi_response: Response,
+ model: Optional[str] = None,
+ user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+):
+ """
+
+ Follows the exact same API spec as `OpenAI's Chat API https://platform.openai.com/docs/api-reference/chat`
+
+ ```bash
+ curl -X POST http://localhost:4000/v1/chat/completions \
+ -H "Content-Type: application/json" \
+ -H "Authorization: Bearer sk-1234" \
+ -d '{
+ "model": "gpt-4o",
+ "messages": [
+ {
+ "role": "user",
+ "content": "Hello!"
+ }
+ ]
+ }'
+ ```
+
+ """
+ global general_settings, user_debug, proxy_logging_obj, llm_model_list
+ global user_temperature, user_request_timeout, user_max_tokens, user_api_base
+ data = await _read_request_body(request=request)
+ base_llm_response_processor = ProxyBaseLLMRequestProcessing(data=data)
+ try:
+ return await base_llm_response_processor.base_process_llm_request(
+ request=request,
+ fastapi_response=fastapi_response,
+ user_api_key_dict=user_api_key_dict,
+ route_type="acompletion",
+ proxy_logging_obj=proxy_logging_obj,
+ llm_router=llm_router,
+ general_settings=general_settings,
+ proxy_config=proxy_config,
+ select_data_generator=select_data_generator,
+ model=model,
+ user_model=user_model,
+ user_temperature=user_temperature,
+ user_request_timeout=user_request_timeout,
+ user_max_tokens=user_max_tokens,
+ user_api_base=user_api_base,
+ version=version,
+ )
+ except RejectedRequestError as e:
+ _data = e.request_data
+ await proxy_logging_obj.post_call_failure_hook(
+ user_api_key_dict=user_api_key_dict,
+ original_exception=e,
+ request_data=_data,
+ )
+ _chat_response = litellm.ModelResponse()
+ _chat_response.choices[0].message.content = e.message # type: ignore
+
+ if data.get("stream", None) is not None and data["stream"] is True:
+ _iterator = litellm.utils.ModelResponseIterator(
+ model_response=_chat_response, convert_to_delta=True
+ )
+ _streaming_response = litellm.CustomStreamWrapper(
+ completion_stream=_iterator,
+ model=data.get("model", ""),
+ custom_llm_provider="cached_response",
+ logging_obj=data.get("litellm_logging_obj", None),
+ )
+ selected_data_generator = select_data_generator(
+ response=_streaming_response,
+ user_api_key_dict=user_api_key_dict,
+ request_data=_data,
+ )
+
+ return StreamingResponse(
+ selected_data_generator,
+ media_type="text/event-stream",
+ )
+ _usage = litellm.Usage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
+ _chat_response.usage = _usage # type: ignore
+ return _chat_response
+ except Exception as e:
+ raise await base_llm_response_processor._handle_llm_api_exception(
+ e=e,
+ user_api_key_dict=user_api_key_dict,
+ proxy_logging_obj=proxy_logging_obj,
+ )
+
+
+@router.post(
+ "/v1/completions", dependencies=[Depends(user_api_key_auth)], tags=["completions"]
+)
+@router.post(
+ "/completions", dependencies=[Depends(user_api_key_auth)], tags=["completions"]
+)
+@router.post(
+ "/engines/{model:path}/completions",
+ dependencies=[Depends(user_api_key_auth)],
+ tags=["completions"],
+)
+@router.post(
+ "/openai/deployments/{model:path}/completions",
+ dependencies=[Depends(user_api_key_auth)],
+ tags=["completions"],
+)
+async def completion( # noqa: PLR0915
+ request: Request,
+ fastapi_response: Response,
+ model: Optional[str] = None,
+ user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+):
+ """
+    Follows the same API spec as OpenAI's Completions API: https://platform.openai.com/docs/api-reference/completions
+
+ ```bash
+    curl -X POST http://localhost:4000/v1/completions \
+    -H "Content-Type: application/json" \
+    -H "Authorization: Bearer sk-1234" \
+ -d '{
+ "model": "gpt-3.5-turbo-instruct",
+ "prompt": "Once upon a time",
+ "max_tokens": 50,
+ "temperature": 0.7
+ }'
+ ```
+ """
+ global user_temperature, user_request_timeout, user_max_tokens, user_api_base
+ data = {}
+ try:
+ data = await _read_request_body(request=request)
+
+ data["model"] = (
+ general_settings.get("completion_model", None) # server default
+ or user_model # model name passed via cli args
+ or model # for azure deployments
+ or data.get("model", None)
+ )
+ if user_model:
+ data["model"] = user_model
+
+ data = await add_litellm_data_to_request(
+ data=data,
+ request=request,
+ general_settings=general_settings,
+ user_api_key_dict=user_api_key_dict,
+ version=version,
+ proxy_config=proxy_config,
+ )
+
+ # override with user settings, these are params passed via cli
+ if user_temperature:
+ data["temperature"] = user_temperature
+ if user_request_timeout:
+ data["request_timeout"] = user_request_timeout
+ if user_max_tokens:
+ data["max_tokens"] = user_max_tokens
+ if user_api_base:
+ data["api_base"] = user_api_base
+
+ ### MODEL ALIAS MAPPING ###
+ # check if model name in model alias map
+ # get the actual model name
+ if data["model"] in litellm.model_alias_map:
+ data["model"] = litellm.model_alias_map[data["model"]]
+
+ ### CALL HOOKS ### - modify incoming data before calling the model
+ data = await proxy_logging_obj.pre_call_hook( # type: ignore
+ user_api_key_dict=user_api_key_dict, data=data, call_type="text_completion"
+ )
+
+ ### ROUTE THE REQUESTs ###
+ llm_call = await route_request(
+ data=data,
+ route_type="atext_completion",
+ llm_router=llm_router,
+ user_model=user_model,
+ )
+
+ # Await the llm_response task
+ response = await llm_call
+
+ hidden_params = getattr(response, "_hidden_params", {}) or {}
+ model_id = hidden_params.get("model_id", None) or ""
+ cache_key = hidden_params.get("cache_key", None) or ""
+ api_base = hidden_params.get("api_base", None) or ""
+ response_cost = hidden_params.get("response_cost", None) or ""
+ litellm_call_id = hidden_params.get("litellm_call_id", None) or ""
+
+ ### ALERTING ###
+ asyncio.create_task(
+ proxy_logging_obj.update_request_status(
+ litellm_call_id=data.get("litellm_call_id", ""), status="success"
+ )
+ )
+
+ verbose_proxy_logger.debug("final response: %s", response)
+ if (
+ "stream" in data and data["stream"] is True
+ ): # use generate_responses to stream responses
+ custom_headers = ProxyBaseLLMRequestProcessing.get_custom_headers(
+ user_api_key_dict=user_api_key_dict,
+ call_id=litellm_call_id,
+ model_id=model_id,
+ cache_key=cache_key,
+ api_base=api_base,
+ version=version,
+ response_cost=response_cost,
+ hidden_params=hidden_params,
+ request_data=data,
+ )
+ selected_data_generator = select_data_generator(
+ response=response,
+ user_api_key_dict=user_api_key_dict,
+ request_data=data,
+ )
+
+ return StreamingResponse(
+ selected_data_generator,
+ media_type="text/event-stream",
+ headers=custom_headers,
+ )
+ ### CALL HOOKS ### - modify outgoing data
+ response = await proxy_logging_obj.post_call_success_hook(
+ data=data, user_api_key_dict=user_api_key_dict, response=response # type: ignore
+ )
+
+ fastapi_response.headers.update(
+ ProxyBaseLLMRequestProcessing.get_custom_headers(
+ user_api_key_dict=user_api_key_dict,
+ call_id=litellm_call_id,
+ model_id=model_id,
+ cache_key=cache_key,
+ api_base=api_base,
+ version=version,
+ response_cost=response_cost,
+ request_data=data,
+ hidden_params=hidden_params,
+ )
+ )
+ await check_response_size_is_safe(response=response)
+ return response
+ except RejectedRequestError as e:
+ _data = e.request_data
+ await proxy_logging_obj.post_call_failure_hook(
+ user_api_key_dict=user_api_key_dict,
+ original_exception=e,
+ request_data=_data,
+ )
+ if _data.get("stream", None) is not None and _data["stream"] is True:
+ _chat_response = litellm.ModelResponse()
+ _usage = litellm.Usage(
+ prompt_tokens=0,
+ completion_tokens=0,
+ total_tokens=0,
+ )
+ _chat_response.usage = _usage # type: ignore
+ _chat_response.choices[0].message.content = e.message # type: ignore
+ _iterator = litellm.utils.ModelResponseIterator(
+ model_response=_chat_response, convert_to_delta=True
+ )
+ _streaming_response = litellm.TextCompletionStreamWrapper(
+ completion_stream=_iterator,
+ model=_data.get("model", ""),
+ )
+
+ selected_data_generator = select_data_generator(
+ response=_streaming_response,
+ user_api_key_dict=user_api_key_dict,
+ request_data=data,
+ )
+
+ return StreamingResponse(
+ selected_data_generator,
+ media_type="text/event-stream",
+ headers={},
+ )
+ else:
+ _response = litellm.TextCompletionResponse()
+ _response.choices[0].text = e.message
+ return _response
+ except Exception as e:
+ await proxy_logging_obj.post_call_failure_hook(
+ user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
+ )
+ verbose_proxy_logger.exception(
+ "litellm.proxy.proxy_server.completion(): Exception occured - {}".format(
+ str(e)
+ )
+ )
+ error_msg = f"{str(e)}"
+ raise ProxyException(
+ message=getattr(e, "message", error_msg),
+ type=getattr(e, "type", "None"),
+ param=getattr(e, "param", "None"),
+ openai_code=getattr(e, "code", None),
+ code=getattr(e, "status_code", 500),
+ )
+
+
+@router.post(
+ "/v1/embeddings",
+ dependencies=[Depends(user_api_key_auth)],
+ response_class=ORJSONResponse,
+ tags=["embeddings"],
+)
+@router.post(
+ "/embeddings",
+ dependencies=[Depends(user_api_key_auth)],
+ response_class=ORJSONResponse,
+ tags=["embeddings"],
+)
+@router.post(
+ "/engines/{model:path}/embeddings",
+ dependencies=[Depends(user_api_key_auth)],
+ response_class=ORJSONResponse,
+ tags=["embeddings"],
+) # azure compatible endpoint
+@router.post(
+ "/openai/deployments/{model:path}/embeddings",
+ dependencies=[Depends(user_api_key_auth)],
+ response_class=ORJSONResponse,
+ tags=["embeddings"],
+) # azure compatible endpoint
+async def embeddings( # noqa: PLR0915
+ request: Request,
+ fastapi_response: Response,
+ model: Optional[str] = None,
+ user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+):
+ """
+    Follows the same API spec as OpenAI's Embeddings API: https://platform.openai.com/docs/api-reference/embeddings
+
+ ```bash
+    curl -X POST http://localhost:4000/v1/embeddings \
+    -H "Content-Type: application/json" \
+    -H "Authorization: Bearer sk-1234" \
+ -d '{
+ "model": "text-embedding-ada-002",
+ "input": "The quick brown fox jumps over the lazy dog"
+ }'
+ ```
+
+"""
+ global proxy_logging_obj
+ data: Any = {}
+ try:
+ # Use orjson to parse JSON data, orjson speeds up requests significantly
+ body = await request.body()
+ data = orjson.loads(body)
+
+ verbose_proxy_logger.debug(
+ "Request received by LiteLLM:\n%s",
+ json.dumps(data, indent=4),
+ )
+
+ # Include original request and headers in the data
+ data = await add_litellm_data_to_request(
+ data=data,
+ request=request,
+ general_settings=general_settings,
+ user_api_key_dict=user_api_key_dict,
+ version=version,
+ proxy_config=proxy_config,
+ )
+
+ data["model"] = (
+ general_settings.get("embedding_model", None) # server default
+ or user_model # model name passed via cli args
+ or model # for azure deployments
+ or data.get("model", None) # default passed in http request
+ )
+ if user_model:
+ data["model"] = user_model
+
+ ### MODEL ALIAS MAPPING ###
+ # check if model name in model alias map
+ # get the actual model name
+ if data["model"] in litellm.model_alias_map:
+ data["model"] = litellm.model_alias_map[data["model"]]
+
+ router_model_names = llm_router.model_names if llm_router is not None else []
+ if (
+ "input" in data
+ and isinstance(data["input"], list)
+ and len(data["input"]) > 0
+ and isinstance(data["input"][0], list)
+ and isinstance(data["input"][0][0], int)
+ ): # check if array of tokens passed in
+ # check if non-openai/azure model called - e.g. for langchain integration
+ if llm_model_list is not None and data["model"] in router_model_names:
+ for m in llm_model_list:
+ if m["model_name"] == data["model"] and (
+ m["litellm_params"]["model"] in litellm.open_ai_embedding_models
+ or m["litellm_params"]["model"].startswith("azure/")
+ ):
+ pass
+ else:
+ # non-openai/azure embedding model called with token input
+ input_list = []
+ for i in data["input"]:
+ input_list.append(
+ litellm.decode(model="gpt-3.5-turbo", tokens=i)
+ )
+ data["input"] = input_list
+ break
+
+ ### CALL HOOKS ### - modify incoming data / reject request before calling the model
+ data = await proxy_logging_obj.pre_call_hook(
+ user_api_key_dict=user_api_key_dict, data=data, call_type="embeddings"
+ )
+
+ tasks = []
+ tasks.append(
+ proxy_logging_obj.during_call_hook(
+ data=data,
+ user_api_key_dict=user_api_key_dict,
+ call_type="embeddings",
+ )
+ )
+
+ ## ROUTE TO CORRECT ENDPOINT ##
+ llm_call = await route_request(
+ data=data,
+ route_type="aembedding",
+ llm_router=llm_router,
+ user_model=user_model,
+ )
+ tasks.append(llm_call)
+
+ # wait for call to end
+ llm_responses = asyncio.gather(
+ *tasks
+ ) # run the moderation check in parallel to the actual llm api call
+
+ responses = await llm_responses
+
+ response = responses[1]
+
+ ### ALERTING ###
+ asyncio.create_task(
+ proxy_logging_obj.update_request_status(
+ litellm_call_id=data.get("litellm_call_id", ""), status="success"
+ )
+ )
+
+ ### RESPONSE HEADERS ###
+ hidden_params = getattr(response, "_hidden_params", {}) or {}
+ model_id = hidden_params.get("model_id", None) or ""
+ cache_key = hidden_params.get("cache_key", None) or ""
+ api_base = hidden_params.get("api_base", None) or ""
+ response_cost = hidden_params.get("response_cost", None) or ""
+ litellm_call_id = hidden_params.get("litellm_call_id", None) or ""
+ additional_headers: dict = hidden_params.get("additional_headers", {}) or {}
+
+ fastapi_response.headers.update(
+ ProxyBaseLLMRequestProcessing.get_custom_headers(
+ user_api_key_dict=user_api_key_dict,
+ model_id=model_id,
+ cache_key=cache_key,
+ api_base=api_base,
+ version=version,
+ response_cost=response_cost,
+ model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
+ call_id=litellm_call_id,
+ request_data=data,
+ hidden_params=hidden_params,
+ **additional_headers,
+ )
+ )
+ await check_response_size_is_safe(response=response)
+
+ return response
+ except Exception as e:
+ await proxy_logging_obj.post_call_failure_hook(
+ user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
+ )
+ litellm_debug_info = getattr(e, "litellm_debug_info", "")
+ verbose_proxy_logger.debug(
+ "\033[1;31mAn error occurred: %s %s\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`",
+ e,
+ litellm_debug_info,
+ )
+ verbose_proxy_logger.exception(
+ "litellm.proxy.proxy_server.embeddings(): Exception occured - {}".format(
+ str(e)
+ )
+ )
+ if isinstance(e, HTTPException):
+ message = get_error_message_str(e)
+ raise ProxyException(
+ message=message,
+ type=getattr(e, "type", "None"),
+ param=getattr(e, "param", "None"),
+ code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
+ )
+ else:
+ error_msg = f"{str(e)}"
+ raise ProxyException(
+ message=getattr(e, "message", error_msg),
+ type=getattr(e, "type", "None"),
+ param=getattr(e, "param", "None"),
+ openai_code=getattr(e, "code", None),
+ code=getattr(e, "status_code", 500),
+ )
+
+
+@router.post(
+ "/v1/images/generations",
+ dependencies=[Depends(user_api_key_auth)],
+ response_class=ORJSONResponse,
+ tags=["images"],
+)
+@router.post(
+ "/images/generations",
+ dependencies=[Depends(user_api_key_auth)],
+ response_class=ORJSONResponse,
+ tags=["images"],
+)
+async def image_generation(
+ request: Request,
+ fastapi_response: Response,
+ user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+):
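+    """
+    Follows the same API spec as OpenAI's Images API: https://platform.openai.com/docs/api-reference/images/create
+
+    Example request (illustrative; the model name must be one configured on your proxy):
+    ```bash
+    curl -X POST http://localhost:4000/v1/images/generations \
+    -H "Content-Type: application/json" \
+    -H "Authorization: Bearer sk-1234" \
+    -d '{
+    "model": "dall-e-3",
+    "prompt": "A cute baby sea otter"
+    }'
+    ```
+    """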
+ global proxy_logging_obj
+ data = {}
+ try:
+ # Use orjson to parse JSON data, orjson speeds up requests significantly
+ body = await request.body()
+ data = orjson.loads(body)
+
+ # Include original request and headers in the data
+ data = await add_litellm_data_to_request(
+ data=data,
+ request=request,
+ general_settings=general_settings,
+ user_api_key_dict=user_api_key_dict,
+ version=version,
+ proxy_config=proxy_config,
+ )
+
+ data["model"] = (
+ general_settings.get("image_generation_model", None) # server default
+ or user_model # model name passed via cli args
+ or data.get("model", None) # default passed in http request
+ )
+ if user_model:
+ data["model"] = user_model
+
+ ### MODEL ALIAS MAPPING ###
+ # check if model name in model alias map
+ # get the actual model name
+ if data["model"] in litellm.model_alias_map:
+ data["model"] = litellm.model_alias_map[data["model"]]
+
+ ### CALL HOOKS ### - modify incoming data / reject request before calling the model
+ data = await proxy_logging_obj.pre_call_hook(
+ user_api_key_dict=user_api_key_dict, data=data, call_type="image_generation"
+ )
+
+ ## ROUTE TO CORRECT ENDPOINT ##
+ llm_call = await route_request(
+ data=data,
+ route_type="aimage_generation",
+ llm_router=llm_router,
+ user_model=user_model,
+ )
+ response = await llm_call
+
+ ### ALERTING ###
+ asyncio.create_task(
+ proxy_logging_obj.update_request_status(
+ litellm_call_id=data.get("litellm_call_id", ""), status="success"
+ )
+ )
+ ### RESPONSE HEADERS ###
+ hidden_params = getattr(response, "_hidden_params", {}) or {}
+ model_id = hidden_params.get("model_id", None) or ""
+ cache_key = hidden_params.get("cache_key", None) or ""
+ api_base = hidden_params.get("api_base", None) or ""
+ response_cost = hidden_params.get("response_cost", None) or ""
+ litellm_call_id = hidden_params.get("litellm_call_id", None) or ""
+
+ fastapi_response.headers.update(
+ ProxyBaseLLMRequestProcessing.get_custom_headers(
+ user_api_key_dict=user_api_key_dict,
+ model_id=model_id,
+ cache_key=cache_key,
+ api_base=api_base,
+ version=version,
+ response_cost=response_cost,
+ model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
+ call_id=litellm_call_id,
+ request_data=data,
+ hidden_params=hidden_params,
+ )
+ )
+
+ return response
+ except Exception as e:
+ await proxy_logging_obj.post_call_failure_hook(
+ user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
+ )
+ verbose_proxy_logger.error(
+ "litellm.proxy.proxy_server.image_generation(): Exception occured - {}".format(
+ str(e)
+ )
+ )
+ verbose_proxy_logger.debug(traceback.format_exc())
+ if isinstance(e, HTTPException):
+ raise ProxyException(
+ message=getattr(e, "message", str(e)),
+ type=getattr(e, "type", "None"),
+ param=getattr(e, "param", "None"),
+ code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
+ )
+ else:
+ error_msg = f"{str(e)}"
+ raise ProxyException(
+ message=getattr(e, "message", error_msg),
+ type=getattr(e, "type", "None"),
+ param=getattr(e, "param", "None"),
+ openai_code=getattr(e, "code", None),
+ code=getattr(e, "status_code", 500),
+ )
+
+
+@router.post(
+ "/v1/audio/speech",
+ dependencies=[Depends(user_api_key_auth)],
+ tags=["audio"],
+)
+@router.post(
+ "/audio/speech",
+ dependencies=[Depends(user_api_key_auth)],
+ tags=["audio"],
+)
+async def audio_speech(
+ request: Request,
+ fastapi_response: Response,
+ user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+):
+ """
+ Same params as:
+
+ https://platform.openai.com/docs/api-reference/audio/createSpeech
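+
+    Example request (illustrative; model and voice names depend on your proxy config):
+    ```bash
+    curl -X POST http://localhost:4000/v1/audio/speech \
+    -H "Content-Type: application/json" \
+    -H "Authorization: Bearer sk-1234" \
+    -d '{
+    "model": "tts-1",
+    "input": "Hello, welcome to LiteLLM!",
+    "voice": "alloy"
+    }' \
+    --output speech.mp3
+    ```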
+ """
+ global proxy_logging_obj
+ data: Dict = {}
+ try:
+ # Use orjson to parse JSON data, orjson speeds up requests significantly
+ body = await request.body()
+ data = orjson.loads(body)
+
+ # Include original request and headers in the data
+ data = await add_litellm_data_to_request(
+ data=data,
+ request=request,
+ general_settings=general_settings,
+ user_api_key_dict=user_api_key_dict,
+ version=version,
+ proxy_config=proxy_config,
+ )
+
+ if data.get("user", None) is None and user_api_key_dict.user_id is not None:
+ data["user"] = user_api_key_dict.user_id
+
+ if user_model:
+ data["model"] = user_model
+
+ ### CALL HOOKS ### - modify incoming data / reject request before calling the model
+ data = await proxy_logging_obj.pre_call_hook(
+ user_api_key_dict=user_api_key_dict, data=data, call_type="image_generation"
+ )
+
+ ## ROUTE TO CORRECT ENDPOINT ##
+ llm_call = await route_request(
+ data=data,
+ route_type="aspeech",
+ llm_router=llm_router,
+ user_model=user_model,
+ )
+ response = await llm_call
+
+ ### ALERTING ###
+ asyncio.create_task(
+ proxy_logging_obj.update_request_status(
+ litellm_call_id=data.get("litellm_call_id", ""), status="success"
+ )
+ )
+
+ ### RESPONSE HEADERS ###
+ hidden_params = getattr(response, "_hidden_params", {}) or {}
+ model_id = hidden_params.get("model_id", None) or ""
+ cache_key = hidden_params.get("cache_key", None) or ""
+ api_base = hidden_params.get("api_base", None) or ""
+ response_cost = hidden_params.get("response_cost", None) or ""
+ litellm_call_id = hidden_params.get("litellm_call_id", None) or ""
+
+ # Printing each chunk size
+ async def generate(_response: HttpxBinaryResponseContent):
+ _generator = await _response.aiter_bytes(chunk_size=1024)
+ async for chunk in _generator:
+ yield chunk
+
+ custom_headers = ProxyBaseLLMRequestProcessing.get_custom_headers(
+ user_api_key_dict=user_api_key_dict,
+ model_id=model_id,
+ cache_key=cache_key,
+ api_base=api_base,
+ version=version,
+ response_cost=response_cost,
+ model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
+ fastest_response_batch_completion=None,
+ call_id=litellm_call_id,
+ request_data=data,
+ hidden_params=hidden_params,
+ )
+
+ select_data_generator(
+ response=response,
+ user_api_key_dict=user_api_key_dict,
+ request_data=data,
+ )
+ return StreamingResponse(
+ generate(response), media_type="audio/mpeg", headers=custom_headers # type: ignore
+ )
+
+ except Exception as e:
+ verbose_proxy_logger.error(
+ "litellm.proxy.proxy_server.audio_speech(): Exception occured - {}".format(
+ str(e)
+ )
+ )
+ verbose_proxy_logger.debug(traceback.format_exc())
+ raise e
+
+
+@router.post(
+ "/v1/audio/transcriptions",
+ dependencies=[Depends(user_api_key_auth)],
+ tags=["audio"],
+)
+@router.post(
+ "/audio/transcriptions",
+ dependencies=[Depends(user_api_key_auth)],
+ tags=["audio"],
+)
+async def audio_transcriptions(
+ request: Request,
+ fastapi_response: Response,
+ file: UploadFile = File(...),
+ user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+):
+ """
+ Same params as:
+
+ https://platform.openai.com/docs/api-reference/audio/createTranscription?lang=curl
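+
+    Example request (illustrative; replace the file path and model with your own):
+    ```bash
+    curl -X POST http://localhost:4000/v1/audio/transcriptions \
+    -H "Authorization: Bearer sk-1234" \
+    -F file="@/path/to/audio.mp3" \
+    -F model="whisper-1"
+    ```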
+ """
+ global proxy_logging_obj
+ data: Dict = {}
+ try:
+ # Use orjson to parse JSON data, orjson speeds up requests significantly
+ form_data = await request.form()
+ data = {key: value for key, value in form_data.items() if key != "file"}
+
+ # Include original request and headers in the data
+ data = await add_litellm_data_to_request(
+ data=data,
+ request=request,
+ general_settings=general_settings,
+ user_api_key_dict=user_api_key_dict,
+ version=version,
+ proxy_config=proxy_config,
+ )
+
+ if data.get("user", None) is None and user_api_key_dict.user_id is not None:
+ data["user"] = user_api_key_dict.user_id
+
+ data["model"] = (
+ general_settings.get("moderation_model", None) # server default
+ or user_model # model name passed via cli args
+ or data.get("model", None) # default passed in http request
+ )
+ if user_model:
+ data["model"] = user_model
+
+ router_model_names = llm_router.model_names if llm_router is not None else []
+
+ if file.filename is None:
+ raise ProxyException(
+ message="File name is None. Please check your file name",
+ code=status.HTTP_400_BAD_REQUEST,
+ type="bad_request",
+ param="file",
+ )
+
+ # Check if File can be read in memory before reading
+ check_file_size_under_limit(
+ request_data=data,
+ file=file,
+ router_model_names=router_model_names,
+ )
+
+ file_content = await file.read()
+ file_object = io.BytesIO(file_content)
+ file_object.name = file.filename
+ data["file"] = file_object
+ try:
+ ### CALL HOOKS ### - modify incoming data / reject request before calling the model
+ data = await proxy_logging_obj.pre_call_hook(
+ user_api_key_dict=user_api_key_dict,
+ data=data,
+ call_type="audio_transcription",
+ )
+
+ ## ROUTE TO CORRECT ENDPOINT ##
+ llm_call = await route_request(
+ data=data,
+ route_type="atranscription",
+ llm_router=llm_router,
+ user_model=user_model,
+ )
+ response = await llm_call
+ except Exception as e:
+ raise HTTPException(status_code=500, detail=str(e))
+ finally:
+ file_object.close() # close the file read in by io library
+
+ ### ALERTING ###
+ asyncio.create_task(
+ proxy_logging_obj.update_request_status(
+ litellm_call_id=data.get("litellm_call_id", ""), status="success"
+ )
+ )
+
+ ### RESPONSE HEADERS ###
+ hidden_params = getattr(response, "_hidden_params", {}) or {}
+ model_id = hidden_params.get("model_id", None) or ""
+ cache_key = hidden_params.get("cache_key", None) or ""
+ api_base = hidden_params.get("api_base", None) or ""
+ response_cost = hidden_params.get("response_cost", None) or ""
+ litellm_call_id = hidden_params.get("litellm_call_id", None) or ""
+ additional_headers: dict = hidden_params.get("additional_headers", {}) or {}
+
+ fastapi_response.headers.update(
+ ProxyBaseLLMRequestProcessing.get_custom_headers(
+ user_api_key_dict=user_api_key_dict,
+ model_id=model_id,
+ cache_key=cache_key,
+ api_base=api_base,
+ version=version,
+ response_cost=response_cost,
+ model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
+ call_id=litellm_call_id,
+ request_data=data,
+ hidden_params=hidden_params,
+ **additional_headers,
+ )
+ )
+
+ return response
+ except Exception as e:
+ await proxy_logging_obj.post_call_failure_hook(
+ user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
+ )
+ verbose_proxy_logger.exception(
+ "litellm.proxy.proxy_server.audio_transcription(): Exception occured - {}".format(
+ str(e)
+ )
+ )
+ if isinstance(e, HTTPException):
+ raise ProxyException(
+ message=getattr(e, "message", str(e.detail)),
+ type=getattr(e, "type", "None"),
+ param=getattr(e, "param", "None"),
+ code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
+ )
+ else:
+ error_msg = f"{str(e)}"
+ raise ProxyException(
+ message=getattr(e, "message", error_msg),
+ type=getattr(e, "type", "None"),
+ param=getattr(e, "param", "None"),
+ openai_code=getattr(e, "code", None),
+ code=getattr(e, "status_code", 500),
+ )
+
+
+######################################################################
+
+# /v1/realtime Endpoints
+
+######################################################################
+from fastapi import FastAPI, WebSocket, WebSocketDisconnect
+
+from litellm import _arealtime
+
+
+@app.websocket("/v1/realtime")
+@app.websocket("/realtime")
+async def websocket_endpoint(
+ websocket: WebSocket,
+ model: str,
+ user_api_key_dict=Depends(user_api_key_auth_websocket),
+):
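+    """
+    Realtime WebSocket passthrough.
+
+    Accepts a client WebSocket connection and routes it to the realtime API of the
+    requested model. The model is passed as a query parameter, e.g.
+    `ws://localhost:4000/v1/realtime?model=<your-realtime-model>`, and connections are
+    authenticated with the same API keys as the HTTP endpoints.
+    """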
+ import websockets
+
+ await websocket.accept()
+
+ data = {
+ "model": model,
+ "websocket": websocket,
+ }
+
+ ### ROUTE THE REQUEST ###
+ try:
+ llm_call = await route_request(
+ data=data,
+ route_type="_arealtime",
+ llm_router=llm_router,
+ user_model=user_model,
+ )
+
+ await llm_call
+ except websockets.exceptions.InvalidStatusCode as e: # type: ignore
+ verbose_proxy_logger.exception("Invalid status code")
+ await websocket.close(code=e.status_code, reason="Invalid status code")
+ except Exception:
+ verbose_proxy_logger.exception("Internal server error")
+ await websocket.close(code=1011, reason="Internal server error")
+
+
+######################################################################
+
+# /v1/assistant Endpoints
+
+
+######################################################################
+
+
+@router.get(
+ "/v1/assistants",
+ dependencies=[Depends(user_api_key_auth)],
+ tags=["assistants"],
+)
+@router.get(
+ "/assistants",
+ dependencies=[Depends(user_api_key_auth)],
+ tags=["assistants"],
+)
+async def get_assistants(
+ request: Request,
+ fastapi_response: Response,
+ user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+):
+ """
+ Returns a list of assistants.
+
+ API Reference docs - https://platform.openai.com/docs/api-reference/assistants/listAssistants
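+
+    Example request (illustrative):
+    ```bash
+    curl -X GET http://localhost:4000/v1/assistants \
+    -H "Authorization: Bearer sk-1234"
+    ```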
+ """
+ global proxy_logging_obj
+ data: Dict = {}
+ try:
+ # Use orjson to parse JSON data, orjson speeds up requests significantly
+ await request.body()
+
+ # Include original request and headers in the data
+ data = await add_litellm_data_to_request(
+ data=data,
+ request=request,
+ general_settings=general_settings,
+ user_api_key_dict=user_api_key_dict,
+ version=version,
+ proxy_config=proxy_config,
+ )
+
+ # for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for acreate_batch
+ if llm_router is None:
+ raise HTTPException(
+ status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
+ )
+ response = await llm_router.aget_assistants(**data)
+
+ ### ALERTING ###
+ asyncio.create_task(
+ proxy_logging_obj.update_request_status(
+ litellm_call_id=data.get("litellm_call_id", ""), status="success"
+ )
+ )
+
+ ### RESPONSE HEADERS ###
+ hidden_params = getattr(response, "_hidden_params", {}) or {}
+ model_id = hidden_params.get("model_id", None) or ""
+ cache_key = hidden_params.get("cache_key", None) or ""
+ api_base = hidden_params.get("api_base", None) or ""
+
+ fastapi_response.headers.update(
+ ProxyBaseLLMRequestProcessing.get_custom_headers(
+ user_api_key_dict=user_api_key_dict,
+ model_id=model_id,
+ cache_key=cache_key,
+ api_base=api_base,
+ version=version,
+ model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
+ request_data=data,
+ hidden_params=hidden_params,
+ )
+ )
+
+ return response
+ except Exception as e:
+ await proxy_logging_obj.post_call_failure_hook(
+ user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
+ )
+ verbose_proxy_logger.error(
+ "litellm.proxy.proxy_server.get_assistants(): Exception occured - {}".format(
+ str(e)
+ )
+ )
+ verbose_proxy_logger.debug(traceback.format_exc())
+ if isinstance(e, HTTPException):
+ raise ProxyException(
+ message=getattr(e, "message", str(e.detail)),
+ type=getattr(e, "type", "None"),
+ param=getattr(e, "param", "None"),
+ code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
+ )
+ else:
+ error_msg = f"{str(e)}"
+ raise ProxyException(
+ message=getattr(e, "message", error_msg),
+ type=getattr(e, "type", "None"),
+ param=getattr(e, "param", "None"),
+ openai_code=getattr(e, "code", None),
+ code=getattr(e, "status_code", 500),
+ )
+
+
+@router.post(
+ "/v1/assistants",
+ dependencies=[Depends(user_api_key_auth)],
+ tags=["assistants"],
+)
+@router.post(
+ "/assistants",
+ dependencies=[Depends(user_api_key_auth)],
+ tags=["assistants"],
+)
+async def create_assistant(
+ request: Request,
+ fastapi_response: Response,
+ user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+):
+ """
+ Create assistant
+
+ API Reference docs - https://platform.openai.com/docs/api-reference/assistants/createAssistant
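+
+    Example request (illustrative; the model must be configured on your proxy):
+    ```bash
+    curl -X POST http://localhost:4000/v1/assistants \
+    -H "Content-Type: application/json" \
+    -H "Authorization: Bearer sk-1234" \
+    -d '{
+    "model": "gpt-4o",
+    "name": "Math Tutor",
+    "instructions": "You are a personal math tutor."
+    }'
+    ```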
+ """
+ global proxy_logging_obj
+ data = {} # ensure data always dict
+ try:
+ # Use orjson to parse JSON data, orjson speeds up requests significantly
+ body = await request.body()
+ data = orjson.loads(body)
+
+ # Include original request and headers in the data
+ data = await add_litellm_data_to_request(
+ data=data,
+ request=request,
+ general_settings=general_settings,
+ user_api_key_dict=user_api_key_dict,
+ version=version,
+ proxy_config=proxy_config,
+ )
+
+ # for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for acreate_batch
+ if llm_router is None:
+ raise HTTPException(
+ status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
+ )
+ response = await llm_router.acreate_assistants(**data)
+
+ ### ALERTING ###
+ asyncio.create_task(
+ proxy_logging_obj.update_request_status(
+ litellm_call_id=data.get("litellm_call_id", ""), status="success"
+ )
+ )
+
+ ### RESPONSE HEADERS ###
+ hidden_params = getattr(response, "_hidden_params", {}) or {}
+ model_id = hidden_params.get("model_id", None) or ""
+ cache_key = hidden_params.get("cache_key", None) or ""
+ api_base = hidden_params.get("api_base", None) or ""
+
+ fastapi_response.headers.update(
+ ProxyBaseLLMRequestProcessing.get_custom_headers(
+ user_api_key_dict=user_api_key_dict,
+ model_id=model_id,
+ cache_key=cache_key,
+ api_base=api_base,
+ version=version,
+ model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
+ request_data=data,
+ hidden_params=hidden_params,
+ )
+ )
+
+ return response
+ except Exception as e:
+ await proxy_logging_obj.post_call_failure_hook(
+ user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
+ )
+ verbose_proxy_logger.error(
+ "litellm.proxy.proxy_server.create_assistant(): Exception occured - {}".format(
+ str(e)
+ )
+ )
+ verbose_proxy_logger.debug(traceback.format_exc())
+ if isinstance(e, HTTPException):
+ raise ProxyException(
+ message=getattr(e, "message", str(e.detail)),
+ type=getattr(e, "type", "None"),
+ param=getattr(e, "param", "None"),
+ code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
+ )
+ else:
+ error_msg = f"{str(e)}"
+ raise ProxyException(
+ message=getattr(e, "message", error_msg),
+ type=getattr(e, "type", "None"),
+ param=getattr(e, "param", "None"),
+ code=getattr(e, "code", getattr(e, "status_code", 500)),
+ )
+
+
+@router.delete(
+ "/v1/assistants/{assistant_id:path}",
+ dependencies=[Depends(user_api_key_auth)],
+ tags=["assistants"],
+)
+@router.delete(
+ "/assistants/{assistant_id:path}",
+ dependencies=[Depends(user_api_key_auth)],
+ tags=["assistants"],
+)
+async def delete_assistant(
+ request: Request,
+ assistant_id: str,
+ fastapi_response: Response,
+ user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+):
+ """
+ Delete assistant
+
+    API Reference docs - https://platform.openai.com/docs/api-reference/assistants/deleteAssistant
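+
+    Example request (illustrative; `asst_abc123` is a placeholder assistant id):
+    ```bash
+    curl -X DELETE http://localhost:4000/v1/assistants/asst_abc123 \
+    -H "Authorization: Bearer sk-1234"
+    ```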
+ """
+ global proxy_logging_obj
+ data: Dict = {}
+ try:
+ # Include original request and headers in the data
+ data = await add_litellm_data_to_request(
+ data=data,
+ request=request,
+ general_settings=general_settings,
+ user_api_key_dict=user_api_key_dict,
+ version=version,
+ proxy_config=proxy_config,
+ )
+
+ # for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for acreate_batch
+ if llm_router is None:
+ raise HTTPException(
+ status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
+ )
+ response = await llm_router.adelete_assistant(assistant_id=assistant_id, **data)
+
+ ### ALERTING ###
+ asyncio.create_task(
+ proxy_logging_obj.update_request_status(
+ litellm_call_id=data.get("litellm_call_id", ""), status="success"
+ )
+ )
+
+ ### RESPONSE HEADERS ###
+ hidden_params = getattr(response, "_hidden_params", {}) or {}
+ model_id = hidden_params.get("model_id", None) or ""
+ cache_key = hidden_params.get("cache_key", None) or ""
+ api_base = hidden_params.get("api_base", None) or ""
+
+ fastapi_response.headers.update(
+ ProxyBaseLLMRequestProcessing.get_custom_headers(
+ user_api_key_dict=user_api_key_dict,
+ model_id=model_id,
+ cache_key=cache_key,
+ api_base=api_base,
+ version=version,
+ model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
+ request_data=data,
+ hidden_params=hidden_params,
+ )
+ )
+
+ return response
+ except Exception as e:
+ await proxy_logging_obj.post_call_failure_hook(
+ user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
+ )
+ verbose_proxy_logger.error(
+ "litellm.proxy.proxy_server.delete_assistant(): Exception occured - {}".format(
+ str(e)
+ )
+ )
+ verbose_proxy_logger.debug(traceback.format_exc())
+ if isinstance(e, HTTPException):
+ raise ProxyException(
+ message=getattr(e, "message", str(e.detail)),
+ type=getattr(e, "type", "None"),
+ param=getattr(e, "param", "None"),
+ code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
+ )
+ else:
+ error_msg = f"{str(e)}"
+ raise ProxyException(
+ message=getattr(e, "message", error_msg),
+ type=getattr(e, "type", "None"),
+ param=getattr(e, "param", "None"),
+ code=getattr(e, "code", getattr(e, "status_code", 500)),
+ )
+
+
+@router.post(
+ "/v1/threads",
+ dependencies=[Depends(user_api_key_auth)],
+ tags=["assistants"],
+)
+@router.post(
+ "/threads",
+ dependencies=[Depends(user_api_key_auth)],
+ tags=["assistants"],
+)
+async def create_threads(
+ request: Request,
+ fastapi_response: Response,
+ user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+):
+ """
+ Create a thread.
+
+ API Reference - https://platform.openai.com/docs/api-reference/threads/createThread
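+
+    Example request (illustrative; an empty JSON body creates an empty thread):
+    ```bash
+    curl -X POST http://localhost:4000/v1/threads \
+    -H "Content-Type: application/json" \
+    -H "Authorization: Bearer sk-1234" \
+    -d '{}'
+    ```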
+ """
+ global proxy_logging_obj
+ data: Dict = {}
+ try:
+ # Use orjson to parse JSON data, orjson speeds up requests significantly
+ await request.body()
+
+ # Include original request and headers in the data
+ data = await add_litellm_data_to_request(
+ data=data,
+ request=request,
+ general_settings=general_settings,
+ user_api_key_dict=user_api_key_dict,
+ version=version,
+ proxy_config=proxy_config,
+ )
+
+ # for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for acreate_batch
+ if llm_router is None:
+ raise HTTPException(
+ status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
+ )
+ response = await llm_router.acreate_thread(**data)
+
+ ### ALERTING ###
+ asyncio.create_task(
+ proxy_logging_obj.update_request_status(
+ litellm_call_id=data.get("litellm_call_id", ""), status="success"
+ )
+ )
+
+ ### RESPONSE HEADERS ###
+ hidden_params = getattr(response, "_hidden_params", {}) or {}
+ model_id = hidden_params.get("model_id", None) or ""
+ cache_key = hidden_params.get("cache_key", None) or ""
+ api_base = hidden_params.get("api_base", None) or ""
+
+ fastapi_response.headers.update(
+ ProxyBaseLLMRequestProcessing.get_custom_headers(
+ user_api_key_dict=user_api_key_dict,
+ model_id=model_id,
+ cache_key=cache_key,
+ api_base=api_base,
+ version=version,
+ model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
+ request_data=data,
+ hidden_params=hidden_params,
+ )
+ )
+
+ return response
+ except Exception as e:
+ await proxy_logging_obj.post_call_failure_hook(
+ user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
+ )
+ verbose_proxy_logger.error(
+ "litellm.proxy.proxy_server.create_threads(): Exception occured - {}".format(
+ str(e)
+ )
+ )
+ verbose_proxy_logger.debug(traceback.format_exc())
+ if isinstance(e, HTTPException):
+ raise ProxyException(
+ message=getattr(e, "message", str(e.detail)),
+ type=getattr(e, "type", "None"),
+ param=getattr(e, "param", "None"),
+ code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
+ )
+ else:
+ error_msg = f"{str(e)}"
+ raise ProxyException(
+ message=getattr(e, "message", error_msg),
+ type=getattr(e, "type", "None"),
+ param=getattr(e, "param", "None"),
+ code=getattr(e, "code", getattr(e, "status_code", 500)),
+ )
+
+
+@router.get(
+ "/v1/threads/{thread_id}",
+ dependencies=[Depends(user_api_key_auth)],
+ tags=["assistants"],
+)
+@router.get(
+ "/threads/{thread_id}",
+ dependencies=[Depends(user_api_key_auth)],
+ tags=["assistants"],
+)
+async def get_thread(
+ request: Request,
+ thread_id: str,
+ fastapi_response: Response,
+ user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+):
+ """
+ Retrieves a thread.
+
+ API Reference - https://platform.openai.com/docs/api-reference/threads/getThread
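+
+    Example request (illustrative; `thread_abc123` is a placeholder thread id):
+    ```bash
+    curl -X GET http://localhost:4000/v1/threads/thread_abc123 \
+    -H "Authorization: Bearer sk-1234"
+    ```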
+ """
+ global proxy_logging_obj
+ data: Dict = {}
+ try:
+
+ # Include original request and headers in the data
+ data = await add_litellm_data_to_request(
+ data=data,
+ request=request,
+ general_settings=general_settings,
+ user_api_key_dict=user_api_key_dict,
+ version=version,
+ proxy_config=proxy_config,
+ )
+
+ # for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for acreate_batch
+ if llm_router is None:
+ raise HTTPException(
+ status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
+ )
+ response = await llm_router.aget_thread(thread_id=thread_id, **data)
+
+ ### ALERTING ###
+ asyncio.create_task(
+ proxy_logging_obj.update_request_status(
+ litellm_call_id=data.get("litellm_call_id", ""), status="success"
+ )
+ )
+
+ ### RESPONSE HEADERS ###
+ hidden_params = getattr(response, "_hidden_params", {}) or {}
+ model_id = hidden_params.get("model_id", None) or ""
+ cache_key = hidden_params.get("cache_key", None) or ""
+ api_base = hidden_params.get("api_base", None) or ""
+
+ fastapi_response.headers.update(
+ ProxyBaseLLMRequestProcessing.get_custom_headers(
+ user_api_key_dict=user_api_key_dict,
+ model_id=model_id,
+ cache_key=cache_key,
+ api_base=api_base,
+ version=version,
+ model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
+ request_data=data,
+ hidden_params=hidden_params,
+ )
+ )
+
+ return response
+ except Exception as e:
+ await proxy_logging_obj.post_call_failure_hook(
+ user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
+ )
+ verbose_proxy_logger.error(
+ "litellm.proxy.proxy_server.get_thread(): Exception occured - {}".format(
+ str(e)
+ )
+ )
+ verbose_proxy_logger.debug(traceback.format_exc())
+ if isinstance(e, HTTPException):
+ raise ProxyException(
+ message=getattr(e, "message", str(e.detail)),
+ type=getattr(e, "type", "None"),
+ param=getattr(e, "param", "None"),
+ code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
+ )
+ else:
+ error_msg = f"{str(e)}"
+ raise ProxyException(
+ message=getattr(e, "message", error_msg),
+ type=getattr(e, "type", "None"),
+ param=getattr(e, "param", "None"),
+ code=getattr(e, "code", getattr(e, "status_code", 500)),
+ )
+
+
+@router.post(
+ "/v1/threads/{thread_id}/messages",
+ dependencies=[Depends(user_api_key_auth)],
+ tags=["assistants"],
+)
+@router.post(
+ "/threads/{thread_id}/messages",
+ dependencies=[Depends(user_api_key_auth)],
+ tags=["assistants"],
+)
+async def add_messages(
+ request: Request,
+ thread_id: str,
+ fastapi_response: Response,
+ user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+):
+ """
+ Create a message.
+
+ API Reference - https://platform.openai.com/docs/api-reference/messages/createMessage
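+
+    Example request (illustrative; `thread_abc123` is a placeholder thread id):
+    ```bash
+    curl -X POST http://localhost:4000/v1/threads/thread_abc123/messages \
+    -H "Content-Type: application/json" \
+    -H "Authorization: Bearer sk-1234" \
+    -d '{
+    "role": "user",
+    "content": "How does AI work?"
+    }'
+    ```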
+ """
+ global proxy_logging_obj
+ data: Dict = {}
+ try:
+ # Use orjson to parse JSON data, orjson speeds up requests significantly
+ body = await request.body()
+ data = orjson.loads(body)
+
+ # Include original request and headers in the data
+ data = await add_litellm_data_to_request(
+ data=data,
+ request=request,
+ general_settings=general_settings,
+ user_api_key_dict=user_api_key_dict,
+ version=version,
+ proxy_config=proxy_config,
+ )
+
+ # for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for acreate_batch
+ if llm_router is None:
+ raise HTTPException(
+ status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
+ )
+ response = await llm_router.a_add_message(thread_id=thread_id, **data)
+
+ ### ALERTING ###
+ asyncio.create_task(
+ proxy_logging_obj.update_request_status(
+ litellm_call_id=data.get("litellm_call_id", ""), status="success"
+ )
+ )
+
+ ### RESPONSE HEADERS ###
+ hidden_params = getattr(response, "_hidden_params", {}) or {}
+ model_id = hidden_params.get("model_id", None) or ""
+ cache_key = hidden_params.get("cache_key", None) or ""
+ api_base = hidden_params.get("api_base", None) or ""
+
+ fastapi_response.headers.update(
+ ProxyBaseLLMRequestProcessing.get_custom_headers(
+ user_api_key_dict=user_api_key_dict,
+ model_id=model_id,
+ cache_key=cache_key,
+ api_base=api_base,
+ version=version,
+ model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
+ request_data=data,
+ hidden_params=hidden_params,
+ )
+ )
+
+ return response
+ except Exception as e:
+ await proxy_logging_obj.post_call_failure_hook(
+ user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
+ )
+ verbose_proxy_logger.error(
+ "litellm.proxy.proxy_server.add_messages(): Exception occured - {}".format(
+ str(e)
+ )
+ )
+ verbose_proxy_logger.debug(traceback.format_exc())
+ if isinstance(e, HTTPException):
+ raise ProxyException(
+ message=getattr(e, "message", str(e.detail)),
+ type=getattr(e, "type", "None"),
+ param=getattr(e, "param", "None"),
+ code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
+ )
+ else:
+ error_msg = f"{str(e)}"
+ raise ProxyException(
+ message=getattr(e, "message", error_msg),
+ type=getattr(e, "type", "None"),
+ param=getattr(e, "param", "None"),
+ code=getattr(e, "code", getattr(e, "status_code", 500)),
+ )
+
+
+@router.get(
+ "/v1/threads/{thread_id}/messages",
+ dependencies=[Depends(user_api_key_auth)],
+ tags=["assistants"],
+)
+@router.get(
+ "/threads/{thread_id}/messages",
+ dependencies=[Depends(user_api_key_auth)],
+ tags=["assistants"],
+)
+async def get_messages(
+ request: Request,
+ thread_id: str,
+ fastapi_response: Response,
+ user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+):
+ """
+ Returns a list of messages for a given thread.
+
+ API Reference - https://platform.openai.com/docs/api-reference/messages/listMessages
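+
+    Example request (illustrative; `thread_abc123` is a placeholder thread id):
+    ```bash
+    curl -X GET http://localhost:4000/v1/threads/thread_abc123/messages \
+    -H "Authorization: Bearer sk-1234"
+    ```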
+ """
+ global proxy_logging_obj
+ data: Dict = {}
+ try:
+ # Include original request and headers in the data
+ data = await add_litellm_data_to_request(
+ data=data,
+ request=request,
+ general_settings=general_settings,
+ user_api_key_dict=user_api_key_dict,
+ version=version,
+ proxy_config=proxy_config,
+ )
+
+ # for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for acreate_batch
+ if llm_router is None:
+ raise HTTPException(
+ status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
+ )
+ response = await llm_router.aget_messages(thread_id=thread_id, **data)
+
+ ### ALERTING ###
+ asyncio.create_task(
+ proxy_logging_obj.update_request_status(
+ litellm_call_id=data.get("litellm_call_id", ""), status="success"
+ )
+ )
+
+ ### RESPONSE HEADERS ###
+ hidden_params = getattr(response, "_hidden_params", {}) or {}
+ model_id = hidden_params.get("model_id", None) or ""
+ cache_key = hidden_params.get("cache_key", None) or ""
+ api_base = hidden_params.get("api_base", None) or ""
+
+ fastapi_response.headers.update(
+ ProxyBaseLLMRequestProcessing.get_custom_headers(
+ user_api_key_dict=user_api_key_dict,
+ model_id=model_id,
+ cache_key=cache_key,
+ api_base=api_base,
+ version=version,
+ model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
+ request_data=data,
+ hidden_params=hidden_params,
+ )
+ )
+
+ return response
+ except Exception as e:
+ await proxy_logging_obj.post_call_failure_hook(
+ user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
+ )
+ verbose_proxy_logger.error(
+ "litellm.proxy.proxy_server.get_messages(): Exception occured - {}".format(
+ str(e)
+ )
+ )
+ verbose_proxy_logger.debug(traceback.format_exc())
+ if isinstance(e, HTTPException):
+ raise ProxyException(
+ message=getattr(e, "message", str(e.detail)),
+ type=getattr(e, "type", "None"),
+ param=getattr(e, "param", "None"),
+ code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
+ )
+ else:
+ error_msg = f"{str(e)}"
+ raise ProxyException(
+ message=getattr(e, "message", error_msg),
+ type=getattr(e, "type", "None"),
+ param=getattr(e, "param", "None"),
+ code=getattr(e, "code", getattr(e, "status_code", 500)),
+ )
+
+
+@router.post(
+ "/v1/threads/{thread_id}/runs",
+ dependencies=[Depends(user_api_key_auth)],
+ tags=["assistants"],
+)
+@router.post(
+ "/threads/{thread_id}/runs",
+ dependencies=[Depends(user_api_key_auth)],
+ tags=["assistants"],
+)
+async def run_thread(
+ request: Request,
+ thread_id: str,
+ fastapi_response: Response,
+ user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+):
+ """
+ Create a run.
+
+ API Reference: https://platform.openai.com/docs/api-reference/runs/createRun
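+
+    Example request (illustrative; both ids are placeholders):
+    ```bash
+    curl -X POST http://localhost:4000/v1/threads/thread_abc123/runs \
+    -H "Content-Type: application/json" \
+    -H "Authorization: Bearer sk-1234" \
+    -d '{
+    "assistant_id": "asst_abc123"
+    }'
+    ```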
+ """
+ global proxy_logging_obj
+ data: Dict = {}
+ try:
+ body = await request.body()
+ data = orjson.loads(body)
+ # Include original request and headers in the data
+ data = await add_litellm_data_to_request(
+ data=data,
+ request=request,
+ general_settings=general_settings,
+ user_api_key_dict=user_api_key_dict,
+ version=version,
+ proxy_config=proxy_config,
+ )
+
+ # for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for acreate_batch
+ if llm_router is None:
+ raise HTTPException(
+ status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
+ )
+ response = await llm_router.arun_thread(thread_id=thread_id, **data)
+
+ if (
+ "stream" in data and data["stream"] is True
+ ): # use generate_responses to stream responses
+ return StreamingResponse(
+ async_assistants_data_generator(
+ user_api_key_dict=user_api_key_dict,
+ response=response,
+ request_data=data,
+ ),
+ media_type="text/event-stream",
+ )
+
+ ### ALERTING ###
+ asyncio.create_task(
+ proxy_logging_obj.update_request_status(
+ litellm_call_id=data.get("litellm_call_id", ""), status="success"
+ )
+ )
+
+ ### RESPONSE HEADERS ###
+ hidden_params = getattr(response, "_hidden_params", {}) or {}
+ model_id = hidden_params.get("model_id", None) or ""
+ cache_key = hidden_params.get("cache_key", None) or ""
+ api_base = hidden_params.get("api_base", None) or ""
+
+ fastapi_response.headers.update(
+ ProxyBaseLLMRequestProcessing.get_custom_headers(
+ user_api_key_dict=user_api_key_dict,
+ model_id=model_id,
+ cache_key=cache_key,
+ api_base=api_base,
+ version=version,
+ model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
+ request_data=data,
+ hidden_params=hidden_params,
+ )
+ )
+
+ return response
+ except Exception as e:
+ await proxy_logging_obj.post_call_failure_hook(
+ user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
+ )
+ verbose_proxy_logger.error(
+ "litellm.proxy.proxy_server.run_thread(): Exception occured - {}".format(
+ str(e)
+ )
+ )
+ verbose_proxy_logger.debug(traceback.format_exc())
+ if isinstance(e, HTTPException):
+ raise ProxyException(
+ message=getattr(e, "message", str(e.detail)),
+ type=getattr(e, "type", "None"),
+ param=getattr(e, "param", "None"),
+ code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
+ )
+ else:
+ error_msg = f"{str(e)}"
+ raise ProxyException(
+ message=getattr(e, "message", error_msg),
+ type=getattr(e, "type", "None"),
+ param=getattr(e, "param", "None"),
+ code=getattr(e, "code", getattr(e, "status_code", 500)),
+ )
+
+
+@router.post(
+ "/v1/moderations",
+ dependencies=[Depends(user_api_key_auth)],
+ response_class=ORJSONResponse,
+ tags=["moderations"],
+)
+@router.post(
+ "/moderations",
+ dependencies=[Depends(user_api_key_auth)],
+ response_class=ORJSONResponse,
+ tags=["moderations"],
+)
+async def moderations(
+ request: Request,
+ fastapi_response: Response,
+ user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+):
+ """
+    The moderations endpoint is a tool you can use to check whether content complies with an LLM provider's policies.
+
+ Quick Start
+ ```
+ curl --location 'http://0.0.0.0:4000/moderations' \
+ --header 'Content-Type: application/json' \
+ --header 'Authorization: Bearer sk-1234' \
+ --data '{"input": "Sample text goes here", "model": "text-moderation-stable"}'
+ ```
+ """
+ global proxy_logging_obj
+ data: Dict = {}
+ try:
+ # Use orjson to parse JSON data, orjson speeds up requests significantly
+ body = await request.body()
+ data = orjson.loads(body)
+
+ # Include original request and headers in the data
+ data = await add_litellm_data_to_request(
+ data=data,
+ request=request,
+ general_settings=general_settings,
+ user_api_key_dict=user_api_key_dict,
+ version=version,
+ proxy_config=proxy_config,
+ )
+
+ data["model"] = (
+ general_settings.get("moderation_model", None) # server default
+ or user_model # model name passed via cli args
+ or data.get("model") # default passed in http request
+ )
+ if user_model:
+ data["model"] = user_model
+
+ ### CALL HOOKS ### - modify incoming data / reject request before calling the model
+ data = await proxy_logging_obj.pre_call_hook(
+ user_api_key_dict=user_api_key_dict, data=data, call_type="moderation"
+ )
+
+
+ ## ROUTE TO CORRECT ENDPOINT ##
+ llm_call = await route_request(
+ data=data,
+ route_type="amoderation",
+ llm_router=llm_router,
+ user_model=user_model,
+ )
+ response = await llm_call
+
+ ### ALERTING ###
+ asyncio.create_task(
+ proxy_logging_obj.update_request_status(
+ litellm_call_id=data.get("litellm_call_id", ""), status="success"
+ )
+ )
+
+ ### RESPONSE HEADERS ###
+ hidden_params = getattr(response, "_hidden_params", {}) or {}
+ model_id = hidden_params.get("model_id", None) or ""
+ cache_key = hidden_params.get("cache_key", None) or ""
+ api_base = hidden_params.get("api_base", None) or ""
+
+ fastapi_response.headers.update(
+ ProxyBaseLLMRequestProcessing.get_custom_headers(
+ user_api_key_dict=user_api_key_dict,
+ model_id=model_id,
+ cache_key=cache_key,
+ api_base=api_base,
+ version=version,
+ model_region=getattr(user_api_key_dict, "allowed_model_region", ""),
+ request_data=data,
+ hidden_params=hidden_params,
+ )
+ )
+
+ return response
+ except Exception as e:
+ await proxy_logging_obj.post_call_failure_hook(
+ user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
+ )
+ verbose_proxy_logger.exception(
+ "litellm.proxy.proxy_server.moderations(): Exception occured - {}".format(
+ str(e)
+ )
+ )
+ if isinstance(e, HTTPException):
+ raise ProxyException(
+ message=getattr(e, "message", str(e)),
+ type=getattr(e, "type", "None"),
+ param=getattr(e, "param", "None"),
+ code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
+ )
+ else:
+ error_msg = f"{str(e)}"
+ raise ProxyException(
+ message=getattr(e, "message", error_msg),
+ type=getattr(e, "type", "None"),
+ param=getattr(e, "param", "None"),
+ code=getattr(e, "status_code", 500),
+ )
+
+
+#### DEV UTILS ####
+
+# @router.get(
+# "/utils/available_routes",
+# tags=["llm utils"],
+# dependencies=[Depends(user_api_key_auth)],
+# )
+# async def get_available_routes(user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth)):
+
+
+@router.post(
+ "/utils/token_counter",
+ tags=["llm utils"],
+ dependencies=[Depends(user_api_key_auth)],
+ response_model=TokenCountResponse,
+)
+async def token_counter(request: TokenCountRequest):
+ """ """
+ from litellm import token_counter
+
+ global llm_router
+
+ prompt = request.prompt
+ messages = request.messages
+ if prompt is None and messages is None:
+ raise HTTPException(
+ status_code=400, detail="prompt or messages must be provided"
+ )
+
+ deployment = None
+ litellm_model_name = None
+ model_info: Optional[ModelMapInfo] = None
+ if llm_router is not None:
+ # get 1 deployment corresponding to the model
+ for _model in llm_router.model_list:
+ if _model["model_name"] == request.model:
+ deployment = _model
+ model_info = llm_router.get_router_model_info(
+ deployment=deployment,
+ received_model_name=request.model,
+ )
+ break
+ if deployment is not None:
+ litellm_model_name = deployment.get("litellm_params", {}).get("model")
+ # remove the custom_llm_provider_prefix in the litellm_model_name
+ if "/" in litellm_model_name:
+ litellm_model_name = litellm_model_name.split("/", 1)[1]
+
+ model_to_use = (
+ litellm_model_name or request.model
+    )  # use the litellm model name; if it's not available, fall back to request.model
+
+ custom_tokenizer: Optional[CustomHuggingfaceTokenizer] = None
+ if model_info is not None:
+ custom_tokenizer = cast(
+ Optional[CustomHuggingfaceTokenizer],
+ model_info.get("custom_tokenizer", None),
+ )
+ _tokenizer_used = litellm.utils._select_tokenizer(
+ model=model_to_use, custom_tokenizer=custom_tokenizer
+ )
+
+ tokenizer_used = str(_tokenizer_used["type"])
+ total_tokens = token_counter(
+ model=model_to_use,
+ text=prompt,
+ messages=messages,
+ custom_tokenizer=_tokenizer_used, # type: ignore
+ )
+ return TokenCountResponse(
+ total_tokens=total_tokens,
+ request_model=request.model,
+ model_used=model_to_use,
+ tokenizer_type=tokenizer_used,
+ )
+
+
+@router.get(
+ "/utils/supported_openai_params",
+ tags=["llm utils"],
+ dependencies=[Depends(user_api_key_auth)],
+)
+async def supported_openai_params(model: str):
+ """
+ Returns supported openai params for a given litellm model name
+
+ e.g. `gpt-4` vs `gpt-3.5-turbo`
+
+ Example curl:
+ ```
+ curl -X GET --location 'http://localhost:4000/utils/supported_openai_params?model=gpt-3.5-turbo-16k' \
+ --header 'Authorization: Bearer sk-1234'
+ ```
+ """
+ try:
+ model, custom_llm_provider, _, _ = litellm.get_llm_provider(model=model)
+ return {
+ "supported_openai_params": litellm.get_supported_openai_params(
+ model=model, custom_llm_provider=custom_llm_provider
+ )
+ }
+ except Exception:
+ raise HTTPException(
+ status_code=400, detail={"error": "Could not map model={}".format(model)}
+ )
+
+
+@router.post(
+ "/utils/transform_request",
+ tags=["llm utils"],
+ dependencies=[Depends(user_api_key_auth)],
+ response_model=RawRequestTypedDict,
+)
+async def transform_request(request: TransformRequestBody):
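+    """
+    Return the raw provider request that LiteLLM would generate for the given call
+    type and request body.
+
+    Example request (a sketch; assumes "completion" is a valid `call_type` value):
+    ```bash
+    curl -X POST http://localhost:4000/utils/transform_request \
+    -H "Content-Type: application/json" \
+    -H "Authorization: Bearer sk-1234" \
+    -d '{
+    "call_type": "completion",
+    "request_body": {"model": "gpt-4o", "messages": [{"role": "user", "content": "Hello!"}]}
+    }'
+    ```
+    """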
+ from litellm.utils import return_raw_request
+
+ return return_raw_request(endpoint=request.call_type, kwargs=request.request_body)
+
+
+@router.get(
+ "/v2/model/info",
+ description="v2 - returns all the models set on the config.yaml, shows 'user_access' = True if the user has access to the model. Provides more info about each model in /models, including config.yaml descriptions (except api key and api base)",
+ tags=["model management"],
+ dependencies=[Depends(user_api_key_auth)],
+ include_in_schema=False,
+)
+async def model_info_v2(
+ user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+ model: Optional[str] = fastapi.Query(
+ None, description="Specify the model name (optional)"
+ ),
+ debug: Optional[bool] = False,
+):
+ """
+ BETA ENDPOINT. Might change unexpectedly. Use `/v1/model/info` for now.
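+
+    Example request (illustrative):
+    ```bash
+    curl -X GET "http://localhost:4000/v2/model/info" \
+    -H "Authorization: Bearer sk-1234"
+    ```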
+ """
+ global llm_model_list, general_settings, user_config_file_path, proxy_config, llm_router
+
+ if llm_router is None:
+ raise HTTPException(
+ status_code=500,
+ detail={
+ "error": f"No model list passed, models router={llm_router}. You can add a model through the config.yaml or on the LiteLLM Admin UI."
+ },
+ )
+
+ # Load existing config
+ await proxy_config.get_config()
+ all_models = copy.deepcopy(llm_router.model_list)
+
+ if user_model is not None:
+ # if user does not use a config.yaml, https://github.com/BerriAI/litellm/issues/2061
+ all_models += [user_model]
+
+ # check all models user has access to in user_api_key_dict
+ if len(user_api_key_dict.models) > 0:
+ pass
+
+ if model is not None:
+ all_models = [m for m in all_models if m["model_name"] == model]
+
+ # fill in model info based on config.yaml and litellm model_prices_and_context_window.json
+ for _model in all_models:
+ # provided model_info in config.yaml
+ model_info = _model.get("model_info", {})
+ if debug is True:
+ _openai_client = "None"
+ if llm_router is not None:
+ _openai_client = (
+ llm_router._get_client(
+ deployment=_model, kwargs={}, client_type="async"
+ )
+ or "None"
+ )
+ else:
+ _openai_client = "llm_router_is_None"
+ openai_client = str(_openai_client)
+ _model["openai_client"] = openai_client
+
+ # read litellm model_prices_and_context_window.json to get the following:
+ # input_cost_per_token, output_cost_per_token, max_tokens
+ litellm_model_info = get_litellm_model_info(model=_model)
+
+ # 2nd pass on the model, try seeing if we can find model in litellm model_cost map
+ if litellm_model_info == {}:
+ # use litellm_param model_name to get model_info
+ litellm_params = _model.get("litellm_params", {})
+ litellm_model = litellm_params.get("model", None)
+ try:
+ litellm_model_info = litellm.get_model_info(model=litellm_model)
+ except Exception:
+ litellm_model_info = {}
+ # 3rd pass on the model, try seeing if we can find model but without the "/" in model cost map
+ if litellm_model_info == {}:
+ # use litellm_param model_name to get model_info
+ litellm_params = _model.get("litellm_params", {})
+ litellm_model = litellm_params.get("model", None)
+ split_model = litellm_model.split("/")
+ if len(split_model) > 0:
+ litellm_model = split_model[-1]
+ try:
+ litellm_model_info = litellm.get_model_info(
+ model=litellm_model, custom_llm_provider=split_model[0]
+ )
+ except Exception:
+ litellm_model_info = {}
+ for k, v in litellm_model_info.items():
+ if k not in model_info:
+ model_info[k] = v
+ _model["model_info"] = model_info
+ # don't return the api key / vertex credentials
+ # don't return the llm credentials
+ _model["litellm_params"].pop("api_key", None)
+ _model["litellm_params"].pop("vertex_credentials", None)
+ _model["litellm_params"].pop("aws_access_key_id", None)
+ _model["litellm_params"].pop("aws_secret_access_key", None)
+
+ verbose_proxy_logger.debug("all_models: %s", all_models)
+ return {"data": all_models}
+
+
+@router.get(
+ "/model/streaming_metrics",
+ description="View time to first token for models in spend logs",
+ tags=["model management"],
+ include_in_schema=False,
+ dependencies=[Depends(user_api_key_auth)],
+)
+async def model_streaming_metrics(
+ user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+ _selected_model_group: Optional[str] = None,
+ startTime: Optional[datetime] = None,
+ endTime: Optional[datetime] = None,
+):
+ global prisma_client, llm_router
+ if prisma_client is None:
+ raise ProxyException(
+ message=CommonProxyErrors.db_not_connected_error.value,
+ type="internal_error",
+ param="None",
+ code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ )
+
+ startTime = startTime or datetime.now() - timedelta(days=7) # show over past week
+ endTime = endTime or datetime.now()
+
+ is_same_day = startTime.date() == endTime.date()
+ if is_same_day:
+ sql_query = """
+ SELECT
+ api_base,
+ model_group,
+ model,
+ "startTime",
+ request_id,
+ EXTRACT(epoch FROM ("completionStartTime" - "startTime")) AS time_to_first_token
+ FROM
+ "LiteLLM_SpendLogs"
+ WHERE
+ "model_group" = $1 AND "cache_hit" != 'True'
+ AND "completionStartTime" IS NOT NULL
+ AND "completionStartTime" != "endTime"
+ AND DATE("startTime") = DATE($2::timestamp)
+ GROUP BY
+ api_base,
+ model_group,
+ model,
+ request_id
+ ORDER BY
+ time_to_first_token DESC;
+ """
+ else:
+ sql_query = """
+ SELECT
+ api_base,
+ model_group,
+ model,
+ DATE_TRUNC('day', "startTime")::DATE AS day,
+ AVG(EXTRACT(epoch FROM ("completionStartTime" - "startTime"))) AS time_to_first_token
+ FROM
+ "LiteLLM_SpendLogs"
+ WHERE
+ "startTime" BETWEEN $2::timestamp AND $3::timestamp
+ AND "model_group" = $1 AND "cache_hit" != 'True'
+ AND "completionStartTime" IS NOT NULL
+ AND "completionStartTime" != "endTime"
+ GROUP BY
+ api_base,
+ model_group,
+ model,
+ day
+ ORDER BY
+ time_to_first_token DESC;
+ """
+
+ _all_api_bases = set()
+ db_response = await prisma_client.db.query_raw(
+ sql_query, _selected_model_group, startTime, endTime
+ )
+ _daily_entries: dict = {} # {"Jun 23": {"model1": 0.002, "model2": 0.003}}
+ if db_response is not None:
+ for model_data in db_response:
+ _api_base = model_data["api_base"]
+ _model = model_data["model"]
+ time_to_first_token = model_data["time_to_first_token"]
+ unique_key = ""
+ if is_same_day:
+ _request_id = model_data["request_id"]
+ unique_key = _request_id
+ if _request_id not in _daily_entries:
+ _daily_entries[_request_id] = {}
+ else:
+ _day = model_data["day"]
+ unique_key = _day
+ if _day not in _daily_entries:
+ _daily_entries[_day] = {}
+ _combined_model_name = str(_model)
+ if "https://" in _api_base:
+ _combined_model_name = str(_api_base)
+ if "/openai/" in _combined_model_name:
+ _combined_model_name = _combined_model_name.split("/openai/")[0]
+
+ _all_api_bases.add(_combined_model_name)
+
+ _daily_entries[unique_key][_combined_model_name] = time_to_first_token
+
+ """
+ each entry needs to be like this:
+ {
+ date: 'Jun 23',
+ 'gpt-4-https://api.openai.com/v1/': 0.002,
+ 'gpt-43-https://api.openai.com-12/v1/': 0.002,
+ }
+ """
+ # convert daily entries to list of dicts
+
+ response: List[dict] = []
+
+ # sort daily entries by date
+ _daily_entries = dict(sorted(_daily_entries.items(), key=lambda item: item[0]))
+ for day in _daily_entries:
+ entry = {"date": str(day)}
+ for model_key, latency in _daily_entries[day].items():
+ entry[model_key] = latency
+ response.append(entry)
+
+ return {
+ "data": response,
+ "all_api_bases": list(_all_api_bases),
+ }
+
+
+@router.get(
+ "/model/metrics",
+ description="View number of requests & avg latency per model on config.yaml",
+ tags=["model management"],
+ include_in_schema=False,
+ dependencies=[Depends(user_api_key_auth)],
+)
+async def model_metrics(
+ user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+ _selected_model_group: Optional[str] = "gpt-4-32k",
+ startTime: Optional[datetime] = None,
+ endTime: Optional[datetime] = None,
+ api_key: Optional[str] = None,
+ customer: Optional[str] = None,
+):
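+ """
+ Used by the Admin UI. Returns the average latency per output token, per model / api_base, per day,
+ computed from `LiteLLM_SpendLogs`.
+
+ Example curl (sketch; all query params are optional, time range defaults to the last 30 days):
+ ```
+ curl -X GET 'http://localhost:4000/model/metrics?_selected_model_group=gpt-4-32k' \
+ --header 'Authorization: Bearer sk-1234'
+ ```
+ """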
+ global prisma_client, llm_router
+ if prisma_client is None:
+ raise ProxyException(
+ message="Prisma Client is not initialized",
+ type="internal_error",
+ param="None",
+ code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ )
+ startTime = startTime or datetime.now() - timedelta(days=30)
+ endTime = endTime or datetime.now()
+
+ if api_key is None or api_key == "undefined":
+ api_key = "null"
+
+ if customer is None or customer == "undefined":
+ customer = "null"
+
+ sql_query = """
+ SELECT
+ api_base,
+ model_group,
+ model,
+ DATE_TRUNC('day', "startTime")::DATE AS day,
+ AVG(EXTRACT(epoch FROM ("endTime" - "startTime")) / NULLIF("completion_tokens", 0)) AS avg_latency_per_token
+ FROM
+ "LiteLLM_SpendLogs"
+ WHERE
+ "startTime" >= $2::timestamp AND "startTime" <= $3::timestamp
+ AND "model_group" = $1 AND "cache_hit" != 'True'
+ AND (
+ CASE
+ WHEN $4 != 'null' THEN "api_key" = $4
+ ELSE TRUE
+ END
+ )
+ AND (
+ CASE
+ WHEN $5 != 'null' THEN "end_user" = $5
+ ELSE TRUE
+ END
+ )
+ GROUP BY
+ api_base,
+ model_group,
+ model,
+ day
+ HAVING
+ SUM(completion_tokens) > 0
+ ORDER BY
+ avg_latency_per_token DESC;
+ """
+ _all_api_bases = set()
+ db_response = await prisma_client.db.query_raw(
+ sql_query, _selected_model_group, startTime, endTime, api_key, customer
+ )
+ _daily_entries: dict = {} # {"Jun 23": {"model1": 0.002, "model2": 0.003}}
+
+ if db_response is not None:
+ for model_data in db_response:
+ _api_base = model_data["api_base"]
+ _model = model_data["model"]
+ _day = model_data["day"]
+ _avg_latency_per_token = model_data["avg_latency_per_token"]
+ if _day not in _daily_entries:
+ _daily_entries[_day] = {}
+ _combined_model_name = str(_model)
+ if _api_base is not None and "https://" in _api_base:
+ _combined_model_name = str(_api_base)
+ if _combined_model_name is not None and "/openai/" in _combined_model_name:
+ _combined_model_name = _combined_model_name.split("/openai/")[0]
+
+ _all_api_bases.add(_combined_model_name)
+ _daily_entries[_day][_combined_model_name] = _avg_latency_per_token
+
+ """
+ each entry needs to be like this:
+ {
+ date: 'Jun 23',
+ 'gpt-4-https://api.openai.com/v1/': 0.002,
+ 'gpt-43-https://api.openai.com-12/v1/': 0.002,
+ }
+ """
+ # convert daily entries to list of dicts
+
+ response: List[dict] = []
+
+ # sort daily entries by date
+ _daily_entries = dict(sorted(_daily_entries.items(), key=lambda item: item[0]))
+ for day in _daily_entries:
+ entry = {"date": str(day)}
+ for model_key, latency in _daily_entries[day].items():
+ entry[model_key] = latency
+ response.append(entry)
+
+ return {
+ "data": response,
+ "all_api_bases": list(_all_api_bases),
+ }
+
+
+@router.get(
+ "/model/metrics/slow_responses",
+ description="View number of hanging requests per model_group",
+ tags=["model management"],
+ include_in_schema=False,
+ dependencies=[Depends(user_api_key_auth)],
+)
+async def model_metrics_slow_responses(
+ user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+ _selected_model_group: Optional[str] = "gpt-4-32k",
+ startTime: Optional[datetime] = None,
+ endTime: Optional[datetime] = None,
+ api_key: Optional[str] = None,
+ customer: Optional[str] = None,
+):
+ global prisma_client, llm_router, proxy_logging_obj
+ if prisma_client is None:
+ raise ProxyException(
+ message="Prisma Client is not initialized",
+ type="internal_error",
+ param="None",
+ code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ )
+ if api_key is None or api_key == "undefined":
+ api_key = "null"
+
+ if customer is None or customer == "undefined":
+ customer = "null"
+
+ startTime = startTime or datetime.now() - timedelta(days=30)
+ endTime = endTime or datetime.now()
+
+ alerting_threshold = (
+ proxy_logging_obj.slack_alerting_instance.alerting_threshold or 300
+ )
+ alerting_threshold = int(alerting_threshold)
+
+ sql_query = """
+SELECT
+ api_base,
+ COUNT(*) AS total_count,
+ SUM(CASE
+ WHEN ("endTime" - "startTime") >= (INTERVAL '1 SECOND' * CAST($1 AS INTEGER)) THEN 1
+ ELSE 0
+ END) AS slow_count
+FROM
+ "LiteLLM_SpendLogs"
+WHERE
+ "model_group" = $2
+ AND "cache_hit" != 'True'
+ AND "startTime" >= $3::timestamp
+ AND "startTime" <= $4::timestamp
+ AND (
+ CASE
+ WHEN $5 != 'null' THEN "api_key" = $5
+ ELSE TRUE
+ END
+ )
+ AND (
+ CASE
+ WHEN $6 != 'null' THEN "end_user" = $6
+ ELSE TRUE
+ END
+ )
+GROUP BY
+ api_base
+ORDER BY
+ slow_count DESC;
+ """
+
+ db_response = await prisma_client.db.query_raw(
+ sql_query,
+ alerting_threshold,
+ _selected_model_group,
+ startTime,
+ endTime,
+ api_key,
+ customer,
+ )
+
+ if db_response is not None:
+ for row in db_response:
+ _api_base = row.get("api_base") or ""
+ if "/openai/" in _api_base:
+ _api_base = _api_base.split("/openai/")[0]
+ row["api_base"] = _api_base
+ return db_response
+
+
+@router.get(
+ "/model/metrics/exceptions",
+ description="View number of failed requests per model on config.yaml",
+ tags=["model management"],
+ include_in_schema=False,
+ dependencies=[Depends(user_api_key_auth)],
+)
+async def model_metrics_exceptions(
+ user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+ _selected_model_group: Optional[str] = None,
+ startTime: Optional[datetime] = None,
+ endTime: Optional[datetime] = None,
+ api_key: Optional[str] = None,
+ customer: Optional[str] = None,
+):
+ global prisma_client, llm_router
+ if prisma_client is None:
+ raise ProxyException(
+ message="Prisma Client is not initialized",
+ type="internal_error",
+ param="None",
+ code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ )
+
+ startTime = startTime or datetime.now() - timedelta(days=30)
+ endTime = endTime or datetime.now()
+
+ if api_key is None or api_key == "undefined":
+ api_key = "null"
+
+ """
+ """
+ sql_query = """
+ WITH cte AS (
+ SELECT
+ CASE WHEN api_base = '' THEN litellm_model_name ELSE CONCAT(litellm_model_name, '-', api_base) END AS combined_model_api_base,
+ exception_type,
+ COUNT(*) AS num_rate_limit_exceptions
+ FROM "LiteLLM_ErrorLogs"
+ WHERE
+ "startTime" >= $1::timestamp
+ AND "endTime" <= $2::timestamp
+ AND model_group = $3
+ GROUP BY combined_model_api_base, exception_type
+ )
+ SELECT
+ combined_model_api_base,
+ COUNT(*) AS total_exceptions,
+ json_object_agg(exception_type, num_rate_limit_exceptions) AS exception_counts
+ FROM cte
+ GROUP BY combined_model_api_base
+ ORDER BY total_exceptions DESC
+ LIMIT 200;
+ """
+ db_response = await prisma_client.db.query_raw(
+ sql_query, startTime, endTime, _selected_model_group, api_key
+ )
+ response: List[dict] = []
+ exception_types = set()
+
+ """
+ Return Data
+ {
+ "combined_model_api_base": "gpt-3.5-turbo-https://api.openai.com/v1/,
+ "total_exceptions": 5,
+ "BadRequestException": 5,
+ "TimeoutException": 2
+ }
+ """
+
+ if db_response is not None:
+ # loop through all models
+ for model_data in db_response:
+ model = model_data.get("combined_model_api_base", "")
+ total_exceptions = model_data.get("total_exceptions", 0)
+ exception_counts = model_data.get("exception_counts", {})
+ curr_row = {
+ "model": model,
+ "total_exceptions": total_exceptions,
+ }
+ curr_row.update(exception_counts)
+ response.append(curr_row)
+ for k, v in exception_counts.items():
+ exception_types.add(k)
+
+ return {"data": response, "exception_types": list(exception_types)}
+
+
+def _get_proxy_model_info(model: dict) -> dict:
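+ """
+ Enrich a single deployment dict with model metadata and strip credentials before returning it.
+
+ Resolution order (first non-empty result wins):
+ 1. `get_litellm_model_info` on the deployment dict itself
+ 2. `litellm.get_model_info` on the full `litellm_params.model` string
+ 3. `litellm.get_model_info` on the model name with its provider prefix split off
+ Values already present in the config.yaml `model_info` are never overwritten.
+ """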
+ # provided model_info in config.yaml
+ model_info = model.get("model_info", {})
+
+ # read litellm model_prices_and_context_window.json to get the following:
+ # input_cost_per_token, output_cost_per_token, max_tokens
+ litellm_model_info = get_litellm_model_info(model=model)
+
+ # 2nd pass on the model, try seeing if we can find model in litellm model_cost map
+ if litellm_model_info == {}:
+ # use litellm_param model_name to get model_info
+ litellm_params = model.get("litellm_params", {})
+ litellm_model = litellm_params.get("model", None)
+ try:
+ litellm_model_info = litellm.get_model_info(model=litellm_model)
+ except Exception:
+ litellm_model_info = {}
+ # 3rd pass on the model, try seeing if we can find model but without the "/" in model cost map
+ if litellm_model_info == {}:
+ # use litellm_param model_name to get model_info
+ litellm_params = model.get("litellm_params", {})
+ litellm_model = litellm_params.get("model", None)
+ split_model = litellm_model.split("/")
+ if len(split_model) > 0:
+ litellm_model = split_model[-1]
+ try:
+ litellm_model_info = litellm.get_model_info(
+ model=litellm_model, custom_llm_provider=split_model[0]
+ )
+ except Exception:
+ litellm_model_info = {}
+ for k, v in litellm_model_info.items():
+ if k not in model_info:
+ model_info[k] = v
+ model["model_info"] = model_info
+ # don't return the llm credentials
+ model = remove_sensitive_info_from_deployment(deployment_dict=model)
+
+ return model
+
+
+@router.get(
+ "/model/info",
+ tags=["model management"],
+ dependencies=[Depends(user_api_key_auth)],
+)
+@router.get(
+ "/v1/model/info",
+ tags=["model management"],
+ dependencies=[Depends(user_api_key_auth)],
+)
+async def model_info_v1( # noqa: PLR0915
+ user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+ litellm_model_id: Optional[str] = None,
+):
+ """
+ Provides more info about each model in /models, including config.yaml descriptions (except api key and api base)
+
+ Parameters:
+ litellm_model_id: Optional[str] = None (this is the value of `x-litellm-model-id` returned in response headers)
+
+ - When litellm_model_id is passed, it will return the info for that specific model
+ - When litellm_model_id is not passed, it will return the info for all models
+
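+ Example Request (sketch):
+ ```
+ curl -X GET 'http://localhost:4000/v1/model/info' \
+ --header 'Authorization: Bearer sk-1234'
+ ```
+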
+ Returns:
+ Returns a dictionary containing information about each model.
+
+ Example Response:
+ ```json
+ {
+ "data": [
+ {
+ "model_name": "fake-openai-endpoint",
+ "litellm_params": {
+ "api_base": "https://exampleopenaiendpoint-production.up.railway.app/",
+ "model": "openai/fake"
+ },
+ "model_info": {
+ "id": "112f74fab24a7a5245d2ced3536dd8f5f9192c57ee6e332af0f0512e08bed5af",
+ "db_model": false
+ }
+ }
+ ]
+ }
+
+ ```
+ """
+ global llm_model_list, general_settings, user_config_file_path, proxy_config, llm_router, user_model
+
+ if user_model is not None:
+ # a model was passed via cli args (`user_model`) - return info for that model directly
+ try:
+ model_info: Dict = cast(Dict, litellm.get_model_info(model=user_model))
+ except Exception:
+ model_info = {}
+ _deployment_info = Deployment(
+ model_name="*",
+ litellm_params=LiteLLM_Params(
+ model=user_model,
+ ),
+ model_info=model_info,
+ )
+ _deployment_info_dict = _deployment_info.model_dump()
+ _deployment_info_dict = remove_sensitive_info_from_deployment(
+ deployment_dict=_deployment_info_dict
+ )
+ return {"data": _deployment_info_dict}
+
+ if llm_model_list is None:
+ raise HTTPException(
+ status_code=500,
+ detail={
+ "error": "LLM Model List not loaded in. Make sure you passed models in your config.yaml or on the LiteLLM Admin UI. - https://docs.litellm.ai/docs/proxy/configs"
+ },
+ )
+
+ if llm_router is None:
+ raise HTTPException(
+ status_code=500,
+ detail={
+ "error": "LLM Router is not loaded in. Make sure you passed models in your config.yaml or on the LiteLLM Admin UI. - https://docs.litellm.ai/docs/proxy/configs"
+ },
+ )
+
+ if litellm_model_id is not None:
+ # user is trying to get specific model from litellm router
+ deployment_info = llm_router.get_deployment(model_id=litellm_model_id)
+ if deployment_info is None:
+ raise HTTPException(
+ status_code=400,
+ detail={
+ "error": f"Model id = {litellm_model_id} not found on litellm proxy"
+ },
+ )
+ _deployment_info_dict = _get_proxy_model_info(
+ model=deployment_info.model_dump(exclude_none=True)
+ )
+ return {"data": [_deployment_info_dict]}
+
+ all_models: List[dict] = []
+ model_access_groups: Dict[str, List[str]] = defaultdict(list)
+ ## CHECK IF MODEL RESTRICTIONS ARE SET AT KEY/TEAM LEVEL ##
+ if llm_router is None:
+ proxy_model_list = []
+ else:
+ proxy_model_list = llm_router.get_model_names()
+ model_access_groups = llm_router.get_model_access_groups()
+ key_models = get_key_models(
+ user_api_key_dict=user_api_key_dict,
+ proxy_model_list=proxy_model_list,
+ model_access_groups=model_access_groups,
+ )
+ team_models = get_team_models(
+ team_models=user_api_key_dict.team_models,
+ proxy_model_list=proxy_model_list,
+ model_access_groups=model_access_groups,
+ )
+ all_models_str = get_complete_model_list(
+ key_models=key_models,
+ team_models=team_models,
+ proxy_model_list=proxy_model_list,
+ user_model=user_model,
+ infer_model_from_keys=general_settings.get("infer_model_from_keys", False),
+ )
+
+ if len(all_models_str) > 0:
+ model_names = all_models_str
+ llm_model_list = llm_router.get_model_list()
+ if llm_model_list is not None:
+ _relevant_models = [
+ m for m in llm_model_list if m["model_name"] in model_names
+ ]
+ all_models = copy.deepcopy(_relevant_models) # type: ignore
+ else:
+ all_models = []
+
+ for model in all_models:
+ model = _get_proxy_model_info(model=model)
+
+ verbose_proxy_logger.debug("all_models: %s", all_models)
+ return {"data": all_models}
+
+
+def _get_model_group_info(
+ llm_router: Router, all_models_str: List[str], model_group: Optional[str]
+) -> List[ModelGroupInfo]:
+ model_groups: List[ModelGroupInfo] = []
+ for model in all_models_str:
+ if model_group is not None and model_group != model:
+ continue
+
+ _model_group_info = llm_router.get_model_group_info(model_group=model)
+ if _model_group_info is not None:
+ model_groups.append(_model_group_info)
+ return model_groups
+
+
+@router.get(
+ "/model_group/info",
+ tags=["model management"],
+ dependencies=[Depends(user_api_key_auth)],
+)
+async def model_group_info(
+ user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+ model_group: Optional[str] = None,
+):
+ """
+ Get information about all the deployments on litellm proxy, including config.yaml descriptions (except api key and api base)
+
+ - /model_group/info returns all model groups. End users of proxy should use /model_group/info since those models will be used for /chat/completions, /embeddings, etc.
+ - /model_group/info?model_group=rerank-english-v3.0 returns all model groups for a specific model group (`model_name` in config.yaml)
+
+
+
+ Example Request (All Models):
+ ```shell
+ curl -X 'GET' \
+ 'http://localhost:4000/model_group/info' \
+ -H 'accept: application/json' \
+ -H 'x-api-key: sk-1234'
+ ```
+
+ Example Request (Specific Model Group):
+ ```shell
+ curl -X 'GET' \
+ 'http://localhost:4000/model_group/info?model_group=rerank-english-v3.0' \
+ -H 'accept: application/json' \
+ -H 'Authorization: Bearer sk-1234'
+ ```
+
+ Example Request (Specific Wildcard Model Group): (e.g. `model_name: openai/*` on config.yaml)
+ ```shell
+ curl -X 'GET' \
+ 'http://localhost:4000/model_group/info?model_group=openai/tts-1' \
+ -H 'accept: application/json' \
+ -H 'Authorization: Bearer sk-1234'
+ ```
+
+ Learn how to use and set wildcard models [here](https://docs.litellm.ai/docs/wildcard_routing)
+
+ Example Response:
+ ```json
+ {
+ "data": [
+ {
+ "model_group": "rerank-english-v3.0",
+ "providers": [
+ "cohere"
+ ],
+ "max_input_tokens": null,
+ "max_output_tokens": null,
+ "input_cost_per_token": 0.0,
+ "output_cost_per_token": 0.0,
+ "mode": null,
+ "tpm": null,
+ "rpm": null,
+ "supports_parallel_function_calling": false,
+ "supports_vision": false,
+ "supports_function_calling": false,
+ "supported_openai_params": [
+ "stream",
+ "temperature",
+ "max_tokens",
+ "logit_bias",
+ "top_p",
+ "frequency_penalty",
+ "presence_penalty",
+ "stop",
+ "n",
+ "extra_headers"
+ ]
+ },
+ {
+ "model_group": "gpt-3.5-turbo",
+ "providers": [
+ "openai"
+ ],
+ "max_input_tokens": 16385.0,
+ "max_output_tokens": 4096.0,
+ "input_cost_per_token": 1.5e-06,
+ "output_cost_per_token": 2e-06,
+ "mode": "chat",
+ "tpm": null,
+ "rpm": null,
+ "supports_parallel_function_calling": false,
+ "supports_vision": false,
+ "supports_function_calling": true,
+ "supported_openai_params": [
+ "frequency_penalty",
+ "logit_bias",
+ "logprobs",
+ "top_logprobs",
+ "max_tokens",
+ "max_completion_tokens",
+ "n",
+ "presence_penalty",
+ "seed",
+ "stop",
+ "stream",
+ "stream_options",
+ "temperature",
+ "top_p",
+ "tools",
+ "tool_choice",
+ "function_call",
+ "functions",
+ "max_retries",
+ "extra_headers",
+ "parallel_tool_calls",
+ "response_format"
+ ]
+ },
+ {
+ "model_group": "llava-hf",
+ "providers": [
+ "openai"
+ ],
+ "max_input_tokens": null,
+ "max_output_tokens": null,
+ "input_cost_per_token": 0.0,
+ "output_cost_per_token": 0.0,
+ "mode": null,
+ "tpm": null,
+ "rpm": null,
+ "supports_parallel_function_calling": false,
+ "supports_vision": true,
+ "supports_function_calling": false,
+ "supported_openai_params": [
+ "frequency_penalty",
+ "logit_bias",
+ "logprobs",
+ "top_logprobs",
+ "max_tokens",
+ "max_completion_tokens",
+ "n",
+ "presence_penalty",
+ "seed",
+ "stop",
+ "stream",
+ "stream_options",
+ "temperature",
+ "top_p",
+ "tools",
+ "tool_choice",
+ "function_call",
+ "functions",
+ "max_retries",
+ "extra_headers",
+ "parallel_tool_calls",
+ "response_format"
+ ]
+ }
+ ]
+ }
+ ```
+ """
+ global llm_model_list, general_settings, user_config_file_path, proxy_config, llm_router
+
+ if llm_model_list is None:
+ raise HTTPException(
+ status_code=500, detail={"error": "LLM Model List not loaded in"}
+ )
+ if llm_router is None:
+ raise HTTPException(
+ status_code=500, detail={"error": "LLM Router is not loaded in"}
+ )
+ ## CHECK IF MODEL RESTRICTIONS ARE SET AT KEY/TEAM LEVEL ##
+ model_access_groups: Dict[str, List[str]] = defaultdict(list)
+ if llm_router is None:
+ proxy_model_list = []
+ else:
+ proxy_model_list = llm_router.get_model_names()
+ model_access_groups = llm_router.get_model_access_groups()
+
+ key_models = get_key_models(
+ user_api_key_dict=user_api_key_dict,
+ proxy_model_list=proxy_model_list,
+ model_access_groups=model_access_groups,
+ )
+ team_models = get_team_models(
+ team_models=user_api_key_dict.team_models,
+ proxy_model_list=proxy_model_list,
+ model_access_groups=model_access_groups,
+ )
+ all_models_str = get_complete_model_list(
+ key_models=key_models,
+ team_models=team_models,
+ proxy_model_list=proxy_model_list,
+ user_model=user_model,
+ infer_model_from_keys=general_settings.get("infer_model_from_keys", False),
+ )
+
+ model_groups: List[ModelGroupInfo] = _get_model_group_info(
+ llm_router=llm_router, all_models_str=all_models_str, model_group=model_group
+ )
+
+ return {"data": model_groups}
+
+
+@router.get(
+ "/model/settings",
+ description="Returns provider name, description, and required parameters for each provider",
+ tags=["model management"],
+ dependencies=[Depends(user_api_key_auth)],
+ include_in_schema=False,
+)
+async def model_settings():
+ """
+ Used by UI to generate 'model add' page
+ {
+ field_name=field_name,
+ field_type=allowed_args[field_name]["type"], # string/int
+ field_description=field_info.description or "", # human-friendly description
+ field_value=general_settings.get(field_name, None), # example value
+ }
+ """
+
+ returned_list = []
+ for provider in litellm.provider_list:
+ returned_list.append(
+ ProviderInfo(
+ name=provider,
+ fields=litellm.get_provider_fields(custom_llm_provider=provider),
+ )
+ )
+
+ return returned_list
+
+
+#### ALERTING MANAGEMENT ENDPOINTS ####
+
+
+@router.get(
+ "/alerting/settings",
+ description="Return the configurable alerting param, description, and current value",
+ tags=["alerting"],
+ dependencies=[Depends(user_api_key_auth)],
+ include_in_schema=False,
+)
+async def alerting_settings(
+ user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+):
+ global proxy_logging_obj, prisma_client
+ """
+ Used by UI to generate 'alerting settings' page
+ {
+ field_name=field_name,
+ field_type=allowed_args[field_name]["type"], # string/int
+ field_description=field_info.description or "", # human-friendly description
+ field_value=general_settings.get(field_name, None), # example value
+ }
+ """
+ if prisma_client is None:
+ raise HTTPException(
+ status_code=400,
+ detail={"error": CommonProxyErrors.db_not_connected_error.value},
+ )
+
+ if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
+ raise HTTPException(
+ status_code=400,
+ detail={
+ "error": "{}, your role={}".format(
+ CommonProxyErrors.not_allowed_access.value,
+ user_api_key_dict.user_role,
+ )
+ },
+ )
+
+ ## get general settings from db
+ db_general_settings = await prisma_client.db.litellm_config.find_first(
+ where={"param_name": "general_settings"}
+ )
+
+ if db_general_settings is not None and db_general_settings.param_value is not None:
+ db_general_settings_dict = dict(db_general_settings.param_value)
+ alerting_args_dict: dict = db_general_settings_dict.get("alerting_args", {}) # type: ignore
+ alerting_values: Optional[list] = db_general_settings_dict.get("alerting") # type: ignore
+ else:
+ alerting_args_dict = {}
+ alerting_values = None
+
+ allowed_args = {
+ "slack_alerting": {"type": "Boolean"},
+ "daily_report_frequency": {"type": "Integer"},
+ "report_check_interval": {"type": "Integer"},
+ "budget_alert_ttl": {"type": "Integer"},
+ "outage_alert_ttl": {"type": "Integer"},
+ "region_outage_alert_ttl": {"type": "Integer"},
+ "minor_outage_alert_threshold": {"type": "Integer"},
+ "major_outage_alert_threshold": {"type": "Integer"},
+ "max_outage_alert_list_size": {"type": "Integer"},
+ }
+
+ _slack_alerting: SlackAlerting = proxy_logging_obj.slack_alerting_instance
+ _slack_alerting_args_dict = _slack_alerting.alerting_args.model_dump()
+
+ return_val = []
+
+ is_slack_enabled = False
+
+ if general_settings.get("alerting") and isinstance(
+ general_settings["alerting"], list
+ ):
+ if "slack" in general_settings["alerting"]:
+ is_slack_enabled = True
+
+ _response_obj = ConfigList(
+ field_name="slack_alerting",
+ field_type=allowed_args["slack_alerting"]["type"],
+ field_description="Enable slack alerting for monitoring proxy in production: llm outages, budgets, spend tracking failures.",
+ field_value=is_slack_enabled,
+ stored_in_db=True if alerting_values is not None else False,
+ field_default_value=None,
+ premium_field=False,
+ )
+ return_val.append(_response_obj)
+
+ for field_name, field_info in SlackAlertingArgs.model_fields.items():
+ if field_name in allowed_args:
+
+ _stored_in_db: Optional[bool] = None
+ if field_name in alerting_args_dict:
+ _stored_in_db = True
+ else:
+ _stored_in_db = False
+
+ _response_obj = ConfigList(
+ field_name=field_name,
+ field_type=allowed_args[field_name]["type"],
+ field_description=field_info.description or "",
+ field_value=_slack_alerting_args_dict.get(field_name, None),
+ stored_in_db=_stored_in_db,
+ field_default_value=field_info.default,
+ premium_field=(
+ True if field_name == "region_outage_alert_ttl" else False
+ ),
+ )
+ return_val.append(_response_obj)
+ return return_val
+
+
+#### EXPERIMENTAL QUEUING ####
+@router.post(
+ "/queue/chat/completions",
+ tags=["experimental"],
+ dependencies=[Depends(user_api_key_auth)],
+ include_in_schema=False,
+)
+async def async_queue_request(
+ request: Request,
+ fastapi_response: Response,
+ model: Optional[str] = None,
+ user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+):
+ global general_settings, user_debug, proxy_logging_obj
+ """
+ v2 attempt at a background worker to handle queuing.
+
+ Just supports /chat/completions calls currently.
+
+ Now using a FastAPI background task + /chat/completions compatible endpoint
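+
+ Example curl (sketch; `priority` is read by the scheduler and echoed back in the `x-litellm-priority` response header):
+ ```
+ curl -X POST 'http://localhost:4000/queue/chat/completions' \
+ --header 'Authorization: Bearer sk-1234' \
+ --header 'Content-Type: application/json' \
+ --data '{"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "hi"}], "priority": 0}'
+ ```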
+ """
+ data = {}
+ try:
+ data = await request.json() # type: ignore
+
+ # Include original request and headers in the data
+ data["proxy_server_request"] = {
+ "url": str(request.url),
+ "method": request.method,
+ "headers": dict(request.headers),
+ "body": copy.copy(data), # use copy instead of deepcopy
+ }
+
+ verbose_proxy_logger.debug("receiving data: %s", data)
+ data["model"] = (
+ general_settings.get("completion_model", None) # server default
+ or user_model # model name passed via cli args
+ or model # for azure deployments
+ or data.get("model", None) # default passed in http request
+ )
+
+ # users can pass in 'user' param to /chat/completions. Don't override it
+ if data.get("user", None) is None and user_api_key_dict.user_id is not None:
+ # if users are using user_api_key_auth, set `user` in `data`
+ data["user"] = user_api_key_dict.user_id
+
+ if "metadata" not in data:
+ data["metadata"] = {}
+ data["metadata"]["user_api_key"] = user_api_key_dict.api_key
+ data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata
+ _headers = dict(request.headers)
+ _headers.pop(
+ "authorization", None
+ ) # do not store the original `sk-..` api key in the db
+ data["metadata"]["headers"] = _headers
+ data["metadata"]["user_api_key_alias"] = getattr(
+ user_api_key_dict, "key_alias", None
+ )
+ data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id
+ data["metadata"]["user_api_key_team_id"] = getattr(
+ user_api_key_dict, "team_id", None
+ )
+ data["metadata"]["endpoint"] = str(request.url)
+
+ global user_temperature, user_request_timeout, user_max_tokens, user_api_base
+ # override with user settings, these are params passed via cli
+ if user_temperature:
+ data["temperature"] = user_temperature
+ if user_request_timeout:
+ data["request_timeout"] = user_request_timeout
+ if user_max_tokens:
+ data["max_tokens"] = user_max_tokens
+ if user_api_base:
+ data["api_base"] = user_api_base
+
+ if llm_router is None:
+ raise HTTPException(
+ status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
+ )
+
+ response = await llm_router.schedule_acompletion(**data)
+
+ if (
+ "stream" in data and data["stream"] is True
+ ): # use generate_responses to stream responses
+ return StreamingResponse(
+ async_data_generator(
+ user_api_key_dict=user_api_key_dict,
+ response=response,
+ request_data=data,
+ ),
+ media_type="text/event-stream",
+ )
+
+ fastapi_response.headers.update({"x-litellm-priority": str(data["priority"])})
+ return response
+ except Exception as e:
+ await proxy_logging_obj.post_call_failure_hook(
+ user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data
+ )
+ if isinstance(e, HTTPException):
+ raise ProxyException(
+ message=getattr(e, "detail", f"Authentication Error({str(e)})"),
+ type=ProxyErrorTypes.auth_error,
+ param=getattr(e, "param", "None"),
+ code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
+ )
+ elif isinstance(e, ProxyException):
+ raise e
+ raise ProxyException(
+ message="Authentication Error, " + str(e),
+ type=ProxyErrorTypes.auth_error,
+ param=getattr(e, "param", "None"),
+ code=status.HTTP_400_BAD_REQUEST,
+ )
+
+
+@app.get("/fallback/login", tags=["experimental"], include_in_schema=False)
+async def fallback_login(request: Request):
+ """
+ Create Proxy API Keys using Google Workspace SSO. Requires setting PROXY_BASE_URL in .env
+ PROXY_BASE_URL should be your deployed proxy endpoint, e.g. PROXY_BASE_URL="https://litellm-production-7002.up.railway.app/"
+ """
+ # get url from request
+ redirect_url = os.getenv("PROXY_BASE_URL", str(request.base_url))
+ ui_username = os.getenv("UI_USERNAME")
+ if redirect_url.endswith("/"):
+ redirect_url += "sso/callback"
+ else:
+ redirect_url += "/sso/callback"
+
+ # No Google / Microsoft SSO here - whether or not UI credentials are set in .env,
+ # serve the username/password login form
+ from fastapi.responses import HTMLResponse
+
+ return HTMLResponse(content=html_form, status_code=200)
+
+
+@router.post(
+ "/login", include_in_schema=False
+) # hidden since this is a helper for UI sso login
+async def login(request: Request): # noqa: PLR0915
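+ """
+ Username/password login for the Admin UI (non-SSO). Reads `username` and `password` from the submitted form,
+ checks them against `UI_USERNAME` / `UI_PASSWORD` (or an invite-link user row in the DB),
+ then redirects to the dashboard with a short-lived JWT cookie.
+
+ Example form post (sketch; by default `UI_USERNAME` is "admin" and `UI_PASSWORD` falls back to the master key):
+ ```
+ curl -X POST 'http://localhost:4000/login' \
+ --data 'username=admin' \
+ --data 'password=sk-1234'
+ ```
+ """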
+ global premium_user, general_settings
+ try:
+ import multipart
+ except ImportError:
+ subprocess.run(["pip", "install", "python-multipart"])
+ global master_key
+ if master_key is None:
+ raise ProxyException(
+ message="Master Key not set for Proxy. Please set Master Key to use Admin UI. Set `LITELLM_MASTER_KEY` in .env or set general_settings:master_key in config.yaml. https://docs.litellm.ai/docs/proxy/virtual_keys. If set, use `--detailed_debug` to debug issue.",
+ type=ProxyErrorTypes.auth_error,
+ param="master_key",
+ code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ )
+ form = await request.form()
+ username = str(form.get("username"))
+ password = str(form.get("password"))
+ ui_username = os.getenv("UI_USERNAME", "admin")
+ ui_password = os.getenv("UI_PASSWORD", None)
+ if ui_password is None:
+ ui_password = str(master_key) if master_key is not None else None
+ if ui_password is None:
+ raise ProxyException(
+ message="set Proxy master key to use UI. https://docs.litellm.ai/docs/proxy/virtual_keys. If set, use `--detailed_debug` to debug issue.",
+ type=ProxyErrorTypes.auth_error,
+ param="UI_PASSWORD",
+ code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ )
+
+ # check if we can find the `username` in the db. on the ui, users can enter username=their email
+ _user_row = None
+ user_role: Optional[
+ Literal[
+ LitellmUserRoles.PROXY_ADMIN,
+ LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY,
+ LitellmUserRoles.INTERNAL_USER,
+ LitellmUserRoles.INTERNAL_USER_VIEW_ONLY,
+ ]
+ ] = None
+ if prisma_client is not None:
+ _user_row = await prisma_client.db.litellm_usertable.find_first(
+ where={"user_email": {"equals": username}}
+ )
+ disabled_non_admin_personal_key_creation = (
+ get_disabled_non_admin_personal_key_creation()
+ )
+ """
+ To login to Admin UI, we support the following
+ - Login with UI_USERNAME and UI_PASSWORD
+ - Login with Invite Link `user_email` and `password` combination
+ """
+ if secrets.compare_digest(username, ui_username) and secrets.compare_digest(
+ password, ui_password
+ ):
+ # Non SSO -> If user is using UI_USERNAME and UI_PASSWORD they are Proxy admin
+ user_role = LitellmUserRoles.PROXY_ADMIN
+ user_id = litellm_proxy_admin_name
+
+ # we want the key created to have PROXY_ADMIN_PERMISSIONS
+ key_user_id = litellm_proxy_admin_name
+ if (
+ os.getenv("PROXY_ADMIN_ID", None) is not None
+ and os.environ["PROXY_ADMIN_ID"] == user_id
+ ) or user_id == litellm_proxy_admin_name:
+ # checks if user is admin
+ key_user_id = os.getenv("PROXY_ADMIN_ID", litellm_proxy_admin_name)
+
+ # Admin is authenticated - generate a key for the UI to access the Proxy
+
+ # ensure this user is set as the proxy admin; there is no SSO on this route, so we can assume this user is the admin
+ await user_update(
+ data=UpdateUserRequest(
+ user_id=key_user_id,
+ user_role=user_role,
+ )
+ )
+ if os.getenv("DATABASE_URL") is not None:
+ response = await generate_key_helper_fn(
+ request_type="key",
+ **{
+ "user_role": LitellmUserRoles.PROXY_ADMIN,
+ "duration": "24hr",
+ "key_max_budget": litellm.max_ui_session_budget,
+ "models": [],
+ "aliases": {},
+ "config": {},
+ "spend": 0,
+ "user_id": key_user_id,
+ "team_id": "litellm-dashboard",
+ }, # type: ignore
+ )
+ else:
+ raise ProxyException(
+ message="No Database connected. Set DATABASE_URL in .env. If set, use `--detailed_debug` to debug issue.",
+ type=ProxyErrorTypes.auth_error,
+ param="DATABASE_URL",
+ code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ )
+ key = response["token"] # type: ignore
+ litellm_dashboard_ui = os.getenv("PROXY_BASE_URL", "")
+ if litellm_dashboard_ui.endswith("/"):
+ litellm_dashboard_ui += "ui/"
+ else:
+ litellm_dashboard_ui += "/ui/"
+ import jwt
+
+ jwt_token = jwt.encode( # type: ignore
+ {
+ "user_id": user_id,
+ "key": key,
+ "user_email": None,
+ "user_role": user_role, # this is the path without sso - we can assume only admins will use this
+ "login_method": "username_password",
+ "premium_user": premium_user,
+ "auth_header_name": general_settings.get(
+ "litellm_key_header_name", "Authorization"
+ ),
+ "disabled_non_admin_personal_key_creation": disabled_non_admin_personal_key_creation,
+ },
+ master_key,
+ algorithm="HS256",
+ )
+ litellm_dashboard_ui += "?userID=" + user_id
+ redirect_response = RedirectResponse(url=litellm_dashboard_ui, status_code=303)
+ redirect_response.set_cookie(key="token", value=jwt_token)
+ return redirect_response
+ elif _user_row is not None:
+ """
+ When sharing invite links
+
+ -> if the user has no role in the DB assume they are only a viewer
+ """
+ user_id = getattr(_user_row, "user_id", "unknown")
+ user_role = getattr(
+ _user_row, "user_role", LitellmUserRoles.INTERNAL_USER_VIEW_ONLY
+ )
+ user_email = getattr(_user_row, "user_email", "unknown")
+ _password = getattr(_user_row, "password", "unknown")
+
+ # check if password == _user_row.password
+ hash_password = hash_token(token=password)
+ if secrets.compare_digest(password, _password) or secrets.compare_digest(
+ hash_password, _password
+ ):
+ if os.getenv("DATABASE_URL") is not None:
+ response = await generate_key_helper_fn(
+ request_type="key",
+ **{ # type: ignore
+ "user_role": user_role,
+ "duration": "24hr",
+ "key_max_budget": litellm.max_ui_session_budget,
+ "models": [],
+ "aliases": {},
+ "config": {},
+ "spend": 0,
+ "user_id": user_id,
+ "team_id": "litellm-dashboard",
+ },
+ )
+ else:
+ raise ProxyException(
+ message="No Database connected. Set DATABASE_URL in .env. If set, use `--detailed_debug` to debug issue.",
+ type=ProxyErrorTypes.auth_error,
+ param="DATABASE_URL",
+ code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ )
+ key = response["token"] # type: ignore
+ litellm_dashboard_ui = os.getenv("PROXY_BASE_URL", "")
+ if litellm_dashboard_ui.endswith("/"):
+ litellm_dashboard_ui += "ui/"
+ else:
+ litellm_dashboard_ui += "/ui/"
+ import jwt
+
+ jwt_token = jwt.encode( # type: ignore
+ {
+ "user_id": user_id,
+ "key": key,
+ "user_email": user_email,
+ "user_role": user_role,
+ "login_method": "username_password",
+ "premium_user": premium_user,
+ "auth_header_name": general_settings.get(
+ "litellm_key_header_name", "Authorization"
+ ),
+ "disabled_non_admin_personal_key_creation": disabled_non_admin_personal_key_creation,
+ },
+ master_key,
+ algorithm="HS256",
+ )
+ litellm_dashboard_ui += "?userID=" + user_id
+ redirect_response = RedirectResponse(
+ url=litellm_dashboard_ui, status_code=303
+ )
+ redirect_response.set_cookie(key="token", value=jwt_token)
+ return redirect_response
+ else:
+ raise ProxyException(
+ message=f"Invalid credentials used to access UI.\nNot valid credentials for {username}",
+ type=ProxyErrorTypes.auth_error,
+ param="invalid_credentials",
+ code=status.HTTP_401_UNAUTHORIZED,
+ )
+ else:
+ raise ProxyException(
+ message="Invalid credentials used to access UI.\nCheck 'UI_USERNAME', 'UI_PASSWORD' in .env file",
+ type=ProxyErrorTypes.auth_error,
+ param="invalid_credentials",
+ code=status.HTTP_401_UNAUTHORIZED,
+ )
+
+
+@app.get("/onboarding/get_token", include_in_schema=False)
+async def onboarding(invite_link: str):
+ """
+ - Get the invite link
+ - Validate it's still 'valid'
+ - Invalidate the link (prevents abuse)
+ - Get user from db
+ - Pass in user_email if set
+ """
+ global prisma_client, master_key, general_settings
+ if master_key is None:
+ raise ProxyException(
+ message="Master Key not set for Proxy. Please set Master Key to use Admin UI. Set `LITELLM_MASTER_KEY` in .env or set general_settings:master_key in config.yaml. https://docs.litellm.ai/docs/proxy/virtual_keys. If set, use `--detailed_debug` to debug issue.",
+ type=ProxyErrorTypes.auth_error,
+ param="master_key",
+ code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ )
+ ### VALIDATE INVITE LINK ###
+ if prisma_client is None:
+ raise HTTPException(
+ status_code=500,
+ detail={"error": CommonProxyErrors.db_not_connected_error.value},
+ )
+
+ invite_obj = await prisma_client.db.litellm_invitationlink.find_unique(
+ where={"id": invite_link}
+ )
+ if invite_obj is None:
+ raise HTTPException(
+ status_code=401, detail={"error": "Invitation link does not exist in db."}
+ )
+ #### CHECK IF EXPIRED
+ # Extract the date part from both datetime objects
+ utc_now_date = litellm.utils.get_utc_datetime().date()
+ expires_at_date = invite_obj.expires_at.date()
+ if expires_at_date < utc_now_date:
+ raise HTTPException(
+ status_code=401, detail={"error": "Invitation link has expired."}
+ )
+
+ #### INVALIDATE LINK
+ current_time = litellm.utils.get_utc_datetime()
+
+ _ = await prisma_client.db.litellm_invitationlink.update(
+ where={"id": invite_link},
+ data={
+ "accepted_at": current_time,
+ "updated_at": current_time,
+ "is_accepted": True,
+ "updated_by": invite_obj.user_id, # type: ignore
+ },
+ )
+
+ ### GET USER OBJECT ###
+ user_obj = await prisma_client.db.litellm_usertable.find_unique(
+ where={"user_id": invite_obj.user_id}
+ )
+
+ if user_obj is None:
+ raise HTTPException(
+ status_code=401, detail={"error": "User does not exist in db."}
+ )
+
+ user_email = user_obj.user_email
+
+ response = await generate_key_helper_fn(
+ request_type="key",
+ **{
+ "user_role": user_obj.user_role,
+ "duration": "24hr",
+ "key_max_budget": litellm.max_ui_session_budget,
+ "models": [],
+ "aliases": {},
+ "config": {},
+ "spend": 0,
+ "user_id": user_obj.user_id,
+ "team_id": "litellm-dashboard",
+ }, # type: ignore
+ )
+ key = response["token"] # type: ignore
+
+ litellm_dashboard_ui = os.getenv("PROXY_BASE_URL", "")
+ if litellm_dashboard_ui.endswith("/"):
+ litellm_dashboard_ui += "ui/onboarding"
+ else:
+ litellm_dashboard_ui += "/ui/onboarding"
+ import jwt
+
+ disabled_non_admin_personal_key_creation = (
+ get_disabled_non_admin_personal_key_creation()
+ )
+
+ jwt_token = jwt.encode( # type: ignore
+ {
+ "user_id": user_obj.user_id,
+ "key": key,
+ "user_email": user_obj.user_email,
+ "user_role": user_obj.user_role,
+ "login_method": "username_password",
+ "premium_user": premium_user,
+ "auth_header_name": general_settings.get(
+ "litellm_key_header_name", "Authorization"
+ ),
+ "disabled_non_admin_personal_key_creation": disabled_non_admin_personal_key_creation,
+ },
+ master_key,
+ algorithm="HS256",
+ )
+
+ litellm_dashboard_ui += "?token={}&user_email={}".format(jwt_token, user_email)
+ return {
+ "login_url": litellm_dashboard_ui,
+ "token": jwt_token,
+ "user_email": user_email,
+ }
+
+
+@app.post("/onboarding/claim_token", include_in_schema=False)
+async def claim_onboarding_link(data: InvitationClaim):
+ """
+ Special route. Allows a user who received a shared UI invite link to update their password.
+
+ - Get the invite link
+ - Validate it's still 'valid'
+ - Check if user within initial session (prevents abuse)
+ - Get user from db
+ - Update user password
+
+ This route can only update user password.
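+
+ Example curl (sketch; field values are placeholders):
+ ```
+ curl -X POST 'http://localhost:4000/onboarding/claim_token' \
+ --header 'Content-Type: application/json' \
+ --data '{"invitation_link": "1234", "user_id": "5678", "password": "new-password"}'
+ ```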
+ """
+ global prisma_client
+ ### VALIDATE INVITE LINK ###
+ if prisma_client is None:
+ raise HTTPException(
+ status_code=500,
+ detail={"error": CommonProxyErrors.db_not_connected_error.value},
+ )
+
+ invite_obj = await prisma_client.db.litellm_invitationlink.find_unique(
+ where={"id": data.invitation_link}
+ )
+ if invite_obj is None:
+ raise HTTPException(
+ status_code=401, detail={"error": "Invitation link does not exist in db."}
+ )
+ #### CHECK IF EXPIRED
+ # Extract the date part from both datetime objects
+ utc_now_date = litellm.utils.get_utc_datetime().date()
+ expires_at_date = invite_obj.expires_at.date()
+ if expires_at_date < utc_now_date:
+ raise HTTPException(
+ status_code=401, detail={"error": "Invitation link has expired."}
+ )
+
+ #### CHECK IF CLAIMED
+ ##### if claimed - accept
+ ##### if unclaimed - reject
+
+ if invite_obj.is_accepted is True:
+ # this is a valid invite that was accepted
+ pass
+ else:
+ raise HTTPException(
+ status_code=401,
+ detail={
+ "error": "The invitation link was never validated. Please file an issue, if this is not intended - https://github.com/BerriAI/litellm/issues."
+ },
+ )
+
+ #### CHECK IF VALID USER ID
+ if invite_obj.user_id != data.user_id:
+ raise HTTPException(
+ status_code=401,
+ detail={
+ "error": "Invalid invitation link. The user id submitted does not match the user id this link is attached to. Got={}, Expected={}".format(
+ data.user_id, invite_obj.user_id
+ )
+ },
+ )
+ ### UPDATE USER OBJECT ###
+ hash_password = hash_token(token=data.password)
+ user_obj = await prisma_client.db.litellm_usertable.update(
+ where={"user_id": invite_obj.user_id}, data={"password": hash_password}
+ )
+
+ if user_obj is None:
+ raise HTTPException(
+ status_code=401, detail={"error": "User does not exist in db."}
+ )
+
+ return user_obj
+
+
+@app.get("/get_image", include_in_schema=False)
+def get_image():
+ """Get logo to show on admin UI"""
+
+ # get current_dir
+ current_dir = os.path.dirname(os.path.abspath(__file__))
+ default_logo = os.path.join(current_dir, "logo.jpg")
+
+ logo_path = os.getenv("UI_LOGO_PATH", default_logo)
+ verbose_proxy_logger.debug("Reading logo from path: %s", logo_path)
+
+ # Check if the logo path is an HTTP/HTTPS URL
+ if logo_path.startswith(("http://", "https://")):
+ # Download the image and cache it
+ client = HTTPHandler()
+ response = client.get(logo_path)
+ if response.status_code == 200:
+ # Save the image to a local file
+ cache_path = os.path.join(current_dir, "cached_logo.jpg")
+ with open(cache_path, "wb") as f:
+ f.write(response.content)
+
+ # Return the cached image as a FileResponse
+ return FileResponse(cache_path, media_type="image/jpeg")
+ else:
+ # Handle the case when the image cannot be downloaded
+ return FileResponse(default_logo, media_type="image/jpeg")
+ else:
+ # Return the local image file if the logo path is not an HTTP/HTTPS URL
+ return FileResponse(logo_path, media_type="image/jpeg")
+
+
+#### INVITATION MANAGEMENT ####
+
+
+@router.post(
+ "/invitation/new",
+ tags=["Invite Links"],
+ dependencies=[Depends(user_api_key_auth)],
+ response_model=InvitationModel,
+ include_in_schema=False,
+)
+async def new_invitation(
+ data: InvitationNew, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth)
+):
+ """
+ Allow admin to create invite links, to onboard new users to Admin UI.
+
+ ```
+ curl -X POST 'http://localhost:4000/invitation/new' \
+ -H 'Content-Type: application/json' \
+ -d '{
+ "user_id": "1234" // šŸ‘ˆ id of user in 'LiteLLM_UserTable'
+ }'
+ ```
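+
+ Example response (sketch - ids and timestamps are illustrative; fields mirror the 'LiteLLM_InvitationLink' table):
+ ```json
+ {
+ "id": "invite-id",
+ "user_id": "1234",
+ "is_accepted": false,
+ "accepted_at": null,
+ "expires_at": "2024-06-30T00:00:00Z",
+ "created_at": "2024-06-23T00:00:00Z",
+ "created_by": "admin",
+ "updated_at": "2024-06-23T00:00:00Z",
+ "updated_by": "admin"
+ }
+ ```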
+ """
+ global prisma_client
+
+ if prisma_client is None:
+ raise HTTPException(
+ status_code=400,
+ detail={"error": CommonProxyErrors.db_not_connected_error.value},
+ )
+
+ if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
+ raise HTTPException(
+ status_code=400,
+ detail={
+ "error": "{}, your role={}".format(
+ CommonProxyErrors.not_allowed_access.value,
+ user_api_key_dict.user_role,
+ )
+ },
+ )
+
+ current_time = litellm.utils.get_utc_datetime()
+ expires_at = current_time + timedelta(days=7)
+
+ try:
+ response = await prisma_client.db.litellm_invitationlink.create(
+ data={
+ "user_id": data.user_id,
+ "created_at": current_time,
+ "expires_at": expires_at,
+ "created_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
+ "updated_at": current_time,
+ "updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
+ } # type: ignore
+ )
+ return response
+ except Exception as e:
+ if "Foreign key constraint failed on the field" in str(e):
+ raise HTTPException(
+ status_code=400,
+ detail={
+ "error": "User id does not exist in 'LiteLLM_UserTable'. Fix this by creating user via `/user/new`."
+ },
+ )
+ raise HTTPException(status_code=500, detail={"error": str(e)})
+
+
+@router.get(
+ "/invitation/info",
+ tags=["Invite Links"],
+ dependencies=[Depends(user_api_key_auth)],
+ response_model=InvitationModel,
+ include_in_schema=False,
+)
+async def invitation_info(
+ invitation_id: str, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth)
+):
+ """
+ Allow admin to view the details of an existing invitation link.
+
+ ```
+ curl -X GET 'http://localhost:4000/invitation/info?invitation_id=1234' \
+ -H 'Authorization: Bearer sk-1234'
+ ```
+ """
+ global prisma_client
+
+ if prisma_client is None:
+ raise HTTPException(
+ status_code=400,
+ detail={"error": CommonProxyErrors.db_not_connected_error.value},
+ )
+
+ if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
+ raise HTTPException(
+ status_code=400,
+ detail={
+ "error": "{}, your role={}".format(
+ CommonProxyErrors.not_allowed_access.value,
+ user_api_key_dict.user_role,
+ )
+ },
+ )
+
+ response = await prisma_client.db.litellm_invitationlink.find_unique(
+ where={"id": invitation_id}
+ )
+
+ if response is None:
+ raise HTTPException(
+ status_code=400,
+ detail={"error": "Invitation id does not exist in the database."},
+ )
+ return response
+
+
+@router.post(
+ "/invitation/update",
+ tags=["Invite Links"],
+ dependencies=[Depends(user_api_key_auth)],
+ response_model=InvitationModel,
+ include_in_schema=False,
+)
+async def invitation_update(
+ data: InvitationUpdate,
+ user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+):
+ """
+ Update when invitation is accepted
+
+ ```
+ curl -X POST 'http://localhost:4000/invitation/update' \
+ -H 'Content-Type: application/json' \
+ -d '{
+ "invitation_id": "1234" // šŸ‘ˆ id of invitation in 'LiteLLM_InvitationTable'
+ "is_accepted": True // when invitation is accepted
+ }'
+ ```
+ """
+ global prisma_client
+
+ if prisma_client is None:
+ raise HTTPException(
+ status_code=400,
+ detail={"error": CommonProxyErrors.db_not_connected_error.value},
+ )
+
+ if user_api_key_dict.user_id is None:
+ raise HTTPException(
+ status_code=500,
+ detail={
+ "error": "Unable to identify user id. Received={}".format(
+ user_api_key_dict.user_id
+ )
+ },
+ )
+
+ current_time = litellm.utils.get_utc_datetime()
+ response = await prisma_client.db.litellm_invitationlink.update(
+ where={"id": data.invitation_id},
+ data={
+ "id": data.invitation_id,
+ "is_accepted": data.is_accepted,
+ "accepted_at": current_time,
+ "updated_at": current_time,
+ "updated_by": user_api_key_dict.user_id, # type: ignore
+ },
+ )
+
+ if response is None:
+ raise HTTPException(
+ status_code=400,
+ detail={"error": "Invitation id does not exist in the database."},
+ )
+ return response
+
+
+@router.post(
+ "/invitation/delete",
+ tags=["Invite Links"],
+ dependencies=[Depends(user_api_key_auth)],
+ response_model=InvitationModel,
+ include_in_schema=False,
+)
+async def invitation_delete(
+ data: InvitationDelete,
+ user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+):
+ """
+ Delete invitation link
+
+ ```
+ curl -X POST 'http://localhost:4000/invitation/delete' \
+ -H 'Content-Type: application/json' \
+ -d '{
+ "invitation_id": "1234" // šŸ‘ˆ id of invitation in 'LiteLLM_InvitationTable'
+ }'
+ ```
+ """
+ global prisma_client
+
+ if prisma_client is None:
+ raise HTTPException(
+ status_code=400,
+ detail={"error": CommonProxyErrors.db_not_connected_error.value},
+ )
+
+ if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
+ raise HTTPException(
+ status_code=400,
+ detail={
+ "error": "{}, your role={}".format(
+ CommonProxyErrors.not_allowed_access.value,
+ user_api_key_dict.user_role,
+ )
+ },
+ )
+
+ response = await prisma_client.db.litellm_invitationlink.delete(
+ where={"id": data.invitation_id}
+ )
+
+ if response is None:
+ raise HTTPException(
+ status_code=400,
+ detail={"error": "Invitation id does not exist in the database."},
+ )
+ return response
+
+
+#### CONFIG MANAGEMENT ####
+@router.post(
+ "/config/update",
+ tags=["config.yaml"],
+ dependencies=[Depends(user_api_key_auth)],
+ include_in_schema=False,
+)
+async def update_config(config_info: ConfigYAML): # noqa: PLR0915
+ """
+ For Admin UI - allows admin to update config via UI
+
+ Currently supports modifying General Settings + LiteLLM settings
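+
+ Example curl (sketch; any subset of `general_settings` / `litellm_settings` / `router_settings` / `environment_variables` may be sent - the callback shown is illustrative):
+ ```
+ curl -X POST 'http://localhost:4000/config/update' \
+ --header 'Authorization: Bearer sk-1234' \
+ --header 'Content-Type: application/json' \
+ --data '{"litellm_settings": {"success_callback": ["langfuse"]}}'
+ ```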
+ """
+ global llm_router, llm_model_list, general_settings, proxy_config, proxy_logging_obj, master_key, prisma_client
+ try:
+ import base64
+
+ """
+ - Update the ConfigTable DB
+ - Run 'add_deployment'
+ """
+ if prisma_client is None:
+ raise Exception("No DB Connected")
+
+ if store_model_in_db is not True:
+ raise HTTPException(
+ status_code=500,
+ detail={
+ "error": "Set `'STORE_MODEL_IN_DB='True'` in your env to enable this feature."
+ },
+ )
+
+ updated_settings = config_info.json(exclude_none=True)
+ updated_settings = prisma_client.jsonify_object(updated_settings)
+ for k, v in updated_settings.items():
+ if k == "router_settings":
+ await prisma_client.db.litellm_config.upsert(
+ where={"param_name": k},
+ data={
+ "create": {"param_name": k, "param_value": v},
+ "update": {"param_value": v},
+ },
+ )
+
+ ### OLD LOGIC [TODO] MOVE TO DB ###
+
+ # Load existing config
+ config = await proxy_config.get_config()
+ verbose_proxy_logger.debug("Loaded config: %s", config)
+
+ # update the general settings
+ if config_info.general_settings is not None:
+ config.setdefault("general_settings", {})
+ updated_general_settings = config_info.general_settings.dict(
+ exclude_none=True
+ )
+
+ _existing_settings = config["general_settings"]
+ for k, v in updated_general_settings.items():
+ # overwrite existing settings with updated values
+ if k == "alert_to_webhook_url":
+ # check if slack is already enabled. if not, enable it
+ if "alerting" not in _existing_settings:
+ _existing_settings = {"alerting": ["slack"]}
+ elif isinstance(_existing_settings["alerting"], list):
+ if "slack" not in _existing_settings["alerting"]:
+ _existing_settings["alerting"].append("slack")
+ _existing_settings[k] = v
+ config["general_settings"] = _existing_settings
+
+ if config_info.environment_variables is not None:
+ config.setdefault("environment_variables", {})
+ _updated_environment_variables = config_info.environment_variables
+
+ # encrypt updated_environment_variables #
+ for k, v in _updated_environment_variables.items():
+ encrypted_value = encrypt_value_helper(value=v)
+ _updated_environment_variables[k] = encrypted_value
+
+ _existing_env_variables = config["environment_variables"]
+
+ for k, v in _updated_environment_variables.items():
+ # overwrite existing env variables with updated values
+ _existing_env_variables[k] = _updated_environment_variables[k]
+
+ # update the litellm settings
+ if config_info.litellm_settings is not None:
+ config.setdefault("litellm_settings", {})
+ updated_litellm_settings = config_info.litellm_settings
+ config["litellm_settings"] = {
+ **updated_litellm_settings,
+ **config["litellm_settings"],
+ }
+
+ # if success_callback is present in both the updated and existing litellm settings, merge the two lists
+ if (
+ "success_callback" in updated_litellm_settings
+ and "success_callback" in config["litellm_settings"]
+ ):
+
+ # check both success callback are lists
+ if isinstance(
+ config["litellm_settings"]["success_callback"], list
+ ) and isinstance(updated_litellm_settings["success_callback"], list):
+ combined_success_callback = (
+ config["litellm_settings"]["success_callback"]
+ + updated_litellm_settings["success_callback"]
+ )
+ combined_success_callback = list(set(combined_success_callback))
+ config["litellm_settings"][
+ "success_callback"
+ ] = combined_success_callback
+
+ # Save the updated config
+ await proxy_config.save_config(new_config=config)
+
+ await proxy_config.add_deployment(
+ prisma_client=prisma_client, proxy_logging_obj=proxy_logging_obj
+ )
+
+ return {"message": "Config updated successfully"}
+ except Exception as e:
+ verbose_proxy_logger.error(
+            "litellm.proxy.proxy_server.update_config(): Exception occurred - {}".format(
+ str(e)
+ )
+ )
+ verbose_proxy_logger.debug(traceback.format_exc())
+ if isinstance(e, HTTPException):
+ raise ProxyException(
+ message=getattr(e, "detail", f"Authentication Error({str(e)})"),
+ type=ProxyErrorTypes.auth_error,
+ param=getattr(e, "param", "None"),
+ code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
+ )
+ elif isinstance(e, ProxyException):
+ raise e
+ raise ProxyException(
+ message="Authentication Error, " + str(e),
+ type=ProxyErrorTypes.auth_error,
+ param=getattr(e, "param", "None"),
+ code=status.HTTP_400_BAD_REQUEST,
+ )
+
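+# Illustrative /config/update request body, based on the fields handled above
+# (values are placeholders; the exact schema is defined by ConfigYAML):
+#
+#   {
+#     "router_settings": {"routing_strategy": "simple-shuffle"},
+#     "litellm_settings": {"success_callback": ["langfuse"]},
+#     "general_settings": {"alert_to_webhook_url": {"<alert_type>": "https://hooks.slack.com/..."}},
+#     "environment_variables": {"LANGFUSE_PUBLIC_KEY": "pk-..."}
+#   }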
+
+### CONFIG GENERAL SETTINGS
+"""
+- Update config settings
+- Get config settings
+
+Keep it precise, to prevent unintentionally overwriting other values.
+"""
+
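+# Illustrative client-side usage of the field-level config endpoints defined below
+# (assumes the `requests` library; base URL, key, and values are placeholders):
+#
+#   import requests
+#
+#   headers = {"Authorization": "Bearer sk-1234"}
+#   base = "http://localhost:4000"
+#
+#   # update a single general-settings field
+#   requests.post(f"{base}/config/field/update", headers=headers,
+#                 json={"field_name": "max_parallel_requests", "field_value": 100})
+#
+#   # read the current value of a field
+#   requests.get(f"{base}/config/field/info", headers=headers,
+#                params={"field_name": "max_parallel_requests"})
+#
+#   # delete the DB override for a field
+#   requests.post(f"{base}/config/field/delete", headers=headers,
+#                 json={"field_name": "max_parallel_requests"})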
+
+@router.post(
+ "/config/field/update",
+ tags=["config.yaml"],
+ dependencies=[Depends(user_api_key_auth)],
+ include_in_schema=False,
+)
+async def update_config_general_settings(
+ data: ConfigFieldUpdate,
+ user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+):
+ """
+ Update a specific field in litellm general settings
+ """
+ global prisma_client
+ ## VALIDATION ##
+ """
+ - Check if prisma_client is None
+ - Check if user allowed to call this endpoint (admin-only)
+ - Check if param in general settings
+ - Check if config value is valid type
+ """
+
+ if prisma_client is None:
+ raise HTTPException(
+ status_code=400,
+ detail={"error": CommonProxyErrors.db_not_connected_error.value},
+ )
+
+ if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
+ raise HTTPException(
+ status_code=400,
+ detail={"error": CommonProxyErrors.not_allowed_access.value},
+ )
+
+ if data.field_name not in ConfigGeneralSettings.model_fields:
+ raise HTTPException(
+ status_code=400,
+ detail={"error": "Invalid field={} passed in.".format(data.field_name)},
+ )
+
+ try:
+ ConfigGeneralSettings(**{data.field_name: data.field_value})
+ except Exception:
+ raise HTTPException(
+ status_code=400,
+ detail={
+ "error": "Invalid type of field value={} passed in.".format(
+ type(data.field_value),
+ )
+ },
+ )
+
+ ## get general settings from db
+ db_general_settings = await prisma_client.db.litellm_config.find_first(
+ where={"param_name": "general_settings"}
+ )
+ ### update value
+
+ if db_general_settings is None or db_general_settings.param_value is None:
+ general_settings = {}
+ else:
+ general_settings = dict(db_general_settings.param_value)
+
+ ## update db
+
+ general_settings[data.field_name] = data.field_value
+
+ response = await prisma_client.db.litellm_config.upsert(
+ where={"param_name": "general_settings"},
+ data={
+ "create": {"param_name": "general_settings", "param_value": json.dumps(general_settings)}, # type: ignore
+ "update": {"param_value": json.dumps(general_settings)}, # type: ignore
+ },
+ )
+
+ return response
+
+
+@router.get(
+ "/config/field/info",
+ tags=["config.yaml"],
+ dependencies=[Depends(user_api_key_auth)],
+ response_model=ConfigFieldInfo,
+ include_in_schema=False,
+)
+async def get_config_general_settings(
+ field_name: str,
+ user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+):
+ global prisma_client
+
+ ## VALIDATION ##
+ """
+ - Check if prisma_client is None
+ - Check if user allowed to call this endpoint (admin-only)
+ - Check if param in general settings
+ """
+ if prisma_client is None:
+ raise HTTPException(
+ status_code=400,
+ detail={"error": CommonProxyErrors.db_not_connected_error.value},
+ )
+
+ if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
+ raise HTTPException(
+ status_code=400,
+ detail={"error": CommonProxyErrors.not_allowed_access.value},
+ )
+
+ if field_name not in ConfigGeneralSettings.model_fields:
+ raise HTTPException(
+ status_code=400,
+ detail={"error": "Invalid field={} passed in.".format(field_name)},
+ )
+
+ ## get general settings from db
+ db_general_settings = await prisma_client.db.litellm_config.find_first(
+ where={"param_name": "general_settings"}
+ )
+    ### read the value
+
+ if db_general_settings is None or db_general_settings.param_value is None:
+ raise HTTPException(
+ status_code=400,
+ detail={"error": "Field name={} not in DB".format(field_name)},
+ )
+ else:
+ general_settings = dict(db_general_settings.param_value)
+
+ if field_name in general_settings:
+ return ConfigFieldInfo(
+ field_name=field_name, field_value=general_settings[field_name]
+ )
+ else:
+ raise HTTPException(
+ status_code=400,
+ detail={"error": "Field name={} not in DB".format(field_name)},
+ )
+
+
+@router.get(
+ "/config/list",
+ tags=["config.yaml"],
+ dependencies=[Depends(user_api_key_auth)],
+ include_in_schema=False,
+)
+async def get_config_list(
+ config_type: Literal["general_settings"],
+ user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+) -> List[ConfigList]:
+ """
+    List the available fields + current values for a given type of setting (currently just 'general_settings').
+ """
+ global prisma_client, general_settings
+
+ ## VALIDATION ##
+ """
+ - Check if prisma_client is None
+ - Check if user allowed to call this endpoint (admin-only)
+ - Check if param in general settings
+ """
+ if prisma_client is None:
+ raise HTTPException(
+ status_code=400,
+ detail={"error": CommonProxyErrors.db_not_connected_error.value},
+ )
+
+ if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
+ raise HTTPException(
+ status_code=400,
+ detail={
+ "error": "{}, your role={}".format(
+ CommonProxyErrors.not_allowed_access.value,
+ user_api_key_dict.user_role,
+ )
+ },
+ )
+
+ ## get general settings from db
+ db_general_settings = await prisma_client.db.litellm_config.find_first(
+ where={"param_name": "general_settings"}
+ )
+
+ if db_general_settings is not None and db_general_settings.param_value is not None:
+ db_general_settings_dict = dict(db_general_settings.param_value)
+ else:
+ db_general_settings_dict = {}
+
+ allowed_args = {
+ "max_parallel_requests": {"type": "Integer"},
+ "global_max_parallel_requests": {"type": "Integer"},
+ "max_request_size_mb": {"type": "Integer"},
+ "max_response_size_mb": {"type": "Integer"},
+ "pass_through_endpoints": {"type": "PydanticModel"},
+ }
+
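+    # Illustrative shape of one entry in the returned list (values are placeholders):
+    #   ConfigList(field_name="max_parallel_requests", field_type="Integer",
+    #              field_description="...", field_value=None, stored_in_db=False,
+    #              field_default_value=None, nested_fields=None)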
+ return_val = []
+
+ for field_name, field_info in ConfigGeneralSettings.model_fields.items():
+ if field_name in allowed_args:
+
+ ## HANDLE TYPED DICT
+
+ typed_dict_type = allowed_args[field_name]["type"]
+
+ if typed_dict_type == "PydanticModel":
+ if field_name == "pass_through_endpoints":
+ pydantic_class_list = [PassThroughGenericEndpoint]
+ else:
+ pydantic_class_list = []
+
+ for pydantic_class in pydantic_class_list:
+                    # Get type hints from the Pydantic model's annotations to create FieldDetail objects
+ nested_fields = [
+ FieldDetail(
+ field_name=sub_field,
+ field_type=sub_field_type.__name__,
+ field_description="", # Add custom logic if descriptions are available
+ field_default_value=general_settings.get(sub_field, None),
+ stored_in_db=None,
+ )
+ for sub_field, sub_field_type in pydantic_class.__annotations__.items()
+ ]
+
+ idx = 0
+ for (
+ sub_field,
+ sub_field_info,
+ ) in pydantic_class.model_fields.items():
+ if (
+ hasattr(sub_field_info, "description")
+ and sub_field_info.description is not None
+ ):
+ nested_fields[idx].field_description = (
+ sub_field_info.description
+ )
+ idx += 1
+
+ _stored_in_db = None
+ if field_name in db_general_settings_dict:
+ _stored_in_db = True
+ elif field_name in general_settings:
+ _stored_in_db = False
+
+ _response_obj = ConfigList(
+ field_name=field_name,
+ field_type=allowed_args[field_name]["type"],
+ field_description=field_info.description or "",
+ field_value=general_settings.get(field_name, None),
+ stored_in_db=_stored_in_db,
+ field_default_value=field_info.default,
+ nested_fields=nested_fields,
+ )
+ return_val.append(_response_obj)
+
+ else:
+ nested_fields = None
+
+ _stored_in_db = None
+ if field_name in db_general_settings_dict:
+ _stored_in_db = True
+ elif field_name in general_settings:
+ _stored_in_db = False
+
+ _response_obj = ConfigList(
+ field_name=field_name,
+ field_type=allowed_args[field_name]["type"],
+ field_description=field_info.description or "",
+ field_value=general_settings.get(field_name, None),
+ stored_in_db=_stored_in_db,
+ field_default_value=field_info.default,
+ nested_fields=nested_fields,
+ )
+ return_val.append(_response_obj)
+
+ return return_val
+
+
+@router.post(
+ "/config/field/delete",
+ tags=["config.yaml"],
+ dependencies=[Depends(user_api_key_auth)],
+ include_in_schema=False,
+)
+async def delete_config_general_settings(
+ data: ConfigFieldDelete,
+ user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+):
+ """
+    Delete the db value of this field in litellm general settings. Resets it to its initial default value in litellm.
+ """
+ global prisma_client
+ ## VALIDATION ##
+ """
+ - Check if prisma_client is None
+ - Check if user allowed to call this endpoint (admin-only)
+ - Check if param in general settings
+ """
+ if prisma_client is None:
+ raise HTTPException(
+ status_code=400,
+ detail={"error": CommonProxyErrors.db_not_connected_error.value},
+ )
+
+ if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
+ raise HTTPException(
+ status_code=400,
+ detail={
+ "error": "{}, your role={}".format(
+ CommonProxyErrors.not_allowed_access.value,
+ user_api_key_dict.user_role,
+ )
+ },
+ )
+
+ if data.field_name not in ConfigGeneralSettings.model_fields:
+ raise HTTPException(
+ status_code=400,
+ detail={"error": "Invalid field={} passed in.".format(data.field_name)},
+ )
+
+ ## get general settings from db
+ db_general_settings = await prisma_client.db.litellm_config.find_first(
+ where={"param_name": "general_settings"}
+ )
+ ### pop the value
+
+ if db_general_settings is None or db_general_settings.param_value is None:
+ raise HTTPException(
+ status_code=400,
+ detail={"error": "Field name={} not in config".format(data.field_name)},
+ )
+ else:
+ general_settings = dict(db_general_settings.param_value)
+
+ ## update db
+
+ general_settings.pop(data.field_name, None)
+
+ response = await prisma_client.db.litellm_config.upsert(
+ where={"param_name": "general_settings"},
+ data={
+ "create": {"param_name": "general_settings", "param_value": json.dumps(general_settings)}, # type: ignore
+ "update": {"param_value": json.dumps(general_settings)}, # type: ignore
+ },
+ )
+
+ return response
+
+
+@router.get(
+ "/get/config/callbacks",
+ tags=["config.yaml"],
+ include_in_schema=False,
+ dependencies=[Depends(user_api_key_auth)],
+)
+async def get_config(): # noqa: PLR0915
+ """
+    For Admin UI - allows the admin to view the config via the UI.
+
+    Returns the callbacks and the env variables for each callback.
+ """
+ global llm_router, llm_model_list, general_settings, proxy_config, proxy_logging_obj, master_key
+ try:
+ import base64
+
+ all_available_callbacks = AllCallbacks()
+
+ config_data = await proxy_config.get_config()
+ _litellm_settings = config_data.get("litellm_settings", {})
+ _general_settings = config_data.get("general_settings", {})
+ environment_variables = config_data.get("environment_variables", {})
+
+ # check if "langfuse" in litellm_settings
+ _success_callbacks = _litellm_settings.get("success_callback", [])
+ _data_to_return = []
+ """
+ [
+ {
+ "name": "langfuse",
+ "variables": {
+ "LANGFUSE_PUB_KEY": "value",
+ "LANGFUSE_SECRET_KEY": "value",
+ "LANGFUSE_HOST": "value"
+ },
+ }
+ ]
+
+ """
+ for _callback in _success_callbacks:
+ if _callback != "langfuse":
+ if _callback == "openmeter":
+ env_vars = [
+ "OPENMETER_API_KEY",
+ ]
+ elif _callback == "braintrust":
+ env_vars = [
+ "BRAINTRUST_API_KEY",
+ ]
+ elif _callback == "traceloop":
+ env_vars = ["TRACELOOP_API_KEY"]
+ elif _callback == "custom_callback_api":
+ env_vars = ["GENERIC_LOGGER_ENDPOINT"]
+ elif _callback == "otel":
+ env_vars = ["OTEL_EXPORTER", "OTEL_ENDPOINT", "OTEL_HEADERS"]
+ elif _callback == "langsmith":
+ env_vars = [
+ "LANGSMITH_API_KEY",
+ "LANGSMITH_PROJECT",
+ "LANGSMITH_DEFAULT_RUN_NAME",
+ ]
+ else:
+ env_vars = []
+
+ env_vars_dict = {}
+ for _var in env_vars:
+ env_variable = environment_variables.get(_var, None)
+ if env_variable is None:
+ env_vars_dict[_var] = None
+ else:
+ # decode + decrypt the value
+ decrypted_value = decrypt_value_helper(value=env_variable)
+ env_vars_dict[_var] = decrypted_value
+
+ _data_to_return.append({"name": _callback, "variables": env_vars_dict})
+ elif _callback == "langfuse":
+ _langfuse_vars = [
+ "LANGFUSE_PUBLIC_KEY",
+ "LANGFUSE_SECRET_KEY",
+ "LANGFUSE_HOST",
+ ]
+ _langfuse_env_vars = {}
+ for _var in _langfuse_vars:
+ env_variable = environment_variables.get(_var, None)
+ if env_variable is None:
+ _langfuse_env_vars[_var] = None
+ else:
+ # decode + decrypt the value
+ decrypted_value = decrypt_value_helper(value=env_variable)
+ _langfuse_env_vars[_var] = decrypted_value
+
+ _data_to_return.append(
+ {"name": _callback, "variables": _langfuse_env_vars}
+ )
+
+ # Check if slack alerting is on
+ _alerting = _general_settings.get("alerting", [])
+ alerting_data = []
+ if "slack" in _alerting:
+ _slack_vars = [
+ "SLACK_WEBHOOK_URL",
+ ]
+ _slack_env_vars = {}
+ for _var in _slack_vars:
+ env_variable = environment_variables.get(_var, None)
+ if env_variable is None:
+ _value = os.getenv("SLACK_WEBHOOK_URL", None)
+ _slack_env_vars[_var] = _value
+ else:
+ # decode + decrypt the value
+ _decrypted_value = decrypt_value_helper(value=env_variable)
+ _slack_env_vars[_var] = _decrypted_value
+
+ _alerting_types = proxy_logging_obj.slack_alerting_instance.alert_types
+ _all_alert_types = (
+ proxy_logging_obj.slack_alerting_instance._all_possible_alert_types()
+ )
+ _alerts_to_webhook = (
+ proxy_logging_obj.slack_alerting_instance.alert_to_webhook_url
+ )
+ alerting_data.append(
+ {
+ "name": "slack",
+ "variables": _slack_env_vars,
+ "active_alerts": _alerting_types,
+ "alerts_to_webhook": _alerts_to_webhook,
+ }
+ )
+ # pass email alerting vars
+ _email_vars = [
+ "SMTP_HOST",
+ "SMTP_PORT",
+ "SMTP_USERNAME",
+ "SMTP_PASSWORD",
+ "SMTP_SENDER_EMAIL",
+ "TEST_EMAIL_ADDRESS",
+ "EMAIL_LOGO_URL",
+ "EMAIL_SUPPORT_CONTACT",
+ ]
+ _email_env_vars = {}
+ for _var in _email_vars:
+ env_variable = environment_variables.get(_var, None)
+ if env_variable is None:
+ _email_env_vars[_var] = None
+ else:
+ # decode + decrypt the value
+ _decrypted_value = decrypt_value_helper(value=env_variable)
+ _email_env_vars[_var] = _decrypted_value
+
+ alerting_data.append(
+ {
+ "name": "email",
+ "variables": _email_env_vars,
+ }
+ )
+
+ if llm_router is None:
+ _router_settings = {}
+ else:
+ _router_settings = llm_router.get_settings()
+
+ return {
+ "status": "success",
+ "callbacks": _data_to_return,
+ "alerts": alerting_data,
+ "router_settings": _router_settings,
+ "available_callbacks": all_available_callbacks,
+ }
+ except Exception as e:
+ verbose_proxy_logger.exception(
+            "litellm.proxy.proxy_server.get_config(): Exception occurred - {}".format(
+ str(e)
+ )
+ )
+ if isinstance(e, HTTPException):
+ raise ProxyException(
+ message=getattr(e, "detail", f"Authentication Error({str(e)})"),
+ type=ProxyErrorTypes.auth_error,
+ param=getattr(e, "param", "None"),
+ code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
+ )
+ elif isinstance(e, ProxyException):
+ raise e
+ raise ProxyException(
+ message="Authentication Error, " + str(e),
+ type=ProxyErrorTypes.auth_error,
+ param=getattr(e, "param", "None"),
+ code=status.HTTP_400_BAD_REQUEST,
+ )
+
+
+@router.get(
+ "/config/yaml",
+ tags=["config.yaml"],
+ dependencies=[Depends(user_api_key_auth)],
+ include_in_schema=False,
+)
+async def config_yaml_endpoint(config_info: ConfigYAML):
+ """
+    This is a mock endpoint that documents what you can set in config.yaml; the details are shown in the Swagger UI.
+
+ Parameters:
+
+ The config.yaml object has the following attributes:
+ - **model_list**: *Optional[List[ModelParams]]* - A list of supported models on the server, along with model-specific configurations. ModelParams includes "model_name" (name of the model), "litellm_params" (litellm-specific parameters for the model), and "model_info" (additional info about the model such as id, mode, cost per token, etc).
+
+ - **litellm_settings**: *Optional[dict]*: Settings for the litellm module. You can specify multiple properties like "drop_params", "set_verbose", "api_base", "cache".
+
+ - **general_settings**: *Optional[ConfigGeneralSettings]*: General settings for the server like "completion_model" (default model for chat completion calls), "use_azure_key_vault" (option to load keys from azure key vault), "master_key" (key required for all calls to proxy), and others.
+
+    Please refer to each class's description for a better understanding of its specific attributes.
+
+ Note: This is a mock endpoint primarily meant for demonstration purposes, and does not actually provide or change any configurations.
+ """
+ return {"hello": "world"}
+
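+# Illustrative config.yaml matching the shape documented above (model names, keys,
+# and endpoints are placeholders):
+#
+#   model_list:
+#     - model_name: gpt-3.5-turbo
+#       litellm_params:
+#         model: azure/<your-deployment-name>
+#         api_base: <your-azure-endpoint>
+#         api_key: os.environ/AZURE_API_KEY
+#   litellm_settings:
+#     drop_params: true
+#     set_verbose: false
+#   general_settings:
+#     master_key: sk-1234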
+
+@router.get(
+ "/get/litellm_model_cost_map",
+ include_in_schema=False,
+ dependencies=[Depends(user_api_key_auth)],
+)
+async def get_litellm_model_cost_map():
+ try:
+ _model_cost_map = litellm.model_cost
+ return _model_cost_map
+ except Exception as e:
+ raise HTTPException(
+ status_code=500,
+ detail=f"Internal Server Error ({str(e)})",
+ )
+
+
+@router.get("/", dependencies=[Depends(user_api_key_auth)])
+async def home(request: Request):
+ return "LiteLLM: RUNNING"
+
+
+@router.get("/routes", dependencies=[Depends(user_api_key_auth)])
+async def get_routes():
+ """
+ Get a list of available routes in the FastAPI application.
+ """
+ routes = []
+ for route in app.routes:
+ endpoint_route = getattr(route, "endpoint", None)
+ if endpoint_route is not None:
+ route_info = {
+ "path": getattr(route, "path", None),
+ "methods": getattr(route, "methods", None),
+ "name": getattr(route, "name", None),
+                "endpoint": endpoint_route.__name__,
+ }
+ routes.append(route_info)
+
+ return {"routes": routes}
+
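+# Illustrative element of the /routes response (actual values depend on the
+# registered routes):
+#   {"path": "/config/field/info", "methods": ["GET"],
+#    "name": "get_config_general_settings", "endpoint": "get_config_general_settings"}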
+
+#### TEST ENDPOINTS ####
+# @router.get(
+# "/token/generate",
+# dependencies=[Depends(user_api_key_auth)],
+# include_in_schema=False,
+# )
+# async def token_generate():
+# """
+# Test endpoint. Admin-only access. Meant for generating admin tokens with specific claims and testing if they work for creating keys, etc.
+# """
+# # Initialize AuthJWTSSO with your OpenID Provider configuration
+# from fastapi_sso import AuthJWTSSO
+
+# auth_jwt_sso = AuthJWTSSO(
+# issuer=os.getenv("OPENID_BASE_URL"),
+# client_id=os.getenv("OPENID_CLIENT_ID"),
+# client_secret=os.getenv("OPENID_CLIENT_SECRET"),
+# scopes=["litellm_proxy_admin"],
+# )
+
+# token = auth_jwt_sso.create_access_token()
+
+# return {"token": token}
+
+
+app.include_router(router)
+app.include_router(response_router)
+app.include_router(batches_router)
+app.include_router(rerank_router)
+app.include_router(fine_tuning_router)
+app.include_router(credential_router)
+app.include_router(vertex_router)
+app.include_router(llm_passthrough_router)
+app.include_router(anthropic_router)
+app.include_router(langfuse_router)
+app.include_router(pass_through_router)
+app.include_router(health_router)
+app.include_router(key_management_router)
+app.include_router(internal_user_router)
+app.include_router(team_router)
+app.include_router(ui_sso_router)
+app.include_router(organization_router)
+app.include_router(customer_router)
+app.include_router(spend_management_router)
+app.include_router(caching_router)
+app.include_router(analytics_router)
+app.include_router(guardrails_router)
+app.include_router(debugging_endpoints_router)
+app.include_router(ui_crud_endpoints_router)
+app.include_router(openai_files_router)
+app.include_router(team_callback_router)
+app.include_router(budget_management_router)
+app.include_router(model_management_router)