aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/core/providers/database/postgres.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/core/providers/database/postgres.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two version of R2R are hereHEADmaster
Diffstat (limited to '.venv/lib/python3.12/site-packages/core/providers/database/postgres.py')
-rw-r--r--.venv/lib/python3.12/site-packages/core/providers/database/postgres.py286
1 files changed, 286 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/core/providers/database/postgres.py b/.venv/lib/python3.12/site-packages/core/providers/database/postgres.py
new file mode 100644
index 00000000..acccc9c0
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/database/postgres.py
@@ -0,0 +1,286 @@
+# TODO: Clean this up and make it more congruent across the vector database and the relational database.
+import logging
+import os
+from typing import TYPE_CHECKING, Any, Optional
+
+from ...base.abstractions import VectorQuantizationType
+from ...base.providers import (
+ DatabaseConfig,
+ DatabaseProvider,
+ PostgresConfigurationSettings,
+)
+from .base import PostgresConnectionManager, SemaphoreConnectionPool
+from .chunks import PostgresChunksHandler
+from .collections import PostgresCollectionsHandler
+from .conversations import PostgresConversationsHandler
+from .documents import PostgresDocumentsHandler
+from .files import PostgresFilesHandler
+from .graphs import (
+ PostgresCommunitiesHandler,
+ PostgresEntitiesHandler,
+ PostgresGraphsHandler,
+ PostgresRelationshipsHandler,
+)
+from .limits import PostgresLimitsHandler
+from .prompts_handler import PostgresPromptsHandler
+from .tokens import PostgresTokensHandler
+from .users import PostgresUserHandler
+
+if TYPE_CHECKING:
+ from ..crypto import BCryptCryptoProvider, NaClCryptoProvider
+
+ CryptoProviderType = BCryptCryptoProvider | NaClCryptoProvider
+
+logger = logging.getLogger()
+
+
+class PostgresDatabaseProvider(DatabaseProvider):
+ # R2R configuration settings
+ config: DatabaseConfig
+ project_name: str
+
+ # Postgres connection settings
+ user: str
+ password: str
+ host: str
+ port: int
+ db_name: str
+ connection_string: str
+ dimension: int | float
+ conn: Optional[Any]
+
+ crypto_provider: "CryptoProviderType"
+ postgres_configuration_settings: PostgresConfigurationSettings
+ default_collection_name: str
+ default_collection_description: str
+
+ connection_manager: PostgresConnectionManager
+ documents_handler: PostgresDocumentsHandler
+ collections_handler: PostgresCollectionsHandler
+ token_handler: PostgresTokensHandler
+ users_handler: PostgresUserHandler
+ chunks_handler: PostgresChunksHandler
+ entities_handler: PostgresEntitiesHandler
+ communities_handler: PostgresCommunitiesHandler
+ relationships_handler: PostgresRelationshipsHandler
+ graphs_handler: PostgresGraphsHandler
+ prompts_handler: PostgresPromptsHandler
+ files_handler: PostgresFilesHandler
+ conversations_handler: PostgresConversationsHandler
+ limits_handler: PostgresLimitsHandler
+
+ def __init__(
+ self,
+ config: DatabaseConfig,
+ dimension: int | float,
+ crypto_provider: "BCryptCryptoProvider | NaClCryptoProvider",
+ quantization_type: VectorQuantizationType = VectorQuantizationType.FP32,
+ *args,
+ **kwargs,
+ ):
+ super().__init__(config)
+
+ env_vars = [
+ ("user", "R2R_POSTGRES_USER"),
+ ("password", "R2R_POSTGRES_PASSWORD"),
+ ("host", "R2R_POSTGRES_HOST"),
+ ("port", "R2R_POSTGRES_PORT"),
+ ("db_name", "R2R_POSTGRES_DBNAME"),
+ ]
+
+ for attr, env_var in env_vars:
+ if value := (getattr(config, attr) or os.getenv(env_var)):
+ setattr(self, attr, value)
+ else:
+ raise ValueError(
+ f"Error, please set a valid {env_var} environment variable or set a '{attr}' in the 'database' settings of your `r2r.toml`."
+ )
+
+ self.port = int(self.port)
+
+ self.project_name = (
+ config.app.project_name
+ or os.getenv("R2R_PROJECT_NAME")
+ or "r2r_default"
+ )
+
+ if not self.project_name:
+ raise ValueError(
+ "Error, please set a valid R2R_PROJECT_NAME environment variable or set a 'project_name' in the 'database' settings of your `r2r.toml`."
+ )
+
+ # Check if it's a Unix socket connection
+ if self.host.startswith("/") and not self.port:
+ self.connection_string = f"postgresql://{self.user}:{self.password}@/{self.db_name}?host={self.host}"
+ logger.info("Connecting to Postgres via Unix socket")
+ else:
+ self.connection_string = f"postgresql://{self.user}:{self.password}@{self.host}:{self.port}/{self.db_name}"
+ logger.info("Connecting to Postgres via TCP/IP")
+
+ self.dimension = dimension
+ self.quantization_type = quantization_type
+ self.conn = None
+ self.config: DatabaseConfig = config
+ self.crypto_provider = crypto_provider
+ self.postgres_configuration_settings: PostgresConfigurationSettings = (
+ self._get_postgres_configuration_settings(config)
+ )
+ self.default_collection_name = config.default_collection_name
+ self.default_collection_description = (
+ config.default_collection_description
+ )
+
+ self.connection_manager: PostgresConnectionManager = (
+ PostgresConnectionManager()
+ )
+ self.documents_handler = PostgresDocumentsHandler(
+ project_name=self.project_name,
+ connection_manager=self.connection_manager,
+ dimension=self.dimension,
+ )
+ self.token_handler = PostgresTokensHandler(
+ self.project_name, self.connection_manager
+ )
+ self.collections_handler = PostgresCollectionsHandler(
+ self.project_name, self.connection_manager, self.config
+ )
+ self.users_handler = PostgresUserHandler(
+ self.project_name, self.connection_manager, self.crypto_provider
+ )
+ self.chunks_handler = PostgresChunksHandler(
+ project_name=self.project_name,
+ connection_manager=self.connection_manager,
+ dimension=self.dimension,
+ quantization_type=(self.quantization_type),
+ )
+ self.conversations_handler = PostgresConversationsHandler(
+ self.project_name, self.connection_manager
+ )
+ self.entities_handler = PostgresEntitiesHandler(
+ project_name=self.project_name,
+ connection_manager=self.connection_manager,
+ collections_handler=self.collections_handler,
+ dimension=self.dimension,
+ quantization_type=self.quantization_type,
+ )
+ self.relationships_handler = PostgresRelationshipsHandler(
+ project_name=self.project_name,
+ connection_manager=self.connection_manager,
+ collections_handler=self.collections_handler,
+ dimension=self.dimension,
+ quantization_type=self.quantization_type,
+ )
+ self.communities_handler = PostgresCommunitiesHandler(
+ project_name=self.project_name,
+ connection_manager=self.connection_manager,
+ collections_handler=self.collections_handler,
+ dimension=self.dimension,
+ quantization_type=self.quantization_type,
+ )
+ self.graphs_handler = PostgresGraphsHandler(
+ project_name=self.project_name,
+ connection_manager=self.connection_manager,
+ collections_handler=self.collections_handler,
+ dimension=self.dimension,
+ quantization_type=self.quantization_type,
+ )
+ self.prompts_handler = PostgresPromptsHandler(
+ self.project_name, self.connection_manager
+ )
+ self.files_handler = PostgresFilesHandler(
+ self.project_name, self.connection_manager
+ )
+
+ self.limits_handler = PostgresLimitsHandler(
+ project_name=self.project_name,
+ connection_manager=self.connection_manager,
+ config=self.config,
+ )
+
+ async def initialize(self):
+ logger.info("Initializing `PostgresDatabaseProvider`.")
+ self.pool = SemaphoreConnectionPool(
+ self.connection_string, self.postgres_configuration_settings
+ )
+ await self.pool.initialize()
+ await self.connection_manager.initialize(self.pool)
+
+ async with self.pool.get_connection() as conn:
+ await conn.execute('CREATE EXTENSION IF NOT EXISTS "uuid-ossp";')
+ await conn.execute("CREATE EXTENSION IF NOT EXISTS vector;")
+ await conn.execute("CREATE EXTENSION IF NOT EXISTS pg_trgm;")
+ await conn.execute("CREATE EXTENSION IF NOT EXISTS fuzzystrmatch;")
+
+ # Create schema if it doesn't exist
+ await conn.execute(
+ f'CREATE SCHEMA IF NOT EXISTS "{self.project_name}";'
+ )
+
+ await self.documents_handler.create_tables()
+ await self.collections_handler.create_tables()
+ await self.token_handler.create_tables()
+ await self.users_handler.create_tables()
+ await self.chunks_handler.create_tables()
+ await self.prompts_handler.create_tables()
+ await self.files_handler.create_tables()
+ await self.graphs_handler.create_tables()
+ await self.communities_handler.create_tables()
+ await self.entities_handler.create_tables()
+ await self.relationships_handler.create_tables()
+ await self.conversations_handler.create_tables()
+ await self.limits_handler.create_tables()
+
+ def _get_postgres_configuration_settings(
+ self, config: DatabaseConfig
+ ) -> PostgresConfigurationSettings:
+ settings = PostgresConfigurationSettings()
+
+ env_mapping = {
+ "checkpoint_completion_target": "R2R_POSTGRES_CHECKPOINT_COMPLETION_TARGET",
+ "default_statistics_target": "R2R_POSTGRES_DEFAULT_STATISTICS_TARGET",
+ "effective_cache_size": "R2R_POSTGRES_EFFECTIVE_CACHE_SIZE",
+ "effective_io_concurrency": "R2R_POSTGRES_EFFECTIVE_IO_CONCURRENCY",
+ "huge_pages": "R2R_POSTGRES_HUGE_PAGES",
+ "maintenance_work_mem": "R2R_POSTGRES_MAINTENANCE_WORK_MEM",
+ "min_wal_size": "R2R_POSTGRES_MIN_WAL_SIZE",
+ "max_connections": "R2R_POSTGRES_MAX_CONNECTIONS",
+ "max_parallel_workers_per_gather": "R2R_POSTGRES_MAX_PARALLEL_WORKERS_PER_GATHER",
+ "max_parallel_workers": "R2R_POSTGRES_MAX_PARALLEL_WORKERS",
+ "max_parallel_maintenance_workers": "R2R_POSTGRES_MAX_PARALLEL_MAINTENANCE_WORKERS",
+ "max_wal_size": "R2R_POSTGRES_MAX_WAL_SIZE",
+ "max_worker_processes": "R2R_POSTGRES_MAX_WORKER_PROCESSES",
+ "random_page_cost": "R2R_POSTGRES_RANDOM_PAGE_COST",
+ "statement_cache_size": "R2R_POSTGRES_STATEMENT_CACHE_SIZE",
+ "shared_buffers": "R2R_POSTGRES_SHARED_BUFFERS",
+ "wal_buffers": "R2R_POSTGRES_WAL_BUFFERS",
+ "work_mem": "R2R_POSTGRES_WORK_MEM",
+ }
+
+ for setting, env_var in env_mapping.items():
+ value = getattr(
+ config.postgres_configuration_settings, setting, None
+ )
+ if value is None:
+ value = os.getenv(env_var)
+
+ if value is not None:
+ field_type = settings.__annotations__[setting]
+ if field_type == Optional[int]:
+ value = int(value)
+ elif field_type == Optional[float]:
+ value = float(value)
+
+ setattr(settings, setting, value)
+
+ return settings
+
+ async def close(self):
+ if self.pool:
+ await self.pool.close()
+
+ async def __aenter__(self):
+ await self.initialize()
+ return self
+
+ async def __aexit__(self, exc_type, exc, tb):
+ await self.close()