about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/litellm/proxy/db
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/proxy/db')
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/proxy/db/base_client.py53
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/proxy/db/check_migration.py104
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/proxy/db/create_views.py227
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/proxy/db/dynamo_db.py71
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/proxy/db/log_db_metrics.py143
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/proxy/db/prisma_client.py278
6 files changed, 876 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/db/base_client.py b/.venv/lib/python3.12/site-packages/litellm/proxy/db/base_client.py
new file mode 100644
index 00000000..07f0ecdc
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/db/base_client.py
@@ -0,0 +1,53 @@
+from typing import Any, Literal, List
+
+
class CustomDB:
    """
    Implements a base class that we expect any custom db implementation (e.g. DynamoDB) to follow
    """

    def __init__(self) -> None:
        pass

    def get_data(self, key: str, table_name: Literal["user", "key", "config"]):
        """Look up a single record, e.g. to check whether a key is valid."""

    def insert_data(self, value: Any, table_name: Literal["user", "key", "config"]):
        """Write a new record, e.g. for new key / user creation."""

    def update_data(
        self, key: str, value: Any, table_name: Literal["user", "key", "config"]
    ):
        """Update an existing record, e.g. for cost tracking."""

    def delete_data(
        self, keys: List[str], table_name: Literal["user", "key", "config"]
    ):
        """Remove records, e.g. for the /key/delete endpoint."""

    def connect(
        self,
    ):
        """Open the db connection and create / update any tables."""

    def disconnect(
        self,
    ):
        """Close the db connection on server shutdown."""
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/db/check_migration.py b/.venv/lib/python3.12/site-packages/litellm/proxy/db/check_migration.py
new file mode 100644
index 00000000..bf180c11
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/db/check_migration.py
@@ -0,0 +1,104 @@
+"""Module for checking differences between Prisma schema and database."""
+
+import os
+import subprocess
+from typing import List, Optional, Tuple
+
+from litellm._logging import verbose_logger
+
+
def extract_sql_commands(diff_output: str) -> List[str]:
    """
    Extract SQL commands from the Prisma migrate diff output.

    Args:
        diff_output (str): The full output from prisma migrate diff.

    Returns:
        List[str]: A list of SQL commands extracted from the diff output.
    """
    commands: List[str] = []
    buffer = ""
    collecting = False

    for raw_line in diff_output.split("\n"):
        line = raw_line.strip()
        if not line:
            # skip blank lines entirely
            continue
        if line.startswith("-- "):
            # A comment line marks the start of the next table operation;
            # flush whatever statement we were accumulating first.
            if collecting and buffer:
                commands.append(buffer.strip())
                buffer = ""
            collecting = True
        elif collecting:
            if line.endswith(";"):
                # statement terminator -> complete the current command
                buffer += line
                commands.append(buffer.strip())
                buffer = ""
                collecting = False
            else:
                buffer += line + " "

    # Flush any trailing, unterminated command.
    if buffer:
        commands.append(buffer.strip())

    return commands
+
+
def check_prisma_schema_diff_helper(db_url: str) -> Tuple[bool, List[str]]:
    """Checks for differences between the current database and the Prisma schema.

    Runs `prisma migrate diff --script` against ./schema.prisma and parses the
    emitted SQL.

    Returns:
        A tuple containing:
        - A boolean indicating if differences were found (True) or not (False).
        - The list of SQL commands needed to bring the two in sync.
    """
    verbose_logger.debug("Checking for Prisma schema diff...")  # noqa: T201

    diff_cmd = [
        "prisma",
        "migrate",
        "diff",
        "--from-url",
        db_url,
        "--to-schema-datamodel",
        "./schema.prisma",
        "--script",
    ]
    try:
        completed = subprocess.run(
            diff_cmd,
            capture_output=True,
            text=True,
            check=True,
        )
    except subprocess.CalledProcessError as e:
        # best-effort: report the failure and treat it as "no diff"
        error_message = f"Failed to generate migration diff. Error: {e.stderr}"
        print(error_message)  # noqa: T201
        return False, []

    sql_commands = extract_sql_commands(completed.stdout)
    if not sql_commands:
        return False, []

    print("Changes to DB Schema detected")  # noqa: T201
    print("Required SQL commands:")  # noqa: T201
    for command in sql_commands:
        print(command)  # noqa: T201
    return True, sql_commands
