import asyncio
import logging
import textwrap
from contextlib import asynccontextmanager
from typing import Optional

import asyncpg

from core.base.providers import DatabaseConnectionManager

logger = logging.getLogger()


class SemaphoreConnectionPool:
    """asyncpg connection pool whose concurrent acquisitions are capped by a
    semaphore at 90% of the configured ``max_connections``, leaving headroom
    for other clients of the same database server.

    Call :meth:`initialize` before :meth:`get_connection`.
    """

    def __init__(self, connection_string, postgres_configuration_settings):
        self.connection_string = connection_string
        self.postgres_configuration_settings = postgres_configuration_settings
        # Populated by initialize(); declared here so that use before
        # initialization fails with a clear error instead of AttributeError.
        self.pool = None
        self.semaphore = None

    async def initialize(self):
        """Create the semaphore and the underlying asyncpg pool.

        Raises:
            ValueError: wrapping any exception raised while connecting.
        """
        try:
            # Compute the cap once; it is both logged and used for the
            # semaphore, so the two must agree.
            concurrency_limit = int(
                self.postgres_configuration_settings.max_connections * 0.9
            )
            # Lazy %-formatting: only rendered if INFO logging is enabled.
            logger.info(
                "Connecting with %s connections to `asyncpg.create_pool`.",
                concurrency_limit,
            )

            self.semaphore = asyncio.Semaphore(concurrency_limit)

            self.pool = await asyncpg.create_pool(
                self.connection_string,
                max_size=self.postgres_configuration_settings.max_connections,
                statement_cache_size=self.postgres_configuration_settings.statement_cache_size,
            )

            logger.info(
                "Successfully connected to Postgres database and created connection pool."
            )
        except Exception as e:
            raise ValueError(
                f"Error {e} occurred while attempting to connect to relational database."
            ) from e

    @asynccontextmanager
    async def get_connection(self):
        """Yield a pooled connection, gated by the semaphore.

        Raises:
            ValueError: if :meth:`initialize` has not completed.
        """
        if self.pool is None or self.semaphore is None:
            raise ValueError("SemaphoreConnectionPool is not initialized.")
        async with self.semaphore:
            async with self.pool.acquire() as conn:
                yield conn

    async def close(self):
        """Close the pool; a no-op if it was never initialized."""
        if self.pool is not None:
            await self.pool.close()


class QueryBuilder:
    """Fluent builder for simple parameterized SQL statements against one table.

    Supports SELECT / INSERT / UPDATE / DELETE. ``where()`` conditions are raw
    SQL fragments supplied by the caller; only INSERT/UPDATE values are turned
    into ``$n`` placeholders. ``build()`` returns ``(query, params)``.
    """

    def __init__(self, table_name: str):
        self.table_name = table_name
        self.conditions: list[str] = []
        self.params: list = []
        self.select_fields = "*"
        self.operation = "SELECT"
        self.limit_value: Optional[int] = None
        self.offset_value: Optional[int] = None
        self.order_by_fields: Optional[str] = None
        self.returning_fields: Optional[list[str]] = None
        self.insert_data: Optional[dict] = None
        self.update_data: Optional[dict] = None
        self.param_counter = 1

    def select(self, fields: list[str]):
        """Set the SELECT column list (default ``*``)."""
        self.select_fields = ", ".join(fields)
        return self

    def insert(self, data: dict):
        """Switch to INSERT mode; ``data`` maps column -> value."""
        self.operation = "INSERT"
        self.insert_data = data
        return self

    def update(self, data: dict):
        """Switch to UPDATE mode; ``data`` maps column -> new value."""
        self.operation = "UPDATE"
        self.update_data = data
        return self

    def delete(self):
        """Switch to DELETE mode."""
        self.operation = "DELETE"
        return self

    def where(self, condition: str):
        """Append a raw SQL condition; conditions are joined with AND."""
        self.conditions.append(condition)
        return self

    def limit(self, value: Optional[int]):
        """Set LIMIT; ``None`` omits the clause."""
        self.limit_value = value
        return self

    def offset(self, value: int):
        """Set OFFSET."""
        self.offset_value = value
        return self

    def order_by(self, fields: str):
        """Set ORDER BY (SELECT only)."""
        self.order_by_fields = fields
        return self

    def returning(self, fields: list[str]):
        """Append a RETURNING clause with the given columns."""
        self.returning_fields = fields
        return self

    def build(self):
        """Render the accumulated state into ``(query, params)``.

        The parameter list is rebuilt from scratch on every call, so calling
        ``build()`` repeatedly is idempotent (previously each call appended
        the INSERT/UPDATE values to ``self.params`` again).

        Raises:
            ValueError: for an unsupported operation, or INSERT/UPDATE
                without data.
        """
        params: list = []

        if self.operation == "SELECT":
            query = f"SELECT {self.select_fields} FROM {self.table_name}"

        elif self.operation == "INSERT":
            if not self.insert_data:
                raise ValueError("INSERT requires data; call .insert() first.")
            columns = ", ".join(self.insert_data.keys())
            placeholders = ", ".join(
                f"${i}" for i in range(1, len(self.insert_data) + 1)
            )
            query = f"INSERT INTO {self.table_name} ({columns}) VALUES ({placeholders})"
            params.extend(self.insert_data.values())

        elif self.operation == "UPDATE":
            if not self.update_data:
                raise ValueError("UPDATE requires data; call .update() first.")
            set_clauses = []
            for i, (key, value) in enumerate(self.update_data.items(), start=1):
                set_clauses.append(f"{key} = ${i}")
                params.append(value)
            query = f"UPDATE {self.table_name} SET {', '.join(set_clauses)}"

        elif self.operation == "DELETE":
            query = f"DELETE FROM {self.table_name}"

        else:
            raise ValueError(f"Unsupported operation: {self.operation}")

        if self.conditions:
            query += " WHERE " + " AND ".join(self.conditions)

        # ORDER BY only makes sense for SELECT in this builder.
        if self.order_by_fields and self.operation == "SELECT":
            query += f" ORDER BY {self.order_by_fields}"

        # PostgreSQL accepts LIMIT/OFFSET in either order; keep the
        # original OFFSET-then-LIMIT emission for byte-identical output.
        if self.offset_value is not None:
            query += f" OFFSET {self.offset_value}"

        if self.limit_value is not None:
            query += f" LIMIT {self.limit_value}"

        if self.returning_fields:
            query += f" RETURNING {', '.join(self.returning_fields)}"

        # Expose the freshly built params for callers that inspect them.
        self.params = params
        return query, params


