diff options
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/proxy/db')
6 files changed, 876 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/db/base_client.py b/.venv/lib/python3.12/site-packages/litellm/proxy/db/base_client.py new file mode 100644 index 00000000..07f0ecdc --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/proxy/db/base_client.py @@ -0,0 +1,53 @@ +from typing import Any, Literal, List + + +class CustomDB: + """ + Implements a base class that we expect any custom db implementation (e.g. DynamoDB) to follow + """ + + def __init__(self) -> None: + pass + + def get_data(self, key: str, table_name: Literal["user", "key", "config"]): + """ + Check if key valid + """ + pass + + def insert_data(self, value: Any, table_name: Literal["user", "key", "config"]): + """ + For new key / user logic + """ + pass + + def update_data( + self, key: str, value: Any, table_name: Literal["user", "key", "config"] + ): + """ + For cost tracking logic + """ + pass + + def delete_data( + self, keys: List[str], table_name: Literal["user", "key", "config"] + ): + """ + For /key/delete endpoint s + """ + + def connect( + self, + ): + """ + For connecting to db and creating / updating any tables + """ + pass + + def disconnect( + self, + ): + """ + For closing connection on server shutdown + """ + pass diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/db/check_migration.py b/.venv/lib/python3.12/site-packages/litellm/proxy/db/check_migration.py new file mode 100644 index 00000000..bf180c11 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/proxy/db/check_migration.py @@ -0,0 +1,104 @@ +"""Module for checking differences between Prisma schema and database.""" + +import os +import subprocess +from typing import List, Optional, Tuple + +from litellm._logging import verbose_logger + + +def extract_sql_commands(diff_output: str) -> List[str]: + """ + Extract SQL commands from the Prisma migrate diff output. + Args: + diff_output (str): The full output from prisma migrate diff. + Returns: + List[str]: A list of SQL commands extracted from the diff output. + """ + # Split the output into lines and remove empty lines + lines = [line.strip() for line in diff_output.split("\n") if line.strip()] + + sql_commands = [] + current_command = "" + in_sql_block = False + + for line in lines: + if line.startswith("-- "): # Comment line, likely a table operation description + if in_sql_block and current_command: + sql_commands.append(current_command.strip()) + current_command = "" + in_sql_block = True + elif in_sql_block: + if line.endswith(";"): + current_command += line + sql_commands.append(current_command.strip()) + current_command = "" + in_sql_block = False + else: + current_command += line + " " + + # Add any remaining command + if current_command: + sql_commands.append(current_command.strip()) + + return sql_commands + + +def check_prisma_schema_diff_helper(db_url: str) -> Tuple[bool, List[str]]: + """Checks for differences between current database and Prisma schema. + Returns: + A tuple containing: + - A boolean indicating if differences were found (True) or not (False). + - A string with the diff output or error message. + Raises: + subprocess.CalledProcessError: If the Prisma command fails. + Exception: For any other errors during execution. + """ + verbose_logger.debug("Checking for Prisma schema diff...") # noqa: T201 + try: + result = subprocess.run( + [ + "prisma", + "migrate", + "diff", + "--from-url", + db_url, + "--to-schema-datamodel", + "./schema.prisma", + "--script", + ], + capture_output=True, + text=True, + check=True, + ) + + # return True, "Migration diff generated successfully." + sql_commands = extract_sql_commands(result.stdout) + + if sql_commands: + print("Changes to DB Schema detected") # noqa: T201 + print("Required SQL commands:") # noqa: T201 + for command in sql_commands: + print(command) # noqa: T201 + return True, sql_commands + else: + return False, [] + except subprocess.CalledProcessError as e: + error_message = f"Failed to generate migration diff. Error: {e.stderr}" + print(error_message) # noqa: T201 + return False, [] + + +def check_prisma_schema_diff(db_url: Optional[str] = None) -> None: + """Main function to run the Prisma schema diff check.""" + if db_url is None: + db_url = os.getenv("DATABASE_URL") + if db_url is None: + raise Exception("DATABASE_URL not set") + has_diff, message = check_prisma_schema_diff_helper(db_url) + if has_diff: + verbose_logger.exception( + "🚨🚨🚨 prisma schema out of sync with db. Consider running these sql_commands to sync the two - {}".format( + message + ) + ) diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/db/create_views.py b/.venv/lib/python3.12/site-packages/litellm/proxy/db/create_views.py new file mode 100644 index 00000000..e9303077 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/proxy/db/create_views.py @@ -0,0 +1,227 @@ +from typing import Any + +from litellm import verbose_logger + +_db = Any + + +async def create_missing_views(db: _db): # noqa: PLR0915 + """ + -------------------------------------------------- + NOTE: Copy of `litellm/db_scripts/create_views.py`. + -------------------------------------------------- + Checks if the LiteLLM_VerificationTokenView and MonthlyGlobalSpend exists in the user's db. + + LiteLLM_VerificationTokenView: This view is used for getting the token + team data in user_api_key_auth + + MonthlyGlobalSpend: This view is used for the admin view to see global spend for this month + + If the view doesn't exist, one will be created. + """ + try: + # Try to select one row from the view + await db.query_raw("""SELECT 1 FROM "LiteLLM_VerificationTokenView" LIMIT 1""") + print("LiteLLM_VerificationTokenView Exists!") # noqa + except Exception: + # If an error occurs, the view does not exist, so create it + await db.execute_raw( + """ + CREATE VIEW "LiteLLM_VerificationTokenView" AS + SELECT + v.*, + t.spend AS team_spend, + t.max_budget AS team_max_budget, + t.tpm_limit AS team_tpm_limit, + t.rpm_limit AS team_rpm_limit + FROM "LiteLLM_VerificationToken" v + LEFT JOIN "LiteLLM_TeamTable" t ON v.team_id = t.team_id; + """ + ) + + print("LiteLLM_VerificationTokenView Created!") # noqa + + try: + await db.query_raw("""SELECT 1 FROM "MonthlyGlobalSpend" LIMIT 1""") + print("MonthlyGlobalSpend Exists!") # noqa + except Exception: + sql_query = """ + CREATE OR REPLACE VIEW "MonthlyGlobalSpend" AS + SELECT + DATE("startTime") AS date, + SUM("spend") AS spend + FROM + "LiteLLM_SpendLogs" + WHERE + "startTime" >= (CURRENT_DATE - INTERVAL '30 days') + GROUP BY + DATE("startTime"); + """ + await db.execute_raw(query=sql_query) + + print("MonthlyGlobalSpend Created!") # noqa + + try: + await db.query_raw("""SELECT 1 FROM "Last30dKeysBySpend" LIMIT 1""") + print("Last30dKeysBySpend Exists!") # noqa + except Exception: + sql_query = """ + CREATE OR REPLACE VIEW "Last30dKeysBySpend" AS + SELECT + L."api_key", + V."key_alias", + V."key_name", + SUM(L."spend") AS total_spend + FROM + "LiteLLM_SpendLogs" L + LEFT JOIN + "LiteLLM_VerificationToken" V + ON + L."api_key" = V."token" + WHERE + L."startTime" >= (CURRENT_DATE - INTERVAL '30 days') + GROUP BY + L."api_key", V."key_alias", V."key_name" + ORDER BY + total_spend DESC; + """ + await db.execute_raw(query=sql_query) + + print("Last30dKeysBySpend Created!") # noqa + + try: + await db.query_raw("""SELECT 1 FROM "Last30dModelsBySpend" LIMIT 1""") + print("Last30dModelsBySpend Exists!") # noqa + except Exception: + sql_query = """ + CREATE OR REPLACE VIEW "Last30dModelsBySpend" AS + SELECT + "model", + SUM("spend") AS total_spend + FROM + "LiteLLM_SpendLogs" + WHERE + "startTime" >= (CURRENT_DATE - INTERVAL '30 days') + AND "model" != '' + GROUP BY + "model" + ORDER BY + total_spend DESC; + """ + await db.execute_raw(query=sql_query) + + print("Last30dModelsBySpend Created!") # noqa + try: + await db.query_raw("""SELECT 1 FROM "MonthlyGlobalSpendPerKey" LIMIT 1""") + print("MonthlyGlobalSpendPerKey Exists!") # noqa + except Exception: + sql_query = """ + CREATE OR REPLACE VIEW "MonthlyGlobalSpendPerKey" AS + SELECT + DATE("startTime") AS date, + SUM("spend") AS spend, + api_key as api_key + FROM + "LiteLLM_SpendLogs" + WHERE + "startTime" >= (CURRENT_DATE - INTERVAL '30 days') + GROUP BY + DATE("startTime"), + api_key; + """ + await db.execute_raw(query=sql_query) + + print("MonthlyGlobalSpendPerKey Created!") # noqa + try: + await db.query_raw( + """SELECT 1 FROM "MonthlyGlobalSpendPerUserPerKey" LIMIT 1""" + ) + print("MonthlyGlobalSpendPerUserPerKey Exists!") # noqa + except Exception: + sql_query = """ + CREATE OR REPLACE VIEW "MonthlyGlobalSpendPerUserPerKey" AS + SELECT + DATE("startTime") AS date, + SUM("spend") AS spend, + api_key as api_key, + "user" as "user" + FROM + "LiteLLM_SpendLogs" + WHERE + "startTime" >= (CURRENT_DATE - INTERVAL '30 days') + GROUP BY + DATE("startTime"), + "user", + api_key; + """ + await db.execute_raw(query=sql_query) + + print("MonthlyGlobalSpendPerUserPerKey Created!") # noqa + + try: + await db.query_raw("""SELECT 1 FROM "DailyTagSpend" LIMIT 1""") + print("DailyTagSpend Exists!") # noqa + except Exception: + sql_query = """ + CREATE OR REPLACE VIEW "DailyTagSpend" AS + SELECT + jsonb_array_elements_text(request_tags) AS individual_request_tag, + DATE(s."startTime") AS spend_date, + COUNT(*) AS log_count, + SUM(spend) AS total_spend + FROM "LiteLLM_SpendLogs" s + GROUP BY individual_request_tag, DATE(s."startTime"); + """ + await db.execute_raw(query=sql_query) + + print("DailyTagSpend Created!") # noqa + + try: + await db.query_raw("""SELECT 1 FROM "Last30dTopEndUsersSpend" LIMIT 1""") + print("Last30dTopEndUsersSpend Exists!") # noqa + except Exception: + sql_query = """ + CREATE VIEW "Last30dTopEndUsersSpend" AS + SELECT end_user, COUNT(*) AS total_events, SUM(spend) AS total_spend + FROM "LiteLLM_SpendLogs" + WHERE end_user <> '' AND end_user <> user + AND "startTime" >= CURRENT_DATE - INTERVAL '30 days' + GROUP BY end_user + ORDER BY total_spend DESC + LIMIT 100; + """ + await db.execute_raw(query=sql_query) + + print("Last30dTopEndUsersSpend Created!") # noqa + + return + + +async def should_create_missing_views(db: _db) -> bool: + """ + Run only on first time startup. + + If SpendLogs table already has values, then don't create views on startup. + """ + + sql_query = """ + SELECT reltuples::BIGINT + FROM pg_class + WHERE oid = '"LiteLLM_SpendLogs"'::regclass; + """ + + result = await db.query_raw(query=sql_query) + + verbose_logger.debug("Estimated Row count of LiteLLM_SpendLogs = {}".format(result)) + if ( + result + and isinstance(result, list) + and len(result) > 0 + and isinstance(result[0], dict) + and "reltuples" in result[0] + and result[0]["reltuples"] + and (result[0]["reltuples"] == 0 or result[0]["reltuples"] == -1) + ): + verbose_logger.debug("Should create views") + return True + + return False diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/db/dynamo_db.py b/.venv/lib/python3.12/site-packages/litellm/proxy/db/dynamo_db.py new file mode 100644 index 00000000..628509d9 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/proxy/db/dynamo_db.py @@ -0,0 +1,71 @@ +""" +Deprecated. Only PostgresSQL is supported. +""" + +from litellm._logging import verbose_proxy_logger +from litellm.proxy._types import DynamoDBArgs +from litellm.proxy.db.base_client import CustomDB + + +class DynamoDBWrapper(CustomDB): + from aiodynamo.credentials import Credentials, StaticCredentials + + credentials: Credentials + + def __init__(self, database_arguments: DynamoDBArgs): + from aiodynamo.models import PayPerRequest, Throughput + + self.throughput_type = None + if database_arguments.billing_mode == "PAY_PER_REQUEST": + self.throughput_type = PayPerRequest() + elif database_arguments.billing_mode == "PROVISIONED_THROUGHPUT": + if ( + database_arguments.read_capacity_units is not None + and isinstance(database_arguments.read_capacity_units, int) + and database_arguments.write_capacity_units is not None + and isinstance(database_arguments.write_capacity_units, int) + ): + self.throughput_type = Throughput(read=database_arguments.read_capacity_units, write=database_arguments.write_capacity_units) # type: ignore + else: + raise Exception( + f"Invalid args passed in. Need to set both read_capacity_units and write_capacity_units. Args passed in - {database_arguments}" + ) + self.database_arguments = database_arguments + self.region_name = database_arguments.region_name + + def set_env_vars_based_on_arn(self): + if self.database_arguments.aws_role_name is None: + return + verbose_proxy_logger.debug( + f"DynamoDB: setting env vars based on arn={self.database_arguments.aws_role_name}" + ) + import os + + import boto3 + + sts_client = boto3.client("sts") + + # call 1 + sts_client.assume_role_with_web_identity( + RoleArn=self.database_arguments.aws_role_name, + RoleSessionName=self.database_arguments.aws_session_name, + WebIdentityToken=self.database_arguments.aws_web_identity_token, + ) + + # call 2 + assumed_role = sts_client.assume_role( + RoleArn=self.database_arguments.assume_role_aws_role_name, + RoleSessionName=self.database_arguments.assume_role_aws_session_name, + ) + + aws_access_key_id = assumed_role["Credentials"]["AccessKeyId"] + aws_secret_access_key = assumed_role["Credentials"]["SecretAccessKey"] + aws_session_token = assumed_role["Credentials"]["SessionToken"] + + verbose_proxy_logger.debug( + f"Got STS assumed Role, aws_access_key_id={aws_access_key_id}" + ) + # set these in the env so aiodynamo can use them + os.environ["AWS_ACCESS_KEY_ID"] = aws_access_key_id + os.environ["AWS_SECRET_ACCESS_KEY"] = aws_secret_access_key + os.environ["AWS_SESSION_TOKEN"] = aws_session_token diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/db/log_db_metrics.py b/.venv/lib/python3.12/site-packages/litellm/proxy/db/log_db_metrics.py new file mode 100644 index 00000000..9bd33507 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/proxy/db/log_db_metrics.py @@ -0,0 +1,143 @@ +""" +Handles logging DB success/failure to ServiceLogger() + +ServiceLogger() then sends DB logs to Prometheus, OTEL, Datadog etc +""" + +import asyncio +from datetime import datetime +from functools import wraps +from typing import Callable, Dict, Tuple + +from litellm._service_logger import ServiceTypes +from litellm.litellm_core_utils.core_helpers import ( + _get_parent_otel_span_from_kwargs, + get_litellm_metadata_from_kwargs, +) + + +def log_db_metrics(func): + """ + Decorator to log the duration of a DB related function to ServiceLogger() + + Handles logging DB success/failure to ServiceLogger(), which logs to Prometheus, OTEL, Datadog + + When logging Failure it checks if the Exception is a PrismaError, httpx.ConnectError or httpx.TimeoutException and then logs that as a DB Service Failure + + Args: + func: The function to be decorated + + Returns: + Result from the decorated function + + Raises: + Exception: If the decorated function raises an exception + """ + + @wraps(func) + async def wrapper(*args, **kwargs): + + start_time: datetime = datetime.now() + + try: + result = await func(*args, **kwargs) + end_time: datetime = datetime.now() + from litellm.proxy.proxy_server import proxy_logging_obj + + if "PROXY" not in func.__name__: + asyncio.create_task( + proxy_logging_obj.service_logging_obj.async_service_success_hook( + service=ServiceTypes.DB, + call_type=func.__name__, + parent_otel_span=kwargs.get("parent_otel_span", None), + duration=(end_time - start_time).total_seconds(), + start_time=start_time, + end_time=end_time, + event_metadata={ + "function_name": func.__name__, + "function_kwargs": kwargs, + "function_args": args, + }, + ) + ) + elif ( + # in litellm custom callbacks kwargs is passed as arg[0] + # https://docs.litellm.ai/docs/observability/custom_callback#callback-functions + args is not None + and len(args) > 1 + and isinstance(args[1], dict) + ): + passed_kwargs = args[1] + parent_otel_span = _get_parent_otel_span_from_kwargs( + kwargs=passed_kwargs + ) + if parent_otel_span is not None: + metadata = get_litellm_metadata_from_kwargs(kwargs=passed_kwargs) + + asyncio.create_task( + proxy_logging_obj.service_logging_obj.async_service_success_hook( + service=ServiceTypes.BATCH_WRITE_TO_DB, + call_type=func.__name__, + parent_otel_span=parent_otel_span, + duration=0.0, + start_time=start_time, + end_time=end_time, + event_metadata=metadata, + ) + ) + # end of logging to otel + return result + except Exception as e: + end_time: datetime = datetime.now() + await _handle_logging_db_exception( + e=e, + func=func, + kwargs=kwargs, + args=args, + start_time=start_time, + end_time=end_time, + ) + raise e + + return wrapper + + +def _is_exception_related_to_db(e: Exception) -> bool: + """ + Returns True if the exception is related to the DB + """ + + import httpx + from prisma.errors import PrismaError + + return isinstance(e, (PrismaError, httpx.ConnectError, httpx.TimeoutException)) + + +async def _handle_logging_db_exception( + e: Exception, + func: Callable, + kwargs: Dict, + args: Tuple, + start_time: datetime, + end_time: datetime, +) -> None: + from litellm.proxy.proxy_server import proxy_logging_obj + + # don't log this as a DB Service Failure, if the DB did not raise an exception + if _is_exception_related_to_db(e) is not True: + return + + await proxy_logging_obj.service_logging_obj.async_service_failure_hook( + error=e, + service=ServiceTypes.DB, + call_type=func.__name__, + parent_otel_span=kwargs.get("parent_otel_span"), + duration=(end_time - start_time).total_seconds(), + start_time=start_time, + end_time=end_time, + event_metadata={ + "function_name": func.__name__, + "function_kwargs": kwargs, + "function_args": args, + }, + ) diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/db/prisma_client.py b/.venv/lib/python3.12/site-packages/litellm/proxy/db/prisma_client.py new file mode 100644 index 00000000..85a3a57a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/litellm/proxy/db/prisma_client.py @@ -0,0 +1,278 @@ +""" +This file contains the PrismaWrapper class, which is used to wrap the Prisma client and handle the RDS IAM token. +""" + +import asyncio +import os +import random +import subprocess +import time +import urllib +import urllib.parse +from datetime import datetime, timedelta +from pathlib import Path +from typing import Any, Optional, Union + +from litellm._logging import verbose_proxy_logger +from litellm.secret_managers.main import str_to_bool + + +class PrismaWrapper: + def __init__(self, original_prisma: Any, iam_token_db_auth: bool): + self._original_prisma = original_prisma + self.iam_token_db_auth = iam_token_db_auth + + def is_token_expired(self, token_url: Optional[str]) -> bool: + if token_url is None: + return True + # Decode the token URL to handle URL-encoded characters + decoded_url = urllib.parse.unquote(token_url) + + # Parse the token URL + parsed_url = urllib.parse.urlparse(decoded_url) + + # Parse the query parameters from the path component (if they exist there) + query_params = urllib.parse.parse_qs(parsed_url.query) + + # Get expiration time from the query parameters + expires = query_params.get("X-Amz-Expires", [None])[0] + if expires is None: + raise ValueError("X-Amz-Expires parameter is missing or invalid.") + + expires_int = int(expires) + + # Get the token's creation time from the X-Amz-Date parameter + token_time_str = query_params.get("X-Amz-Date", [""])[0] + if not token_time_str: + raise ValueError("X-Amz-Date parameter is missing or invalid.") + + # Ensure the token time string is parsed correctly + try: + token_time = datetime.strptime(token_time_str, "%Y%m%dT%H%M%SZ") + except ValueError as e: + raise ValueError(f"Invalid X-Amz-Date format: {e}") + + # Calculate the expiration time + expiration_time = token_time + timedelta(seconds=expires_int) + + # Current time in UTC + current_time = datetime.utcnow() + + # Check if the token is expired + return current_time > expiration_time + + def get_rds_iam_token(self) -> Optional[str]: + if self.iam_token_db_auth: + from litellm.proxy.auth.rds_iam_token import generate_iam_auth_token + + db_host = os.getenv("DATABASE_HOST") + db_port = os.getenv("DATABASE_PORT") + db_user = os.getenv("DATABASE_USER") + db_name = os.getenv("DATABASE_NAME") + db_schema = os.getenv("DATABASE_SCHEMA") + + token = generate_iam_auth_token( + db_host=db_host, db_port=db_port, db_user=db_user + ) + + # print(f"token: {token}") + _db_url = f"postgresql://{db_user}:{token}@{db_host}:{db_port}/{db_name}" + if db_schema: + _db_url += f"?schema={db_schema}" + + os.environ["DATABASE_URL"] = _db_url + return _db_url + return None + + async def recreate_prisma_client( + self, new_db_url: str, http_client: Optional[Any] = None + ): + from prisma import Prisma # type: ignore + + if http_client is not None: + self._original_prisma = Prisma(http=http_client) + else: + self._original_prisma = Prisma() + + await self._original_prisma.connect() + + def __getattr__(self, name: str): + original_attr = getattr(self._original_prisma, name) + if self.iam_token_db_auth: + db_url = os.getenv("DATABASE_URL") + if self.is_token_expired(db_url): + db_url = self.get_rds_iam_token() + loop = asyncio.get_event_loop() + + if db_url: + if loop.is_running(): + asyncio.run_coroutine_threadsafe( + self.recreate_prisma_client(db_url), loop + ) + else: + asyncio.run(self.recreate_prisma_client(db_url)) + else: + raise ValueError("Failed to get RDS IAM token") + + return original_attr + + +class PrismaManager: + @staticmethod + def _get_prisma_dir() -> str: + """Get the path to the migrations directory""" + abspath = os.path.abspath(__file__) + dname = os.path.dirname(os.path.dirname(abspath)) + return dname + + @staticmethod + def _create_baseline_migration(schema_path: str) -> bool: + """Create a baseline migration for an existing database""" + prisma_dir = PrismaManager._get_prisma_dir() + prisma_dir_path = Path(prisma_dir) + init_dir = prisma_dir_path / "migrations" / "0_init" + + # Create migrations/0_init directory + init_dir.mkdir(parents=True, exist_ok=True) + + # Generate migration SQL file + migration_file = init_dir / "migration.sql" + + try: + # Generate migration diff with increased timeout + subprocess.run( + [ + "prisma", + "migrate", + "diff", + "--from-empty", + "--to-schema-datamodel", + str(schema_path), + "--script", + ], + stdout=open(migration_file, "w"), + check=True, + timeout=30, + ) # 30 second timeout + + # Mark migration as applied with increased timeout + subprocess.run( + [ + "prisma", + "migrate", + "resolve", + "--applied", + "0_init", + ], + check=True, + timeout=30, + ) + + return True + except subprocess.TimeoutExpired: + verbose_proxy_logger.warning( + "Migration timed out - the database might be under heavy load." + ) + return False + except subprocess.CalledProcessError as e: + verbose_proxy_logger.warning(f"Error creating baseline migration: {e}") + return False + + @staticmethod + def setup_database(use_migrate: bool = False) -> bool: + """ + Set up the database using either prisma migrate or prisma db push + + Returns: + bool: True if setup was successful, False otherwise + """ + + for attempt in range(4): + original_dir = os.getcwd() + prisma_dir = PrismaManager._get_prisma_dir() + schema_path = prisma_dir + "/schema.prisma" + os.chdir(prisma_dir) + try: + if use_migrate: + verbose_proxy_logger.info("Running prisma migrate deploy") + # First try to run migrate deploy directly + try: + subprocess.run( + ["prisma", "migrate", "deploy"], + timeout=60, + check=True, + capture_output=True, + text=True, + ) + verbose_proxy_logger.info("prisma migrate deploy completed") + return True + except subprocess.CalledProcessError as e: + # Check if this is the non-empty schema error + if ( + "P3005" in e.stderr + and "database schema is not empty" in e.stderr + ): + # Create baseline migration + if PrismaManager._create_baseline_migration(schema_path): + # Try migrate deploy again after baseline + subprocess.run( + ["prisma", "migrate", "deploy"], + timeout=60, + check=True, + ) + return True + else: + # If it's a different error, raise it + raise e + else: + # Use prisma db push with increased timeout + subprocess.run( + ["prisma", "db", "push", "--accept-data-loss"], + timeout=60, + check=True, + ) + return True + except subprocess.TimeoutExpired: + verbose_proxy_logger.warning(f"Attempt {attempt + 1} timed out") + time.sleep(random.randrange(5, 15)) + except subprocess.CalledProcessError as e: + attempts_left = 3 - attempt + retry_msg = ( + f" Retrying... ({attempts_left} attempts left)" + if attempts_left > 0 + else "" + ) + verbose_proxy_logger.warning( + f"The process failed to execute. Details: {e}.{retry_msg}" + ) + time.sleep(random.randrange(5, 15)) + finally: + os.chdir(original_dir) + return False + + +def should_update_prisma_schema( + disable_updates: Optional[Union[bool, str]] = None +) -> bool: + """ + Determines if Prisma Schema updates should be applied during startup. + + Args: + disable_updates: Controls whether schema updates are disabled. + Accepts boolean or string ('true'/'false'). Defaults to checking DISABLE_SCHEMA_UPDATE env var. + + Returns: + bool: True if schema updates should be applied, False if updates are disabled. + + Examples: + >>> should_update_prisma_schema() # Checks DISABLE_SCHEMA_UPDATE env var + >>> should_update_prisma_schema(True) # Explicitly disable updates + >>> should_update_prisma_schema("false") # Enable updates using string + """ + if disable_updates is None: + disable_updates = os.getenv("DISABLE_SCHEMA_UPDATE", "false") + + if isinstance(disable_updates, str): + disable_updates = str_to_bool(disable_updates) + + return not bool(disable_updates) |