author     S. Solomon Darnell  2025-03-28 21:52:21 -0500
committer  S. Solomon Darnell  2025-03-28 21:52:21 -0500
commit     4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree       ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/core/providers/database/base.py
parent     cc961e04ba734dd72309fb548a2f97d67d578813 (diff)

two versions of R2R are here (HEAD, master)
Diffstat (limited to '.venv/lib/python3.12/site-packages/core/providers/database/base.py')

-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/database/base.py  247
1 file changed, 247 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/core/providers/database/base.py b/.venv/lib/python3.12/site-packages/core/providers/database/base.py
new file mode 100644
index 00000000..c70c1352
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/database/base.py
@@ -0,0 +1,247 @@
+import asyncio
+import logging
+import textwrap
+from contextlib import asynccontextmanager
+from typing import Optional
+
+import asyncpg
+
+from core.base.providers import DatabaseConnectionManager
+
+logger = logging.getLogger()
+
+
+class SemaphoreConnectionPool:
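+    """An asyncpg connection pool guarded by a semaphore that caps
+    concurrent acquisitions below the configured connection maximum."""
+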
+    def __init__(self, connection_string, postgres_configuration_settings):
+        self.connection_string = connection_string
+        self.postgres_configuration_settings = postgres_configuration_settings
+
+    async def initialize(self):
+        try:
+            # Gate connection acquisition at 90% of max_connections, keeping
+            # headroom below the pool's configured maximum.
+            semaphore_limit = int(
+                self.postgres_configuration_settings.max_connections * 0.9
+            )
+            logger.info(
+                f"Creating `asyncpg` connection pool with a concurrency limit of {semaphore_limit} connections."
+            )
+
+            self.semaphore = asyncio.Semaphore(semaphore_limit)
+
+            self.pool = await asyncpg.create_pool(
+                self.connection_string,
+                max_size=self.postgres_configuration_settings.max_connections,
+                statement_cache_size=self.postgres_configuration_settings.statement_cache_size,
+            )
+
+            logger.info(
+                "Successfully connected to Postgres database and created connection pool."
+            )
+        except Exception as e:
+            raise ValueError(
+                f"Error {e} occurred while attempting to connect to the relational database."
+            ) from e
+
+    @asynccontextmanager
+    async def get_connection(self):
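+        """Acquire the semaphore, then lease a connection from the pool."""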
+        async with self.semaphore:
+            async with self.pool.acquire() as conn:
+                yield conn
+
+    async def close(self):
+        await self.pool.close()
+
+
+class QueryBuilder:
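+    """Fluent builder for simple single-table SQL statements.
+
+    Conditions passed to `where()` are appended verbatim, so positional
+    parameters referenced in them must be supplied by the caller at
+    execution time. An illustrative sketch (not part of this module):
+
+        query, params = (
+            QueryBuilder("documents")
+            .select(["id", "title"])
+            .where("owner_id = $1")
+            .limit(10)
+            .build()
+        )
+        # `params` only collects INSERT/UPDATE values; the value for $1
+        # must be passed separately, e.g. `conn.fetch(query, owner_id)`.
+    """
+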
+    def __init__(self, table_name: str):
+        self.table_name = table_name
+        self.conditions: list[str] = []
+        self.params: list = []
+        self.select_fields = "*"
+        self.operation = "SELECT"
+        self.limit_value: Optional[int] = None
+        self.offset_value: Optional[int] = None
+        self.order_by_fields: Optional[str] = None
+        self.returning_fields: Optional[list[str]] = None
+        self.insert_data: Optional[dict] = None
+        self.update_data: Optional[dict] = None
+        self.param_counter = 1
+
+    def select(self, fields: list[str]):
+        self.select_fields = ", ".join(fields)
+        return self
+
+    def insert(self, data: dict):
+        self.operation = "INSERT"
+        self.insert_data = data
+        return self
+
+    def update(self, data: dict):
+        self.operation = "UPDATE"
+        self.update_data = data
+        return self
+
+    def delete(self):
+        self.operation = "DELETE"
+        return self
+
+    def where(self, condition: str):
+        self.conditions.append(condition)
+        return self
+
+    def limit(self, value: Optional[int]):
+        self.limit_value = value
+        return self
+
+    def offset(self, value: int):
+        self.offset_value = value
+        return self
+
+    def order_by(self, fields: str):
+        self.order_by_fields = fields
+        return self
+
+    def returning(self, fields: list[str]):
+        self.returning_fields = fields
+        return self
+
+    def build(self):
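+        """Assemble the SQL string and the collected positional parameters
+        for the configured operation."""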
+        if self.operation == "SELECT":
+            query = f"SELECT {self.select_fields} FROM {self.table_name}"
+
+        elif self.operation == "INSERT":
+            columns = ", ".join(self.insert_data.keys())
+            placeholders = ", ".join(
+                f"${i}" for i in range(1, len(self.insert_data) + 1)
+            )
+            query = f"INSERT INTO {self.table_name} ({columns}) VALUES ({placeholders})"
+            self.params.extend(list(self.insert_data.values()))
+
+        elif self.operation == "UPDATE":
+            set_clauses = []
+            for i, (key, value) in enumerate(
+                self.update_data.items(), start=len(self.params) + 1
+            ):
+                set_clauses.append(f"{key} = ${i}")
+                self.params.append(value)
+            query = f"UPDATE {self.table_name} SET {', '.join(set_clauses)}"
+
+        elif self.operation == "DELETE":
+            query = f"DELETE FROM {self.table_name}"
+
+        else:
+            raise ValueError(f"Unsupported operation: {self.operation}")
+
+        if self.conditions:
+            query += " WHERE " + " AND ".join(self.conditions)
+
+        if self.order_by_fields and self.operation == "SELECT":
+            query += f" ORDER BY {self.order_by_fields}"
+
+        if self.offset_value is not None:
+            query += f" OFFSET {self.offset_value}"
+
+        if self.limit_value is not None:
+            query += f" LIMIT {self.limit_value}"
+
+        if self.returning_fields:
+            query += f" RETURNING {', '.join(self.returning_fields)}"
+
+        return query, self.params
+
+
+class PostgresConnectionManager(DatabaseConnectionManager):
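+    """DatabaseConnectionManager backed by a SemaphoreConnectionPool.
+
+    `initialize()` must be called with a pool before any query method is
+    used; otherwise a ValueError is raised."""
+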
+    def __init__(self):
+        self.pool: Optional[SemaphoreConnectionPool] = None
+
+    async def initialize(self, pool: SemaphoreConnectionPool):
+        self.pool = pool
+
+    async def execute_query(self, query, params=None, isolation_level=None):
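+        """Execute a single statement, optionally inside a transaction at
+        the given isolation level, and return asyncpg's status string."""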
+        if not self.pool:
+            raise ValueError("PostgresConnectionManager is not initialized.")
+        async with self.pool.get_connection() as conn:
+            if isolation_level:
+                async with conn.transaction(isolation=isolation_level):
+                    if params:
+                        return await conn.execute(query, *params)
+                    else:
+                        return await conn.execute(query)
+            else:
+                if params:
+                    return await conn.execute(query, *params)
+                else:
+                    return await conn.execute(query)
+
+    async def execute_many(self, query, params=None, batch_size=1000):
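+        """Execute `query` once per parameter tuple, chunking `params` into
+        batches of `batch_size`."""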
+        if not self.pool:
+            raise ValueError("PostgresConnectionManager is not initialized.")
+        async with self.pool.get_connection() as conn:
+            async with conn.transaction():
+                if params:
+                    results = []
+                    for i in range(0, len(params), batch_size):
+                        param_batch = params[i : i + batch_size]
+                        result = await conn.executemany(query, param_batch)
+                        results.append(result)
+                    return results
+                else:
+                    # asyncpg's executemany requires an args iterable; pass an
+                    # empty list when no parameter batches were supplied.
+                    return await conn.executemany(query, [])
+
+    async def fetch_query(self, query, params=None):
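+        """Run a query and return all matching rows, translating prepared-
+        statement cache errors into an actionable configuration hint."""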
+        if not self.pool:
+            raise ValueError("PostgresConnectionManager is not initialized.")
+        try:
+            async with self.pool.get_connection() as conn:
+                async with conn.transaction():
+                    return (
+                        await conn.fetch(query, *params)
+                        if params
+                        else await conn.fetch(query)
+                    )
+        except asyncpg.exceptions.DuplicatePreparedStatementError:
+            error_msg = textwrap.dedent("""
+                Database Configuration Error
+
+                Your database provider does not support statement caching.
+
+                To fix this, either:
+                • Set R2R_POSTGRES_STATEMENT_CACHE_SIZE=0 in your environment
+                • Add statement_cache_size = 0 to your database configuration:
+
+                    [database.postgres_configuration_settings]
+                    statement_cache_size = 0
+
+                This is required when using connection poolers like PgBouncer or
+                managed database services like Supabase.
+            """).strip()
+            raise ValueError(error_msg) from None
+
+    async def fetchrow_query(self, query, params=None):
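+        """Run a query and return the first row, or None if no row matches."""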
+        if not self.pool:
+            raise ValueError("PostgresConnectionManager is not initialized.")
+        async with self.pool.get_connection() as conn:
+            async with conn.transaction():
+                if params:
+                    return await conn.fetchrow(query, *params)
+                else:
+                    return await conn.fetchrow(query)
+
+    @asynccontextmanager
+    async def transaction(self, isolation_level=None):
+        """Async context manager for database transactions.
+
+        Args:
+            isolation_level: Optional isolation level for the transaction
+
+        Yields:
+            The connection manager instance for use within the transaction
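+
+        Note:
+            Query methods on the yielded manager acquire their own pooled
+            connections, so statements issued through them do not run on
+            this transaction's connection.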
+        """
+        if not self.pool:
+            raise ValueError("PostgresConnectionManager is not initialized.")
+
+        async with self.pool.get_connection() as conn:
+            async with conn.transaction(isolation=isolation_level):
+                try:
+                    yield self
+                except Exception as e:
+                    logger.error(f"Transaction failed: {e}")
+                    raise