aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/core/base/providers/database.py
diff options
context:
space:
mode:
author: S. Solomon Darnell, 2025-03-28 21:52:21 -0500
committer: S. Solomon Darnell, 2025-03-28 21:52:21 -0500
commit: 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree: ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/core/base/providers/database.py
parent: cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
download: gn-ai-master.tar.gz
two version of R2R are here (refs: HEAD, master)
Diffstat (limited to '.venv/lib/python3.12/site-packages/core/base/providers/database.py')
-rw-r--r--.venv/lib/python3.12/site-packages/core/base/providers/database.py197
1 files changed, 197 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/core/base/providers/database.py b/.venv/lib/python3.12/site-packages/core/base/providers/database.py
new file mode 100644
index 00000000..845a8109
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/base/providers/database.py
@@ -0,0 +1,197 @@
+"""Base classes for database providers."""
+
+import logging
+from abc import ABC, abstractmethod
+from typing import Any, Optional, Sequence, cast
+from uuid import UUID
+
+from pydantic import BaseModel
+
+from core.base.abstractions import (
+ GraphCreationSettings,
+ GraphEnrichmentSettings,
+ GraphSearchSettings,
+)
+
+from .base import Provider, ProviderConfig
+
# Module-scoped logger: records carry this module's name (so they can be
# filtered per module) and still propagate to the root logger's handlers.
logger = logging.getLogger(__name__)
+
+
+class DatabaseConnectionManager(ABC):
+ @abstractmethod
+ def execute_query(
+ self,
+ query: str,
+ params: Optional[dict[str, Any] | Sequence[Any]] = None,
+ isolation_level: Optional[str] = None,
+ ):
+ pass
+
+ @abstractmethod
+ async def execute_many(self, query, params=None, batch_size=1000):
+ pass
+
+ @abstractmethod
+ def fetch_query(
+ self,
+ query: str,
+ params: Optional[dict[str, Any] | Sequence[Any]] = None,
+ ):
+ pass
+
+ @abstractmethod
+ def fetchrow_query(
+ self,
+ query: str,
+ params: Optional[dict[str, Any] | Sequence[Any]] = None,
+ ):
+ pass
+
+ @abstractmethod
+ async def initialize(self, pool: Any):
+ pass
+
+
class Handler(ABC):
    """Abstract base that pairs a project name with a connection manager.

    Subclasses must implement ``create_tables``.
    """

    def __init__(
        self,
        project_name: str,
        connection_manager: DatabaseConnectionManager,
    ):
        # Plain attribute storage; nothing is validated or contacted here.
        self.connection_manager = connection_manager
        self.project_name = project_name

    def _get_table_name(self, base_name: str) -> str:
        """Return ``base_name`` prefixed with ``<project_name>.``."""
        prefix = self.project_name
        return f"{prefix}.{base_name}"

    @abstractmethod
    def create_tables(self):
        """Subclass hook; this base class provides no implementation."""
+
+
class PostgresConfigurationSettings(BaseModel):
    """Postgres tuning parameters, defaulting to the PGVector docker image.

    These settings help manage connections to the database. For guidance on
    tuning them for a particular deployment, see
    https://pgtune.leopard.in.ua/

    Every field is optional. NOTE(review): the numeric values are presumably
    expressed in each parameter's native PostgreSQL unit -- confirm against
    the code that applies them.
    """

    checkpoint_completion_target: float | None = 0.9
    default_statistics_target: int | None = 100
    effective_io_concurrency: int | None = 1
    effective_cache_size: int | None = 524288
    huge_pages: str | None = "try"
    maintenance_work_mem: int | None = 65536
    max_connections: int | None = 256
    max_parallel_workers_per_gather: int | None = 2
    max_parallel_workers: int | None = 8
    max_parallel_maintenance_workers: int | None = 2
    max_wal_size: int | None = 1024
    max_worker_processes: int | None = 8
    min_wal_size: int | None = 80
    shared_buffers: int | None = 16384
    statement_cache_size: int | None = 100
    random_page_cost: float | None = 4
    wal_buffers: int | None = 512
    work_mem: int | None = 4096
