diff options
author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
---|---|---|
committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) | |
tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/core/base/providers/database.py | |
parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff) | |
download | gn-ai-master.tar.gz |
Diffstat (limited to '.venv/lib/python3.12/site-packages/core/base/providers/database.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/core/base/providers/database.py | 197 |
1 files changed, 197 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/core/base/providers/database.py b/.venv/lib/python3.12/site-packages/core/base/providers/database.py new file mode 100644 index 00000000..845a8109 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/core/base/providers/database.py @@ -0,0 +1,197 @@ +"""Base classes for database providers.""" + +import logging +from abc import ABC, abstractmethod +from typing import Any, Optional, Sequence, cast +from uuid import UUID + +from pydantic import BaseModel + +from core.base.abstractions import ( + GraphCreationSettings, + GraphEnrichmentSettings, + GraphSearchSettings, +) + +from .base import Provider, ProviderConfig + +logger = logging.getLogger() + + +class DatabaseConnectionManager(ABC): + @abstractmethod + def execute_query( + self, + query: str, + params: Optional[dict[str, Any] | Sequence[Any]] = None, + isolation_level: Optional[str] = None, + ): + pass + + @abstractmethod + async def execute_many(self, query, params=None, batch_size=1000): + pass + + @abstractmethod + def fetch_query( + self, + query: str, + params: Optional[dict[str, Any] | Sequence[Any]] = None, + ): + pass + + @abstractmethod + def fetchrow_query( + self, + query: str, + params: Optional[dict[str, Any] | Sequence[Any]] = None, + ): + pass + + @abstractmethod + async def initialize(self, pool: Any): + pass + + +class Handler(ABC): + def __init__( + self, + project_name: str, + connection_manager: DatabaseConnectionManager, + ): + self.project_name = project_name + self.connection_manager = connection_manager + + def _get_table_name(self, base_name: str) -> str: + return f"{self.project_name}.{base_name}" + + @abstractmethod + def create_tables(self): + pass + + +class PostgresConfigurationSettings(BaseModel): + """Configuration settings with defaults defined by the PGVector docker + image. + + These settings are helpful in managing the connections to the database. To + tune these settings for a specific deployment, see + https://pgtune.leopard.in.ua/ + """ + + checkpoint_completion_target: Optional[float] = 0.9 + default_statistics_target: Optional[int] = 100 + effective_io_concurrency: Optional[int] = 1 + effective_cache_size: Optional[int] = 524288 + huge_pages: Optional[str] = "try" + maintenance_work_mem: Optional[int] = 65536 + max_connections: Optional[int] = 256 + max_parallel_workers_per_gather: Optional[int] = 2 + max_parallel_workers: Optional[int] = 8 + max_parallel_maintenance_workers: Optional[int] = 2 + max_wal_size: Optional[int] = 1024 + max_worker_processes: Optional[int] = 8 + min_wal_size: Optional[int] = 80 + shared_buffers: Optional[int] = 16384 + statement_cache_size: Optional[int] = 100 + random_page_cost: Optional[float] = 4 + wal_buffers: Optional[int] = 512 + work_mem: Optional[int] = 4096 + + +class LimitSettings(BaseModel): + global_per_min: Optional[int] = None + route_per_min: Optional[int] = None + monthly_limit: Optional[int] = None + + def merge_with_defaults( + self, defaults: "LimitSettings" + ) -> "LimitSettings": + return LimitSettings( + global_per_min=self.global_per_min or defaults.global_per_min, + route_per_min=self.route_per_min or defaults.route_per_min, + monthly_limit=self.monthly_limit or defaults.monthly_limit, + ) + + +class DatabaseConfig(ProviderConfig): + """A base database configuration class.""" + + provider: str = "postgres" + user: Optional[str] = None + password: Optional[str] = None + host: Optional[str] = None + port: Optional[int] = None + db_name: Optional[str] = None + project_name: Optional[str] = None + postgres_configuration_settings: Optional[ + PostgresConfigurationSettings + ] = None + default_collection_name: str = "Default" + default_collection_description: str = "Your default collection." + collection_summary_system_prompt: str = "system" + collection_summary_prompt: str = "collection_summary" + enable_fts: bool = False + + # Graph settings + batch_size: Optional[int] = 1 + graph_search_results_store_path: Optional[str] = None + graph_enrichment_settings: GraphEnrichmentSettings = ( + GraphEnrichmentSettings() + ) + graph_creation_settings: GraphCreationSettings = GraphCreationSettings() + graph_search_settings: GraphSearchSettings = GraphSearchSettings() + + # Rate limits + limits: LimitSettings = LimitSettings( + global_per_min=60, route_per_min=20, monthly_limit=10000 + ) + route_limits: dict[str, LimitSettings] = {} + user_limits: dict[UUID, LimitSettings] = {} + + def validate_config(self) -> None: + if self.provider not in self.supported_providers: + raise ValueError(f"Provider '{self.provider}' is not supported.") + + @property + def supported_providers(self) -> list[str]: + return ["postgres"] + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "DatabaseConfig": + instance = cls.create(**data) + + instance = cast(DatabaseConfig, instance) + + limits_data = data.get("limits", {}) + default_limits = LimitSettings( + global_per_min=limits_data.get("global_per_min", 60), + route_per_min=limits_data.get("route_per_min", 20), + monthly_limit=limits_data.get("monthly_limit", 10000), + ) + + instance.limits = default_limits + + route_limits_data = limits_data.get("routes", {}) + for route_str, route_cfg in route_limits_data.items(): + instance.route_limits[route_str] = LimitSettings(**route_cfg) + + return instance + + +class DatabaseProvider(Provider): + connection_manager: DatabaseConnectionManager + config: DatabaseConfig + project_name: str + + def __init__(self, config: DatabaseConfig): + logger.info(f"Initializing DatabaseProvider with config {config}.") + super().__init__(config) + + @abstractmethod + async def __aenter__(self): + pass + + @abstractmethod + async def __aexit__(self, exc_type, exc, tb): + pass |