diff options
author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
---|---|---|
committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) | |
tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/core/providers/database/postgres.py | |
parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff) | |
download | gn-ai-master.tar.gz |
Diffstat (limited to '.venv/lib/python3.12/site-packages/core/providers/database/postgres.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/core/providers/database/postgres.py | 286 |
1 files changed, 286 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/core/providers/database/postgres.py b/.venv/lib/python3.12/site-packages/core/providers/database/postgres.py new file mode 100644 index 00000000..acccc9c0 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/core/providers/database/postgres.py @@ -0,0 +1,286 @@ +# TODO: Clean this up and make it more congruent across the vector database and the relational database. +import logging +import os +from typing import TYPE_CHECKING, Any, Optional + +from ...base.abstractions import VectorQuantizationType +from ...base.providers import ( + DatabaseConfig, + DatabaseProvider, + PostgresConfigurationSettings, +) +from .base import PostgresConnectionManager, SemaphoreConnectionPool +from .chunks import PostgresChunksHandler +from .collections import PostgresCollectionsHandler +from .conversations import PostgresConversationsHandler +from .documents import PostgresDocumentsHandler +from .files import PostgresFilesHandler +from .graphs import ( + PostgresCommunitiesHandler, + PostgresEntitiesHandler, + PostgresGraphsHandler, + PostgresRelationshipsHandler, +) +from .limits import PostgresLimitsHandler +from .prompts_handler import PostgresPromptsHandler +from .tokens import PostgresTokensHandler +from .users import PostgresUserHandler + +if TYPE_CHECKING: + from ..crypto import BCryptCryptoProvider, NaClCryptoProvider + + CryptoProviderType = BCryptCryptoProvider | NaClCryptoProvider + +logger = logging.getLogger() + + +class PostgresDatabaseProvider(DatabaseProvider): + # R2R configuration settings + config: DatabaseConfig + project_name: str + + # Postgres connection settings + user: str + password: str + host: str + port: int + db_name: str + connection_string: str + dimension: int | float + conn: Optional[Any] + + crypto_provider: "CryptoProviderType" + postgres_configuration_settings: PostgresConfigurationSettings + default_collection_name: str + default_collection_description: str + + connection_manager: PostgresConnectionManager + documents_handler: PostgresDocumentsHandler + collections_handler: PostgresCollectionsHandler + token_handler: PostgresTokensHandler + users_handler: PostgresUserHandler + chunks_handler: PostgresChunksHandler + entities_handler: PostgresEntitiesHandler + communities_handler: PostgresCommunitiesHandler + relationships_handler: PostgresRelationshipsHandler + graphs_handler: PostgresGraphsHandler + prompts_handler: PostgresPromptsHandler + files_handler: PostgresFilesHandler + conversations_handler: PostgresConversationsHandler + limits_handler: PostgresLimitsHandler + + def __init__( + self, + config: DatabaseConfig, + dimension: int | float, + crypto_provider: "BCryptCryptoProvider | NaClCryptoProvider", + quantization_type: VectorQuantizationType = VectorQuantizationType.FP32, + *args, + **kwargs, + ): + super().__init__(config) + + env_vars = [ + ("user", "R2R_POSTGRES_USER"), + ("password", "R2R_POSTGRES_PASSWORD"), + ("host", "R2R_POSTGRES_HOST"), + ("port", "R2R_POSTGRES_PORT"), + ("db_name", "R2R_POSTGRES_DBNAME"), + ] + + for attr, env_var in env_vars: + if value := (getattr(config, attr) or os.getenv(env_var)): + setattr(self, attr, value) + else: + raise ValueError( + f"Error, please set a valid {env_var} environment variable or set a '{attr}' in the 'database' settings of your `r2r.toml`." + ) + + self.port = int(self.port) + + self.project_name = ( + config.app.project_name + or os.getenv("R2R_PROJECT_NAME") + or "r2r_default" + ) + + if not self.project_name: + raise ValueError( + "Error, please set a valid R2R_PROJECT_NAME environment variable or set a 'project_name' in the 'database' settings of your `r2r.toml`." + ) + + # Check if it's a Unix socket connection + if self.host.startswith("/") and not self.port: + self.connection_string = f"postgresql://{self.user}:{self.password}@/{self.db_name}?host={self.host}" + logger.info("Connecting to Postgres via Unix socket") + else: + self.connection_string = f"postgresql://{self.user}:{self.password}@{self.host}:{self.port}/{self.db_name}" + logger.info("Connecting to Postgres via TCP/IP") + + self.dimension = dimension + self.quantization_type = quantization_type + self.conn = None + self.config: DatabaseConfig = config + self.crypto_provider = crypto_provider + self.postgres_configuration_settings: PostgresConfigurationSettings = ( + self._get_postgres_configuration_settings(config) + ) + self.default_collection_name = config.default_collection_name + self.default_collection_description = ( + config.default_collection_description + ) + + self.connection_manager: PostgresConnectionManager = ( + PostgresConnectionManager() + ) + self.documents_handler = PostgresDocumentsHandler( + project_name=self.project_name, + connection_manager=self.connection_manager, + dimension=self.dimension, + ) + self.token_handler = PostgresTokensHandler( + self.project_name, self.connection_manager + ) + self.collections_handler = PostgresCollectionsHandler( + self.project_name, self.connection_manager, self.config + ) + self.users_handler = PostgresUserHandler( + self.project_name, self.connection_manager, self.crypto_provider + ) + self.chunks_handler = PostgresChunksHandler( + project_name=self.project_name, + connection_manager=self.connection_manager, + dimension=self.dimension, + quantization_type=(self.quantization_type), + ) + self.conversations_handler = PostgresConversationsHandler( + self.project_name, self.connection_manager + ) + self.entities_handler = PostgresEntitiesHandler( + project_name=self.project_name, + connection_manager=self.connection_manager, + collections_handler=self.collections_handler, + dimension=self.dimension, + quantization_type=self.quantization_type, + ) + self.relationships_handler = PostgresRelationshipsHandler( + project_name=self.project_name, + connection_manager=self.connection_manager, + collections_handler=self.collections_handler, + dimension=self.dimension, + quantization_type=self.quantization_type, + ) + self.communities_handler = PostgresCommunitiesHandler( + project_name=self.project_name, + connection_manager=self.connection_manager, + collections_handler=self.collections_handler, + dimension=self.dimension, + quantization_type=self.quantization_type, + ) + self.graphs_handler = PostgresGraphsHandler( + project_name=self.project_name, + connection_manager=self.connection_manager, + collections_handler=self.collections_handler, + dimension=self.dimension, + quantization_type=self.quantization_type, + ) + self.prompts_handler = PostgresPromptsHandler( + self.project_name, self.connection_manager + ) + self.files_handler = PostgresFilesHandler( + self.project_name, self.connection_manager + ) + + self.limits_handler = PostgresLimitsHandler( + project_name=self.project_name, + connection_manager=self.connection_manager, + config=self.config, + ) + + async def initialize(self): + logger.info("Initializing `PostgresDatabaseProvider`.") + self.pool = SemaphoreConnectionPool( + self.connection_string, self.postgres_configuration_settings + ) + await self.pool.initialize() + await self.connection_manager.initialize(self.pool) + + async with self.pool.get_connection() as conn: + await conn.execute('CREATE EXTENSION IF NOT EXISTS "uuid-ossp";') + await conn.execute("CREATE EXTENSION IF NOT EXISTS vector;") + await conn.execute("CREATE EXTENSION IF NOT EXISTS pg_trgm;") + await conn.execute("CREATE EXTENSION IF NOT EXISTS fuzzystrmatch;") + + # Create schema if it doesn't exist + await conn.execute( + f'CREATE SCHEMA IF NOT EXISTS "{self.project_name}";' + ) + + await self.documents_handler.create_tables() + await self.collections_handler.create_tables() + await self.token_handler.create_tables() + await self.users_handler.create_tables() + await self.chunks_handler.create_tables() + await self.prompts_handler.create_tables() + await self.files_handler.create_tables() + await self.graphs_handler.create_tables() + await self.communities_handler.create_tables() + await self.entities_handler.create_tables() + await self.relationships_handler.create_tables() + await self.conversations_handler.create_tables() + await self.limits_handler.create_tables() + + def _get_postgres_configuration_settings( + self, config: DatabaseConfig + ) -> PostgresConfigurationSettings: + settings = PostgresConfigurationSettings() + + env_mapping = { + "checkpoint_completion_target": "R2R_POSTGRES_CHECKPOINT_COMPLETION_TARGET", + "default_statistics_target": "R2R_POSTGRES_DEFAULT_STATISTICS_TARGET", + "effective_cache_size": "R2R_POSTGRES_EFFECTIVE_CACHE_SIZE", + "effective_io_concurrency": "R2R_POSTGRES_EFFECTIVE_IO_CONCURRENCY", + "huge_pages": "R2R_POSTGRES_HUGE_PAGES", + "maintenance_work_mem": "R2R_POSTGRES_MAINTENANCE_WORK_MEM", + "min_wal_size": "R2R_POSTGRES_MIN_WAL_SIZE", + "max_connections": "R2R_POSTGRES_MAX_CONNECTIONS", + "max_parallel_workers_per_gather": "R2R_POSTGRES_MAX_PARALLEL_WORKERS_PER_GATHER", + "max_parallel_workers": "R2R_POSTGRES_MAX_PARALLEL_WORKERS", + "max_parallel_maintenance_workers": "R2R_POSTGRES_MAX_PARALLEL_MAINTENANCE_WORKERS", + "max_wal_size": "R2R_POSTGRES_MAX_WAL_SIZE", + "max_worker_processes": "R2R_POSTGRES_MAX_WORKER_PROCESSES", + "random_page_cost": "R2R_POSTGRES_RANDOM_PAGE_COST", + "statement_cache_size": "R2R_POSTGRES_STATEMENT_CACHE_SIZE", + "shared_buffers": "R2R_POSTGRES_SHARED_BUFFERS", + "wal_buffers": "R2R_POSTGRES_WAL_BUFFERS", + "work_mem": "R2R_POSTGRES_WORK_MEM", + } + + for setting, env_var in env_mapping.items(): + value = getattr( + config.postgres_configuration_settings, setting, None + ) + if value is None: + value = os.getenv(env_var) + + if value is not None: + field_type = settings.__annotations__[setting] + if field_type == Optional[int]: + value = int(value) + elif field_type == Optional[float]: + value = float(value) + + setattr(settings, setting, value) + + return settings + + async def close(self): + if self.pool: + await self.pool.close() + + async def __aenter__(self): + await self.initialize() + return self + + async def __aexit__(self, exc_type, exc, tb): + await self.close() |