import logging
from abc import ABC, abstractmethod
from typing import Optional, Union
from ..abstractions.document import DocumentInfo
from ..abstractions.search import VectorSearchResult
from ..abstractions.vector import VectorEntry
from .base_provider import Provider, ProviderConfig
logger = logging.getLogger(__name__)
class VectorDBConfig(ProviderConfig):
    """Configuration object for a vector database provider.

    ``provider`` names the backend; it must be one of the entries returned
    by ``supported_providers``.
    """

    provider: str

    def __post_init__(self):
        """Validate the config, then expose every extra field as an attribute.

        NOTE(review): whether this hook is invoked automatically depends on
        ``ProviderConfig``, which is not visible in this module.
        """
        self.validate()
        # Capture additional fields
        extras = self.extra_fields
        for field_name, field_value in extras.items():
            setattr(self, field_name, field_value)

    def validate(self) -> None:
        """Check that ``provider`` is supported.

        Raises:
            ValueError: if ``provider`` is not listed in
                ``supported_providers``.
        """
        if self.provider in self.supported_providers:
            return
        raise ValueError(f"Provider '{self.provider}' is not supported.")

    @property
    def supported_providers(self) -> list[str]:
        """Provider names accepted by ``validate``."""
        known_backends = ["local", "pgvector"]
        return known_backends
