"""Base classes for database providers."""

import logging
from abc import ABC, abstractmethod
from typing import Any, Optional, Sequence, cast
from uuid import UUID

from pydantic import BaseModel

from core.base.abstractions import (
    GraphCreationSettings,
    GraphEnrichmentSettings,
    GraphSearchSettings,
)

from .base import Provider, ProviderConfig

# Module-scoped logger. Naming it after the module (instead of using the
# unnamed root logger) lets logging configuration target this file and makes
# the origin visible in each record; records still propagate to the root
# logger's handlers.
logger = logging.getLogger(__name__)


class DatabaseConnectionManager(ABC):
    @abstractmethod
    def execute_query(
        self,
        query: str,
        params: Optional[dict[str, Any] | Sequence[Any]] = None,
        isolation_level: Optional[str] = None,
    ):
        pass

    @abstractmethod
    async def execute_many(self, query, params=None, batch_size=1000):
        pass

    @abstractmethod
    def fetch_query(
        self,
        query: str,
        params: Optional[dict[str, Any] | Sequence[Any]] = None,
    ):
        pass

    @abstractmethod
    def fetchrow_query(
        self,
        query: str,
        params: Optional[dict[str, Any] | Sequence[Any]] = None,
    ):
        pass

    @abstractmethod
    async def initialize(self, pool: Any):
        pass


class Handler(ABC):
    """Abstract base for objects tied to a project and a connection manager.

    Stores the two constructor arguments, offers a helper that prefixes table
    names with the project name, and requires subclasses to implement
    ``create_tables``.
    """

    def __init__(
        self,
        project_name: str,
        connection_manager: DatabaseConnectionManager,
    ):
        """Keep references to the project name and the connection manager.

        Args:
            project_name: Used as the prefix in ``_get_table_name``.
            connection_manager: Stored as-is on the instance.
        """
        self.project_name = project_name
        self.connection_manager = connection_manager

    def _get_table_name(self, base_name: str) -> str:
        """Return ``base_name`` prefixed with the project name and a dot.

        The result has the form ``<project_name>.<base_name>`` (presumably a
        schema-qualified table name -- confirm against callers).
        """
        return "{}.{}".format(self.project_name, base_name)

    @abstractmethod
    def create_tables(self):
        """Abstract hook; subclasses must provide the implementation."""
        ...


class PostgresConfigurationSettings(BaseModel):
    """Configuration settings with defaults defined by the PGVector docker
    image.

    These settings are helpful in managing the connections to the database. To
    tune these settings for a specific deployment, see
    https://pgtune.leopard.in.ua/

    Every field is optional and has a default, so the model can be built with
    no arguments.
    """

    # NOTE(review): the field names look like PostgreSQL server parameters and
    # the integer values are presumably expressed in each parameter's own
    # default unit -- confirm against the PostgreSQL documentation before
    # changing any of them.
    checkpoint_completion_target: Optional[float] = 0.9
    default_statistics_target: Optional[int] = 100
    effective_io_concurrency: Optional[int] = 1
    effective_cache_size: Optional[int] = 524288
    huge_pages: Optional[str] = "try"
    maintenance_work_mem: Optional[int] = 65536
    max_connections: Optional[int] = 256
    max_parallel_workers_per_gather: Optional[int] = 2
    max_parallel_workers: Optional[int] = 8
    max_parallel_maintenance_workers: Optional[int] = 2
    max_wal_size: Optional[int] = 1024
    max_worker_processes: Optional[int] = 8
    min_wal_size: Optional[int] = 80
    shared_buffers: Optional[int] = 16384
    # NOTE(review): this one may be a client-side (driver) setting rather
    # than a server parameter -- confirm where it is consumed.
    statement_cache_size: Optional[int] = 100
    # The default is written as the int literal 4 although the field is
    # annotated as float.
    random_page_cost: Optional[float] = 4
    wal_buffers: Optional[int] = 512
    work_mem: Optional[int] = 4096


class LimitSettings(BaseModel):
    """A set of three optional integer limits.

    Each field defaults to ``None``, which ``merge_with_defaults`` treats as
    "not set here, take the value from the defaults".
    """

    # Judging by the names: a global per-minute cap, a per-route per-minute
    # cap, and a monthly cap -- enforcement happens outside this class.
    global_per_min: Optional[int] = None
    route_per_min: Optional[int] = None
    monthly_limit: Optional[int] = None

    def merge_with_defaults(
        self, defaults: "LimitSettings"
    ) -> "LimitSettings":
        """Return a new ``LimitSettings`` with unset fields filled in.

        A field of this instance is kept whenever it is not ``None``;
        otherwise the corresponding field of ``defaults`` is used. Neither
        ``self`` nor ``defaults`` is modified.

        Args:
            defaults: Source of fallback values for fields that are ``None``.

        Returns:
            A freshly constructed ``LimitSettings``.
        """

        def pick(own: Optional[int], fallback: Optional[int]) -> Optional[int]:
            # Compare against None explicitly: a plain `own or fallback`
            # would discard an explicit limit of 0 because 0 is falsy.
            return fallback if own is None else own

        return LimitSettings(
            global_per_min=pick(self.global_per_min, defaults.global_per_min),
            route_per_min=pick(self.route_per_min, defaults.route_per_min),
            monthly_limit=pick(self.monthly_limit, defaults.monthly_limit),
        )


