author     S. Solomon Darnell   2025-03-28 21:52:21 -0500
committer  S. Solomon Darnell   2025-03-28 21:52:21 -0500
commit     4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree       ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/core/providers/database/base.py
parent     cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
Diffstat (limited to '.venv/lib/python3.12/site-packages/core/providers/database/base.py')
-rw-r--r--   .venv/lib/python3.12/site-packages/core/providers/database/base.py   247
1 file changed, 247 insertions(+), 0 deletions(-)
diff --git a/.venv/lib/python3.12/site-packages/core/providers/database/base.py b/.venv/lib/python3.12/site-packages/core/providers/database/base.py
new file mode 100644
index 00000000..c70c1352
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/database/base.py
@@ -0,0 +1,247 @@
+import asyncio
+import logging
+import textwrap
+from contextlib import asynccontextmanager
+from typing import Optional
+
+import asyncpg
+
+from core.base.providers import DatabaseConnectionManager
+
+logger = logging.getLogger()
+
+
+class SemaphoreConnectionPool:
+    def __init__(self, connection_string, postgres_configuration_settings):
+        self.connection_string = connection_string
+        self.postgres_configuration_settings = postgres_configuration_settings
+
+    async def initialize(self):
+        try:
+            logger.info(
+                f"Connecting with {int(self.postgres_configuration_settings.max_connections * 0.9)} connections to `asyncpg.create_pool`."
+            )
+
+            self.semaphore = asyncio.Semaphore(
+                int(self.postgres_configuration_settings.max_connections * 0.9)
+            )
+
+            self.pool = await asyncpg.create_pool(
+                self.connection_string,
+                max_size=self.postgres_configuration_settings.max_connections,
+                statement_cache_size=self.postgres_configuration_settings.statement_cache_size,
+            )
+
+            logger.info(
+                "Successfully connected to Postgres database and created connection pool."
+            )
+        except Exception as e:
+            raise ValueError(
+                f"Error {e} occurred while attempting to connect to relational database."
+            ) from e
+
+    @asynccontextmanager
+    async def get_connection(self):
+        async with self.semaphore:
+            async with self.pool.acquire() as conn:
+                yield conn
+
+    async def close(self):
+        await self.pool.close()
+
+
+class QueryBuilder:
+    def __init__(self, table_name: str):
+        self.table_name = table_name
+        self.conditions: list[str] = []
+        self.params: list = []
+        self.select_fields = "*"
+        self.operation = "SELECT"
+        self.limit_value: Optional[int] = None
+        self.offset_value: Optional[int] = None
+        self.order_by_fields: Optional[str] = None
+        self.returning_fields: Optional[list[str]] = None
+        self.insert_data: Optional[dict] = None
+        self.update_data: Optional[dict] = None
+        self.param_counter = 1
+
+    def select(self, fields: list[str]):
+        self.select_fields = ", ".join(fields)
+        return self
+
+    def insert(self, data: dict):
+        self.operation = "INSERT"
+        self.insert_data = data
+        return self
+
+    def update(self, data: dict):
+        self.operation = "UPDATE"
+        self.update_data = data
+        return self
+
+    def delete(self):
+        self.operation = "DELETE"
+        return self
+
+    def where(self, condition: str):
+        self.conditions.append(condition)
+        return self
+
+    def limit(self, value: Optional[int]):
+        self.limit_value = value
+        return self
+
+    def offset(self, value: int):
+        self.offset_value = value
+        return self
+
+    def order_by(self, fields: str):
+        self.order_by_fields = fields
+        return self
+
+    def returning(self, fields: list[str]):
+        self.returning_fields = fields
+        return self
+
+    def build(self):
+        if self.operation == "SELECT":
+            query = f"SELECT {self.select_fields} FROM {self.table_name}"
+
+        elif self.operation == "INSERT":
+            columns = ", ".join(self.insert_data.keys())
+            placeholders = ", ".join(
+                f"${i}" for i in range(1, len(self.insert_data) + 1)
+            )
+            query = f"INSERT INTO {self.table_name} ({columns}) VALUES ({placeholders})"
+            self.params.extend(list(self.insert_data.values()))
+
+        elif self.operation == "UPDATE":
+            set_clauses = []
+            for i, (key, value) in enumerate(
+                self.update_data.items(), start=len(self.params) + 1
+            ):
+                set_clauses.append(f"{key} = ${i}")
+                self.params.append(value)
+            query = f"UPDATE {self.table_name} SET {', '.join(set_clauses)}"

+
+        elif self.operation == "DELETE":
+            query = f"DELETE FROM {self.table_name}"
+
+        else:
+            raise ValueError(f"Unsupported operation: {self.operation}")
+
+        if self.conditions:
+            query += " WHERE " + " AND ".join(self.conditions)
+
+        if self.order_by_fields and self.operation == "SELECT":
+            query += f" ORDER BY {self.order_by_fields}"
+
+        if self.offset_value is not None:
+            query += f" OFFSET {self.offset_value}"
+
+        if self.limit_value is not None:
+            query += f" LIMIT {self.limit_value}"
+
+        if self.returning_fields:
+            query += f" RETURNING {', '.join(self.returning_fields)}"
+
+        return query, self.params
+
+
+class PostgresConnectionManager(DatabaseConnectionManager):
+    def __init__(self):
+        self.pool: Optional[SemaphoreConnectionPool] = None
+
+    async def initialize(self, pool: SemaphoreConnectionPool):
+        self.pool = pool
+
+    async def execute_query(self, query, params=None, isolation_level=None):
+        if not self.pool:
+            raise ValueError("PostgresConnectionManager is not initialized.")
+        async with self.pool.get_connection() as conn:
+            if isolation_level:
+                async with conn.transaction(isolation=isolation_level):
+                    if params:
+                        return await conn.execute(query, *params)
+                    else:
+                        return await conn.execute(query)
+            else:
+                if params:
+                    return await conn.execute(query, *params)
+                else:
+                    return await conn.execute(query)
+
+    async def execute_many(self, query, params=None, batch_size=1000):
+        if not self.pool:
+            raise ValueError("PostgresConnectionManager is not initialized.")
+        async with self.pool.get_connection() as conn:
+            async with conn.transaction():
+                if params:
+                    results = []
+                    for i in range(0, len(params), batch_size):
+                        param_batch = params[i : i + batch_size]
+                        result = await conn.executemany(query, param_batch)
+                        results.append(result)
+                    return results
+                else:
+                    return await conn.executemany(query)
+
+    async def fetch_query(self, query, params=None):
+        if not self.pool:
+            raise ValueError("PostgresConnectionManager is not initialized.")
+        try:
+            async with self.pool.get_connection() as conn:
+                async with conn.transaction():
+                    return (
+                        await conn.fetch(query, *params)
+                        if params
+                        else await conn.fetch(query)
+                    )
+        except asyncpg.exceptions.DuplicatePreparedStatementError:
+            error_msg = textwrap.dedent("""
+                Database Configuration Error
+
+                Your database provider does not support statement caching.
+
+                To fix this, either:
+                • Set R2R_POSTGRES_STATEMENT_CACHE_SIZE=0 in your environment
+                • Add statement_cache_size = 0 to your database configuration:
+
+                [database.postgres_configuration_settings]
+                statement_cache_size = 0
+
+                This is required when using connection poolers like PgBouncer or
+                managed database services like Supabase.
+            """).strip()
+            raise ValueError(error_msg) from None
+
+    async def fetchrow_query(self, query, params=None):
+        if not self.pool:
+            raise ValueError("PostgresConnectionManager is not initialized.")
+        async with self.pool.get_connection() as conn:
+            async with conn.transaction():
+                if params:
+                    return await conn.fetchrow(query, *params)
+                else:
+                    return await conn.fetchrow(query)
+
+    @asynccontextmanager
+    async def transaction(self, isolation_level=None):
+        """Async context manager for database transactions.
+
+        Args:
+            isolation_level: Optional isolation level for the transaction
+
+        Yields:
+            The connection manager instance for use within the transaction
+        """
+        if not self.pool:
+            raise ValueError("PostgresConnectionManager is not initialized.")
+
+        async with self.pool.get_connection() as conn:
+            async with conn.transaction(isolation=isolation_level):
+                try:
+                    yield self
+                except Exception as e:
+                    logger.error(f"Transaction failed: {str(e)}")
+                    raise
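For readers skimming the diff, here is a minimal sketch of how the QueryBuilder added above composes SQL. It only exercises build(), so no database is needed; the table and column names (documents, owner_id, and so on) are illustrative assumptions, not anything taken from this commit, and the import path simply mirrors the module location shown in the diff.

```python
# Minimal sketch: exercising QueryBuilder.build() from the module added above.
# Table/column names here are illustrative assumptions, not part of the diff.
from core.providers.database.base import QueryBuilder

# SELECT with a WHERE clause, ordering, and pagination; the $1 placeholder is
# bound later by the caller, so build() returns an empty params list here.
query, params = (
    QueryBuilder("documents")
    .select(["id", "title"])
    .where("owner_id = $1")
    .order_by("created_at DESC")
    .limit(10)
    .offset(20)
    .build()
)
print(query)
# SELECT id, title FROM documents WHERE owner_id = $1
#   ORDER BY created_at DESC OFFSET 20 LIMIT 10

# INSERT: the builder numbers the placeholders itself and collects the values.
insert_query, insert_params = (
    QueryBuilder("documents")
    .insert({"id": "doc-1", "title": "Example"})
    .returning(["id"])
    .build()
)
print(insert_query)   # INSERT INTO documents (id, title) VALUES ($1, $2) RETURNING id
print(insert_params)  # ['doc-1', 'Example']
```

Note that build() emits OFFSET before LIMIT; PostgreSQL accepts the two clauses in either order, so the generated SQL runs as-is.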
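Below is a hedged sketch of how the pool and connection manager might be wired together end to end. The DSN is a placeholder, SimpleNamespace stands in for the configuration object the rest of the codebase normally supplies (only the two attributes SemaphoreConnectionPool actually reads are provided), and running it requires a reachable Postgres instance.

```python
# Hedged wiring sketch, not the project's own bootstrap code: a stand-in
# settings object and a placeholder DSN are used purely for illustration.
import asyncio
from types import SimpleNamespace

from core.providers.database.base import (
    PostgresConnectionManager,
    SemaphoreConnectionPool,
)


async def main():
    # Only the attributes read by SemaphoreConnectionPool are supplied here.
    settings = SimpleNamespace(max_connections=10, statement_cache_size=100)
    pool = SemaphoreConnectionPool(
        "postgresql://postgres:postgres@localhost:5432/postgres",  # placeholder DSN
        settings,
    )
    await pool.initialize()

    manager = PostgresConnectionManager()
    await manager.initialize(pool)

    # fetch_query acquires a connection through the semaphore-guarded pool,
    # wraps the fetch in a transaction, and returns asyncpg Record objects.
    rows = await manager.fetch_query("SELECT 1 AS answer")
    print(rows)  # e.g. [<Record answer=1>]

    await pool.close()


asyncio.run(main())
```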