+
+
def check_prisma_schema_diff(db_url: Optional[str] = None) -> None:
    """Main entry point for the Prisma schema drift check."""
    # Fall back to the DATABASE_URL env var when no URL is given explicitly.
    resolved_url = db_url if db_url is not None else os.getenv("DATABASE_URL")
    if resolved_url is None:
        raise Exception("DATABASE_URL not set")
    has_diff, message = check_prisma_schema_diff_helper(resolved_url)
    if has_diff:
        verbose_logger.exception(
            "🚨🚨🚨 prisma schema out of sync with db. Consider running these sql_commands to sync the two - {}".format(
                message
            )
        )
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/db/create_views.py b/.venv/lib/python3.12/site-packages/litellm/proxy/db/create_views.py
new file mode 100644
index 00000000..e9303077
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/db/create_views.py
@@ -0,0 +1,227 @@
+from typing import Any
+
+from litellm import verbose_logger
+
+_db = Any
+
+
async def create_missing_views(db: Any):
    """
    --------------------------------------------------
    NOTE: Copy of `litellm/db_scripts/create_views.py`.
    --------------------------------------------------
    Checks if the LiteLLM_VerificationTokenView and MonthlyGlobalSpend exists in the user's db.

    LiteLLM_VerificationTokenView: This view is used for getting the token + team data in user_api_key_auth

    MonthlyGlobalSpend: This view is used for the admin view to see global spend for this month

    If a view doesn't exist, it is created. The original implementation repeated
    the same try/probe/create stanza eight times; it is factored into a single
    helper applied over a (view name, CREATE statement) table.
    """

    async def _ensure_view(view_name: str, create_sql: str) -> None:
        """Create `view_name` via `create_sql` if it does not already exist."""
        try:
            # Probe with a trivial SELECT; any exception is taken to mean
            # the view is missing (same heuristic as the original code).
            await db.query_raw(f'SELECT 1 FROM "{view_name}" LIMIT 1')
            print(f"{view_name} Exists!")  # noqa
        except Exception:
            await db.execute_raw(query=create_sql)
            print(f"{view_name} Created!")  # noqa

    # (view name, CREATE statement) pairs, applied in order.
    views = [
        (
            "LiteLLM_VerificationTokenView",
            """
            CREATE VIEW "LiteLLM_VerificationTokenView" AS
            SELECT
            v.*,
            t.spend AS team_spend,
            t.max_budget AS team_max_budget,
            t.tpm_limit AS team_tpm_limit,
            t.rpm_limit AS team_rpm_limit
            FROM "LiteLLM_VerificationToken" v
            LEFT JOIN "LiteLLM_TeamTable" t ON v.team_id = t.team_id;
            """,
        ),
        (
            "MonthlyGlobalSpend",
            """
            CREATE OR REPLACE VIEW "MonthlyGlobalSpend" AS
            SELECT
            DATE("startTime") AS date,
            SUM("spend") AS spend
            FROM
            "LiteLLM_SpendLogs"
            WHERE
            "startTime" >= (CURRENT_DATE - INTERVAL '30 days')
            GROUP BY
            DATE("startTime");
            """,
        ),
        (
            "Last30dKeysBySpend",
            """
            CREATE OR REPLACE VIEW "Last30dKeysBySpend" AS
            SELECT
            L."api_key",
            V."key_alias",
            V."key_name",
            SUM(L."spend") AS total_spend
            FROM
            "LiteLLM_SpendLogs" L
            LEFT JOIN
            "LiteLLM_VerificationToken" V
            ON
            L."api_key" = V."token"
            WHERE
            L."startTime" >= (CURRENT_DATE - INTERVAL '30 days')
            GROUP BY
            L."api_key", V."key_alias", V."key_name"
            ORDER BY
            total_spend DESC;
            """,
        ),
        (
            "Last30dModelsBySpend",
            """
            CREATE OR REPLACE VIEW "Last30dModelsBySpend" AS
            SELECT
            "model",
            SUM("spend") AS total_spend
            FROM
            "LiteLLM_SpendLogs"
            WHERE
            "startTime" >= (CURRENT_DATE - INTERVAL '30 days')
            AND "model" != ''
            GROUP BY
            "model"
            ORDER BY
            total_spend DESC;
            """,
        ),
        (
            "MonthlyGlobalSpendPerKey",
            """
            CREATE OR REPLACE VIEW "MonthlyGlobalSpendPerKey" AS
            SELECT
            DATE("startTime") AS date,
            SUM("spend") AS spend,
            api_key as api_key
            FROM
            "LiteLLM_SpendLogs"
            WHERE
            "startTime" >= (CURRENT_DATE - INTERVAL '30 days')
            GROUP BY
            DATE("startTime"),
            api_key;
            """,
        ),
        (
            "MonthlyGlobalSpendPerUserPerKey",
            """
            CREATE OR REPLACE VIEW "MonthlyGlobalSpendPerUserPerKey" AS
            SELECT
            DATE("startTime") AS date,
            SUM("spend") AS spend,
            api_key as api_key,
            "user" as "user"
            FROM
            "LiteLLM_SpendLogs"
            WHERE
            "startTime" >= (CURRENT_DATE - INTERVAL '30 days')
            GROUP BY
            DATE("startTime"),
            "user",
            api_key;
            """,
        ),
        (
            "DailyTagSpend",
            """
            CREATE OR REPLACE VIEW "DailyTagSpend" AS
            SELECT
                jsonb_array_elements_text(request_tags) AS individual_request_tag,
                DATE(s."startTime") AS spend_date,
                COUNT(*) AS log_count,
                SUM(spend) AS total_spend
            FROM "LiteLLM_SpendLogs" s
            GROUP BY individual_request_tag, DATE(s."startTime");
            """,
        ),
        (
            "Last30dTopEndUsersSpend",
            """
            CREATE VIEW "Last30dTopEndUsersSpend" AS
            SELECT end_user, COUNT(*) AS total_events, SUM(spend) AS total_spend
            FROM "LiteLLM_SpendLogs"
            WHERE end_user <> '' AND end_user <> user
            AND "startTime" >= CURRENT_DATE - INTERVAL '30 days'
            GROUP BY end_user
            ORDER BY total_spend DESC
            LIMIT 100;
            """,
        ),
    ]

    for view_name, create_sql in views:
        await _ensure_view(view_name, create_sql)

    return