class VectorDBProvider(Provider, ABC):
    """Abstract interface that vector database backends must implement.

    The base class itself only contributes config type-checking in
    ``__init__``, the batch helpers ``upsert_entries`` / ``copy_entries``,
    and a reusable length check in ``delete_by_metadata``. Everything else
    is declared abstract and left to concrete subclasses.
    """

    def __init__(self, config: VectorDBConfig):
        """Type-check ``config``, log it, and forward it to ``Provider``.

        Args:
            config: the provider configuration.

        Raises:
            ValueError: if ``config`` is not a ``VectorDBConfig`` instance.
        """
        if not isinstance(config, VectorDBConfig):
            raise ValueError(
                "VectorDBProvider must be initialized with a `VectorDBConfig`."
            )
        # Lazy %-formatting: the config is only rendered when INFO is enabled.
        logger.info("Initializing VectorDBProvider with config %s.", config)
        super().__init__(config)

    @abstractmethod
    def initialize_collection(self, dimension: int) -> None:
        """Set up the backing collection.

        Args:
            dimension: presumably the length of the stored vectors —
                confirm against the concrete implementations.
        """

    @abstractmethod
    def copy(self, entry: VectorEntry, commit: bool = True) -> None:
        """Copy a single entry into the store (backend-defined semantics).

        Args:
            entry: the vector entry to copy.
            commit: forwarded flag; presumably controls whether the write
                is committed immediately — confirm with implementations.
        """

    @abstractmethod
    def upsert(self, entry: VectorEntry, commit: bool = True) -> None:
        """Insert or update a single entry (backend-defined semantics).

        Args:
            entry: the vector entry to upsert.
            commit: forwarded flag; presumably controls whether the write
                is committed immediately — confirm with implementations.
        """

    @abstractmethod
    def search(
        self,
        query_vector: list[float],
        # ``None`` instead of ``{}``: avoids a shared mutable default and
        # matches the ``filters`` parameter of ``hybrid_search``.
        filters: Optional[dict[str, Union[bool, int, str]]] = None,
        limit: int = 10,
        *args,
        **kwargs,
    ) -> list[VectorSearchResult]:
        """Search the store with a query vector.

        Args:
            query_vector: the vector to search with.
            filters: optional mapping of filter names to values; ``None``
                means no filters were supplied.
            limit: defaults to 10; presumably the maximum number of
                results — confirm with implementations.
            *args: extra positional arguments for the backend.
            **kwargs: extra keyword arguments for the backend.

        Returns:
            A list of ``VectorSearchResult`` objects.
        """

    @abstractmethod
    def hybrid_search(
        self,
        query_text: str,
        query_vector: list[float],
        limit: int = 10,
        filters: Optional[dict[str, Union[bool, int, str]]] = None,
        # Hybrid search parameters
        full_text_weight: float = 1.0,
        semantic_weight: float = 1.0,
        rrf_k: int = 20,  # typical value is ~2x the number of results you want
        *args,
        **kwargs,
    ) -> list[VectorSearchResult]:
        """Search using both a text query and a query vector.

        Args:
            query_text: the textual query.
            query_vector: the vector query.
            limit: defaults to 10; presumably the maximum number of
                results — confirm with implementations.
            filters: optional mapping of filter names to values.
            full_text_weight: weight for the text side (default 1.0).
            semantic_weight: weight for the vector side (default 1.0).
            rrf_k: defaults to 20; presumably the rank-fusion constant —
                confirm with implementations.
            *args: extra positional arguments for the backend.
            **kwargs: extra keyword arguments for the backend.

        Returns:
            A list of ``VectorSearchResult`` objects.
        """

    @abstractmethod
    def create_index(self, index_type, column_name, index_options):
        """Create an index in the backend.

        The parameters are unannotated here; their accepted values are
        defined by each concrete implementation.
        """

    def upsert_entries(
        self, entries: list[VectorEntry], commit: bool = True
    ) -> None:
        """Upsert each entry in order by calling ``upsert`` once per entry.

        Args:
            entries: the entries to upsert.
            commit: passed unchanged to every ``upsert`` call.
        """
        for entry in entries:
            self.upsert(entry, commit=commit)

    def copy_entries(
        self, entries: list[VectorEntry], commit: bool = True
    ) -> None:
        """Copy each entry in order by calling ``copy`` once per entry.

        Args:
            entries: the entries to copy.
            commit: passed unchanged to every ``copy`` call.
        """
        for entry in entries:
            self.copy(entry, commit=commit)

    @abstractmethod
    def delete_by_metadata(
        self,
        metadata_fields: list[str],
        metadata_values: list[Union[bool, int, str]],
    ) -> list[str]:
        """Delete entries selected by metadata (backend-defined semantics).

        This base body only checks that the two lists have equal length and
        then returns ``None``; subclasses may call it through ``super()`` to
        reuse the check before doing the actual deletion.

        Args:
            metadata_fields: metadata field names.
            metadata_values: values paired positionally with the fields.

        Returns:
            Annotated as a list of strings; the content is defined by the
            concrete implementation.

        Raises:
            ValueError: if the two lists differ in length.
        """
        if len(metadata_fields) != len(metadata_values):
            raise ValueError(
                "The number of metadata fields and values must be equal."
            )

    @abstractmethod
    def get_metadatas(
        self,
        metadata_fields: list[str],
        filter_field: Optional[str] = None,
        filter_value: Optional[str] = None,
    ) -> list[str]:
        """Fetch metadata for the requested fields.

        Args:
            metadata_fields: names of the metadata fields to fetch.
            filter_field: optional field name to filter on.
            filter_value: optional value paired with ``filter_field``.

        Returns:
            Annotated as a list of strings; the exact content is defined by
            the concrete implementation.
        """

    @abstractmethod
    def upsert_documents_overview(
        self, document_infs: list[DocumentInfo]
    ) -> None:
        """Insert or update document overview records.

        Args:
            document_infs: the ``DocumentInfo`` records to store. The
                parameter name is kept as-is so keyword callers keep
                working.
        """

    @abstractmethod
    def get_documents_overview(
        self,
        filter_document_ids: Optional[list[str]] = None,
        filter_user_ids: Optional[list[str]] = None,
    ) -> list[DocumentInfo]:
        """Return document overview records.

        Args:
            filter_document_ids: optional document ids to restrict to.
            filter_user_ids: optional user ids to restrict to.

        Returns:
            A list of ``DocumentInfo`` records.
        """

    @abstractmethod
    def get_document_chunks(self, document_id: str) -> list[dict]:
        """Return the chunks of one document as a list of dicts.

        Args:
            document_id: id of the document whose chunks are requested.
        """

    @abstractmethod
    def delete_from_documents_overview(
        self, document_id: str, version: Optional[str] = None
    ) -> dict:
        """Remove a document from the documents overview.

        Args:
            document_id: id of the document to remove.
            version: optional version; presumably restricts the deletion to
                that version — confirm with implementations.

        Returns:
            A dict whose content is defined by the concrete implementation.
        """

    @abstractmethod
    def get_users_overview(self, user_ids: Optional[list[str]] = None) -> dict:
        """Return overview information about users as a dict.

        Args:
            user_ids: optional user ids to restrict the overview to.
        """