+
+
class LimitSettings(BaseModel):
    """Rate-limit thresholds; a field left as ``None`` is considered unset."""

    global_per_min: Optional[int] = None
    route_per_min: Optional[int] = None
    monthly_limit: Optional[int] = None

    def merge_with_defaults(
        self, defaults: "LimitSettings"
    ) -> "LimitSettings":
        """Return a new ``LimitSettings`` with unset fields filled in.

        Args:
            defaults: Settings supplying the value for every field that is
                ``None`` on ``self``.

        Returns:
            A new instance; neither ``self`` nor ``defaults`` is modified.
        """

        def pick(own: Optional[int], fallback: Optional[int]) -> Optional[int]:
            # Compare against None rather than relying on truthiness, so an
            # explicit limit of 0 is kept instead of being replaced by the
            # default.
            return fallback if own is None else own

        return LimitSettings(
            global_per_min=pick(self.global_per_min, defaults.global_per_min),
            route_per_min=pick(self.route_per_min, defaults.route_per_min),
            monthly_limit=pick(self.monthly_limit, defaults.monthly_limit),
        )
+
+
class DatabaseConfig(ProviderConfig):
    """Base configuration for a database provider.

    Only the ``"postgres"`` provider is accepted by ``validate_config``.
    """

    # Connection settings
    provider: str = "postgres"
    user: str | None = None
    password: str | None = None
    host: str | None = None
    port: int | None = None
    db_name: str | None = None
    project_name: str | None = None
    postgres_configuration_settings: PostgresConfigurationSettings | None = (
        None
    )

    # Default collection and summary prompt names
    default_collection_name: str = "Default"
    default_collection_description: str = "Your default collection."
    collection_summary_system_prompt: str = "system"
    collection_summary_prompt: str = "collection_summary"
    enable_fts: bool = False

    # Graph settings
    batch_size: int | None = 1
    graph_search_results_store_path: str | None = None
    graph_enrichment_settings: GraphEnrichmentSettings = (
        GraphEnrichmentSettings()
    )
    graph_creation_settings: GraphCreationSettings = GraphCreationSettings()
    graph_search_settings: GraphSearchSettings = GraphSearchSettings()

    # Rate limits
    limits: LimitSettings = LimitSettings(
        global_per_min=60, route_per_min=20, monthly_limit=10000
    )
    route_limits: dict[str, LimitSettings] = {}
    user_limits: dict[UUID, LimitSettings] = {}

    def validate_config(self) -> None:
        """Raise ``ValueError`` when ``provider`` is not a supported one."""
        if self.provider in self.supported_providers:
            return
        raise ValueError(f"Provider '{self.provider}' is not supported.")

    @property
    def supported_providers(self) -> list[str]:
        """Provider names accepted by ``validate_config``."""
        return ["postgres"]

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "DatabaseConfig":
        """Build a config from ``data`` and apply its rate-limit section.

        The instance is created with ``cls.create(**data)``. Its ``limits``
        are then replaced by values read from ``data["limits"]`` (falling
        back to 60 / 20 / 10000 for missing keys), and each entry under
        ``data["limits"]["routes"]`` is stored in ``route_limits``.
        """
        config = cast(DatabaseConfig, cls.create(**data))

        raw_limits = data.get("limits", {})
        config.limits = LimitSettings(
            global_per_min=raw_limits.get("global_per_min", 60),
            route_per_min=raw_limits.get("route_per_min", 20),
            monthly_limit=raw_limits.get("monthly_limit", 10000),
        )

        # Per-route overrides are keyed by the route string as given.
        config.route_limits.update(
            (route, LimitSettings(**overrides))
            for route, overrides in raw_limits.get("routes", {}).items()
        )
        return config
+
+
class DatabaseProvider(Provider):
    """Abstract database provider that is used as an async context manager.

    Subclasses must implement ``__aenter__`` and ``__aexit__``.
    """

    connection_manager: DatabaseConnectionManager
    config: DatabaseConfig
    project_name: str

    def __init__(self, config: DatabaseConfig):
        """Log a summary of ``config`` and pass it to the base ``Provider``.

        Args:
            config: Database configuration handed on to
                ``Provider.__init__``.
        """
        # Log only non-secret connection details: formatting the whole
        # config object would write ``config.password`` into the logs.
        safe_fields = ("provider", "host", "port", "db_name", "project_name")
        summary = ", ".join(
            f"{name}={getattr(config, name, None)!r}" for name in safe_fields
        )
        logger.info("Initializing DatabaseProvider with config (%s).", summary)
        super().__init__(config)

    @abstractmethod
    async def __aenter__(self):
        """Enter the async context; implemented by subclasses."""

    @abstractmethod
    async def __aexit__(self, exc_type, exc, tb):
        """Exit the async context; implemented by subclasses."""