aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/core/providers/database/base.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/core/providers/database/base.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two versions of R2R are hereHEADmaster
Diffstat (limited to '.venv/lib/python3.12/site-packages/core/providers/database/base.py')
-rw-r--r--.venv/lib/python3.12/site-packages/core/providers/database/base.py247
1 file changed, 247 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/core/providers/database/base.py b/.venv/lib/python3.12/site-packages/core/providers/database/base.py
new file mode 100644
index 00000000..c70c1352
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/database/base.py
@@ -0,0 +1,247 @@
+import asyncio
+import logging
+import textwrap
+from contextlib import asynccontextmanager
+from typing import Optional
+
+import asyncpg
+
+from core.base.providers import DatabaseConnectionManager
+
+logger = logging.getLogger()
+
+
class SemaphoreConnectionPool:
    """Wrap an asyncpg connection pool with an in-process semaphore.

    The semaphore is sized to 90% of the configured ``max_connections`` so
    that a burst of concurrent callers queues inside this process instead of
    exhausting the server-side pool (the remaining 10% is headroom for other
    clients of the same database).
    """

    def __init__(self, connection_string, postgres_configuration_settings):
        """Store connection parameters; no I/O happens here.

        The pool itself is created lazily in :meth:`initialize` because
        asyncpg requires a running event loop.
        """
        self.connection_string = connection_string
        self.postgres_configuration_settings = postgres_configuration_settings
        # Pre-declare so use-before-initialize() shows up as a clear
        # None value rather than an AttributeError.
        self.pool = None
        self.semaphore = None

    async def initialize(self):
        """Create the semaphore and the asyncpg pool.

        Raises:
            ValueError: if the pool cannot be created; the original
                exception is chained as the cause.
        """
        max_connections = self.postgres_configuration_settings.max_connections
        concurrency_cap = int(max_connections * 0.9)
        try:
            logger.info(
                f"Connecting with {concurrency_cap} connections to `asyncpg.create_pool`."
            )

            self.semaphore = asyncio.Semaphore(concurrency_cap)

            self.pool = await asyncpg.create_pool(
                self.connection_string,
                max_size=max_connections,
                statement_cache_size=self.postgres_configuration_settings.statement_cache_size,
            )

            logger.info(
                "Successfully connected to Postgres database and created connection pool."
            )
        except Exception as e:
            raise ValueError(
                f"Error {e} occurred while attempting to connect to relational database."
            ) from e

    @asynccontextmanager
    async def get_connection(self):
        """Yield a pooled connection, gated by the concurrency semaphore."""
        async with self.semaphore:
            async with self.pool.acquire() as conn:
                yield conn

    async def close(self):
        """Close the underlying pool; a no-op if never initialized."""
        if self.pool is not None:
            await self.pool.close()
+
+
class QueryBuilder:
    """Fluent builder for simple parameterized SQL over a single table.

    Supports SELECT / INSERT / UPDATE / DELETE. ``where()`` conditions are
    raw SQL fragments; any ``$n`` placeholders inside them must be numbered
    by the caller relative to the parameters contributed by ``insert()`` /
    ``update()`` data (which always occupy $1..$k).
    """

    def __init__(self, table_name: str):
        self.table_name = table_name
        self.conditions: list[str] = []
        self.params: list = []
        self.select_fields = "*"
        self.operation = "SELECT"
        self.limit_value: Optional[int] = None
        self.offset_value: Optional[int] = None
        self.order_by_fields: Optional[str] = None
        self.returning_fields: Optional[list[str]] = None
        self.insert_data: Optional[dict] = None
        self.update_data: Optional[dict] = None
        # NOTE(review): unused, retained only for interface compatibility.
        self.param_counter = 1

    def select(self, fields: list[str]):
        """Set the SELECT column list (default is ``*``)."""
        self.select_fields = ", ".join(fields)
        return self

    def insert(self, data: dict):
        """Switch to INSERT mode; `data` maps column name -> value."""
        self.operation = "INSERT"
        self.insert_data = data
        return self

    def update(self, data: dict):
        """Switch to UPDATE mode; `data` maps column name -> new value."""
        self.operation = "UPDATE"
        self.update_data = data
        return self

    def delete(self):
        """Switch to DELETE mode."""
        self.operation = "DELETE"
        return self

    def where(self, condition: str):
        """Append a raw SQL condition; conditions are AND-ed together."""
        self.conditions.append(condition)
        return self

    def limit(self, value: Optional[int]):
        """Set LIMIT; pass None to leave it unset."""
        self.limit_value = value
        return self

    def offset(self, value: int):
        """Set OFFSET."""
        self.offset_value = value
        return self

    def order_by(self, fields: str):
        """Set the ORDER BY clause (SELECT only)."""
        self.order_by_fields = fields
        return self

    def returning(self, fields: list[str]):
        """Set a RETURNING column list."""
        self.returning_fields = fields
        return self

    def build(self):
        """Return ``(query, params)`` for the configured statement.

        Parameters are assembled into a fresh list on every call, so
        ``build()`` is idempotent — the original implementation appended to
        ``self.params`` each time, duplicating values on a second call.

        Raises:
            ValueError: if ``self.operation`` is not one of the four
                supported statements.
        """
        params = list(self.params)

        if self.operation == "SELECT":
            query = f"SELECT {self.select_fields} FROM {self.table_name}"

        elif self.operation == "INSERT":
            columns = ", ".join(self.insert_data.keys())
            placeholders = ", ".join(
                f"${i}" for i in range(1, len(self.insert_data) + 1)
            )
            query = f"INSERT INTO {self.table_name} ({columns}) VALUES ({placeholders})"
            params.extend(self.insert_data.values())

        elif self.operation == "UPDATE":
            set_clauses = []
            for i, (key, value) in enumerate(
                self.update_data.items(), start=len(params) + 1
            ):
                set_clauses.append(f"{key} = ${i}")
                params.append(value)
            query = f"UPDATE {self.table_name} SET {', '.join(set_clauses)}"

        elif self.operation == "DELETE":
            query = f"DELETE FROM {self.table_name}"

        else:
            raise ValueError(f"Unsupported operation: {self.operation}")

        if self.conditions:
            query += " WHERE " + " AND ".join(self.conditions)

        if self.order_by_fields and self.operation == "SELECT":
            query += f" ORDER BY {self.order_by_fields}"

        # Postgres accepts OFFSET and LIMIT in either order; this emits
        # OFFSET first to match the original builder's output exactly.
        if self.offset_value is not None:
            query += f" OFFSET {self.offset_value}"

        if self.limit_value is not None:
            query += f" LIMIT {self.limit_value}"

        if self.returning_fields:
            query += f" RETURNING {', '.join(self.returning_fields)}"

        return query, params
