aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/litellm/proxy/db/prisma_client.py
blob: 85a3a57adc0a9e0bb69f584b08f91829e32d09b1 (about) (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
"""
This file contains the PrismaWrapper class, which is used to wrap the Prisma client and handle the RDS IAM token.
"""

import asyncio
import os
import random
import subprocess
import time
import urllib
import urllib.parse
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any, Optional, Union

from litellm._logging import verbose_proxy_logger
from litellm.secret_managers.main import str_to_bool


class PrismaWrapper:
    """
    Proxy around a Prisma client that refreshes the RDS IAM auth token on demand.

    Any attribute not defined on the wrapper is forwarded to the wrapped client.
    When ``iam_token_db_auth`` is enabled, each forwarded lookup first checks
    whether the token embedded in ``DATABASE_URL`` has expired. If it has, a
    new URL is generated and the wrapped client is recreated.
    """

    def __init__(self, original_prisma: Any, iam_token_db_auth: bool):
        self._original_prisma = original_prisma
        self.iam_token_db_auth = iam_token_db_auth

    def is_token_expired(self, token_url: Optional[str]) -> bool:
        """
        Return True if the IAM token carried inside ``token_url`` has expired.

        The token occupies the password position of the database URL and may be
        percent-encoded, so the URL is unquoted first. This exposes the token's
        ``X-Amz-Date`` / ``X-Amz-Expires`` query parameters.

        Args:
            token_url: Database URL containing the token, or None.

        Returns:
            True when ``token_url`` is None or the token's expiry time is in
            the past; False otherwise.

        Raises:
            ValueError: If ``X-Amz-Expires`` or ``X-Amz-Date`` is missing or
                malformed.
        """
        from datetime import timezone

        if token_url is None:
            return True

        # Decode first so the token's own query string becomes the URL query.
        decoded_url = urllib.parse.unquote(token_url)
        parsed_url = urllib.parse.urlparse(decoded_url)
        query_params = urllib.parse.parse_qs(parsed_url.query)

        expires = query_params.get("X-Amz-Expires", [None])[0]
        if expires is None:
            raise ValueError("X-Amz-Expires parameter is missing or invalid.")
        try:
            expires_int = int(expires)
        except ValueError as e:
            raise ValueError("X-Amz-Expires parameter is missing or invalid.") from e

        token_time_str = query_params.get("X-Amz-Date", [""])[0]
        if not token_time_str:
            raise ValueError("X-Amz-Date parameter is missing or invalid.")
        try:
            token_time = datetime.strptime(token_time_str, "%Y%m%dT%H%M%SZ")
        except ValueError as e:
            raise ValueError(f"Invalid X-Amz-Date format: {e}") from e

        expiration_time = token_time + timedelta(seconds=expires_int)

        # The trailing "Z" marks X-Amz-Date as UTC, and strptime yields a naive
        # datetime. Compare against a naive UTC "now" derived from an aware
        # value, because datetime.utcnow() is deprecated since Python 3.12.
        current_time = datetime.now(timezone.utc).replace(tzinfo=None)

        return current_time > expiration_time

    def get_rds_iam_token(self) -> Optional[str]:
        """
        Build a Postgres URL authenticated with a freshly generated RDS IAM token.

        Reads DATABASE_HOST, DATABASE_PORT, DATABASE_USER, DATABASE_NAME and the
        optional DATABASE_SCHEMA from the environment. The token is used as the
        URL password. The resulting URL is written back to
        ``os.environ["DATABASE_URL"]``.

        Returns:
            The new database URL, or None when IAM token auth is disabled.
        """
        if self.iam_token_db_auth:
            from litellm.proxy.auth.rds_iam_token import generate_iam_auth_token

            db_host = os.getenv("DATABASE_HOST")
            db_port = os.getenv("DATABASE_PORT")
            db_user = os.getenv("DATABASE_USER")
            db_name = os.getenv("DATABASE_NAME")
            db_schema = os.getenv("DATABASE_SCHEMA")

            token = generate_iam_auth_token(
                db_host=db_host, db_port=db_port, db_user=db_user
            )

            _db_url = f"postgresql://{db_user}:{token}@{db_host}:{db_port}/{db_name}"
            if db_schema:
                _db_url += f"?schema={db_schema}"

            os.environ["DATABASE_URL"] = _db_url
            return _db_url
        return None

    async def recreate_prisma_client(
        self, new_db_url: str, http_client: Optional[Any] = None
    ):
        """
        Replace the wrapped client with a new, connected ``prisma.Prisma``.

        Args:
            new_db_url: The refreshed database URL. It is not passed to
                ``Prisma`` here. Presumably the client picks up
                ``DATABASE_URL``, which ``get_rds_iam_token`` has just
                updated. TODO confirm against the schema's datasource.
            http_client: Optional HTTP client, forwarded as ``Prisma(http=...)``.

        NOTE(review): the previous client is replaced without being
        disconnected.
        """
        from prisma import Prisma  # type: ignore

        if http_client is not None:
            self._original_prisma = Prisma(http=http_client)
        else:
            self._original_prisma = Prisma()

        await self._original_prisma.connect()

    def __getattr__(self, name: str):
        """
        Forward attribute access to the wrapped Prisma client.

        With IAM auth enabled, an expired token triggers a refresh. A new
        DATABASE_URL is generated, then the client is recreated in one of two
        ways:

        - If an event loop is running in this thread, the coroutine is
          scheduled on it and not awaited.
        - Otherwise the coroutine is run to completion with ``asyncio.run``.

        The returned attribute always comes from the client that was current
        *before* any refresh.

        Raises:
            AttributeError: If the wrapped client lacks ``name``, or the
                wrapper itself is not initialised yet.
            ValueError: If a refreshed token could not be produced.
        """
        if name in ("_original_prisma", "iam_token_db_auth"):
            # Only reached when __init__ has not populated the instance, e.g.
            # during copy/unpickling or after __new__. Without this guard,
            # looking up self._original_prisma below would re-enter
            # __getattr__ forever.
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )

        original_attr = getattr(self._original_prisma, name)
        if self.iam_token_db_auth:
            db_url = os.getenv("DATABASE_URL")
            if self.is_token_expired(db_url):
                db_url = self.get_rds_iam_token()
                if not db_url:
                    raise ValueError("Failed to get RDS IAM token")

                # get_running_loop() avoids the deprecated implicit loop
                # creation performed by get_event_loop() outside a loop.
                try:
                    loop: Optional[asyncio.AbstractEventLoop] = (
                        asyncio.get_running_loop()
                    )
                except RuntimeError:
                    loop = None

                if loop is not None:
                    asyncio.run_coroutine_threadsafe(
                        self.recreate_prisma_client(db_url), loop
                    )
                else:
                    asyncio.run(self.recreate_prisma_client(db_url))

        return original_attr