+
+
async def should_create_missing_views(db: Any) -> bool:
    """
    Run only on first time startup.

    If the LiteLLM_SpendLogs table already has rows, don't create views on startup.

    Uses pg_class.reltuples (Postgres' cheap row-count estimate) rather than a
    full COUNT(*). reltuples is 0 for an empty analyzed table and -1 for a
    table that has never been vacuumed/analyzed; both mean "fresh database".

    Returns:
        bool: True when the views should be created (estimated row count is
        0 or -1), False otherwise or when the query result is malformed.
    """

    sql_query = """
    SELECT reltuples::BIGINT
    FROM pg_class
    WHERE oid = '"LiteLLM_SpendLogs"'::regclass;
    """

    result = await db.query_raw(query=sql_query)

    verbose_logger.debug("Estimated Row count of LiteLLM_SpendLogs = {}".format(result))
    if (
        result
        and isinstance(result, list)
        and len(result) > 0
        and isinstance(result[0], dict)
        and "reltuples" in result[0]
        # BUG FIX: the original additionally required result[0]["reltuples"]
        # to be truthy, which short-circuited when the estimate was exactly 0
        # -- the very value the comparison below is meant to match. An empty,
        # analyzed table therefore never triggered view creation.
        and result[0]["reltuples"] in (0, -1)
    ):
        verbose_logger.debug("Should create views")
        return True

    return False
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/db/dynamo_db.py b/.venv/lib/python3.12/site-packages/litellm/proxy/db/dynamo_db.py
new file mode 100644
index 00000000..628509d9
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/db/dynamo_db.py
@@ -0,0 +1,71 @@
+"""
+Deprecated. Only PostgreSQL is supported.
+"""
+
+from litellm._logging import verbose_proxy_logger
+from litellm.proxy._types import DynamoDBArgs
+from litellm.proxy.db.base_client import CustomDB
+
+
class DynamoDBWrapper(CustomDB):
    """
    Deprecated DynamoDB implementation of CustomDB (module is kept only for
    backwards compatibility; PostgreSQL is the supported backend).
    """

    from aiodynamo.credentials import Credentials, StaticCredentials

    # aiodynamo credentials object used to authenticate requests.
    credentials: Credentials

    def __init__(self, database_arguments: DynamoDBArgs):
        """
        Translate `database_arguments.billing_mode` into an aiodynamo
        throughput type and store the connection arguments.

        Raises:
            Exception: if billing_mode is PROVISIONED_THROUGHPUT but
                read_capacity_units / write_capacity_units are missing or
                not integers.
        """
        from aiodynamo.models import PayPerRequest, Throughput

        self.throughput_type = None
        if database_arguments.billing_mode == "PAY_PER_REQUEST":
            self.throughput_type = PayPerRequest()
        elif database_arguments.billing_mode == "PROVISIONED_THROUGHPUT":
            # Provisioned mode requires both capacity values, as ints.
            if (
                database_arguments.read_capacity_units is not None
                and isinstance(database_arguments.read_capacity_units, int)
                and database_arguments.write_capacity_units is not None
                and isinstance(database_arguments.write_capacity_units, int)
            ):
                self.throughput_type = Throughput(read=database_arguments.read_capacity_units, write=database_arguments.write_capacity_units)  # type: ignore
            else:
                raise Exception(
                    f"Invalid args passed in. Need to set both read_capacity_units and write_capacity_units. Args passed in - {database_arguments}"
                )
        self.database_arguments = database_arguments
        self.region_name = database_arguments.region_name

    def set_env_vars_based_on_arn(self):
        """
        Assume an AWS role via STS and export the resulting temporary
        credentials as env vars so aiodynamo picks them up.

        No-op when aws_role_name is not configured.
        """
        if self.database_arguments.aws_role_name is None:
            return
        verbose_proxy_logger.debug(
            f"DynamoDB: setting env vars based on arn={self.database_arguments.aws_role_name}"
        )
        import os

        import boto3

        sts_client = boto3.client("sts")

        # call 1: web-identity assume-role; the return value is discarded.
        # NOTE(review): unclear from this code whether this call is required
        # before the role-chaining call below -- confirm against AWS docs.
        sts_client.assume_role_with_web_identity(
            RoleArn=self.database_arguments.aws_role_name,
            RoleSessionName=self.database_arguments.aws_session_name,
            WebIdentityToken=self.database_arguments.aws_web_identity_token,
        )

        # call 2: assume the final role; these are the credentials we keep.
        assumed_role = sts_client.assume_role(
            RoleArn=self.database_arguments.assume_role_aws_role_name,
            RoleSessionName=self.database_arguments.assume_role_aws_session_name,
        )

        aws_access_key_id = assumed_role["Credentials"]["AccessKeyId"]
        aws_secret_access_key = assumed_role["Credentials"]["SecretAccessKey"]
        aws_session_token = assumed_role["Credentials"]["SessionToken"]

        verbose_proxy_logger.debug(
            f"Got STS assumed Role, aws_access_key_id={aws_access_key_id}"
        )
        # set these in the env so aiodynamo can use them
        os.environ["AWS_ACCESS_KEY_ID"] = aws_access_key_id
        os.environ["AWS_SECRET_ACCESS_KEY"] = aws_secret_access_key
        os.environ["AWS_SESSION_TOKEN"] = aws_session_token
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/db/log_db_metrics.py b/.venv/lib/python3.12/site-packages/litellm/proxy/db/log_db_metrics.py
new file mode 100644
index 00000000..9bd33507
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/db/log_db_metrics.py
@@ -0,0 +1,143 @@
+"""
+Handles logging DB success/failure to ServiceLogger()
+
+ServiceLogger() then sends DB logs to Prometheus, OTEL, Datadog etc
+"""
+
+import asyncio
+from datetime import datetime
+from functools import wraps
+from typing import Callable, Dict, Tuple
+
+from litellm._service_logger import ServiceTypes
+from litellm.litellm_core_utils.core_helpers import (
+    _get_parent_otel_span_from_kwargs,
+    get_litellm_metadata_from_kwargs,
+)
+
+
def log_db_metrics(func):
    """
    Decorator to log the duration of a DB related function to ServiceLogger()

    Handles logging DB success/failure to ServiceLogger(), which logs to Prometheus, OTEL, Datadog

    When logging Failure it checks if the Exception is a PrismaError, httpx.ConnectError or httpx.TimeoutException and then logs that as a DB Service Failure

    Args:
        func: The async function to be decorated

    Returns:
        Result from the decorated function

    Raises:
        Exception: If the decorated function raises an exception
    """

    @wraps(func)
    async def wrapper(*args, **kwargs):

        start_time: datetime = datetime.now()

        try:
            result = await func(*args, **kwargs)
            end_time: datetime = datetime.now()
            # imported lazily to avoid a circular import with proxy_server
            from litellm.proxy.proxy_server import proxy_logging_obj

            if "PROXY" not in func.__name__:
                # Plain DB helper: record its wall-clock duration as a DB
                # service success. create_task = fire-and-forget so logging
                # never blocks the DB call path.
                asyncio.create_task(
                    proxy_logging_obj.service_logging_obj.async_service_success_hook(
                        service=ServiceTypes.DB,
                        call_type=func.__name__,
                        parent_otel_span=kwargs.get("parent_otel_span", None),
                        duration=(end_time - start_time).total_seconds(),
                        start_time=start_time,
                        end_time=end_time,
                        event_metadata={
                            "function_name": func.__name__,
                            "function_kwargs": kwargs,
                            "function_args": args,
                        },
                    )
                )
            elif (
                # in litellm custom callbacks kwargs is passed as arg[0]
                # https://docs.litellm.ai/docs/observability/custom_callback#callback-functions
                args is not None
                and len(args) > 1
                and isinstance(args[1], dict)
            ):
                passed_kwargs = args[1]
                parent_otel_span = _get_parent_otel_span_from_kwargs(
                    kwargs=passed_kwargs
                )
                if parent_otel_span is not None:
                    # Batch-write path: only logged when a parent OTEL span
                    # exists; duration is deliberately reported as 0.0 here.
                    metadata = get_litellm_metadata_from_kwargs(kwargs=passed_kwargs)

                    asyncio.create_task(
                        proxy_logging_obj.service_logging_obj.async_service_success_hook(
                            service=ServiceTypes.BATCH_WRITE_TO_DB,
                            call_type=func.__name__,
                            parent_otel_span=parent_otel_span,
                            duration=0.0,
                            start_time=start_time,
                            end_time=end_time,
                            event_metadata=metadata,
                        )
                    )
            # end of logging to otel
            return result
        except Exception as e:
            end_time: datetime = datetime.now()
            # Only DB-layer exceptions are logged as DB service failures
            # (see _handle_logging_db_exception); always re-raise for caller.
            await _handle_logging_db_exception(
                e=e,
                func=func,
                kwargs=kwargs,
                args=args,
                start_time=start_time,
                end_time=end_time,
            )
            raise e

    return wrapper