+
+
class PostgresConnectionManager(DatabaseConnectionManager):
    """Connection manager that executes queries through a SemaphoreConnectionPool."""

    def __init__(self):
        # The pool is injected via initialize(); every public method goes
        # through _require_pool() so uninitialized use fails consistently.
        self.pool: Optional[SemaphoreConnectionPool] = None

    async def initialize(self, pool: SemaphoreConnectionPool):
        """Attach an already-initialized connection pool."""
        self.pool = pool

    def _require_pool(self) -> SemaphoreConnectionPool:
        """Return the pool or raise if initialize() was never called."""
        if not self.pool:
            raise ValueError("PostgresConnectionManager is not initialized.")
        return self.pool

    async def execute_query(self, query, params=None, isolation_level=None):
        """Execute a single statement and return asyncpg's status string.

        Args:
            query: SQL text with $n placeholders.
            params: optional positional parameter sequence.
            isolation_level: if given, the statement runs inside a
                transaction with this isolation level.
        """
        pool = self._require_pool()
        async with pool.get_connection() as conn:
            if isolation_level:
                async with conn.transaction(isolation=isolation_level):
                    if params:
                        return await conn.execute(query, *params)
                    return await conn.execute(query)
            if params:
                return await conn.execute(query, *params)
            return await conn.execute(query)

    async def execute_many(self, query, params=None, batch_size=1000):
        """Execute `query` once per parameter tuple, in a single transaction.

        Parameters are submitted in batches of `batch_size`; a list with one
        result per batch is returned. With no params this is a no-op
        returning an empty list — the original called
        ``conn.executemany(query)`` with no args sequence, which raises
        TypeError since asyncpg's executemany requires one.
        """
        pool = self._require_pool()
        async with pool.get_connection() as conn:
            async with conn.transaction():
                all_params = params if params else []
                results = []
                for start in range(0, len(all_params), batch_size):
                    batch = all_params[start : start + batch_size]
                    results.append(await conn.executemany(query, batch))
                return results

    async def fetch_query(self, query, params=None):
        """Fetch all rows for `query`, run inside a transaction.

        Raises:
            ValueError: with configuration guidance when the backend rejects
                prepared-statement caching (e.g. behind PgBouncer/Supabase).
        """
        pool = self._require_pool()
        try:
            async with pool.get_connection() as conn:
                async with conn.transaction():
                    if params:
                        return await conn.fetch(query, *params)
                    return await conn.fetch(query)
        except asyncpg.exceptions.DuplicatePreparedStatementError:
            error_msg = textwrap.dedent("""
                Database Configuration Error

                Your database provider does not support statement caching.

                To fix this, either:
                • Set R2R_POSTGRES_STATEMENT_CACHE_SIZE=0 in your environment
                • Add statement_cache_size = 0 to your database configuration:

                [database.postgres_configuration_settings]
                statement_cache_size = 0

                This is required when using connection poolers like PgBouncer or
                managed database services like Supabase.
                """).strip()
            raise ValueError(error_msg) from None

    async def fetchrow_query(self, query, params=None):
        """Fetch the first row (or None) for `query`, inside a transaction."""
        pool = self._require_pool()
        async with pool.get_connection() as conn:
            async with conn.transaction():
                if params:
                    return await conn.fetchrow(query, *params)
                return await conn.fetchrow(query)

    @asynccontextmanager
    async def transaction(self, isolation_level=None):
        """Async context manager for database transactions.

        Args:
            isolation_level: Optional isolation level for the transaction

        Yields:
            The connection manager instance for use within the transaction

        NOTE(review): queries issued through the yielded manager acquire
        NEW pool connections and therefore do not actually run inside this
        transaction's connection — confirm intended semantics with callers.
        """
        pool = self._require_pool()
        async with pool.get_connection() as conn:
            async with conn.transaction(isolation=isolation_level):
                try:
                    yield self
                except Exception as e:
                    logger.error(f"Transaction failed: {str(e)}")
                    raise