aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/litellm/proxy/db/prisma_client.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/proxy/db/prisma_client.py')
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/proxy/db/prisma_client.py278
1 files changed, 278 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/db/prisma_client.py b/.venv/lib/python3.12/site-packages/litellm/proxy/db/prisma_client.py
new file mode 100644
index 00000000..85a3a57a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/db/prisma_client.py
@@ -0,0 +1,278 @@
+"""
+This file contains the PrismaWrapper class, which is used to wrap the Prisma client and handle the RDS IAM token.
+"""
+
+import asyncio
+import os
+import random
+import subprocess
+import time
+import urllib
+import urllib.parse
+from datetime import datetime, timedelta
+from pathlib import Path
+from typing import Any, Optional, Union
+
+from litellm._logging import verbose_proxy_logger
+from litellm.secret_managers.main import str_to_bool
+
+
+class PrismaWrapper:
+ def __init__(self, original_prisma: Any, iam_token_db_auth: bool):
+ self._original_prisma = original_prisma
+ self.iam_token_db_auth = iam_token_db_auth
+
+ def is_token_expired(self, token_url: Optional[str]) -> bool:
+ if token_url is None:
+ return True
+ # Decode the token URL to handle URL-encoded characters
+ decoded_url = urllib.parse.unquote(token_url)
+
+ # Parse the token URL
+ parsed_url = urllib.parse.urlparse(decoded_url)
+
+ # Parse the query parameters from the path component (if they exist there)
+ query_params = urllib.parse.parse_qs(parsed_url.query)
+
+ # Get expiration time from the query parameters
+ expires = query_params.get("X-Amz-Expires", [None])[0]
+ if expires is None:
+ raise ValueError("X-Amz-Expires parameter is missing or invalid.")
+
+ expires_int = int(expires)
+
+ # Get the token's creation time from the X-Amz-Date parameter
+ token_time_str = query_params.get("X-Amz-Date", [""])[0]
+ if not token_time_str:
+ raise ValueError("X-Amz-Date parameter is missing or invalid.")
+
+ # Ensure the token time string is parsed correctly
+ try:
+ token_time = datetime.strptime(token_time_str, "%Y%m%dT%H%M%SZ")
+ except ValueError as e:
+ raise ValueError(f"Invalid X-Amz-Date format: {e}")
+
+ # Calculate the expiration time
+ expiration_time = token_time + timedelta(seconds=expires_int)
+
+ # Current time in UTC
+ current_time = datetime.utcnow()
+
+ # Check if the token is expired
+ return current_time > expiration_time
+
+ def get_rds_iam_token(self) -> Optional[str]:
+ if self.iam_token_db_auth:
+ from litellm.proxy.auth.rds_iam_token import generate_iam_auth_token
+
+ db_host = os.getenv("DATABASE_HOST")
+ db_port = os.getenv("DATABASE_PORT")
+ db_user = os.getenv("DATABASE_USER")
+ db_name = os.getenv("DATABASE_NAME")
+ db_schema = os.getenv("DATABASE_SCHEMA")
+
+ token = generate_iam_auth_token(
+ db_host=db_host, db_port=db_port, db_user=db_user
+ )
+
+ # print(f"token: {token}")
+ _db_url = f"postgresql://{db_user}:{token}@{db_host}:{db_port}/{db_name}"
+ if db_schema:
+ _db_url += f"?schema={db_schema}"
+
+ os.environ["DATABASE_URL"] = _db_url
+ return _db_url
+ return None
+
+ async def recreate_prisma_client(
+ self, new_db_url: str, http_client: Optional[Any] = None
+ ):
+ from prisma import Prisma # type: ignore
+
+ if http_client is not None:
+ self._original_prisma = Prisma(http=http_client)
+ else:
+ self._original_prisma = Prisma()
+
+ await self._original_prisma.connect()
+
+ def __getattr__(self, name: str):
+ original_attr = getattr(self._original_prisma, name)
+ if self.iam_token_db_auth:
+ db_url = os.getenv("DATABASE_URL")
+ if self.is_token_expired(db_url):
+ db_url = self.get_rds_iam_token()
+ loop = asyncio.get_event_loop()
+
+ if db_url:
+ if loop.is_running():
+ asyncio.run_coroutine_threadsafe(
+ self.recreate_prisma_client(db_url), loop
+ )
+ else:
+ asyncio.run(self.recreate_prisma_client(db_url))
+ else:
+ raise ValueError("Failed to get RDS IAM token")
+
+ return original_attr
+
+
+class PrismaManager:
+ @staticmethod
+ def _get_prisma_dir() -> str:
+ """Get the path to the migrations directory"""
+ abspath = os.path.abspath(__file__)
+ dname = os.path.dirname(os.path.dirname(abspath))
+ return dname
+
+ @staticmethod
+ def _create_baseline_migration(schema_path: str) -> bool:
+ """Create a baseline migration for an existing database"""
+ prisma_dir = PrismaManager._get_prisma_dir()
+ prisma_dir_path = Path(prisma_dir)
+ init_dir = prisma_dir_path / "migrations" / "0_init"
+
+ # Create migrations/0_init directory
+ init_dir.mkdir(parents=True, exist_ok=True)
+
+ # Generate migration SQL file
+ migration_file = init_dir / "migration.sql"
+
+ try:
+ # Generate migration diff with increased timeout
+ subprocess.run(
+ [
+ "prisma",
+ "migrate",
+ "diff",
+ "--from-empty",
+ "--to-schema-datamodel",
+ str(schema_path),
+ "--script",
+ ],
+ stdout=open(migration_file, "w"),
+ check=True,
+ timeout=30,
+ ) # 30 second timeout
+
+ # Mark migration as applied with increased timeout
+ subprocess.run(
+ [
+ "prisma",
+ "migrate",
+ "resolve",
+ "--applied",
+ "0_init",
+ ],
+ check=True,
+ timeout=30,
+ )
+
+ return True
+ except subprocess.TimeoutExpired:
+ verbose_proxy_logger.warning(
+ "Migration timed out - the database might be under heavy load."
+ )
+ return False
+ except subprocess.CalledProcessError as e:
+ verbose_proxy_logger.warning(f"Error creating baseline migration: {e}")
+ return False
+
+ @staticmethod
+ def setup_database(use_migrate: bool = False) -> bool:
+ """
+ Set up the database using either prisma migrate or prisma db push
+
+ Returns:
+ bool: True if setup was successful, False otherwise
+ """
+
+ for attempt in range(4):
+ original_dir = os.getcwd()
+ prisma_dir = PrismaManager._get_prisma_dir()
+ schema_path = prisma_dir + "/schema.prisma"
+ os.chdir(prisma_dir)
+ try:
+ if use_migrate:
+ verbose_proxy_logger.info("Running prisma migrate deploy")
+ # First try to run migrate deploy directly
+ try:
+ subprocess.run(
+ ["prisma", "migrate", "deploy"],
+ timeout=60,
+ check=True,
+ capture_output=True,
+ text=True,
+ )
+ verbose_proxy_logger.info("prisma migrate deploy completed")
+ return True
+ except subprocess.CalledProcessError as e:
+ # Check if this is the non-empty schema error
+ if (
+ "P3005" in e.stderr
+ and "database schema is not empty" in e.stderr
+ ):
+ # Create baseline migration
+ if PrismaManager._create_baseline_migration(schema_path):
+ # Try migrate deploy again after baseline
+ subprocess.run(
+ ["prisma", "migrate", "deploy"],
+ timeout=60,
+ check=True,
+ )
+ return True
+ else:
+ # If it's a different error, raise it
+ raise e
+ else:
+ # Use prisma db push with increased timeout
+ subprocess.run(
+ ["prisma", "db", "push", "--accept-data-loss"],
+ timeout=60,
+ check=True,
+ )
+ return True
+ except subprocess.TimeoutExpired:
+ verbose_proxy_logger.warning(f"Attempt {attempt + 1} timed out")
+ time.sleep(random.randrange(5, 15))
+ except subprocess.CalledProcessError as e:
+ attempts_left = 3 - attempt
+ retry_msg = (
+ f" Retrying... ({attempts_left} attempts left)"
+ if attempts_left > 0
+ else ""
+ )
+ verbose_proxy_logger.warning(
+ f"The process failed to execute. Details: {e}.{retry_msg}"
+ )
+ time.sleep(random.randrange(5, 15))
+ finally:
+ os.chdir(original_dir)
+ return False
+
+
+def should_update_prisma_schema(
+ disable_updates: Optional[Union[bool, str]] = None
+) -> bool:
+ """
+ Determines if Prisma Schema updates should be applied during startup.
+
+ Args:
+ disable_updates: Controls whether schema updates are disabled.
+ Accepts boolean or string ('true'/'false'). Defaults to checking DISABLE_SCHEMA_UPDATE env var.
+
+ Returns:
+ bool: True if schema updates should be applied, False if updates are disabled.
+
+ Examples:
+ >>> should_update_prisma_schema() # Checks DISABLE_SCHEMA_UPDATE env var
+ >>> should_update_prisma_schema(True) # Explicitly disable updates
+ >>> should_update_prisma_schema("false") # Enable updates using string
+ """
+ if disable_updates is None:
+ disable_updates = os.getenv("DISABLE_SCHEMA_UPDATE", "false")
+
+ if isinstance(disable_updates, str):
+ disable_updates = str_to_bool(disable_updates)
+
+ return not bool(disable_updates)