Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/proxy/db')
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/proxy/db/base_client.py      53
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/proxy/db/check_migration.py  104
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/proxy/db/create_views.py     227
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/proxy/db/dynamo_db.py        71
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/proxy/db/log_db_metrics.py   143
-rw-r--r--  .venv/lib/python3.12/site-packages/litellm/proxy/db/prisma_client.py    278
6 files changed, 876 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/db/base_client.py b/.venv/lib/python3.12/site-packages/litellm/proxy/db/base_client.py
new file mode 100644
index 00000000..07f0ecdc
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/db/base_client.py
@@ -0,0 +1,53 @@
+from typing import Any, Literal, List
+
+
+class CustomDB:
+ """
+    Base class that any custom db implementation (e.g. DynamoDB) is expected to follow
+ """
+
+ def __init__(self) -> None:
+ pass
+
+ def get_data(self, key: str, table_name: Literal["user", "key", "config"]):
+ """
+        Check if a key is valid
+ """
+ pass
+
+ def insert_data(self, value: Any, table_name: Literal["user", "key", "config"]):
+ """
+ For new key / user logic
+ """
+ pass
+
+ def update_data(
+ self, key: str, value: Any, table_name: Literal["user", "key", "config"]
+ ):
+ """
+ For cost tracking logic
+ """
+ pass
+
+ def delete_data(
+ self, keys: List[str], table_name: Literal["user", "key", "config"]
+ ):
+ """
+        For /key/delete endpoint
+ """
+
+ def connect(
+ self,
+ ):
+ """
+ For connecting to db and creating / updating any tables
+ """
+ pass
+
+ def disconnect(
+ self,
+ ):
+ """
+ For closing connection on server shutdown
+ """
+ pass
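
A minimal sketch of what a CustomDB subclass could look like. The InMemoryDB class, its dict-backed tables, and the (key, record) convention in insert_data are hypothetical; the base class above does not pin down the payload shape.

    from typing import Any, List, Literal

    from litellm.proxy.db.base_client import CustomDB


    class InMemoryDB(CustomDB):
        """Hypothetical CustomDB backed by one dict per table."""

        def __init__(self) -> None:
            self.tables: dict = {"user": {}, "key": {}, "config": {}}

        def get_data(self, key: str, table_name: Literal["user", "key", "config"]):
            return self.tables[table_name].get(key)

        def insert_data(self, value: Any, table_name: Literal["user", "key", "config"]):
            # assumption: callers pass a (key, record) tuple
            k, record = value
            self.tables[table_name][k] = record

        def update_data(self, key: str, value: Any, table_name: Literal["user", "key", "config"]):
            self.tables[table_name][key] = value

        def delete_data(self, keys: List[str], table_name: Literal["user", "key", "config"]):
            for k in keys:
                self.tables[table_name].pop(k, None)

        def connect(self):
            pass  # nothing to open for an in-memory store

        def disconnect(self):
            pass
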
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/db/check_migration.py b/.venv/lib/python3.12/site-packages/litellm/proxy/db/check_migration.py
new file mode 100644
index 00000000..bf180c11
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/db/check_migration.py
@@ -0,0 +1,104 @@
+"""Module for checking differences between Prisma schema and database."""
+
+import os
+import subprocess
+from typing import List, Optional, Tuple
+
+from litellm._logging import verbose_logger
+
+
+def extract_sql_commands(diff_output: str) -> List[str]:
+ """
+ Extract SQL commands from the Prisma migrate diff output.
+ Args:
+ diff_output (str): The full output from prisma migrate diff.
+ Returns:
+ List[str]: A list of SQL commands extracted from the diff output.
+ """
+ # Split the output into lines and remove empty lines
+ lines = [line.strip() for line in diff_output.split("\n") if line.strip()]
+
+ sql_commands = []
+ current_command = ""
+ in_sql_block = False
+
+ for line in lines:
+ if line.startswith("-- "): # Comment line, likely a table operation description
+ if in_sql_block and current_command:
+ sql_commands.append(current_command.strip())
+ current_command = ""
+ in_sql_block = True
+ elif in_sql_block:
+ if line.endswith(";"):
+ current_command += line
+ sql_commands.append(current_command.strip())
+ current_command = ""
+ in_sql_block = False
+ else:
+ current_command += line + " "
+
+ # Add any remaining command
+ if current_command:
+ sql_commands.append(current_command.strip())
+
+ return sql_commands
+
+
+def check_prisma_schema_diff_helper(db_url: str) -> Tuple[bool, List[str]]:
+ """Checks for differences between current database and Prisma schema.
+ Returns:
+ A tuple containing:
+ - A boolean indicating if differences were found (True) or not (False).
+        - A list of the SQL commands needed to sync the two (empty if none).
+    Note:
+        If the Prisma CLI call fails, subprocess.CalledProcessError is caught
+        and logged, and (False, []) is returned.
+ """
+    verbose_logger.debug("Checking for Prisma schema diff...")
+ try:
+ result = subprocess.run(
+ [
+ "prisma",
+ "migrate",
+ "diff",
+ "--from-url",
+ db_url,
+ "--to-schema-datamodel",
+ "./schema.prisma",
+ "--script",
+ ],
+ capture_output=True,
+ text=True,
+ check=True,
+ )
+
+ sql_commands = extract_sql_commands(result.stdout)
+
+ if sql_commands:
+ print("Changes to DB Schema detected") # noqa: T201
+ print("Required SQL commands:") # noqa: T201
+ for command in sql_commands:
+ print(command) # noqa: T201
+ return True, sql_commands
+ else:
+ return False, []
+ except subprocess.CalledProcessError as e:
+ error_message = f"Failed to generate migration diff. Error: {e.stderr}"
+ print(error_message) # noqa: T201
+ return False, []
+
+
+def check_prisma_schema_diff(db_url: Optional[str] = None) -> None:
+ """Main function to run the Prisma schema diff check."""
+ if db_url is None:
+ db_url = os.getenv("DATABASE_URL")
+ if db_url is None:
+ raise Exception("DATABASE_URL not set")
+ has_diff, message = check_prisma_schema_diff_helper(db_url)
+ if has_diff:
+        verbose_logger.warning(
+ "🚨🚨🚨 prisma schema out of sync with db. Consider running these sql_commands to sync the two - {}".format(
+ message
+ )
+ )
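
To see how extract_sql_commands() splits the output of prisma migrate diff --script, here is a small driver; the table, column, and index names are made up:

    from litellm.proxy.db.check_migration import extract_sql_commands

    diff_output = """
    -- AlterTable
    ALTER TABLE "LiteLLM_SpendLogs"
    ADD COLUMN "example_col" TEXT;

    -- CreateIndex
    CREATE INDEX "example_idx" ON "LiteLLM_SpendLogs"("api_key");
    """

    commands = extract_sql_commands(diff_output)
    # commands ==
    # ['ALTER TABLE "LiteLLM_SpendLogs" ADD COLUMN "example_col" TEXT;',
    #  'CREATE INDEX "example_idx" ON "LiteLLM_SpendLogs"("api_key");']
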
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/db/create_views.py b/.venv/lib/python3.12/site-packages/litellm/proxy/db/create_views.py
new file mode 100644
index 00000000..e9303077
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/db/create_views.py
@@ -0,0 +1,227 @@
+from typing import Any
+
+from litellm import verbose_logger
+
+_db = Any
+
+
+async def create_missing_views(db: _db): # noqa: PLR0915
+ """
+ --------------------------------------------------
+ NOTE: Copy of `litellm/db_scripts/create_views.py`.
+ --------------------------------------------------
+    Checks if the views used by the proxy (LiteLLM_VerificationTokenView, MonthlyGlobalSpend, and the related spend views) exist in the user's db.
+
+ LiteLLM_VerificationTokenView: This view is used for getting the token + team data in user_api_key_auth
+
+ MonthlyGlobalSpend: This view is used for the admin view to see global spend for this month
+
+    Any view that doesn't exist is created.
+ """
+ try:
+ # Try to select one row from the view
+ await db.query_raw("""SELECT 1 FROM "LiteLLM_VerificationTokenView" LIMIT 1""")
+ print("LiteLLM_VerificationTokenView Exists!") # noqa
+ except Exception:
+ # If an error occurs, the view does not exist, so create it
+ await db.execute_raw(
+ """
+ CREATE VIEW "LiteLLM_VerificationTokenView" AS
+ SELECT
+ v.*,
+ t.spend AS team_spend,
+ t.max_budget AS team_max_budget,
+ t.tpm_limit AS team_tpm_limit,
+ t.rpm_limit AS team_rpm_limit
+ FROM "LiteLLM_VerificationToken" v
+ LEFT JOIN "LiteLLM_TeamTable" t ON v.team_id = t.team_id;
+ """
+ )
+
+ print("LiteLLM_VerificationTokenView Created!") # noqa
+
+ try:
+ await db.query_raw("""SELECT 1 FROM "MonthlyGlobalSpend" LIMIT 1""")
+ print("MonthlyGlobalSpend Exists!") # noqa
+ except Exception:
+ sql_query = """
+ CREATE OR REPLACE VIEW "MonthlyGlobalSpend" AS
+ SELECT
+ DATE("startTime") AS date,
+ SUM("spend") AS spend
+ FROM
+ "LiteLLM_SpendLogs"
+ WHERE
+ "startTime" >= (CURRENT_DATE - INTERVAL '30 days')
+ GROUP BY
+ DATE("startTime");
+ """
+ await db.execute_raw(query=sql_query)
+
+ print("MonthlyGlobalSpend Created!") # noqa
+
+ try:
+ await db.query_raw("""SELECT 1 FROM "Last30dKeysBySpend" LIMIT 1""")
+ print("Last30dKeysBySpend Exists!") # noqa
+ except Exception:
+ sql_query = """
+ CREATE OR REPLACE VIEW "Last30dKeysBySpend" AS
+ SELECT
+ L."api_key",
+ V."key_alias",
+ V."key_name",
+ SUM(L."spend") AS total_spend
+ FROM
+ "LiteLLM_SpendLogs" L
+ LEFT JOIN
+ "LiteLLM_VerificationToken" V
+ ON
+ L."api_key" = V."token"
+ WHERE
+ L."startTime" >= (CURRENT_DATE - INTERVAL '30 days')
+ GROUP BY
+ L."api_key", V."key_alias", V."key_name"
+ ORDER BY
+ total_spend DESC;
+ """
+ await db.execute_raw(query=sql_query)
+
+ print("Last30dKeysBySpend Created!") # noqa
+
+ try:
+ await db.query_raw("""SELECT 1 FROM "Last30dModelsBySpend" LIMIT 1""")
+ print("Last30dModelsBySpend Exists!") # noqa
+ except Exception:
+ sql_query = """
+ CREATE OR REPLACE VIEW "Last30dModelsBySpend" AS
+ SELECT
+ "model",
+ SUM("spend") AS total_spend
+ FROM
+ "LiteLLM_SpendLogs"
+ WHERE
+ "startTime" >= (CURRENT_DATE - INTERVAL '30 days')
+ AND "model" != ''
+ GROUP BY
+ "model"
+ ORDER BY
+ total_spend DESC;
+ """
+ await db.execute_raw(query=sql_query)
+
+ print("Last30dModelsBySpend Created!") # noqa
+ try:
+ await db.query_raw("""SELECT 1 FROM "MonthlyGlobalSpendPerKey" LIMIT 1""")
+ print("MonthlyGlobalSpendPerKey Exists!") # noqa
+ except Exception:
+ sql_query = """
+ CREATE OR REPLACE VIEW "MonthlyGlobalSpendPerKey" AS
+ SELECT
+ DATE("startTime") AS date,
+ SUM("spend") AS spend,
+ api_key as api_key
+ FROM
+ "LiteLLM_SpendLogs"
+ WHERE
+ "startTime" >= (CURRENT_DATE - INTERVAL '30 days')
+ GROUP BY
+ DATE("startTime"),
+ api_key;
+ """
+ await db.execute_raw(query=sql_query)
+
+ print("MonthlyGlobalSpendPerKey Created!") # noqa
+ try:
+ await db.query_raw(
+ """SELECT 1 FROM "MonthlyGlobalSpendPerUserPerKey" LIMIT 1"""
+ )
+ print("MonthlyGlobalSpendPerUserPerKey Exists!") # noqa
+ except Exception:
+ sql_query = """
+ CREATE OR REPLACE VIEW "MonthlyGlobalSpendPerUserPerKey" AS
+ SELECT
+ DATE("startTime") AS date,
+ SUM("spend") AS spend,
+ api_key as api_key,
+ "user" as "user"
+ FROM
+ "LiteLLM_SpendLogs"
+ WHERE
+ "startTime" >= (CURRENT_DATE - INTERVAL '30 days')
+ GROUP BY
+ DATE("startTime"),
+ "user",
+ api_key;
+ """
+ await db.execute_raw(query=sql_query)
+
+ print("MonthlyGlobalSpendPerUserPerKey Created!") # noqa
+
+ try:
+ await db.query_raw("""SELECT 1 FROM "DailyTagSpend" LIMIT 1""")
+ print("DailyTagSpend Exists!") # noqa
+ except Exception:
+ sql_query = """
+ CREATE OR REPLACE VIEW "DailyTagSpend" AS
+ SELECT
+ jsonb_array_elements_text(request_tags) AS individual_request_tag,
+ DATE(s."startTime") AS spend_date,
+ COUNT(*) AS log_count,
+ SUM(spend) AS total_spend
+ FROM "LiteLLM_SpendLogs" s
+ GROUP BY individual_request_tag, DATE(s."startTime");
+ """
+ await db.execute_raw(query=sql_query)
+
+ print("DailyTagSpend Created!") # noqa
+
+ try:
+ await db.query_raw("""SELECT 1 FROM "Last30dTopEndUsersSpend" LIMIT 1""")
+ print("Last30dTopEndUsersSpend Exists!") # noqa
+ except Exception:
+ sql_query = """
+ CREATE VIEW "Last30dTopEndUsersSpend" AS
+ SELECT end_user, COUNT(*) AS total_events, SUM(spend) AS total_spend
+ FROM "LiteLLM_SpendLogs"
+        WHERE end_user <> '' AND end_user <> "user"
+ AND "startTime" >= CURRENT_DATE - INTERVAL '30 days'
+ GROUP BY end_user
+ ORDER BY total_spend DESC
+ LIMIT 100;
+ """
+ await db.execute_raw(query=sql_query)
+
+ print("Last30dTopEndUsersSpend Created!") # noqa
+
+ return
+
+
+async def should_create_missing_views(db: _db) -> bool:
+ """
+ Run only on first time startup.
+
+ If SpendLogs table already has values, then don't create views on startup.
+ """
+
+ sql_query = """
+ SELECT reltuples::BIGINT
+ FROM pg_class
+ WHERE oid = '"LiteLLM_SpendLogs"'::regclass;
+ """
+
+ result = await db.query_raw(query=sql_query)
+
+ verbose_logger.debug("Estimated Row count of LiteLLM_SpendLogs = {}".format(result))
+ if (
+ result
+ and isinstance(result, list)
+ and len(result) > 0
+ and isinstance(result[0], dict)
+ and "reltuples" in result[0]
+ and result[0]["reltuples"]
+ and (result[0]["reltuples"] == 0 or result[0]["reltuples"] == -1)
+ ):
+ verbose_logger.debug("Should create views")
+ return True
+
+ return False
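
A sketch of how these two helpers might be wired together at startup, assuming an otherwise-configured prisma-client-py Prisma instance; the init_views wrapper is hypothetical:

    import asyncio

    from prisma import Prisma

    from litellm.proxy.db.create_views import (
        create_missing_views,
        should_create_missing_views,
    )


    async def init_views() -> None:
        db = Prisma()
        await db.connect()
        try:
            # only create views on first-time startup, i.e. while
            # LiteLLM_SpendLogs is still (estimated to be) empty
            if await should_create_missing_views(db=db):
                await create_missing_views(db=db)
        finally:
            await db.disconnect()


    asyncio.run(init_views())
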
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/db/dynamo_db.py b/.venv/lib/python3.12/site-packages/litellm/proxy/db/dynamo_db.py
new file mode 100644
index 00000000..628509d9
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/db/dynamo_db.py
@@ -0,0 +1,71 @@
+"""
+Deprecated. Only PostgreSQL is supported.
+"""
+
+from litellm._logging import verbose_proxy_logger
+from litellm.proxy._types import DynamoDBArgs
+from litellm.proxy.db.base_client import CustomDB
+
+
+class DynamoDBWrapper(CustomDB):
+    from aiodynamo.credentials import Credentials
+
+ credentials: Credentials
+
+ def __init__(self, database_arguments: DynamoDBArgs):
+ from aiodynamo.models import PayPerRequest, Throughput
+
+ self.throughput_type = None
+ if database_arguments.billing_mode == "PAY_PER_REQUEST":
+ self.throughput_type = PayPerRequest()
+ elif database_arguments.billing_mode == "PROVISIONED_THROUGHPUT":
+ if (
+ database_arguments.read_capacity_units is not None
+ and isinstance(database_arguments.read_capacity_units, int)
+ and database_arguments.write_capacity_units is not None
+ and isinstance(database_arguments.write_capacity_units, int)
+ ):
+ self.throughput_type = Throughput(read=database_arguments.read_capacity_units, write=database_arguments.write_capacity_units) # type: ignore
+ else:
+ raise Exception(
+ f"Invalid args passed in. Need to set both read_capacity_units and write_capacity_units. Args passed in - {database_arguments}"
+ )
+ self.database_arguments = database_arguments
+ self.region_name = database_arguments.region_name
+
+ def set_env_vars_based_on_arn(self):
+ if self.database_arguments.aws_role_name is None:
+ return
+ verbose_proxy_logger.debug(
+ f"DynamoDB: setting env vars based on arn={self.database_arguments.aws_role_name}"
+ )
+ import os
+
+ import boto3
+
+ sts_client = boto3.client("sts")
+
+ # call 1
+ sts_client.assume_role_with_web_identity(
+ RoleArn=self.database_arguments.aws_role_name,
+ RoleSessionName=self.database_arguments.aws_session_name,
+ WebIdentityToken=self.database_arguments.aws_web_identity_token,
+ )
+
+ # call 2
+ assumed_role = sts_client.assume_role(
+ RoleArn=self.database_arguments.assume_role_aws_role_name,
+ RoleSessionName=self.database_arguments.assume_role_aws_session_name,
+ )
+
+ aws_access_key_id = assumed_role["Credentials"]["AccessKeyId"]
+ aws_secret_access_key = assumed_role["Credentials"]["SecretAccessKey"]
+ aws_session_token = assumed_role["Credentials"]["SessionToken"]
+
+ verbose_proxy_logger.debug(
+ f"Got STS assumed Role, aws_access_key_id={aws_access_key_id}"
+ )
+ # set these in the env so aiodynamo can use them
+ os.environ["AWS_ACCESS_KEY_ID"] = aws_access_key_id
+ os.environ["AWS_SECRET_ACCESS_KEY"] = aws_secret_access_key
+ os.environ["AWS_SESSION_TOKEN"] = aws_session_token
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/db/log_db_metrics.py b/.venv/lib/python3.12/site-packages/litellm/proxy/db/log_db_metrics.py
new file mode 100644
index 00000000..9bd33507
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/db/log_db_metrics.py
@@ -0,0 +1,143 @@
+"""
+Handles logging DB success/failure to ServiceLogger()
+
+ServiceLogger() then sends DB logs to Prometheus, OTEL, Datadog etc
+"""
+
+import asyncio
+from datetime import datetime
+from functools import wraps
+from typing import Callable, Dict, Tuple
+
+from litellm._service_logger import ServiceTypes
+from litellm.litellm_core_utils.core_helpers import (
+ _get_parent_otel_span_from_kwargs,
+ get_litellm_metadata_from_kwargs,
+)
+
+
+def log_db_metrics(func):
+ """
+ Decorator to log the duration of a DB related function to ServiceLogger()
+
+ Handles logging DB success/failure to ServiceLogger(), which logs to Prometheus, OTEL, Datadog
+
+ When logging Failure it checks if the Exception is a PrismaError, httpx.ConnectError or httpx.TimeoutException and then logs that as a DB Service Failure
+
+ Args:
+ func: The function to be decorated
+
+ Returns:
+ Result from the decorated function
+
+ Raises:
+ Exception: If the decorated function raises an exception
+ """
+
+ @wraps(func)
+ async def wrapper(*args, **kwargs):
+
+ start_time: datetime = datetime.now()
+
+ try:
+ result = await func(*args, **kwargs)
+ end_time: datetime = datetime.now()
+ from litellm.proxy.proxy_server import proxy_logging_obj
+
+ if "PROXY" not in func.__name__:
+ asyncio.create_task(
+ proxy_logging_obj.service_logging_obj.async_service_success_hook(
+ service=ServiceTypes.DB,
+ call_type=func.__name__,
+ parent_otel_span=kwargs.get("parent_otel_span", None),
+ duration=(end_time - start_time).total_seconds(),
+ start_time=start_time,
+ end_time=end_time,
+ event_metadata={
+ "function_name": func.__name__,
+ "function_kwargs": kwargs,
+ "function_args": args,
+ },
+ )
+ )
+ elif (
+                # in litellm custom callbacks the kwargs dict is passed as a
+                # positional arg (args[1], after self) - see
+                # https://docs.litellm.ai/docs/observability/custom_callback#callback-functions
+ args is not None
+ and len(args) > 1
+ and isinstance(args[1], dict)
+ ):
+ passed_kwargs = args[1]
+ parent_otel_span = _get_parent_otel_span_from_kwargs(
+ kwargs=passed_kwargs
+ )
+ if parent_otel_span is not None:
+ metadata = get_litellm_metadata_from_kwargs(kwargs=passed_kwargs)
+
+ asyncio.create_task(
+ proxy_logging_obj.service_logging_obj.async_service_success_hook(
+ service=ServiceTypes.BATCH_WRITE_TO_DB,
+ call_type=func.__name__,
+ parent_otel_span=parent_otel_span,
+ duration=0.0,
+ start_time=start_time,
+ end_time=end_time,
+ event_metadata=metadata,
+ )
+ )
+ # end of logging to otel
+ return result
+ except Exception as e:
+ end_time: datetime = datetime.now()
+ await _handle_logging_db_exception(
+ e=e,
+ func=func,
+ kwargs=kwargs,
+ args=args,
+ start_time=start_time,
+ end_time=end_time,
+ )
+ raise e
+
+ return wrapper
+
+
+def _is_exception_related_to_db(e: Exception) -> bool:
+ """
+ Returns True if the exception is related to the DB
+ """
+
+ import httpx
+ from prisma.errors import PrismaError
+
+ return isinstance(e, (PrismaError, httpx.ConnectError, httpx.TimeoutException))
+
+
+async def _handle_logging_db_exception(
+ e: Exception,
+ func: Callable,
+ kwargs: Dict,
+ args: Tuple,
+ start_time: datetime,
+ end_time: datetime,
+) -> None:
+ from litellm.proxy.proxy_server import proxy_logging_obj
+
+ # don't log this as a DB Service Failure, if the DB did not raise an exception
+ if _is_exception_related_to_db(e) is not True:
+ return
+
+ await proxy_logging_obj.service_logging_obj.async_service_failure_hook(
+ error=e,
+ service=ServiceTypes.DB,
+ call_type=func.__name__,
+ parent_otel_span=kwargs.get("parent_otel_span"),
+ duration=(end_time - start_time).total_seconds(),
+ start_time=start_time,
+ end_time=end_time,
+ event_metadata={
+ "function_name": func.__name__,
+ "function_kwargs": kwargs,
+ "function_args": args,
+ },
+ )
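
Usage sketch: any async DB helper on the proxy can be wrapped so its latency is reported to ServiceLogger(). The fetch_user_row function is hypothetical:

    from litellm.proxy.db.log_db_metrics import log_db_metrics


    @log_db_metrics
    async def fetch_user_row(user_id: str, parent_otel_span=None):
        # a real implementation would query the DB here
        return {"user_id": user_id}

    # On success, the decorator schedules async_service_success_hook with
    # call_type="fetch_user_row". On PrismaError, httpx.ConnectError, or
    # httpx.TimeoutException it logs a DB service failure and re-raises;
    # unrelated exceptions are re-raised without a DB failure log.
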
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/db/prisma_client.py b/.venv/lib/python3.12/site-packages/litellm/proxy/db/prisma_client.py
new file mode 100644
index 00000000..85a3a57a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/db/prisma_client.py
@@ -0,0 +1,278 @@
+"""
+This file contains the PrismaWrapper class, which is used to wrap the Prisma client and handle the RDS IAM token.
+"""
+
+import asyncio
+import os
+import random
+import subprocess
+import time
+import urllib.parse
+from datetime import datetime, timedelta
+from pathlib import Path
+from typing import Any, Optional, Union
+
+from litellm._logging import verbose_proxy_logger
+from litellm.secret_managers.main import str_to_bool
+
+
+class PrismaWrapper:
+ def __init__(self, original_prisma: Any, iam_token_db_auth: bool):
+ self._original_prisma = original_prisma
+ self.iam_token_db_auth = iam_token_db_auth
+
+ def is_token_expired(self, token_url: Optional[str]) -> bool:
+ if token_url is None:
+ return True
+ # Decode the token URL to handle URL-encoded characters
+ decoded_url = urllib.parse.unquote(token_url)
+
+ # Parse the token URL
+ parsed_url = urllib.parse.urlparse(decoded_url)
+
+        # Parse the query parameters from the URL
+ query_params = urllib.parse.parse_qs(parsed_url.query)
+
+ # Get expiration time from the query parameters
+ expires = query_params.get("X-Amz-Expires", [None])[0]
+ if expires is None:
+ raise ValueError("X-Amz-Expires parameter is missing or invalid.")
+
+ expires_int = int(expires)
+
+ # Get the token's creation time from the X-Amz-Date parameter
+ token_time_str = query_params.get("X-Amz-Date", [""])[0]
+ if not token_time_str:
+ raise ValueError("X-Amz-Date parameter is missing or invalid.")
+
+ # Ensure the token time string is parsed correctly
+ try:
+ token_time = datetime.strptime(token_time_str, "%Y%m%dT%H%M%SZ")
+ except ValueError as e:
+ raise ValueError(f"Invalid X-Amz-Date format: {e}")
+
+ # Calculate the expiration time
+ expiration_time = token_time + timedelta(seconds=expires_int)
+
+ # Current time in UTC
+ current_time = datetime.utcnow()
+
+ # Check if the token is expired
+ return current_time > expiration_time
+
+ def get_rds_iam_token(self) -> Optional[str]:
+ if self.iam_token_db_auth:
+ from litellm.proxy.auth.rds_iam_token import generate_iam_auth_token
+
+ db_host = os.getenv("DATABASE_HOST")
+ db_port = os.getenv("DATABASE_PORT")
+ db_user = os.getenv("DATABASE_USER")
+ db_name = os.getenv("DATABASE_NAME")
+ db_schema = os.getenv("DATABASE_SCHEMA")
+
+ token = generate_iam_auth_token(
+ db_host=db_host, db_port=db_port, db_user=db_user
+ )
+
+ # print(f"token: {token}")
+ _db_url = f"postgresql://{db_user}:{token}@{db_host}:{db_port}/{db_name}"
+ if db_schema:
+ _db_url += f"?schema={db_schema}"
+
+ os.environ["DATABASE_URL"] = _db_url
+ return _db_url
+ return None
+
+ async def recreate_prisma_client(
+ self, new_db_url: str, http_client: Optional[Any] = None
+ ):
+ from prisma import Prisma # type: ignore
+
+ if http_client is not None:
+ self._original_prisma = Prisma(http=http_client)
+ else:
+ self._original_prisma = Prisma()
+
+ await self._original_prisma.connect()
+
+ def __getattr__(self, name: str):
+ original_attr = getattr(self._original_prisma, name)
+ if self.iam_token_db_auth:
+ db_url = os.getenv("DATABASE_URL")
+ if self.is_token_expired(db_url):
+ db_url = self.get_rds_iam_token()
+ loop = asyncio.get_event_loop()
+
+ if db_url:
+ if loop.is_running():
+ asyncio.run_coroutine_threadsafe(
+ self.recreate_prisma_client(db_url), loop
+ )
+ else:
+ asyncio.run(self.recreate_prisma_client(db_url))
+ else:
+ raise ValueError("Failed to get RDS IAM token")
+
+ return original_attr
+
+
+class PrismaManager:
+ @staticmethod
+ def _get_prisma_dir() -> str:
+ """Get the path to the migrations directory"""
+ abspath = os.path.abspath(__file__)
+ dname = os.path.dirname(os.path.dirname(abspath))
+ return dname
+
+ @staticmethod
+ def _create_baseline_migration(schema_path: str) -> bool:
+ """Create a baseline migration for an existing database"""
+ prisma_dir = PrismaManager._get_prisma_dir()
+ prisma_dir_path = Path(prisma_dir)
+ init_dir = prisma_dir_path / "migrations" / "0_init"
+
+ # Create migrations/0_init directory
+ init_dir.mkdir(parents=True, exist_ok=True)
+
+ # Generate migration SQL file
+ migration_file = init_dir / "migration.sql"
+
+        try:
+            # Generate migration diff (30 second timeout); a context manager
+            # ensures the migration file handle is closed on any failure
+            with open(migration_file, "w") as f:
+                subprocess.run(
+                    [
+                        "prisma",
+                        "migrate",
+                        "diff",
+                        "--from-empty",
+                        "--to-schema-datamodel",
+                        str(schema_path),
+                        "--script",
+                    ],
+                    stdout=f,
+                    check=True,
+                    timeout=30,
+                )
+
+            # Mark the baseline migration as applied (30 second timeout)
+ subprocess.run(
+ [
+ "prisma",
+ "migrate",
+ "resolve",
+ "--applied",
+ "0_init",
+ ],
+ check=True,
+ timeout=30,
+ )
+
+ return True
+ except subprocess.TimeoutExpired:
+ verbose_proxy_logger.warning(
+ "Migration timed out - the database might be under heavy load."
+ )
+ return False
+ except subprocess.CalledProcessError as e:
+ verbose_proxy_logger.warning(f"Error creating baseline migration: {e}")
+ return False
+
+ @staticmethod
+ def setup_database(use_migrate: bool = False) -> bool:
+ """
+ Set up the database using either prisma migrate or prisma db push
+
+ Returns:
+ bool: True if setup was successful, False otherwise
+ """
+
+ for attempt in range(4):
+ original_dir = os.getcwd()
+ prisma_dir = PrismaManager._get_prisma_dir()
+ schema_path = prisma_dir + "/schema.prisma"
+ os.chdir(prisma_dir)
+ try:
+ if use_migrate:
+ verbose_proxy_logger.info("Running prisma migrate deploy")
+ # First try to run migrate deploy directly
+ try:
+ subprocess.run(
+ ["prisma", "migrate", "deploy"],
+ timeout=60,
+ check=True,
+ capture_output=True,
+ text=True,
+ )
+ verbose_proxy_logger.info("prisma migrate deploy completed")
+ return True
+ except subprocess.CalledProcessError as e:
+ # Check if this is the non-empty schema error
+ if (
+ "P3005" in e.stderr
+ and "database schema is not empty" in e.stderr
+ ):
+ # Create baseline migration
+ if PrismaManager._create_baseline_migration(schema_path):
+ # Try migrate deploy again after baseline
+ subprocess.run(
+ ["prisma", "migrate", "deploy"],
+ timeout=60,
+ check=True,
+ )
+ return True
+ else:
+ # If it's a different error, raise it
+ raise e
+ else:
+ # Use prisma db push with increased timeout
+ subprocess.run(
+ ["prisma", "db", "push", "--accept-data-loss"],
+ timeout=60,
+ check=True,
+ )
+ return True
+ except subprocess.TimeoutExpired:
+ verbose_proxy_logger.warning(f"Attempt {attempt + 1} timed out")
+ time.sleep(random.randrange(5, 15))
+ except subprocess.CalledProcessError as e:
+ attempts_left = 3 - attempt
+ retry_msg = (
+ f" Retrying... ({attempts_left} attempts left)"
+ if attempts_left > 0
+ else ""
+ )
+ verbose_proxy_logger.warning(
+ f"The process failed to execute. Details: {e}.{retry_msg}"
+ )
+ time.sleep(random.randrange(5, 15))
+ finally:
+ os.chdir(original_dir)
+ return False
+
+
+def should_update_prisma_schema(
+ disable_updates: Optional[Union[bool, str]] = None
+) -> bool:
+ """
+ Determines if Prisma Schema updates should be applied during startup.
+
+ Args:
+ disable_updates: Controls whether schema updates are disabled.
+ Accepts boolean or string ('true'/'false'). Defaults to checking DISABLE_SCHEMA_UPDATE env var.
+
+ Returns:
+ bool: True if schema updates should be applied, False if updates are disabled.
+
+ Examples:
+ >>> should_update_prisma_schema() # Checks DISABLE_SCHEMA_UPDATE env var
+ >>> should_update_prisma_schema(True) # Explicitly disable updates
+ >>> should_update_prisma_schema("false") # Enable updates using string
+ """
+ if disable_updates is None:
+ disable_updates = os.getenv("DISABLE_SCHEMA_UPDATE", "false")
+
+ if isinstance(disable_updates, str):
+ disable_updates = str_to_bool(disable_updates)
+
+ return not bool(disable_updates)