class DatabaseConfig(ProviderConfig):
    """Database provider configuration.

    Groups connection details, default-collection metadata, graph settings
    and rate limits. ``from_dict`` is the alternate constructor that also
    reads per-route limits from the ``limits`` section of the input.
    """

    # Connection details; all but ``provider`` default to None.
    provider: str = "postgres"
    user: Optional[str] = None
    password: Optional[str] = None
    host: Optional[str] = None
    port: Optional[int] = None
    db_name: Optional[str] = None
    project_name: Optional[str] = None
    postgres_configuration_settings: Optional[
        PostgresConfigurationSettings
    ] = None

    # Default collection metadata and prompt names.
    default_collection_name: str = "Default"
    default_collection_description: str = "Your default collection."
    collection_summary_system_prompt: str = "system"
    collection_summary_prompt: str = "collection_summary"
    enable_fts: bool = False

    # Graph settings
    batch_size: Optional[int] = 1
    graph_search_results_store_path: Optional[str] = None
    graph_enrichment_settings: GraphEnrichmentSettings = (
        GraphEnrichmentSettings()
    )
    graph_creation_settings: GraphCreationSettings = GraphCreationSettings()
    graph_search_settings: GraphSearchSettings = GraphSearchSettings()

    # Rate limits
    limits: LimitSettings = LimitSettings(
        global_per_min=60, route_per_min=20, monthly_limit=10000
    )
    route_limits: dict[str, LimitSettings] = {}
    user_limits: dict[UUID, LimitSettings] = {}

    def validate_config(self) -> None:
        """Check that ``provider`` is one of ``supported_providers``.

        Raises:
            ValueError: If the configured provider is not supported.
        """
        if self.provider in self.supported_providers:
            return
        raise ValueError(f"Provider '{self.provider}' is not supported.")

    @property
    def supported_providers(self) -> list[str]:
        """Names of the providers this configuration accepts."""
        return ["postgres"]

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "DatabaseConfig":
        """Build a configuration from ``data`` and apply its rate limits.

        The instance is first produced by ``cls.create(**data)``. Its
        ``limits`` attribute is then replaced with values read from
        ``data["limits"]`` (falling back to 60 / 20 / 10000 for missing
        keys), and every entry of ``data["limits"]["routes"]`` is turned
        into a ``LimitSettings`` stored in ``route_limits`` under the same
        key.

        Args:
            data: Raw configuration mapping.

        Returns:
            The populated configuration instance.
        """
        config = cast(DatabaseConfig, cls.create(**data))

        raw_limits = data.get("limits", {})
        config.limits = LimitSettings(
            global_per_min=raw_limits.get("global_per_min", 60),
            route_per_min=raw_limits.get("route_per_min", 20),
            monthly_limit=raw_limits.get("monthly_limit", 10000),
        )

        # Per-route overrides are nested under limits["routes"].
        for route, overrides in raw_limits.get("routes", {}).items():
            config.route_limits[route] = LimitSettings(**overrides)

        return config


class DatabaseProvider(Provider):
    """Abstract base for database providers.

    Subclasses must implement the asynchronous context-manager hooks
    ``__aenter__`` and ``__aexit__``.
    """

    # Declared for type checkers only; no values are assigned in this class.
    connection_manager: DatabaseConnectionManager
    config: DatabaseConfig
    project_name: str

    def __init__(self, config: DatabaseConfig):
        """Log the start of initialization and delegate to ``Provider``.

        Args:
            config: The database configuration handed to the base class.
        """
        # The configuration declares a plain-text ``password`` field, so
        # formatting the whole object could write credentials to the log.
        # Only non-secret connection details are logged instead.
        logger.info(
            "Initializing DatabaseProvider with provider=%s, host=%s, "
            "port=%s, db_name=%s, project_name=%s.",
            config.provider,
            config.host,
            config.port,
            config.db_name,
            config.project_name,
        )
        super().__init__(config)

    @abstractmethod
    async def __aenter__(self):
        """Abstract hook run when entering ``async with``."""
        ...

    @abstractmethod
    async def __aexit__(self, exc_type, exc, tb):
        """Abstract hook run when leaving ``async with``.

        Args:
            exc_type: Exception class raised in the block, or ``None``.
            exc: Exception instance raised in the block, or ``None``.
            tb: Traceback of that exception, or ``None``.
        """
        ...