# TODO: Clean this up and make it more consistent across the vector database and the relational database.
import logging
import os
from typing import TYPE_CHECKING, Any, Optional

from ...base.abstractions import VectorQuantizationType
from ...base.providers import (
    DatabaseConfig,
    DatabaseProvider,
    PostgresConfigurationSettings,
)
from .base import PostgresConnectionManager, SemaphoreConnectionPool
from .chunks import PostgresChunksHandler
from .collections import PostgresCollectionsHandler
from .conversations import PostgresConversationsHandler
from .documents import PostgresDocumentsHandler
from .files import PostgresFilesHandler
from .graphs import (
    PostgresCommunitiesHandler,
    PostgresEntitiesHandler,
    PostgresGraphsHandler,
    PostgresRelationshipsHandler,
)
from .limits import PostgresLimitsHandler
from .prompts_handler import PostgresPromptsHandler
from .tokens import PostgresTokensHandler
from .users import PostgresUserHandler

if TYPE_CHECKING:
    # Imported for static type checking only; this branch is not executed at
    # runtime.
    from ..crypto import BCryptCryptoProvider, NaClCryptoProvider

    # This alias exists only under TYPE_CHECKING, so annotations that use it
    # must be written as strings.
    CryptoProviderType = BCryptCryptoProvider | NaClCryptoProvider

# No name is passed, so this is the root logger rather than a module logger.
logger = logging.getLogger()


