import logging
from abc import ABC, abstractmethod
from typing import Optional, Union

from ..abstractions.document import DocumentInfo
from ..abstractions.search import VectorSearchResult
from ..abstractions.vector import VectorEntry
from .base_provider import Provider, ProviderConfig

logger = logging.getLogger(__name__)


class VectorDBConfig(ProviderConfig):
    """Configuration for a vector database provider.

    ``provider`` names the backend and must be one of
    :attr:`supported_providers`; any extra fields are copied onto the
    instance as attributes.
    """

    provider: str

    def __post_init__(self):
        # NOTE(review): `__post_init__` implies ProviderConfig is (or behaves
        # like) a dataclass, and `extra_fields` is assumed to be provided by
        # the base class — confirm against base_provider.
        self.validate()
        # Capture additional fields as plain attributes on the config.
        for key, value in self.extra_fields.items():
            setattr(self, key, value)

    def validate(self) -> None:
        """Raise ``ValueError`` if ``provider`` is not a supported backend."""
        if self.provider not in self.supported_providers:
            raise ValueError(f"Provider '{self.provider}' is not supported.")

    @property
    def supported_providers(self) -> list[str]:
        """Names of the vector database backends this config accepts."""
        return ["local", "pgvector"]


class VectorDBProvider(Provider, ABC):
    """Abstract interface that every vector database backend implements."""

    def __init__(self, config: VectorDBConfig):
        """Validate the config type, log it, and hand it to ``Provider``.

        Raises:
            ValueError: if ``config`` is not a ``VectorDBConfig``.
        """
        if not isinstance(config, VectorDBConfig):
            raise ValueError(
                "VectorDBProvider must be initialized with a `VectorDBConfig`."
            )
        # Lazy %-formatting: the config is only stringified when INFO logging
        # is actually enabled.
        logger.info("Initializing VectorDBProvider with config %s.", config)
        super().__init__(config)

    @abstractmethod
    def initialize_collection(self, dimension: int) -> None:
        """Prepare storage for vectors of length ``dimension``."""

    @abstractmethod
    def copy(self, entry: VectorEntry, commit: bool = True) -> None:
        """Write a single entry.

        NOTE(review): presumably an insert-only path, as opposed to
        :meth:`upsert` — confirm against concrete implementations.
        """

    @abstractmethod
    def upsert(self, entry: VectorEntry, commit: bool = True) -> None:
        """Insert ``entry``, or update it if it already exists."""

    @abstractmethod
    def search(
        self,
        query_vector: list[float],
        filters: Optional[dict[str, Union[bool, int, str]]] = None,
        limit: int = 10,
        *args,
        **kwargs,
    ) -> list[VectorSearchResult]:
        """Return up to ``limit`` results nearest to ``query_vector``.

        ``filters`` defaults to ``None`` rather than a shared ``{}`` literal
        (a mutable default is shared across calls). Implementations should
        treat ``None`` as "no filters".
        """

    @abstractmethod
    def hybrid_search(
        self,
        query_text: str,
        query_vector: list[float],
        limit: int = 10,
        filters: Optional[dict[str, Union[bool, int, str]]] = None,
        # Hybrid search parameters
        full_text_weight: float = 1.0,
        semantic_weight: float = 1.0,
        rrf_k: int = 20,  # typical value is ~2x the number of results you want
        *args,
        **kwargs,
    ) -> list[VectorSearchResult]:
        """Search combining full-text and vector similarity.

        ``full_text_weight`` and ``semantic_weight`` weight the two result
        sources; ``rrf_k`` is presumably the reciprocal-rank-fusion constant
        — confirm against implementations.
        """

    @abstractmethod
    def create_index(self, index_type, column_name, index_options):
        """Create an index on ``column_name``.

        NOTE(review): parameter types are unannotated here; their expected
        shapes are defined by concrete implementations.
        """

    def upsert_entries(
        self, entries: list[VectorEntry], commit: bool = True
    ) -> None:
        """Call :meth:`upsert` on each entry in order, passing ``commit``."""
        for entry in entries:
            self.upsert(entry, commit=commit)

    def copy_entries(
        self, entries: list[VectorEntry], commit: bool = True
    ) -> None:
        """Call :meth:`copy` on each entry in order, passing ``commit``."""
        for entry in entries:
            self.copy(entry, commit=commit)

    @abstractmethod
    def delete_by_metadata(
        self,
        metadata_fields: list[str],
        metadata_values: list[Union[bool, int, str]],
    ) -> list[str]:
        """Delete entries matching the paired metadata fields and values.

        ``metadata_fields[i]`` is paired with ``metadata_values[i]``. This
        base body only checks that the lists are the same length;
        implementations may call it via ``super()`` before deleting.

        Returns:
            A list of strings — presumably the ids of deleted entries
            (TODO confirm with implementations).

        Raises:
            ValueError: if the two lists differ in length.
        """
        if len(metadata_fields) != len(metadata_values):
            raise ValueError(
                "The number of metadata fields and values must be equal."
            )

    @abstractmethod
    def get_metadatas(
        self,
        metadata_fields: list[str],
        filter_field: Optional[str] = None,
        filter_value: Optional[str] = None,
    ) -> list[str]:
        """Fetch values of ``metadata_fields``, optionally filtered.

        NOTE(review): presumably rows are restricted to those where
        ``filter_field == filter_value`` when both are given — confirm.
        """

    @abstractmethod
    def upsert_documents_overview(
        self, document_infs: list[DocumentInfo]
    ) -> None:
        """Insert or update document overview records.

        The parameter name ``document_infs`` (sic) is kept as-is so that
        keyword callers keep working.
        """

    @abstractmethod
    def get_documents_overview(
        self,
        filter_document_ids: Optional[list[str]] = None,
        filter_user_ids: Optional[list[str]] = None,
    ) -> list[DocumentInfo]:
        """Return document overviews, optionally filtered by document/user ids."""

    @abstractmethod
    def get_document_chunks(self, document_id: str) -> list[dict]:
        """Return the stored chunks for ``document_id``."""

    @abstractmethod
    def delete_from_documents_overview(
        self, document_id: str, version: Optional[str] = None
    ) -> dict:
        """Remove ``document_id`` (optionally a specific ``version``) from the overview."""

    @abstractmethod
    def get_users_overview(self, user_ids: Optional[list[str]] = None) -> dict:
        """Return a per-user overview, optionally limited to ``user_ids``."""