about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/core/base/providers/database.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/core/base/providers/database.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two version of R2R are here HEAD master
Diffstat (limited to '.venv/lib/python3.12/site-packages/core/base/providers/database.py')
-rw-r--r--.venv/lib/python3.12/site-packages/core/base/providers/database.py197
1 files changed, 197 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/core/base/providers/database.py b/.venv/lib/python3.12/site-packages/core/base/providers/database.py
new file mode 100644
index 00000000..845a8109
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/base/providers/database.py
@@ -0,0 +1,197 @@
+"""Base classes for database providers."""
+
+import logging
+from abc import ABC, abstractmethod
+from typing import Any, Optional, Sequence, cast
+from uuid import UUID
+
+from pydantic import BaseModel
+
+from core.base.abstractions import (
+    GraphCreationSettings,
+    GraphEnrichmentSettings,
+    GraphSearchSettings,
+)
+
+from .base import Provider, ProviderConfig
+
+logger = logging.getLogger()
+
+
+class DatabaseConnectionManager(ABC):
+    @abstractmethod
+    def execute_query(
+        self,
+        query: str,
+        params: Optional[dict[str, Any] | Sequence[Any]] = None,
+        isolation_level: Optional[str] = None,
+    ):
+        pass
+
+    @abstractmethod
+    async def execute_many(self, query, params=None, batch_size=1000):
+        pass
+
+    @abstractmethod
+    def fetch_query(
+        self,
+        query: str,
+        params: Optional[dict[str, Any] | Sequence[Any]] = None,
+    ):
+        pass
+
+    @abstractmethod
+    def fetchrow_query(
+        self,
+        query: str,
+        params: Optional[dict[str, Any] | Sequence[Any]] = None,
+    ):
+        pass
+
+    @abstractmethod
+    async def initialize(self, pool: Any):
+        pass
+
+
+class Handler(ABC):
+    def __init__(
+        self,
+        project_name: str,
+        connection_manager: DatabaseConnectionManager,
+    ):
+        self.project_name = project_name
+        self.connection_manager = connection_manager
+
+    def _get_table_name(self, base_name: str) -> str:
+        return f"{self.project_name}.{base_name}"
+
+    @abstractmethod
+    def create_tables(self):
+        pass
+
+
+class PostgresConfigurationSettings(BaseModel):
+    """Configuration settings with defaults defined by the PGVector docker
+    image.
+
+    These settings are helpful in managing the connections to the database. To
+    tune these settings for a specific deployment, see
+    https://pgtune.leopard.in.ua/
+    """
+
+    checkpoint_completion_target: Optional[float] = 0.9
+    default_statistics_target: Optional[int] = 100
+    effective_io_concurrency: Optional[int] = 1
+    effective_cache_size: Optional[int] = 524288
+    huge_pages: Optional[str] = "try"
+    maintenance_work_mem: Optional[int] = 65536
+    max_connections: Optional[int] = 256
+    max_parallel_workers_per_gather: Optional[int] = 2
+    max_parallel_workers: Optional[int] = 8
+    max_parallel_maintenance_workers: Optional[int] = 2
+    max_wal_size: Optional[int] = 1024
+    max_worker_processes: Optional[int] = 8
+    min_wal_size: Optional[int] = 80
+    shared_buffers: Optional[int] = 16384
+    statement_cache_size: Optional[int] = 100
+    random_page_cost: Optional[float] = 4
+    wal_buffers: Optional[int] = 512
+    work_mem: Optional[int] = 4096
+
+
+class LimitSettings(BaseModel):
+    global_per_min: Optional[int] = None
+    route_per_min: Optional[int] = None
+    monthly_limit: Optional[int] = None
+
+    def merge_with_defaults(
+        self, defaults: "LimitSettings"
+    ) -> "LimitSettings":
+        return LimitSettings(
+            global_per_min=self.global_per_min or defaults.global_per_min,
+            route_per_min=self.route_per_min or defaults.route_per_min,
+            monthly_limit=self.monthly_limit or defaults.monthly_limit,
+        )
+
+
+class DatabaseConfig(ProviderConfig):
+    """A base database configuration class."""
+
+    provider: str = "postgres"
+    user: Optional[str] = None
+    password: Optional[str] = None
+    host: Optional[str] = None
+    port: Optional[int] = None
+    db_name: Optional[str] = None
+    project_name: Optional[str] = None
+    postgres_configuration_settings: Optional[
+        PostgresConfigurationSettings
+    ] = None
+    default_collection_name: str = "Default"
+    default_collection_description: str = "Your default collection."
+    collection_summary_system_prompt: str = "system"
+    collection_summary_prompt: str = "collection_summary"
+    enable_fts: bool = False
+
+    # Graph settings
+    batch_size: Optional[int] = 1
+    graph_search_results_store_path: Optional[str] = None
+    graph_enrichment_settings: GraphEnrichmentSettings = (
+        GraphEnrichmentSettings()
+    )
+    graph_creation_settings: GraphCreationSettings = GraphCreationSettings()
+    graph_search_settings: GraphSearchSettings = GraphSearchSettings()
+
+    # Rate limits
+    limits: LimitSettings = LimitSettings(
+        global_per_min=60, route_per_min=20, monthly_limit=10000
+    )
+    route_limits: dict[str, LimitSettings] = {}
+    user_limits: dict[UUID, LimitSettings] = {}
+
+    def validate_config(self) -> None:
+        if self.provider not in self.supported_providers:
+            raise ValueError(f"Provider '{self.provider}' is not supported.")
+
+    @property
+    def supported_providers(self) -> list[str]:
+        return ["postgres"]
+
+    @classmethod
+    def from_dict(cls, data: dict[str, Any]) -> "DatabaseConfig":
+        instance = cls.create(**data)
+
+        instance = cast(DatabaseConfig, instance)
+
+        limits_data = data.get("limits", {})
+        default_limits = LimitSettings(
+            global_per_min=limits_data.get("global_per_min", 60),
+            route_per_min=limits_data.get("route_per_min", 20),
+            monthly_limit=limits_data.get("monthly_limit", 10000),
+        )
+
+        instance.limits = default_limits
+
+        route_limits_data = limits_data.get("routes", {})
+        for route_str, route_cfg in route_limits_data.items():
+            instance.route_limits[route_str] = LimitSettings(**route_cfg)
+
+        return instance
+
+
+class DatabaseProvider(Provider):
+    connection_manager: DatabaseConnectionManager
+    config: DatabaseConfig
+    project_name: str
+
+    def __init__(self, config: DatabaseConfig):
+        logger.info(f"Initializing DatabaseProvider with config {config}.")
+        super().__init__(config)
+
+    @abstractmethod
+    async def __aenter__(self):
+        pass
+
+    @abstractmethod
+    async def __aexit__(self, exc_type, exc, tb):
+        pass