class PostgresDatabaseProvider(DatabaseProvider):
    """Postgres implementation of `DatabaseProvider`.

    The constructor resolves the connection settings from the config (falling
    back to environment variables), builds the connection string and creates
    one handler per domain, all sharing a single `PostgresConnectionManager`.
    No connection is opened until `initialize()` is awaited.

    The provider can be used as an async context manager: entering awaits
    `initialize()` and exiting awaits `close()`.
    """

    # R2R configuration settings
    config: DatabaseConfig
    project_name: str

    # Postgres connection settings
    user: str
    password: str
    host: str
    port: int
    db_name: str
    connection_string: str
    dimension: int | float
    quantization_type: VectorQuantizationType
    conn: Optional[Any]
    # Created by `initialize()`; stays None until then.
    pool: Optional[SemaphoreConnectionPool]

    crypto_provider: "CryptoProviderType"
    postgres_configuration_settings: PostgresConfigurationSettings
    default_collection_name: str
    default_collection_description: str

    connection_manager: PostgresConnectionManager
    documents_handler: PostgresDocumentsHandler
    collections_handler: PostgresCollectionsHandler
    token_handler: PostgresTokensHandler
    users_handler: PostgresUserHandler
    chunks_handler: PostgresChunksHandler
    entities_handler: PostgresEntitiesHandler
    communities_handler: PostgresCommunitiesHandler
    relationships_handler: PostgresRelationshipsHandler
    graphs_handler: PostgresGraphsHandler
    prompts_handler: PostgresPromptsHandler
    files_handler: PostgresFilesHandler
    conversations_handler: PostgresConversationsHandler
    limits_handler: PostgresLimitsHandler

    def __init__(
        self,
        config: DatabaseConfig,
        dimension: int | float,
        crypto_provider: "BCryptCryptoProvider | NaClCryptoProvider",
        quantization_type: VectorQuantizationType = VectorQuantizationType.FP32,
        *args,
        **kwargs,
    ):
        """Resolve connection settings and construct all handlers.

        Args:
            config: Database configuration. Each of `user`, `password`,
                `host`, `port` and `db_name` is read from it first and, when
                falsy, from the matching `R2R_POSTGRES_*` environment
                variable.
            dimension: Passed through to the handlers that take a
                `dimension` argument.
            crypto_provider: Passed to the users handler.
            quantization_type: Passed through to the handlers that take a
                `quantization_type` argument.
            *args: Accepted but unused.
            **kwargs: Accepted but unused.

        Raises:
            ValueError: If a connection setting is present in neither the
                config nor the environment, or if the resolved port cannot
                be converted to an integer.
        """
        super().__init__(config)

        env_vars = [
            ("user", "R2R_POSTGRES_USER"),
            ("password", "R2R_POSTGRES_PASSWORD"),
            ("host", "R2R_POSTGRES_HOST"),
            ("port", "R2R_POSTGRES_PORT"),
            ("db_name", "R2R_POSTGRES_DBNAME"),
        ]

        # Config takes precedence; the environment is only a fallback.
        for attr, env_var in env_vars:
            if value := (getattr(config, attr) or os.getenv(env_var)):
                setattr(self, attr, value)
            else:
                raise ValueError(
                    f"Error, please set a valid {env_var} environment variable or set a '{attr}' in the 'database' settings of your `r2r.toml`."
                )

        # The port may have come from the environment as a string.
        self.port = int(self.port)

        # The literal fallback guarantees a non-empty project name, so no
        # further validation is needed here.
        self.project_name = (
            config.app.project_name
            or os.getenv("R2R_PROJECT_NAME")
            or "r2r_default"
        )

        # Check if it's a Unix socket connection.
        # The port is mandatory above, so this branch is only reached when it
        # parses to 0 (e.g. R2R_POSTGRES_PORT=0) and the host is a path.
        if self.host.startswith("/") and not self.port:
            self.connection_string = f"postgresql://{self.user}:{self.password}@/{self.db_name}?host={self.host}"
            logger.info("Connecting to Postgres via Unix socket")
        else:
            self.connection_string = f"postgresql://{self.user}:{self.password}@{self.host}:{self.port}/{self.db_name}"
            logger.info("Connecting to Postgres via TCP/IP")

        self.dimension = dimension
        self.quantization_type = quantization_type
        self.conn = None
        # Defined up front so `close()` is safe even if `initialize()` was
        # never awaited.
        self.pool = None
        self.config: DatabaseConfig = config
        self.crypto_provider = crypto_provider
        self.postgres_configuration_settings: PostgresConfigurationSettings = (
            self._get_postgres_configuration_settings(config)
        )
        self.default_collection_name = config.default_collection_name
        self.default_collection_description = (
            config.default_collection_description
        )

        # Every handler below shares this one connection manager.
        self.connection_manager: PostgresConnectionManager = (
            PostgresConnectionManager()
        )
        self.documents_handler = PostgresDocumentsHandler(
            project_name=self.project_name,
            connection_manager=self.connection_manager,
            dimension=self.dimension,
        )
        self.token_handler = PostgresTokensHandler(
            self.project_name, self.connection_manager
        )
        self.collections_handler = PostgresCollectionsHandler(
            self.project_name, self.connection_manager, self.config
        )
        self.users_handler = PostgresUserHandler(
            self.project_name, self.connection_manager, self.crypto_provider
        )
        self.chunks_handler = PostgresChunksHandler(
            project_name=self.project_name,
            connection_manager=self.connection_manager,
            dimension=self.dimension,
            quantization_type=(self.quantization_type),
        )
        self.conversations_handler = PostgresConversationsHandler(
            self.project_name, self.connection_manager
        )
        # The graph-related handlers all receive the collections handler, so
        # it must be constructed before them.
        self.entities_handler = PostgresEntitiesHandler(
            project_name=self.project_name,
            connection_manager=self.connection_manager,
            collections_handler=self.collections_handler,
            dimension=self.dimension,
            quantization_type=self.quantization_type,
        )
        self.relationships_handler = PostgresRelationshipsHandler(
            project_name=self.project_name,
            connection_manager=self.connection_manager,
            collections_handler=self.collections_handler,
            dimension=self.dimension,
            quantization_type=self.quantization_type,
        )
        self.communities_handler = PostgresCommunitiesHandler(
            project_name=self.project_name,
            connection_manager=self.connection_manager,
            collections_handler=self.collections_handler,
            dimension=self.dimension,
            quantization_type=self.quantization_type,
        )
        self.graphs_handler = PostgresGraphsHandler(
            project_name=self.project_name,
            connection_manager=self.connection_manager,
            collections_handler=self.collections_handler,
            dimension=self.dimension,
            quantization_type=self.quantization_type,
        )
        self.prompts_handler = PostgresPromptsHandler(
            self.project_name, self.connection_manager
        )
        self.files_handler = PostgresFilesHandler(
            self.project_name, self.connection_manager
        )

        self.limits_handler = PostgresLimitsHandler(
            project_name=self.project_name,
            connection_manager=self.connection_manager,
            config=self.config,
        )

    async def initialize(self):
        """Open the connection pool and create extensions, schema and tables.

        Creates a `SemaphoreConnectionPool` from the connection string,
        initializes the shared connection manager with it, issues
        `CREATE EXTENSION IF NOT EXISTS` / `CREATE SCHEMA IF NOT EXISTS`
        statements, and then awaits `create_tables()` on every handler.
        """
        logger.info("Initializing `PostgresDatabaseProvider`.")
        self.pool = SemaphoreConnectionPool(
            self.connection_string, self.postgres_configuration_settings
        )
        await self.pool.initialize()
        await self.connection_manager.initialize(self.pool)

        async with self.pool.get_connection() as conn:
            await conn.execute('CREATE EXTENSION IF NOT EXISTS "uuid-ossp";')
            await conn.execute("CREATE EXTENSION IF NOT EXISTS vector;")
            await conn.execute("CREATE EXTENSION IF NOT EXISTS pg_trgm;")
            await conn.execute("CREATE EXTENSION IF NOT EXISTS fuzzystrmatch;")

            # Create schema if it doesn't exist
            await conn.execute(
                f'CREATE SCHEMA IF NOT EXISTS "{self.project_name}";'
            )

        # NOTE(review): the order below is preserved as-is; later tables may
        # depend on earlier ones — confirm before reordering.
        await self.documents_handler.create_tables()
        await self.collections_handler.create_tables()
        await self.token_handler.create_tables()
        await self.users_handler.create_tables()
        await self.chunks_handler.create_tables()
        await self.prompts_handler.create_tables()
        await self.files_handler.create_tables()
        await self.graphs_handler.create_tables()
        await self.communities_handler.create_tables()
        await self.entities_handler.create_tables()
        await self.relationships_handler.create_tables()
        await self.conversations_handler.create_tables()
        await self.limits_handler.create_tables()

    def _get_postgres_configuration_settings(
        self, config: DatabaseConfig
    ) -> PostgresConfigurationSettings:
        """Build the Postgres tuning settings from config and environment.

        For every known setting the value is taken from
        `config.postgres_configuration_settings` and, when that is `None`,
        from the matching `R2R_POSTGRES_*` environment variable. Settings
        found in neither place keep the default of a fresh
        `PostgresConfigurationSettings`.

        Args:
            config: Database configuration to read the settings from.

        Returns:
            A new `PostgresConfigurationSettings` with the resolved values.

        Raises:
            ValueError: If a value for an `Optional[int]` or
                `Optional[float]` field cannot be converted.
        """
        settings = PostgresConfigurationSettings()

        env_mapping = {
            "checkpoint_completion_target": "R2R_POSTGRES_CHECKPOINT_COMPLETION_TARGET",
            "default_statistics_target": "R2R_POSTGRES_DEFAULT_STATISTICS_TARGET",
            "effective_cache_size": "R2R_POSTGRES_EFFECTIVE_CACHE_SIZE",
            "effective_io_concurrency": "R2R_POSTGRES_EFFECTIVE_IO_CONCURRENCY",
            "huge_pages": "R2R_POSTGRES_HUGE_PAGES",
            "maintenance_work_mem": "R2R_POSTGRES_MAINTENANCE_WORK_MEM",
            "min_wal_size": "R2R_POSTGRES_MIN_WAL_SIZE",
            "max_connections": "R2R_POSTGRES_MAX_CONNECTIONS",
            "max_parallel_workers_per_gather": "R2R_POSTGRES_MAX_PARALLEL_WORKERS_PER_GATHER",
            "max_parallel_workers": "R2R_POSTGRES_MAX_PARALLEL_WORKERS",
            "max_parallel_maintenance_workers": "R2R_POSTGRES_MAX_PARALLEL_MAINTENANCE_WORKERS",
            "max_wal_size": "R2R_POSTGRES_MAX_WAL_SIZE",
            "max_worker_processes": "R2R_POSTGRES_MAX_WORKER_PROCESSES",
            "random_page_cost": "R2R_POSTGRES_RANDOM_PAGE_COST",
            "statement_cache_size": "R2R_POSTGRES_STATEMENT_CACHE_SIZE",
            "shared_buffers": "R2R_POSTGRES_SHARED_BUFFERS",
            "wal_buffers": "R2R_POSTGRES_WAL_BUFFERS",
            "work_mem": "R2R_POSTGRES_WORK_MEM",
        }

        for setting, env_var in env_mapping.items():
            value = getattr(
                config.postgres_configuration_settings, setting, None
            )
            if value is None:
                value = os.getenv(env_var)

            if value is not None:
                # Environment values arrive as strings, so coerce them to the
                # numeric type the field is annotated with.
                field_type = settings.__annotations__[setting]
                if field_type == Optional[int]:
                    value = int(value)
                elif field_type == Optional[float]:
                    value = float(value)

                setattr(settings, setting, value)

        return settings

    async def close(self):
        """Close the connection pool if `initialize()` created one."""
        if self.pool:
            await self.pool.close()

    async def __aenter__(self):
        """Initialize the provider and return it."""
        await self.initialize()
        return self

    async def __aexit__(self, exc_type, exc, tb):
        """Close the provider; exceptions from the body are not suppressed."""
        await self.close()