+
+
def _is_exception_related_to_db(e: Exception) -> bool:
    """
    Return True when `e` originated in the DB layer: a Prisma error, or a
    connect/timeout error from the underlying httpx transport.
    """

    import httpx
    from prisma.errors import PrismaError

    db_error_types = (
        PrismaError,
        httpx.ConnectError,
        httpx.TimeoutException,
    )
    return any(isinstance(e, err_type) for err_type in db_error_types)
+
+
async def _handle_logging_db_exception(
    e: Exception,
    func: Callable,
    kwargs: Dict,
    args: Tuple,
    start_time: datetime,
    end_time: datetime,
) -> None:
    """Log `e` as a DB service failure, but only when the DB layer raised it."""
    from litellm.proxy.proxy_server import proxy_logging_obj

    # don't log this as a DB Service Failure, if the DB did not raise an exception
    if not _is_exception_related_to_db(e):
        return

    duration_seconds = (end_time - start_time).total_seconds()
    await proxy_logging_obj.service_logging_obj.async_service_failure_hook(
        error=e,
        service=ServiceTypes.DB,
        call_type=func.__name__,
        parent_otel_span=kwargs.get("parent_otel_span"),
        duration=duration_seconds,
        start_time=start_time,
        end_time=end_time,
        event_metadata={
            "function_name": func.__name__,
            "function_kwargs": kwargs,
            "function_args": args,
        },
    )
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/db/prisma_client.py b/.venv/lib/python3.12/site-packages/litellm/proxy/db/prisma_client.py
new file mode 100644
index 00000000..85a3a57a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/db/prisma_client.py
@@ -0,0 +1,278 @@
+"""
+This file contains the PrismaWrapper class, which is used to wrap the Prisma client and handle the RDS IAM token.
+"""
+
+import asyncio
+import os
+import random
+import subprocess
+import time
+import urllib
+import urllib.parse
+from datetime import datetime, timedelta
+from pathlib import Path
+from typing import Any, Optional, Union
+
+from litellm._logging import verbose_proxy_logger
+from litellm.secret_managers.main import str_to_bool
+
+
class PrismaWrapper:
    """
    Proxy around a Prisma client that, when RDS IAM auth is enabled, checks
    token expiry on every attribute access and transparently mints a new
    token and reconnects the underlying client.
    """

    def __init__(self, original_prisma: Any, iam_token_db_auth: bool):
        # the wrapped prisma.Prisma instance; all access is forwarded to it
        self._original_prisma = original_prisma
        # when True, DATABASE_URL embeds a short-lived RDS IAM auth token
        self.iam_token_db_auth = iam_token_db_auth

    def is_token_expired(self, token_url: Optional[str]) -> bool:
        """
        Return True when the RDS IAM token embedded in `token_url` has
        expired (or when no URL is available at all).

        Expiry is computed as X-Amz-Date + X-Amz-Expires from the URL's
        query string.

        Raises:
            ValueError: if X-Amz-Expires or X-Amz-Date is missing/malformed.
        """
        if token_url is None:
            return True
        # Decode the token URL to handle URL-encoded characters
        decoded_url = urllib.parse.unquote(token_url)

        # Parse the token URL
        parsed_url = urllib.parse.urlparse(decoded_url)

        # Parse the query parameters from the path component (if they exist there)
        query_params = urllib.parse.parse_qs(parsed_url.query)

        # Get expiration time from the query parameters
        expires = query_params.get("X-Amz-Expires", [None])[0]
        if expires is None:
            raise ValueError("X-Amz-Expires parameter is missing or invalid.")

        expires_int = int(expires)

        # Get the token's creation time from the X-Amz-Date parameter
        token_time_str = query_params.get("X-Amz-Date", [""])[0]
        if not token_time_str:
            raise ValueError("X-Amz-Date parameter is missing or invalid.")

        # Ensure the token time string is parsed correctly
        try:
            token_time = datetime.strptime(token_time_str, "%Y%m%dT%H%M%SZ")
        except ValueError as e:
            raise ValueError(f"Invalid X-Amz-Date format: {e}")

        # Calculate the expiration time
        expiration_time = token_time + timedelta(seconds=expires_int)

        # Current time in UTC (naive, matching the naive X-Amz-Date above)
        current_time = datetime.utcnow()

        # Check if the token is expired
        return current_time > expiration_time

    def get_rds_iam_token(self) -> Optional[str]:
        """
        Mint a fresh RDS IAM auth token, rebuild the postgres DATABASE_URL
        with it, export it to the environment, and return it.

        Returns None when iam_token_db_auth is disabled.
        """
        if self.iam_token_db_auth:
            from litellm.proxy.auth.rds_iam_token import generate_iam_auth_token

            db_host = os.getenv("DATABASE_HOST")
            db_port = os.getenv("DATABASE_PORT")
            db_user = os.getenv("DATABASE_USER")
            db_name = os.getenv("DATABASE_NAME")
            db_schema = os.getenv("DATABASE_SCHEMA")

            token = generate_iam_auth_token(
                db_host=db_host, db_port=db_port, db_user=db_user
            )

            # print(f"token: {token}")
            _db_url = f"postgresql://{db_user}:{token}@{db_host}:{db_port}/{db_name}"
            if db_schema:
                _db_url += f"?schema={db_schema}"

            os.environ["DATABASE_URL"] = _db_url
            return _db_url
        return None

    async def recreate_prisma_client(
        self, new_db_url: str, http_client: Optional[Any] = None
    ):
        """
        Replace the wrapped client with a freshly-connected Prisma instance.

        NOTE(review): `new_db_url` is never passed to Prisma() here -- the new
        client presumably reads the refreshed DATABASE_URL env var set by
        get_rds_iam_token(); confirm against the prisma client docs.
        """
        from prisma import Prisma  # type: ignore

        if http_client is not None:
            self._original_prisma = Prisma(http=http_client)
        else:
            self._original_prisma = Prisma()

        await self._original_prisma.connect()

    def __getattr__(self, name: str):
        """
        Forward attribute access to the wrapped client, first refreshing the
        RDS IAM token when it has expired.
        """
        original_attr = getattr(self._original_prisma, name)
        if self.iam_token_db_auth:
            db_url = os.getenv("DATABASE_URL")
            if self.is_token_expired(db_url):
                db_url = self.get_rds_iam_token()
                loop = asyncio.get_event_loop()

                if db_url:
                    if loop.is_running():
                        # fire-and-forget: reconnect is scheduled on the
                        # running loop; this call still returns the attribute
                        # bound to the old client.
                        asyncio.run_coroutine_threadsafe(
                            self.recreate_prisma_client(db_url), loop
                        )
                    else:
                        asyncio.run(self.recreate_prisma_client(db_url))
                else:
                    raise ValueError("Failed to get RDS IAM token")

        return original_attr
