diff options
Diffstat (limited to 'R2R/r2r/base/providers/vector_db_provider.py')
-rwxr-xr-x | R2R/r2r/base/providers/vector_db_provider.py | 142 |
1 files changed, 142 insertions, 0 deletions
# diff --git a/R2R/r2r/base/providers/vector_db_provider.py
#            b/R2R/r2r/base/providers/vector_db_provider.py
# new file mode 100755, index 00000000..a6d5aaa8 (@@ -0,0 +1,142 @@)
"""Abstract interface and configuration for vector database providers."""

import logging
from abc import ABC, abstractmethod
from typing import Optional, Union

from ..abstractions.document import DocumentInfo
from ..abstractions.search import VectorSearchResult
from ..abstractions.vector import VectorEntry
from .base_provider import Provider, ProviderConfig

logger = logging.getLogger(__name__)


class VectorDBConfig(ProviderConfig):
    """Configuration for a vector database provider.

    Attributes:
        provider: Name of the backend; must be one of
            ``supported_providers``.
    """

    provider: str

    def __post_init__(self):
        """Validate the config, then expose extra fields as attributes.

        NOTE(review): relies on ``ProviderConfig`` supplying
        ``extra_fields`` and on ``__post_init__`` being invoked by the base
        class machinery (presumably a dataclass) -- confirm against
        ``base_provider``.
        """
        self.validate()
        # Capture additional fields
        for key, value in self.extra_fields.items():
            setattr(self, key, value)

    def validate(self) -> None:
        """Check that ``provider`` names a supported backend.

        Raises:
            ValueError: If ``provider`` is not in ``supported_providers``.
        """
        if self.provider not in self.supported_providers:
            raise ValueError(f"Provider '{self.provider}' is not supported.")

    @property
    def supported_providers(self) -> list[str]:
        """Return the provider names accepted by ``validate``."""
        return ["local", "pgvector"]


class VectorDBProvider(Provider, ABC):
    """Abstract base class that every vector database backend implements.

    Concrete subclasses must override all ``@abstractmethod`` members. The
    batch helpers ``upsert_entries`` and ``copy_entries`` are implemented
    here in terms of the single-entry abstract methods.
    """

    def __init__(self, config: VectorDBConfig):
        """Store ``config`` after checking it is a ``VectorDBConfig``.

        Raises:
            ValueError: If ``config`` is not a ``VectorDBConfig`` instance.
        """
        if not isinstance(config, VectorDBConfig):
            raise ValueError(
                "VectorDBProvider must be initialized with a `VectorDBConfig`."
            )
        # Lazy %-style arguments: the message is only formatted when the
        # INFO level is actually enabled.
        logger.info("Initializing VectorDBProvider with config %s.", config)
        super().__init__(config)

    @abstractmethod
    def initialize_collection(self, dimension: int) -> None:
        """Prepare the backing collection for vectors of size ``dimension``."""

    @abstractmethod
    def copy(self, entry: VectorEntry, commit: bool = True) -> None:
        """Copy a single ``entry`` into the store.

        NOTE(review): the exact difference from ``upsert`` is defined by the
        concrete backend -- not visible from this interface.
        """

    @abstractmethod
    def upsert(self, entry: VectorEntry, commit: bool = True) -> None:
        """Insert or update a single ``entry`` in the store."""

    @abstractmethod
    def search(
        self,
        query_vector: list[float],
        filters: Optional[dict[str, Union[bool, int, str]]] = None,
        limit: int = 10,
        *args,
        **kwargs,
    ) -> list[VectorSearchResult]:
        """Return up to ``limit`` results for ``query_vector``.

        Args:
            query_vector: The embedding to search with.
            filters: Optional metadata constraints. ``None`` (the default)
                means "no filters"; implementations should treat it the
                same as an empty mapping. A ``None`` default is used instead
                of ``{}`` to avoid a shared mutable default and to match
                ``hybrid_search``.
            limit: Maximum number of results to return.
        """

    @abstractmethod
    def hybrid_search(
        self,
        query_text: str,
        query_vector: list[float],
        limit: int = 10,
        filters: Optional[dict[str, Union[bool, int, str]]] = None,
        # Hybrid search parameters
        full_text_weight: float = 1.0,
        semantic_weight: float = 1.0,
        rrf_k: int = 20,  # typical value is ~2x the number of results you want
        *args,
        **kwargs,
    ) -> list[VectorSearchResult]:
        """Return up to ``limit`` results combining text and vector search.

        Args:
            query_text: The raw query string for the full-text component.
            query_vector: The embedding for the semantic component.
            limit: Maximum number of results to return.
            filters: Optional metadata constraints; ``None`` means none.
            full_text_weight: Weight given to the full-text component.
            semantic_weight: Weight given to the semantic component.
            rrf_k: Rank-combination constant (presumably reciprocal rank
                fusion -- confirm against the concrete backends).
        """

    @abstractmethod
    def create_index(self, index_type, column_name, index_options):
        """Create an index of ``index_type`` on ``column_name``."""

    def upsert_entries(
        self, entries: list[VectorEntry], commit: bool = True
    ) -> None:
        """Call ``upsert`` for each entry, forwarding ``commit`` unchanged."""
        for entry in entries:
            self.upsert(entry, commit=commit)

    def copy_entries(
        self, entries: list[VectorEntry], commit: bool = True
    ) -> None:
        """Call ``copy`` for each entry, forwarding ``commit`` unchanged."""
        for entry in entries:
            self.copy(entry, commit=commit)

    @abstractmethod
    def delete_by_metadata(
        self,
        metadata_fields: list[str],
        metadata_values: list[Union[bool, int, str]],
    ) -> list[str]:
        """Delete entries matching the given metadata field/value pairs.

        This base implementation only validates the arguments and returns
        ``None``; subclasses may call ``super().delete_by_metadata(...)`` to
        reuse the length check before performing the actual deletion.

        Raises:
            ValueError: If the two lists differ in length.
        """
        if len(metadata_fields) != len(metadata_values):
            raise ValueError(
                "The number of metadata fields and values must be equal."
            )

    @abstractmethod
    def get_metadatas(
        self,
        metadata_fields: list[str],
        filter_field: Optional[str] = None,
        filter_value: Optional[str] = None,
    ) -> list[str]:
        """Return metadata for ``metadata_fields``, optionally filtered.

        ``filter_field`` / ``filter_value`` are both optional; how they are
        combined is up to the concrete backend.
        """

    @abstractmethod
    def upsert_documents_overview(
        self, document_infs: list[DocumentInfo]
    ) -> None:
        """Insert or update the overview records in ``document_infs``."""

    @abstractmethod
    def get_documents_overview(
        self,
        filter_document_ids: Optional[list[str]] = None,
        filter_user_ids: Optional[list[str]] = None,
    ) -> list[DocumentInfo]:
        """Return document overview records, optionally filtered by ids."""

    @abstractmethod
    def get_document_chunks(self, document_id: str) -> list[dict]:
        """Return the chunks stored for ``document_id``."""

    @abstractmethod
    def delete_from_documents_overview(
        self, document_id: str, version: Optional[str] = None
    ) -> dict:
        """Remove ``document_id`` (optionally one ``version``) from the overview."""

    @abstractmethod
    def get_users_overview(self, user_ids: Optional[list[str]] = None) -> dict:
        """Return overview data for users, optionally limited to ``user_ids``."""