about summary refs log tree commit diff
path: root/R2R/r2r/base/providers/vector_db_provider.py
diff options
context:
space:
mode:
Diffstat (limited to 'R2R/r2r/base/providers/vector_db_provider.py')
-rwxr-xr-x R2R/r2r/base/providers/vector_db_provider.py 142
1 files changed, 142 insertions, 0 deletions
diff --git a/R2R/r2r/base/providers/vector_db_provider.py b/R2R/r2r/base/providers/vector_db_provider.py
new file mode 100755
index 00000000..a6d5aaa8
--- /dev/null
+++ b/R2R/r2r/base/providers/vector_db_provider.py
@@ -0,0 +1,142 @@
+import logging
+from abc import ABC, abstractmethod
+from typing import Optional, Union
+
+from ..abstractions.document import DocumentInfo
+from ..abstractions.search import VectorSearchResult
+from ..abstractions.vector import VectorEntry
+from .base_provider import Provider, ProviderConfig
+
+logger = logging.getLogger(__name__)
+
+
+class VectorDBConfig(ProviderConfig):
+    """Configuration for a vector database provider.
+
+    ``provider`` names the backend and must be one of the entries in
+    ``supported_providers``.
+    """
+
+    provider: str
+
+    def __post_init__(self):
+        """Validate the config, then expose extra fields as attributes."""
+        self.validate()
+        # Mirror every ``extra_fields`` entry onto the instance itself.
+        for field_name, field_value in self.extra_fields.items():
+            setattr(self, field_name, field_value)
+
+    def validate(self) -> None:
+        """Raise ``ValueError`` unless ``provider`` is a supported backend."""
+        if self.provider in self.supported_providers:
+            return
+        raise ValueError(f"Provider '{self.provider}' is not supported.")
+
+    @property
+    def supported_providers(self) -> list[str]:
+        """Names of the backends this configuration accepts."""
+        return ["local", "pgvector"]
+
+
+class VectorDBProvider(Provider, ABC):
+    """Abstract interface that every vector database backend implements.
+
+    Concrete subclasses supply the storage, search, and overview
+    operations declared here. ``upsert_entries`` and ``copy_entries`` are
+    ready-made batch helpers built on the single-entry hooks.
+    """
+
+    def __init__(self, config: VectorDBConfig):
+        """Check the config type, log it, and pass it to the base class.
+
+        Args:
+            config: Settings for the backend.
+
+        Raises:
+            ValueError: If ``config`` is not a ``VectorDBConfig``.
+        """
+        if not isinstance(config, VectorDBConfig):
+            raise ValueError(
+                "VectorDBProvider must be initialized with a `VectorDBConfig`."
+            )
+        logger.info(f"Initializing VectorDBProvider with config {config}.")
+        super().__init__(config)
+
+    @abstractmethod
+    def initialize_collection(self, dimension: int) -> None:
+        """Set up the backing collection (subclass hook).
+
+        Args:
+            dimension: Presumably the length of the stored vectors --
+                confirm against the concrete providers.
+        """
+
+    @abstractmethod
+    def copy(self, entry: VectorEntry, commit: bool = True) -> None:
+        """Store one entry through the backend's copy path (subclass hook).
+
+        How this differs from ``upsert`` is defined by each subclass.
+
+        Args:
+            entry: The vector entry to store.
+            commit: Forwarded flag; its meaning is backend-specific.
+        """
+
+    @abstractmethod
+    def upsert(self, entry: VectorEntry, commit: bool = True) -> None:
+        """Store one entry through the backend's upsert path (subclass hook).
+
+        Args:
+            entry: The vector entry to store.
+            commit: Forwarded flag; its meaning is backend-specific.
+        """
+
+    @abstractmethod
+    def search(
+        self,
+        query_vector: list[float],
+        filters: Optional[dict[str, Union[bool, int, str]]] = None,
+        limit: int = 10,
+        *args,
+        **kwargs,
+    ) -> list[VectorSearchResult]:
+        """Run a vector search (subclass hook).
+
+        Args:
+            query_vector: The vector to search with.
+            filters: Optional field/value filters. The default is ``None``
+                rather than a shared mutable ``{}``, matching
+                ``hybrid_search``; implementations should treat ``None``
+                the same as an empty dict.
+            limit: Maximum number of results requested.
+
+        Returns:
+            The matching search results.
+        """
+
+    @abstractmethod
+    def hybrid_search(
+        self,
+        query_text: str,
+        query_vector: list[float],
+        limit: int = 10,
+        filters: Optional[dict[str, Union[bool, int, str]]] = None,
+        # Hybrid search parameters
+        full_text_weight: float = 1.0,
+        semantic_weight: float = 1.0,
+        rrf_k: int = 20,  # typical value is ~2x the number of results you want
+        *args,
+        **kwargs,
+    ) -> list[VectorSearchResult]:
+        """Run a combined text and vector search (subclass hook).
+
+        Args:
+            query_text: The text form of the query.
+            query_vector: The vector form of the query.
+            limit: Maximum number of results requested.
+            filters: Optional field/value filters.
+            full_text_weight: Weight for the full-text part of the search.
+            semantic_weight: Weight for the semantic part of the search.
+            rrf_k: Tuning constant; typically about twice ``limit``.
+
+        Returns:
+            The matching search results.
+        """
+
+    @abstractmethod
+    def create_index(self, index_type, column_name, index_options):
+        """Create an index in the backend (subclass hook).
+
+        The argument types are not declared here; each concrete provider
+        defines what it accepts.
+        """
+
+    def upsert_entries(
+        self, entries: list[VectorEntry], commit: bool = True
+    ) -> None:
+        """Call ``upsert`` once per entry, forwarding ``commit`` each time."""
+        for entry in entries:
+            self.upsert(entry, commit=commit)
+
+    def copy_entries(
+        self, entries: list[VectorEntry], commit: bool = True
+    ) -> None:
+        """Call ``copy`` once per entry, forwarding ``commit`` each time."""
+        for entry in entries:
+            self.copy(entry, commit=commit)
+
+    @abstractmethod
+    def delete_by_metadata(
+        self,
+        metadata_fields: list[str],
+        metadata_values: list[Union[bool, int, str]],
+    ) -> list[str]:
+        """Delete entries selected by metadata (subclass hook).
+
+        This base body only checks the arguments and returns ``None``.
+        A subclass can call ``super().delete_by_metadata(...)`` to reuse
+        the check before doing the actual deletion.
+
+        Args:
+            metadata_fields: Names of the metadata fields to match.
+            metadata_values: One value per entry in ``metadata_fields``.
+
+        Raises:
+            ValueError: If the two lists differ in length.
+        """
+        if len(metadata_fields) != len(metadata_values):
+            raise ValueError(
+                "The number of metadata fields and values must be equal."
+            )
+
+    @abstractmethod
+    def get_metadatas(
+        self,
+        metadata_fields: list[str],
+        filter_field: Optional[str] = None,
+        filter_value: Optional[str] = None,
+    ) -> list[str]:
+        """Fetch metadata for the given fields (subclass hook).
+
+        Args:
+            metadata_fields: Names of the metadata fields to read.
+            filter_field: Optional field name to filter on.
+            filter_value: Optional value paired with ``filter_field``.
+        """
+
+    @abstractmethod
+    def upsert_documents_overview(
+        self, document_infs: list[DocumentInfo]
+    ) -> None:
+        """Store document overview records (subclass hook).
+
+        Args:
+            document_infs: The ``DocumentInfo`` records to store. The
+                name is kept as spelled so keyword callers keep working.
+        """
+
+    @abstractmethod
+    def get_documents_overview(
+        self,
+        filter_document_ids: Optional[list[str]] = None,
+        filter_user_ids: Optional[list[str]] = None,
+    ) -> list[DocumentInfo]:
+        """Fetch document overview records (subclass hook).
+
+        Args:
+            filter_document_ids: Optional document ids to filter by.
+            filter_user_ids: Optional user ids to filter by.
+        """
+
+    @abstractmethod
+    def get_document_chunks(self, document_id: str) -> list[dict]:
+        """Fetch the chunks of one document as dicts (subclass hook)."""
+
+    @abstractmethod
+    def delete_from_documents_overview(
+        self, document_id: str, version: Optional[str] = None
+    ) -> dict:
+        """Remove a document from the overview (subclass hook).
+
+        Args:
+            document_id: Id of the document to remove.
+            version: Optional version selector; semantics are
+                backend-specific.
+        """
+
+    @abstractmethod
+    def get_users_overview(self, user_ids: Optional[list[str]] = None) -> dict:
+        """Fetch a user overview, optionally by ``user_ids`` (subclass hook)."""