+
+
class PrismaManager:
    """Helpers for applying the Prisma schema to the database at startup."""

    @staticmethod
    def _get_prisma_dir() -> str:
        """Return the directory expected to contain schema.prisma (the parent
        of this file's package directory)."""
        abspath = os.path.abspath(__file__)
        dname = os.path.dirname(os.path.dirname(abspath))
        return dname

    @staticmethod
    def _create_baseline_migration(schema_path: str) -> bool:
        """Create a baseline migration for an existing database.

        Used when `prisma migrate deploy` fails because the database already
        has a schema but no migration history (P3005): generate the full
        schema SQL into migrations/0_init/migration.sql, then mark it as
        already applied.

        Returns:
            bool: True on success, False on timeout or CLI failure.
        """
        prisma_dir = PrismaManager._get_prisma_dir()
        prisma_dir_path = Path(prisma_dir)
        init_dir = prisma_dir_path / "migrations" / "0_init"

        # Create migrations/0_init directory
        init_dir.mkdir(parents=True, exist_ok=True)

        # Generate migration SQL file
        migration_file = init_dir / "migration.sql"

        try:
            # BUG FIX: the original passed `stdout=open(migration_file, "w")`
            # and never closed the handle (leaked file descriptor, possibly
            # unflushed output). A context manager guarantees close/flush.
            with open(migration_file, "w") as migration_out:
                subprocess.run(
                    [
                        "prisma",
                        "migrate",
                        "diff",
                        "--from-empty",
                        "--to-schema-datamodel",
                        str(schema_path),
                        "--script",
                    ],
                    stdout=migration_out,
                    check=True,
                    timeout=30,
                )  # 30 second timeout

            # Mark migration as applied with increased timeout
            subprocess.run(
                [
                    "prisma",
                    "migrate",
                    "resolve",
                    "--applied",
                    "0_init",
                ],
                check=True,
                timeout=30,
            )

            return True
        except subprocess.TimeoutExpired:
            verbose_proxy_logger.warning(
                "Migration timed out - the database might be under heavy load."
            )
            return False
        except subprocess.CalledProcessError as e:
            verbose_proxy_logger.warning(f"Error creating baseline migration: {e}")
            return False

    @staticmethod
    def setup_database(use_migrate: bool = False) -> bool:
        """
        Set up the database using either prisma migrate or prisma db push

        Retries up to 4 times, sleeping a random 5-15s between attempts on
        timeout or CLI failure.

        Args:
            use_migrate: when True run `prisma migrate deploy` (baselining an
                existing database on a P3005 error); otherwise run
                `prisma db push --accept-data-loss`.

        Returns:
            bool: True if setup was successful, False otherwise
        """

        for attempt in range(4):
            original_dir = os.getcwd()
            prisma_dir = PrismaManager._get_prisma_dir()
            schema_path = prisma_dir + "/schema.prisma"
            # prisma CLI commands must run from the directory with schema.prisma
            os.chdir(prisma_dir)
            try:
                if use_migrate:
                    verbose_proxy_logger.info("Running prisma migrate deploy")
                    # First try to run migrate deploy directly
                    try:
                        subprocess.run(
                            ["prisma", "migrate", "deploy"],
                            timeout=60,
                            check=True,
                            capture_output=True,
                            text=True,
                        )
                        verbose_proxy_logger.info("prisma migrate deploy completed")
                        return True
                    except subprocess.CalledProcessError as e:
                        # P3005 = "database schema is not empty": an existing
                        # DB with no migration history. Baseline it, then
                        # retry the deploy once.
                        if (
                            "P3005" in e.stderr
                            and "database schema is not empty" in e.stderr
                        ):
                            # Create baseline migration
                            if PrismaManager._create_baseline_migration(schema_path):
                                # Try migrate deploy again after baseline
                                subprocess.run(
                                    ["prisma", "migrate", "deploy"],
                                    timeout=60,
                                    check=True,
                                )
                                return True
                        else:
                            # If it's a different error, raise it
                            raise e
                else:
                    # Use prisma db push with increased timeout
                    subprocess.run(
                        ["prisma", "db", "push", "--accept-data-loss"],
                        timeout=60,
                        check=True,
                    )
                    return True
            except subprocess.TimeoutExpired:
                verbose_proxy_logger.warning(f"Attempt {attempt + 1} timed out")
                time.sleep(random.randrange(5, 15))
            except subprocess.CalledProcessError as e:
                attempts_left = 3 - attempt
                retry_msg = (
                    f" Retrying... ({attempts_left} attempts left)"
                    if attempts_left > 0
                    else ""
                )
                verbose_proxy_logger.warning(
                    f"The process failed to execute. Details: {e}.{retry_msg}"
                )
                time.sleep(random.randrange(5, 15))
            finally:
                # Always restore the caller's working directory.
                os.chdir(original_dir)
        return False
+
+
def should_update_prisma_schema(
    disable_updates: Optional[Union[bool, str]] = None
) -> bool:
    """
    Determines if Prisma Schema updates should be applied during startup.

    Args:
        disable_updates: Controls whether schema updates are disabled.
            Accepts boolean or string ('true'/'false'). Defaults to checking the
            DISABLE_SCHEMA_UPDATE env var.

    Returns:
        bool: True if schema updates should be applied, False if updates are disabled.

    Examples:
        >>> should_update_prisma_schema()  # Checks DISABLE_SCHEMA_UPDATE env var
        >>> should_update_prisma_schema(True)  # Explicitly disable updates
        >>> should_update_prisma_schema("false")  # Enable updates using string
    """
    flag = disable_updates
    if flag is None:
        # fall back to the env var; unset means updates are enabled
        flag = os.getenv("DISABLE_SCHEMA_UPDATE", "false")
    if isinstance(flag, str):
        flag = str_to_bool(flag)
    return not bool(flag)