class PrismaManager:
    """Run Prisma CLI commands that bring the proxy database schema up to date."""

    @staticmethod
    def _get_prisma_dir() -> str:
        """
        Return the Prisma working directory, i.e. the directory two levels above this file.

        ``setup_database`` expects ``schema.prisma`` there, and
        ``_create_baseline_migration`` writes ``migrations/0_init`` under it.
        """
        abspath = os.path.abspath(__file__)
        return os.path.dirname(os.path.dirname(abspath))

    @staticmethod
    def _create_baseline_migration(schema_path: str) -> bool:
        """
        Create a ``0_init`` baseline migration for an existing database and mark it applied.

        The SQL for the full schema is generated with
        ``prisma migrate diff --from-empty`` into
        ``migrations/0_init/migration.sql``. That migration is then recorded as
        applied with ``prisma migrate resolve --applied 0_init``.

        Args:
            schema_path: Path to the ``schema.prisma`` file to diff against.

        Returns:
            True on success. False if either command times out (30 s each) or
            exits non-zero; the failure is logged as a warning.
        """
        prisma_dir_path = Path(PrismaManager._get_prisma_dir())
        init_dir = prisma_dir_path / "migrations" / "0_init"

        # Create migrations/0_init directory
        init_dir.mkdir(parents=True, exist_ok=True)
        migration_file = init_dir / "migration.sql"

        try:
            # Context manager closes the handle even if the command fails or
            # times out.
            with open(migration_file, "w") as migration_out:
                subprocess.run(
                    [
                        "prisma",
                        "migrate",
                        "diff",
                        "--from-empty",
                        "--to-schema-datamodel",
                        str(schema_path),
                        "--script",
                    ],
                    stdout=migration_out,
                    check=True,
                    timeout=30,
                )

            # Mark migration as applied
            subprocess.run(
                ["prisma", "migrate", "resolve", "--applied", "0_init"],
                check=True,
                timeout=30,
            )
            return True
        except subprocess.TimeoutExpired:
            verbose_proxy_logger.warning(
                "Migration timed out - the database might be under heavy load."
            )
            return False
        except subprocess.CalledProcessError as e:
            verbose_proxy_logger.warning(f"Error creating baseline migration: {e}")
            return False

    @staticmethod
    def _migrate_deploy_with_baseline(schema_path: str) -> bool:
        """
        Run ``prisma migrate deploy``, baselining the database first if Prisma asks for it.

        When deploy fails with P3005 ("database schema is not empty"), a
        baseline migration is created and deploy is run once more.

        Returns:
            True once deploy succeeds. False if a baseline was needed but
            could not be created.

        Raises:
            subprocess.CalledProcessError: Deploy failed for any other reason,
                or the post-baseline deploy failed.
            subprocess.TimeoutExpired: A deploy run exceeded 60 seconds.
        """
        verbose_proxy_logger.info("Running prisma migrate deploy")
        try:
            subprocess.run(
                ["prisma", "migrate", "deploy"],
                timeout=60,
                check=True,
                capture_output=True,
                text=True,
            )
            verbose_proxy_logger.info("prisma migrate deploy completed")
            return True
        except subprocess.CalledProcessError as e:
            stderr = e.stderr or ""
            if not ("P3005" in stderr and "database schema is not empty" in stderr):
                raise

        if not PrismaManager._create_baseline_migration(schema_path):
            return False

        subprocess.run(["prisma", "migrate", "deploy"], timeout=60, check=True)
        return True

    @staticmethod
    def setup_database(use_migrate: bool = False) -> bool:
        """
        Set up the database using either ``prisma migrate deploy`` or ``prisma db push``.

        Commands run with the working directory switched to the Prisma
        directory. The original directory is restored after every attempt.

        Up to four attempts are made. A timeout or non-zero exit is logged and
        retried after a random 5-14 second pause. There is no pause after the
        final attempt. Other exceptions (for example a missing ``prisma``
        executable) propagate.

        Args:
            use_migrate: If True, apply migrations with ``prisma migrate deploy``,
                baselining a non-empty database when needed. Otherwise sync the
                schema with ``prisma db push --accept-data-loss``.

        Returns:
            bool: True if setup was successful, False once every attempt failed.
        """
        max_attempts = 4
        prisma_dir = PrismaManager._get_prisma_dir()
        schema_path = os.path.join(prisma_dir, "schema.prisma")

        for attempt in range(max_attempts):
            attempts_left = max_attempts - 1 - attempt
            original_dir = os.getcwd()
            os.chdir(prisma_dir)
            try:
                if use_migrate:
                    if PrismaManager._migrate_deploy_with_baseline(schema_path):
                        return True
                    # Baselining failed (already logged by the helper). Move
                    # straight on to the next attempt.
                else:
                    subprocess.run(
                        ["prisma", "db", "push", "--accept-data-loss"],
                        timeout=60,
                        check=True,
                    )
                    return True
            except subprocess.TimeoutExpired:
                verbose_proxy_logger.warning(f"Attempt {attempt + 1} timed out")
                if attempts_left > 0:
                    time.sleep(random.randrange(5, 15))
            except subprocess.CalledProcessError as e:
                retry_msg = (
                    f" Retrying... ({attempts_left} attempts left)"
                    if attempts_left > 0
                    else ""
                )
                verbose_proxy_logger.warning(
                    f"The process failed to execute. Details: {e}.{retry_msg}"
                )
                if attempts_left > 0:
                    time.sleep(random.randrange(5, 15))
            finally:
                os.chdir(original_dir)
        return False


def should_update_prisma_schema(
    disable_updates: Optional[Union[bool, str]] = None
) -> bool:
    """
    Decide whether Prisma schema updates should run at startup.

    Args:
        disable_updates: Whether schema updates are disabled. Accepts a bool or
            a string such as 'true' / 'false'. When omitted, the
            DISABLE_SCHEMA_UPDATE environment variable is read, defaulting to
            "false".

    Returns:
        bool: False when updates are disabled, True otherwise.

    Examples::

        should_update_prisma_schema()         # consult DISABLE_SCHEMA_UPDATE
        should_update_prisma_schema(True)     # -> False (updates disabled)
        should_update_prisma_schema("false")  # -> True (updates enabled)
    """
    flag = (
        os.getenv("DISABLE_SCHEMA_UPDATE", "false")
        if disable_updates is None
        else disable_updates
    )

    # String values go through the shared parser.
    if isinstance(flag, str):
        flag = str_to_bool(flag)

    return not flag