"""Abstract vector database provider interface and its configuration."""
import logging
from abc import ABC, abstractmethod
from typing import Optional, Union

from ..abstractions.document import DocumentInfo
from ..abstractions.search import VectorSearchResult
from ..abstractions.vector import VectorEntry
from .base_provider import Provider, ProviderConfig

logger = logging.getLogger(__name__)


class VectorDBConfig(ProviderConfig):
    """Configuration for a vector database provider.

    The only field declared here is ``provider``; it must be one of
    ``supported_providers``. ``extra_fields`` is read from the instance and
    is presumably supplied by ``ProviderConfig`` -- confirm against the base.
    """

    provider: str

    @property
    def supported_providers(self) -> list[str]:
        """Return the provider names this config accepts."""
        return ["local", "pgvector"]

    def validate(self) -> None:
        """Check that ``provider`` is a supported provider name.

        Raises:
            ValueError: If ``provider`` is not in ``supported_providers``.
        """
        if self.provider in self.supported_providers:
            return
        raise ValueError(f"Provider '{self.provider}' is not supported.")

    def __post_init__(self):
        """Validate first, then mirror each extra field onto the instance."""
        self.validate()
        # Expose every entry of extra_fields as a regular attribute.
        for field_name, field_value in self.extra_fields.items():
            setattr(self, field_name, field_value)


class VectorDBProvider(Provider, ABC):
    """Abstract interface that concrete vector database providers implement.

    Only ``upsert_entries`` and ``copy_entries`` carry behavior here; every
    other method is an abstract hook whose exact semantics are defined by the
    concrete provider.
    """

    def __init__(self, config: VectorDBConfig):
        """Check the config type, log it, and pass it to ``Provider``.

        Args:
            config: The provider configuration.

        Raises:
            ValueError: If ``config`` is not a ``VectorDBConfig``.
        """
        if not isinstance(config, VectorDBConfig):
            raise ValueError(
                "VectorDBProvider must be initialized with a `VectorDBConfig`."
            )
        logger.info(f"Initializing VectorDBProvider with config {config}.")
        super().__init__(config)

    @abstractmethod
    def initialize_collection(self, dimension: int) -> None:
        """Abstract hook: initialize the collection for the given dimension."""

    @abstractmethod
    def copy(self, entry: VectorEntry, commit: bool = True) -> None:
        """Abstract hook: copy a single entry; ``commit`` defaults to True."""

    @abstractmethod
    def upsert(self, entry: VectorEntry, commit: bool = True) -> None:
        """Abstract hook: upsert a single entry; ``commit`` defaults to True."""

    @abstractmethod
    def search(
        self,
        query_vector: list[float],
        # None instead of a shared ``{}`` literal: a mutable default would be
        # one dict object reused (and possibly mutated) across all calls.
        filters: Optional[dict[str, Union[bool, int, str]]] = None,
        limit: int = 10,
        *args,
        **kwargs,
    ) -> list[VectorSearchResult]:
        """Abstract hook: search by ``query_vector``.

        Args:
            query_vector: The query embedding.
            filters: Optional filter mapping; ``None`` means no filters were
                supplied. Implementations should treat ``None`` like an
                empty mapping.
            limit: Defaults to 10; presumably the maximum number of results
                -- confirm against implementations.

        Returns:
            A list of ``VectorSearchResult``.
        """

    @abstractmethod
    def hybrid_search(
        self,
        query_text: str,
        query_vector: list[float],
        limit: int = 10,
        filters: Optional[dict[str, Union[bool, int, str]]] = None,
        # Hybrid search parameters
        full_text_weight: float = 1.0,
        semantic_weight: float = 1.0,
        rrf_k: int = 20,  # typical value is ~2x the number of results you want
        *args,
        **kwargs,
    ) -> list[VectorSearchResult]:
        """Abstract hook: search using both ``query_text`` and ``query_vector``.

        Args:
            query_text: The textual query.
            query_vector: The query embedding.
            limit: Defaults to 10.
            filters: Optional filter mapping; defaults to ``None``.
            full_text_weight: Defaults to 1.0.
            semantic_weight: Defaults to 1.0.
            rrf_k: Defaults to 20; typically about twice the number of
                results wanted.

        Returns:
            A list of ``VectorSearchResult``.
        """

    @abstractmethod
    def create_index(self, index_type, column_name, index_options):
        """Abstract hook: create an index.

        Argument types are not annotated here; see the concrete provider
        for the accepted values.
        """

    def upsert_entries(
        self, entries: list[VectorEntry], commit: bool = True
    ) -> None:
        """Call ``upsert`` once per entry, forwarding ``commit`` unchanged."""
        for entry in entries:
            self.upsert(entry, commit=commit)

    def copy_entries(
        self, entries: list[VectorEntry], commit: bool = True
    ) -> None:
        """Call ``copy`` once per entry, forwarding ``commit`` unchanged."""
        for entry in entries:
            self.copy(entry, commit=commit)

    @abstractmethod
    def delete_by_metadata(
        self,
        metadata_fields: list[str],
        metadata_values: list[Union[bool, int, str]],
    ) -> list[str]:
        """Abstract hook: delete by paired metadata fields and values.

        The base body only checks that the two lists have the same length
        and then returns ``None``; the deletion itself is left to the
        concrete provider.

        Raises:
            ValueError: If ``metadata_fields`` and ``metadata_values``
                differ in length.
        """
        if len(metadata_fields) != len(metadata_values):
            raise ValueError(
                "The number of metadata fields and values must be equal."
            )

    @abstractmethod
    def get_metadatas(
        self,
        metadata_fields: list[str],
        filter_field: Optional[str] = None,
        filter_value: Optional[str] = None,
    ) -> list[str]:
        """Abstract hook: fetch metadata for ``metadata_fields``.

        ``filter_field`` and ``filter_value`` are optional and default to
        ``None``.
        """

    @abstractmethod
    def upsert_documents_overview(
        self, document_infs: list[DocumentInfo]
    ) -> None:
        """Abstract hook: upsert the given ``DocumentInfo`` records."""

    @abstractmethod
    def get_documents_overview(
        self,
        filter_document_ids: Optional[list[str]] = None,
        filter_user_ids: Optional[list[str]] = None,
    ) -> list[DocumentInfo]:
        """Abstract hook: return ``DocumentInfo`` records.

        Both filters are optional and default to ``None``.
        """

    @abstractmethod
    def get_document_chunks(self, document_id: str) -> list[dict]:
        """Abstract hook: return a list of dicts for ``document_id``."""

    @abstractmethod
    def delete_from_documents_overview(
        self, document_id: str, version: Optional[str] = None
    ) -> dict:
        """Abstract hook: delete ``document_id`` from the documents overview.

        ``version`` is optional and defaults to ``None``.
        """

    @abstractmethod
    def get_users_overview(self, user_ids: Optional[list[str]] = None) -> dict:
        """Abstract hook: return a users overview dict.

        ``user_ids`` is optional and defaults to ``None``.
        """