class PostgresConnectionManager(DatabaseConnectionManager):
    """DatabaseConnectionManager implementation backed by a
    :class:`SemaphoreConnectionPool`.

    :meth:`initialize` must be called with an already-initialized pool
    before any query method is used.
    """

    def __init__(self):
        # Set via initialize(); every public method guards against use
        # before initialization.
        self.pool: Optional[SemaphoreConnectionPool] = None

    async def initialize(self, pool: SemaphoreConnectionPool):
        """Attach the connection pool this manager will draw from."""
        self.pool = pool

    def _require_pool(self):
        """Return the pool, or raise if initialize() was never called."""
        if not self.pool:
            raise ValueError("PostgresConnectionManager is not initialized.")
        return self.pool

    async def execute_query(self, query, params=None, isolation_level=None):
        """Execute a single statement and return asyncpg's status string.

        Args:
            query: SQL text with ``$n`` placeholders.
            params: Optional sequence of placeholder values.
            isolation_level: If given, run inside a transaction with this
                isolation level; otherwise execute outside any explicit
                transaction.
        """
        pool = self._require_pool()
        async with pool.get_connection() as conn:
            # `params or []` matches the original truthiness check: a falsy
            # params (None or []) expands to zero arguments.
            if isolation_level:
                async with conn.transaction(isolation=isolation_level):
                    return await conn.execute(query, *(params or []))
            return await conn.execute(query, *(params or []))

    async def execute_many(self, query, params=None, batch_size=1000):
        """Execute one statement for each params tuple, batched inside a
        single transaction.

        Returns a list with one asyncpg result per batch.

        Note: asyncpg's ``executemany(command, args)`` requires the args
        sequence; the previous ``conn.executemany(query)`` call for falsy
        ``params`` raised TypeError. An empty args list (execute zero
        times) is passed instead.
        """
        pool = self._require_pool()
        async with pool.get_connection() as conn:
            async with conn.transaction():
                if not params:
                    return await conn.executemany(query, [])
                results = []
                for start in range(0, len(params), batch_size):
                    batch = params[start : start + batch_size]
                    results.append(await conn.executemany(query, batch))
                return results

    async def fetch_query(self, query, params=None):
        """Run a query inside a transaction and return all rows.

        Raises:
            ValueError: if uninitialized, or with configuration guidance
                when the server rejects prepared-statement caching (e.g.
                behind PgBouncer/Supabase).
        """
        pool = self._require_pool()
        try:
            async with pool.get_connection() as conn:
                async with conn.transaction():
                    return await conn.fetch(query, *(params or []))
        except asyncpg.exceptions.DuplicatePreparedStatementError:
            error_msg = textwrap.dedent("""
                Database Configuration Error

                Your database provider does not support statement caching.

                To fix this, either:
                • Set R2R_POSTGRES_STATEMENT_CACHE_SIZE=0 in your environment
                • Add statement_cache_size = 0 to your database configuration:

                    [database.postgres_configuration_settings]
                    statement_cache_size = 0

                This is required when using connection poolers like PgBouncer or
                managed database services like Supabase.
            """).strip()
            # from None: the asyncpg traceback adds noise, not signal, here.
            raise ValueError(error_msg) from None

    async def fetchrow_query(self, query, params=None):
        """Run a query inside a transaction and return the first row."""
        pool = self._require_pool()
        async with pool.get_connection() as conn:
            async with conn.transaction():
                return await conn.fetchrow(query, *(params or []))

    @asynccontextmanager
    async def transaction(self, isolation_level=None):
        """Async context manager for database transactions.

        Args:
            isolation_level: Optional isolation level for the transaction

        Yields:
            The connection manager instance for use within the transaction
        """
        pool = self._require_pool()
        async with pool.get_connection() as conn:
            async with conn.transaction(isolation=isolation_level):
                try:
                    yield self
                except Exception as e:
                    # Log, then re-raise so asyncpg rolls the transaction back.
                    logger.error(f"Transaction failed: {str(e)}")
                    raise