Diffstat (limited to 'R2R/r2r/vecs')
-rwxr-xr-x | R2R/r2r/vecs/__init__.py | 28
-rwxr-xr-x | R2R/r2r/vecs/adapter/__init__.py | 15
-rwxr-xr-x | R2R/r2r/vecs/adapter/base.py | 111
-rwxr-xr-x | R2R/r2r/vecs/adapter/markdown.py | 88
-rwxr-xr-x | R2R/r2r/vecs/adapter/noop.py | 55
-rwxr-xr-x | R2R/r2r/vecs/adapter/text.py | 151
-rwxr-xr-x | R2R/r2r/vecs/client.py | 313
-rwxr-xr-x | R2R/r2r/vecs/collection.py | 1132
-rwxr-xr-x | R2R/r2r/vecs/exc.py | 83
9 files changed, 1976 insertions, 0 deletions
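
The diff below vendors the `vecs` Postgres/pgvector client under `r2r.vecs`. As a quick orientation before the file-by-file changes, here is a minimal sketch of how the added modules are driven end to end; this is not part of the diff, it assumes a reachable Postgres instance with the pgvector extension, and the connection string, collection name, vectors, and metadata values are placeholders.

import r2r.vecs as vecs

DB_CONNECTION = "postgresql://user:password@localhost:5432/postgres"  # placeholder

with vecs.create_client(DB_CONNECTION) as vx:
    # Creates the "vecs" schema, required extensions, and the collection table on first use.
    docs = vx.get_or_create_collection(name="docs", dimension=3)

    # Each record is (id, vector, metadata).
    docs.upsert(
        [
            ("doc_1", [0.1, 0.2, 0.3], {"document_id": "a", "topic": "intro"}),
            ("doc_2", [0.9, 0.1, 0.4], {"document_id": "b", "topic": "body"}),
        ]
    )

    # Build an ANN index; the default IndexMethod.auto picks HNSW when pgvector
    # supports it (>= 0.5.0) and falls back to IVFFlat otherwise.
    docs.create_index(measure=vecs.IndexMeasure.cosine_distance)

    # Cosine-distance query, restricted by a metadata filter.
    hits = docs.query(
        data=[0.1, 0.2, 0.3],
        limit=5,
        filters={"topic": {"$eq": "intro"}},
        include_value=True,
        include_metadata=True,
    )
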
diff --git a/R2R/r2r/vecs/__init__.py b/R2R/r2r/vecs/__init__.py new file mode 100755 index 00000000..9d4f1d7e --- /dev/null +++ b/R2R/r2r/vecs/__init__.py @@ -0,0 +1,28 @@ +from . import exc +from .client import Client +from .collection import ( + Collection, + IndexArgsHNSW, + IndexArgsIVFFlat, + IndexMeasure, + IndexMethod, +) + +__project__ = "vecs" +__version__ = "0.4.2" + + +__all__ = [ + "IndexArgsIVFFlat", + "IndexArgsHNSW", + "IndexMethod", + "IndexMeasure", + "Collection", + "Client", + "exc", +] + + +def create_client(connection_string: str, *args, **kwargs) -> Client: + """Creates a client from a Postgres connection string""" + return Client(connection_string, *args, **kwargs) diff --git a/R2R/r2r/vecs/adapter/__init__.py b/R2R/r2r/vecs/adapter/__init__.py new file mode 100755 index 00000000..9cd9860d --- /dev/null +++ b/R2R/r2r/vecs/adapter/__init__.py @@ -0,0 +1,15 @@ +from .base import Adapter, AdapterContext, AdapterStep +from .markdown import MarkdownChunker +from .noop import NoOp +from .text import ParagraphChunker, TextEmbedding, TextEmbeddingModel + +__all__ = [ + "Adapter", + "AdapterContext", + "AdapterStep", + "NoOp", + "ParagraphChunker", + "TextEmbedding", + "TextEmbeddingModel", + "MarkdownChunker", +] diff --git a/R2R/r2r/vecs/adapter/base.py b/R2R/r2r/vecs/adapter/base.py new file mode 100755 index 00000000..7734e802 --- /dev/null +++ b/R2R/r2r/vecs/adapter/base.py @@ -0,0 +1,111 @@ +""" +The `vecs.experimental.adapter.base` module provides abstract classes and utilities +for creating and handling adapters in vecs. Adapters allow users to interact with +a collection using media types other than vectors. + +All public classes, enums, and functions are re-exported by `vecs.adapters` module. +""" + +from abc import ABC, abstractmethod +from enum import Enum +from typing import Any, Dict, Generator, Iterable, Optional, Tuple + +from vecs.exc import ArgError + + +class AdapterContext(str, Enum): + """ + An enum representing the different contexts in which a Pipeline + will be invoked. + + Attributes: + upsert (str): The Collection.upsert method + query (str): The Collection.query method + """ + + upsert = "upsert" + query = "query" + + +class AdapterStep(ABC): + """ + Abstract class representing a step in the adapter pipeline. + + Each adapter step should adapt a user media into a tuple of: + - id (str) + - media (unknown type) + - metadata (dict) + + If the user provides id or metadata, default production is overridden. + """ + + @property + def exported_dimension(self) -> Optional[int]: + """ + Property that should be overridden by subclasses to provide the output dimension + of the adapter step. + """ + return None + + @abstractmethod + def __call__( + self, + records: Iterable[Tuple[str, Any, Optional[Dict]]], + adapter_context: AdapterContext, + ) -> Generator[Tuple[str, Any, Dict], None, None]: + """ + Abstract method that should be overridden by subclasses to handle each record. + """ + + +class Adapter: + """ + Class representing a sequence of AdapterStep instances forming a pipeline. + """ + + def __init__(self, steps: list[AdapterStep]): + """ + Initialize an Adapter instance with a list of AdapterStep instances. + + Args: + steps: list of AdapterStep instances. + + Raises: + ArgError: Raised if the steps list is empty. + """ + self.steps = steps + if len(steps) < 1: + raise ArgError("Adapter must contain at least 1 step") + + @property + def exported_dimension(self) -> Optional[int]: + """ + The output dimension of the adapter. 
Returns the exported dimension of the last + AdapterStep that provides one (from end to start of the steps list). + """ + for step in reversed(self.steps): + step_dim = step.exported_dimension + if step_dim is not None: + return step_dim + return None + + def __call__( + self, + records: Iterable[Tuple[str, Any, Optional[Dict]]], + adapter_context: AdapterContext, + ) -> Generator[Tuple[str, Any, Dict], None, None]: + """ + Invokes the adapter pipeline on an iterable of records. + + Args: + records: Iterable of tuples each containing an id, a media and an optional dict. + adapter_context: Context of the adapter. + + Yields: + Tuples each containing an id, a media and a dict. + """ + pipeline = records + for step in self.steps: + pipeline = step(pipeline, adapter_context) + + yield from pipeline # type: ignore diff --git a/R2R/r2r/vecs/adapter/markdown.py b/R2R/r2r/vecs/adapter/markdown.py new file mode 100755 index 00000000..149573f4 --- /dev/null +++ b/R2R/r2r/vecs/adapter/markdown.py @@ -0,0 +1,88 @@ +import re +from typing import Any, Dict, Generator, Iterable, Optional, Tuple + +from flupy import flu + +from .base import AdapterContext, AdapterStep + + +class MarkdownChunker(AdapterStep): + """ + MarkdownChunker is an AdapterStep that splits a markdown string into chunks where a heading signifies the start of a chunk, and yields each chunk as a separate record. + """ + + def __init__(self, *, skip_during_query: bool): + """ + Initializes the MarkdownChunker adapter. + + Args: + skip_during_query (bool): Whether to skip chunking during querying. + """ + self.skip_during_query = skip_during_query + + @staticmethod + def split_by_heading( + md: str, max_tokens: int + ) -> Generator[str, None, None]: + regex_split = r"^(#{1,6}\s+.+)$" + headings = [ + match.span()[0] + for match in re.finditer(regex_split, md, flags=re.MULTILINE) + ] + + if headings == [] or headings[0] != 0: + headings.insert(0, 0) + + sections = [md[i:j] for i, j in zip(headings, headings[1:] + [None])] + + for section in sections: + chunks = flu(section.split(" ")).chunk(max_tokens) + + is_not_useless_chunk = lambda i: not i in ["", "\n", []] + + joined_chunks = filter( + is_not_useless_chunk, [" ".join(chunk) for chunk in chunks] + ) + + for joined_chunk in joined_chunks: + yield joined_chunk + + def __call__( + self, + records: Iterable[Tuple[str, Any, Optional[Dict]]], + adapter_context: AdapterContext, + max_tokens: int = 99999999, + ) -> Generator[Tuple[str, Any, Dict], None, None]: + """ + Splits each markdown string in the records into chunks where each heading starts a new chunk, and yields each chunk + as a separate record. If the `skip_during_query` attribute is set to True, + this step is skipped during querying. + + Args: + records (Iterable[Tuple[str, Any, Optional[Dict]]]): Iterable of tuples each containing an id, a markdown string and an optional dict. + adapter_context (AdapterContext): Context of the adapter. + max_tokens (int): The maximum number of tokens per chunk + + Yields: + Tuple[str, Any, Dict]: The id appended with chunk index, the chunk, and the metadata. 
+ """ + if max_tokens and max_tokens < 1: + raise ValueError("max_tokens must be a nonzero positive integer") + + if ( + adapter_context == AdapterContext("query") + and self.skip_during_query + ): + for id, markdown, metadata in records: + yield (id, markdown, metadata or {}) + else: + for id, markdown, metadata in records: + headings = MarkdownChunker.split_by_heading( + markdown, max_tokens + ) + for heading_ix, heading in enumerate(headings): + yield ( + f"{id}_head_{str(heading_ix).zfill(3)}", + heading, + metadata or {}, + ) diff --git a/R2R/r2r/vecs/adapter/noop.py b/R2R/r2r/vecs/adapter/noop.py new file mode 100755 index 00000000..b587a552 --- /dev/null +++ b/R2R/r2r/vecs/adapter/noop.py @@ -0,0 +1,55 @@ +""" +The `vecs.experimental.adapter.noop` module provides a default no-op (no operation) adapter +that passes the inputs through without any modification. This can be useful when no specific +adapter processing is required. + +All public classes, enums, and functions are re-exported by `vecs.adapters` module. +""" + +from typing import Any, Dict, Generator, Iterable, Optional, Tuple + +from .base import AdapterContext, AdapterStep + + +class NoOp(AdapterStep): + """ + NoOp is a no-operation AdapterStep. It is a default adapter that passes through + the input records without any modifications. + """ + + def __init__(self, dimension: int): + """ + Initializes the NoOp adapter with a dimension. + + Args: + dimension (int): The dimension of the input vectors. + """ + self._dimension = dimension + + @property + def exported_dimension(self) -> Optional[int]: + """ + Returns the dimension of the adapter. + + Returns: + int: The dimension of the input vectors. + """ + return self._dimension + + def __call__( + self, + records: Iterable[Tuple[str, Any, Optional[Dict]]], + adapter_context: AdapterContext, + ) -> Generator[Tuple[str, Any, Dict], None, None]: + """ + Yields the input records without any modification. + + Args: + records: Iterable of tuples each containing an id, a media and an optional dict. + adapter_context: Context of the adapter. + + Yields: + Tuple[str, Any, Dict]: The input record. + """ + for id, media, metadata in records: + yield (id, media, metadata or {}) diff --git a/R2R/r2r/vecs/adapter/text.py b/R2R/r2r/vecs/adapter/text.py new file mode 100755 index 00000000..78ae7732 --- /dev/null +++ b/R2R/r2r/vecs/adapter/text.py @@ -0,0 +1,151 @@ +""" +The `vecs.experimental.adapter.text` module provides adapter steps specifically designed for +handling text data. It provides two main classes, `TextEmbedding` and `ParagraphChunker`. + +All public classes, enums, and functions are re-exported by `vecs.adapters` module. +""" + +from typing import Any, Dict, Generator, Iterable, Literal, Optional, Tuple + +from flupy import flu +from vecs.exc import MissingDependency + +from .base import AdapterContext, AdapterStep + +TextEmbeddingModel = Literal[ + "all-mpnet-base-v2", + "multi-qa-mpnet-base-dot-v1", + "all-distilroberta-v1", + "all-MiniLM-L12-v2", + "multi-qa-distilbert-cos-v1", + "mixedbread-ai/mxbai-embed-large-v1", + "multi-qa-MiniLM-L6-cos-v1", + "paraphrase-multilingual-mpnet-base-v2", + "paraphrase-albert-small-v2", + "paraphrase-multilingual-MiniLM-L12-v2", + "paraphrase-MiniLM-L3-v2", + "distiluse-base-multilingual-cased-v1", + "distiluse-base-multilingual-cased-v2", +] + + +class TextEmbedding(AdapterStep): + """ + TextEmbedding is an AdapterStep that converts text media into + embeddings using a specified sentence transformers model. 
+ """ + + def __init__( + self, + *, + model: TextEmbeddingModel, + batch_size: int = 8, + use_auth_token: str = None, + ): + """ + Initializes the TextEmbedding adapter with a sentence transformers model. + + Args: + model (TextEmbeddingModel): The sentence transformers model to use for embeddings. + batch_size (int): The number of records to encode simultaneously. + use_auth_token (str): The HuggingFace Hub auth token to use for private models. + + Raises: + MissingDependency: If the sentence_transformers library is not installed. + """ + try: + from sentence_transformers import SentenceTransformer as ST + except ImportError: + raise MissingDependency( + "Missing feature vecs[text_embedding]. Hint: `pip install 'vecs[text_embedding]'`" + ) + + self.model = ST(model, use_auth_token=use_auth_token) + self._exported_dimension = ( + self.model.get_sentence_embedding_dimension() + ) + self.batch_size = batch_size + + @property + def exported_dimension(self) -> Optional[int]: + """ + Returns the dimension of the embeddings produced by the sentence transformers model. + + Returns: + int: The dimension of the embeddings. + """ + return self._exported_dimension + + def __call__( + self, + records: Iterable[Tuple[str, Any, Optional[Dict]]], + adapter_context: AdapterContext, # pyright: ignore + ) -> Generator[Tuple[str, Any, Dict], None, None]: + """ + Converts each media in the records to an embedding and yields the result. + + Args: + records: Iterable of tuples each containing an id, a media and an optional dict. + adapter_context: Context of the adapter. + + Yields: + Tuple[str, Any, Dict]: The id, the embedding, and the metadata. + """ + for batch in flu(records).chunk(self.batch_size): + batch_records = [x for x in batch] + media = [text for _, text, _ in batch_records] + + embeddings = self.model.encode(media, normalize_embeddings=True) + + for (id, _, metadata), embedding in zip(batch_records, embeddings): # type: ignore + yield (id, embedding, metadata or {}) + + +class ParagraphChunker(AdapterStep): + """ + ParagraphChunker is an AdapterStep that splits text media into + paragraphs and yields each paragraph as a separate record. + """ + + def __init__(self, *, skip_during_query: bool): + """ + Initializes the ParagraphChunker adapter. + + Args: + skip_during_query (bool): Whether to skip chunking during querying. + """ + self.skip_during_query = skip_during_query + + def __call__( + self, + records: Iterable[Tuple[str, Any, Optional[Dict]]], + adapter_context: AdapterContext, + ) -> Generator[Tuple[str, Any, Dict], None, None]: + """ + Splits each media in the records into paragraphs and yields each paragraph + as a separate record. If the `skip_during_query` attribute is set to True, + this step is skipped during querying. + + Args: + records (Iterable[Tuple[str, Any, Optional[Dict]]]): Iterable of tuples each containing an id, a media and an optional dict. + adapter_context (AdapterContext): Context of the adapter. + + Yields: + Tuple[str, Any, Dict]: The id appended with paragraph index, the paragraph, and the metadata. 
+ """ + if ( + adapter_context == AdapterContext("query") + and self.skip_during_query + ): + for id, media, metadata in records: + yield (id, media, metadata or {}) + else: + for id, media, metadata in records: + paragraphs = media.split("\n\n") + + for paragraph_ix, paragraph in enumerate(paragraphs): + yield ( + f"{id}_para_{str(paragraph_ix).zfill(3)}", + paragraph, + metadata or {}, + ) diff --git a/R2R/r2r/vecs/client.py b/R2R/r2r/vecs/client.py new file mode 100755 index 00000000..6259f1d8 --- /dev/null +++ b/R2R/r2r/vecs/client.py @@ -0,0 +1,313 @@ +""" +Defines the 'Client' class + +Importing from the `vecs.client` directly is not supported. +All public classes, enums, and functions are re-exported by the top level `vecs` module. +""" + +from __future__ import annotations + +import logging +import time +from typing import TYPE_CHECKING, List, Optional + +import sqlalchemy +from deprecated import deprecated +from sqlalchemy import MetaData, create_engine, text +from sqlalchemy.orm import sessionmaker +from sqlalchemy.pool import QueuePool + +from .adapter import Adapter +from .exc import CollectionNotFound + +if TYPE_CHECKING: + from r2r.vecs.collection import Collection + +logger = logging.getLogger(__name__) + + +class Client: + """ + The `vecs.Client` class serves as an interface to a PostgreSQL database with pgvector support. It facilitates + the creation, retrieval, listing and deletion of vector collections, while managing connections to the + database. + + A `Client` instance represents a connection to a PostgreSQL database. This connection can be used to create + and manipulate vector collections, where each collection is a group of vector records in a PostgreSQL table. + + The `vecs.Client` class can be also supports usage as a context manager to ensure the connection to the database + is properly closed after operations, or used directly. + + Example usage: + + DB_CONNECTION = "postgresql://<user>:<password>@<host>:<port>/<db_name>" + + with vecs.create_client(DB_CONNECTION) as vx: + # do some work + pass + + # OR + + vx = vecs.create_client(DB_CONNECTION) + # do some work + vx.disconnect() + """ + + def __init__( + self, + connection_string: str, + pool_size: int = 1, + max_retries: int = 3, + retry_delay: int = 1, + ): + self.engine = create_engine( + connection_string, + pool_size=pool_size, + poolclass=QueuePool, + pool_recycle=300, # Recycle connections after 5 min + ) + self.meta = MetaData(schema="vecs") + self.Session = sessionmaker(self.engine) + self.max_retries = max_retries + self.retry_delay = retry_delay + self.vector_version: Optional[str] = None + self._initialize_database() + + def _initialize_database(self): + retries = 0 + error = None + while retries < self.max_retries: + try: + with self.Session() as sess: + with sess.begin(): + self._create_schema(sess) + self._create_extension(sess) + self._get_vector_version(sess) + return + except Exception as e: + logger.warning( + f"Database connection error: {str(e)}. Retrying in {self.retry_delay} seconds..." 
+ ) + retries += 1 + time.sleep(self.retry_delay) + error = e + + error_message = f"Failed to initialize database after {self.max_retries} retries with error: {str(error)}" + logger.error(error_message) + raise RuntimeError(error_message) + + def _create_schema(self, sess): + try: + sess.execute(text("CREATE SCHEMA IF NOT EXISTS vecs;")) + except Exception as e: + logger.warning(f"Failed to create schema: {str(e)}") + + def _create_extension(self, sess): + try: + sess.execute(text("CREATE EXTENSION IF NOT EXISTS vector;")) + sess.execute(text("CREATE EXTENSION IF NOT EXISTS pg_trgm;")) + sess.execute(text("CREATE EXTENSION IF NOT EXISTS fuzzystrmatch;")) + except Exception as e: + logger.warning(f"Failed to create extension: {str(e)}") + + def _get_vector_version(self, sess): + try: + self.vector_version = sess.execute( + text( + "SELECT installed_version FROM pg_available_extensions WHERE name = 'vector' LIMIT 1;" + ) + ).scalar_one() + except sqlalchemy.exc.InternalError as e: + logger.error(f"Failed with internal alchemy error: {str(e)}") + + import psycopg2 + + if isinstance(e.orig, psycopg2.errors.InFailedSqlTransaction): + sess.rollback() + self.vector_version = sess.execute( + text( + "SELECT installed_version FROM pg_available_extensions WHERE name = 'vector' LIMIT 1;" + ) + ).scalar_one() + else: + raise e + except Exception as e: + logger.error(f"Failed to retrieve vector version: {str(e)}") + raise e + + def _supports_hnsw(self): + return ( + not self.vector_version.startswith("0.4") + and not self.vector_version.startswith("0.3") + and not self.vector_version.startswith("0.2") + and not self.vector_version.startswith("0.1") + and not self.vector_version.startswith("0.0") + ) + + def get_or_create_collection( + self, + name: str, + *, + dimension: Optional[int] = None, + adapter: Optional[Adapter] = None, + ) -> Collection: + """ + Get a vector collection by name, or create it if no collection with + *name* exists. + + Args: + name (str): The name of the collection. + + Keyword Args: + dimension (int): The dimensionality of the vectors in the collection. + pipeline (int): The dimensionality of the vectors in the collection. + + Returns: + Collection: The created collection. + + Raises: + CollectionAlreadyExists: If a collection with the same name already exists + """ + from r2r.vecs.collection import Collection + + adapter_dimension = adapter.exported_dimension if adapter else None + + collection = Collection( + name=name, + dimension=dimension or adapter_dimension, # type: ignore + client=self, + adapter=adapter, + ) + + return collection._create_if_not_exists() + + @deprecated("use Client.get_or_create_collection") + def create_collection(self, name: str, dimension: int) -> Collection: + """ + Create a new vector collection. + + Args: + name (str): The name of the collection. + dimension (int): The dimensionality of the vectors in the collection. + + Returns: + Collection: The created collection. + + Raises: + CollectionAlreadyExists: If a collection with the same name already exists + """ + from r2r.vecs.collection import Collection + + return Collection(name, dimension, self)._create() + + @deprecated("use Client.get_or_create_collection") + def get_collection(self, name: str) -> Collection: + """ + Retrieve an existing vector collection. + + Args: + name (str): The name of the collection. + + Returns: + Collection: The retrieved collection. + + Raises: + CollectionNotFound: If no collection with the given name exists. 
+ """ + from r2r.vecs.collection import Collection + + query = text( + f""" + select + relname as table_name, + atttypmod as embedding_dim + from + pg_class pc + join pg_attribute pa + on pc.oid = pa.attrelid + where + pc.relnamespace = 'vecs'::regnamespace + and pc.relkind = 'r' + and pa.attname = 'vec' + and not pc.relname ^@ '_' + and pc.relname = :name + """ + ).bindparams(name=name) + with self.Session() as sess: + query_result = sess.execute(query).fetchone() + + if query_result is None: + raise CollectionNotFound( + "No collection found with requested name" + ) + + name, dimension = query_result + return Collection( + name, + dimension, + self, + ) + + def list_collections(self) -> List["Collection"]: + """ + List all vector collections. + + Returns: + list[Collection]: A list of all collections. + """ + from r2r.vecs.collection import Collection + + return Collection._list_collections(self) + + def delete_collection(self, name: str) -> None: + """ + Delete a vector collection. + + If no collection with requested name exists, does nothing. + + Args: + name (str): The name of the collection. + + Returns: + None + """ + from r2r.vecs.collection import Collection + + Collection(name, -1, self)._drop() + return + + def disconnect(self) -> None: + """ + Disconnect the client from the database. + + Returns: + None + """ + self.engine.dispose() + logger.info("Disconnected from the database.") + return + + def __enter__(self) -> "Client": + """ + Enable use of the 'with' statement. + + Returns: + Client: The current instance of the Client. + """ + + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """ + Disconnect the client on exiting the 'with' statement context. + + Args: + exc_type: The exception type, if any. + exc_val: The exception value, if any. + exc_tb: The traceback, if any. + + Returns: + None + """ + self.disconnect() + return diff --git a/R2R/r2r/vecs/collection.py b/R2R/r2r/vecs/collection.py new file mode 100755 index 00000000..2293d49b --- /dev/null +++ b/R2R/r2r/vecs/collection.py @@ -0,0 +1,1132 @@ +""" +Defines the 'Collection' class + +Importing from the `vecs.collection` directly is not supported. +All public classes, enums, and functions are re-exported by the top level `vecs` module. +""" + +from __future__ import annotations + +import math +import uuid +import warnings +from dataclasses import dataclass +from enum import Enum +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Iterable, + List, + Optional, + Tuple, + Union, +) + +import psycopg2 +from flupy import flu +from sqlalchemy import ( + Column, + MetaData, + String, + Table, + alias, + and_, + cast, + delete, + distinct, + func, + or_, + select, + text, +) +from sqlalchemy.dialects import postgresql +from sqlalchemy.types import Float, UserDefinedType + +from .adapter import Adapter, AdapterContext, NoOp +from .exc import ( + ArgError, + CollectionAlreadyExists, + CollectionNotFound, + FilterError, + MismatchedDimension, + Unreachable, +) + +if TYPE_CHECKING: + from vecs.client import Client + + +MetadataValues = Union[str, int, float, bool, List[str]] +Metadata = Dict[str, MetadataValues] +Numeric = Union[int, float, complex] +Record = Tuple[str, Iterable[Numeric], Metadata] + + +class IndexMethod(str, Enum): + """ + An enum representing the index methods available. + + This class currently only supports the 'ivfflat' method but may + expand in the future. + + Attributes: + auto (str): Automatically choose the best available index method. + ivfflat (str): The ivfflat index method. 
+ hnsw (str): The hnsw index method. + """ + + auto = "auto" + ivfflat = "ivfflat" + hnsw = "hnsw" + + +class IndexMeasure(str, Enum): + """ + An enum representing the types of distance measures available for indexing. + + Attributes: + cosine_distance (str): The cosine distance measure for indexing. + l2_distance (str): The Euclidean (L2) distance measure for indexing. + max_inner_product (str): The maximum inner product measure for indexing. + """ + + cosine_distance = "cosine_distance" + l2_distance = "l2_distance" + max_inner_product = "max_inner_product" + + +@dataclass +class IndexArgsIVFFlat: + """ + A class for arguments that can optionally be supplied to the index creation + method when building an IVFFlat type index. + + Attributes: + nlist (int): The number of IVF centroids that the index should use + """ + + n_lists: int + + +@dataclass +class IndexArgsHNSW: + """ + A class for arguments that can optionally be supplied to the index creation + method when building an HNSW type index. + + Ref: https://github.com/pgvector/pgvector#index-options + + Both attributes are Optional in case the user only wants to specify one and + leave the other as default + + Attributes: + m (int): Maximum number of connections per node per layer (default: 16) + ef_construction (int): Size of the dynamic candidate list for + constructing the graph (default: 64) + """ + + m: Optional[int] = 16 + ef_construction: Optional[int] = 64 + + +INDEX_MEASURE_TO_OPS = { + # Maps the IndexMeasure enum options to the SQL ops string required by + # the pgvector `create index` statement + IndexMeasure.cosine_distance: "vector_cosine_ops", + IndexMeasure.l2_distance: "vector_l2_ops", + IndexMeasure.max_inner_product: "vector_ip_ops", +} + +INDEX_MEASURE_TO_SQLA_ACC = { + IndexMeasure.cosine_distance: lambda x: x.cosine_distance, + IndexMeasure.l2_distance: lambda x: x.l2_distance, + IndexMeasure.max_inner_product: lambda x: x.max_inner_product, +} + + +class Vector(UserDefinedType): + cache_ok = True + + def __init__(self, dim=None): + super(UserDefinedType, self).__init__() + self.dim = dim + + def get_col_spec(self, **kw): + return "VECTOR" if self.dim is None else f"VECTOR({self.dim})" + + def bind_processor(self, dialect): + def process(value): + if value is None: + return value + if not isinstance(value, list): + raise ValueError("Expected a list") + if self.dim is not None and len(value) != self.dim: + raise ValueError( + f"Expected {self.dim} dimensions, not {len(value)}" + ) + return "[" + ",".join(str(float(v)) for v in value) + "]" + + return process + + def result_processor(self, dialect, coltype): + return lambda value: ( + value + if value is None + else [float(v) for v in value[1:-1].split(",")] + ) + + class comparator_factory(UserDefinedType.Comparator): + def l2_distance(self, other): + return self.op("<->", return_type=Float)(other) + + def max_inner_product(self, other): + return self.op("<#>", return_type=Float)(other) + + def cosine_distance(self, other): + return self.op("<=>", return_type=Float)(other) + + +class Collection: + """ + The `vecs.Collection` class represents a collection of vectors within a PostgreSQL database with pgvector support. + It provides methods to manage (create, delete, fetch, upsert), index, and perform similarity searches on these vector collections. + + The collections are stored in separate tables in the database, with each vector associated with an identifier and optional metadata. 
+ + Example usage: + + with vecs.create_client(DB_CONNECTION) as vx: + collection = vx.create_collection(name="docs", dimension=3) + collection.upsert([("id1", [1, 1, 1], {"key": "value"})]) + # Further operations on 'collection' + + Public Attributes: + name: The name of the vector collection. + dimension: The dimension of vectors in the collection. + + Note: Some methods of this class can raise exceptions from the `vecs.exc` module if errors occur. + """ + + def __init__( + self, + name: str, + dimension: int, + client: Client, + adapter: Optional[Adapter] = None, + ): + """ + Initializes a new instance of the `Collection` class. + + During expected use, developers initialize instances of `Collection` using the + `vecs.Client` with `vecs.Client.create_collection(...)` rather than directly. + + Args: + name (str): The name of the collection. + dimension (int): The dimension of the vectors in the collection. + client (Client): The client to use for interacting with the database. + """ + from r2r.vecs.adapter import Adapter + + self.client = client + self.name = name + self.dimension = dimension + self.table = build_table(name, client.meta, dimension) + self._index: Optional[str] = None + self.adapter = adapter or Adapter(steps=[NoOp(dimension=dimension)]) + + reported_dimensions = set( + [ + x + for x in [ + dimension, + adapter.exported_dimension if adapter else None, + ] + if x is not None + ] + ) + if len(reported_dimensions) == 0: + raise ArgError( + "One of dimension or adapter must provide a dimension" + ) + elif len(reported_dimensions) > 1: + raise MismatchedDimension( + "Mismatch in the reported dimensions of the selected vector collection and embedding model. Correct the selected embedding model or specify a new vector collection by modifying the `POSTGRES_VECS_COLLECTION` environment variable." + ) + + def __repr__(self): + """ + Returns a string representation of the `Collection` instance. + + Returns: + str: A string representation of the `Collection` instance. + """ + return ( + f'vecs.Collection(name="{self.name}", dimension={self.dimension})' + ) + + def __len__(self) -> int: + """ + Returns the number of vectors in the collection. + + Returns: + int: The number of vectors in the collection. + """ + with self.client.Session() as sess: + with sess.begin(): + stmt = select(func.count()).select_from(self.table) + return sess.execute(stmt).scalar() or 0 + + def _create_if_not_exists(self): + """ + PRIVATE + + Creates a new collection in the database if it doesn't already exist + + Returns: + Collection: The found or created collection. + """ + query = text( + f""" + select + relname as table_name, + atttypmod as embedding_dim + from + pg_class pc + join pg_attribute pa + on pc.oid = pa.attrelid + where + pc.relnamespace = 'vecs'::regnamespace + and pc.relkind = 'r' + and pa.attname = 'vec' + and not pc.relname ^@ '_' + and pc.relname = :name + """ + ).bindparams(name=self.name) + with self.client.Session() as sess: + query_result = sess.execute(query).fetchone() + + if query_result: + _, collection_dimension = query_result + else: + collection_dimension = None + + reported_dimensions = set( + [ + x + for x in [self.dimension, collection_dimension] + if x is not None + ] + ) + if len(reported_dimensions) > 1: + raise MismatchedDimension( + "Dimensions reported by adapter, dimension, and collection do not match. The likely cause of this is a mismatch between the dimensions of the selected vector collection and embedding model. 
Select the correct embedding model, or specify a new vector collection by modifying your `POSTGRES_VECS_COLLECTION` environment variable. If the selected colelction does not exist then it will be automatically with dimensions that match the selected embedding model." + ) + + if not collection_dimension: + self.table.create(self.client.engine) + + return self + + def _create(self): + """ + PRIVATE + + Creates a new collection in the database. Raises a `vecs.exc.CollectionAlreadyExists` + exception if a collection with the specified name already exists. + + Returns: + Collection: The newly created collection. + """ + + collection_exists = self.__class__._does_collection_exist( + self.client, self.name + ) + if collection_exists: + raise CollectionAlreadyExists( + "Collection with requested name already exists" + ) + self.table.create(self.client.engine) + + unique_string = str(uuid.uuid4()).replace("-", "_")[0:7] + with self.client.Session() as sess: + sess.execute( + text( + f""" + create index ix_meta_{unique_string} + on vecs."{self.table.name}" + using gin ( metadata jsonb_path_ops ) + """ + ) + ) + return self + + def _drop(self): + """ + PRIVATE + + Deletes the collection from the database. Raises a `vecs.exc.CollectionNotFound` + exception if no collection with the specified name exists. + + Returns: + Collection: The deleted collection. + """ + with self.client.Session() as sess: + sess.execute(text(f"DROP TABLE IF EXISTS {self.name} CASCADE")) + sess.commit() + + return self + + def get_unique_metadata_values( + self, + field: str, + filter_field: Optional[str] = None, + filter_value: Optional[MetadataValues] = None, + ) -> List[MetadataValues]: + """ + Fetches all unique metadata values of a specific field, optionally filtered by another metadata field. + Args: + field (str): The metadata field for which to fetch unique values. + filter_field (Optional[str], optional): The metadata field to filter on. Defaults to None. + filter_value (Optional[MetadataValues], optional): The value to filter the metadata field with. Defaults to None. + Returns: + List[MetadataValues]: A list of unique metadata values for the specified field. + """ + with self.client.Session() as sess: + with sess.begin(): + stmt = select( + distinct(self.table.c.metadata[field].astext) + ).where(self.table.c.metadata[field] != None) + + if filter_field is not None and filter_value is not None: + stmt = stmt.where( + self.table.c.metadata[filter_field].astext + == str(filter_value) + ) + + result = sess.execute(stmt) + unique_values = result.scalars().all() + + return unique_values + + def copy( + self, + records: Iterable[Tuple[str, Any, Metadata]], + skip_adapter: bool = False, + ) -> None: + """ + Copies records into the collection. + + Args: + records (Iterable[Tuple[str, Any, Metadata]]): An iterable of content to copy. + Each record is a tuple where: + - the first element is a unique string identifier + - the second element is an iterable of numeric values or relevant input type for the + adapter assigned to the collection + - the third element is metadata associated with the vector + + skip_adapter (bool): Should the adapter be skipped while copying. i.e. 
if vectors are being + provided, rather than a media type that needs to be transformed + """ + import csv + import io + import json + import os + + pipeline = flu(records) + for record in pipeline: + with psycopg2.connect( + database=os.getenv("POSTGRES_DBNAME"), + user=os.getenv("POSTGRES_USER"), + password=os.getenv("POSTGRES_PASSWORD"), + host=os.getenv("POSTGRES_HOST"), + port=os.getenv("POSTGRES_PORT"), + ) as conn: + with conn.cursor() as cur: + f = io.StringIO() + id, vec, metadata = record + + writer = csv.writer(f, delimiter=",", quotechar='"') + writer.writerow( + [ + str(id), + [float(ele) for ele in vec], + json.dumps(metadata), + ] + ) + f.seek(0) + result = f.getvalue() + + writer_name = ( + f'vecs."{self.table.fullname.split(".")[-1]}"' + ) + g = io.StringIO(result) + cur.copy_expert( + f"COPY {writer_name}(id, vec, metadata) FROM STDIN WITH (FORMAT csv)", + g, + ) + conn.commit() + cur.close() + conn.close() + + def upsert( + self, + records: Iterable[Tuple[str, Any, Metadata]], + skip_adapter: bool = False, + ) -> None: + """ + Inserts or updates *vectors* records in the collection. + + Args: + records (Iterable[Tuple[str, Any, Metadata]]): An iterable of content to upsert. + Each record is a tuple where: + - the first element is a unique string identifier + - the second element is an iterable of numeric values or relevant input type for the + adapter assigned to the collection + - the third element is metadata associated with the vector + + skip_adapter (bool): Should the adapter be skipped while upserting. i.e. if vectors are being + provided, rather than a media type that needs to be transformed + """ + + chunk_size = 512 + + if skip_adapter: + pipeline = flu(records).chunk(chunk_size) + else: + # Construct a lazy pipeline of steps to transform and chunk user input + pipeline = flu( + self.adapter(records, AdapterContext("upsert")) + ).chunk(chunk_size) + + with self.client.Session() as sess: + with sess.begin(): + for chunk in pipeline: + stmt = postgresql.insert(self.table).values(chunk) + stmt = stmt.on_conflict_do_update( + index_elements=[self.table.c.id], + set_=dict( + vec=stmt.excluded.vec, + metadata=stmt.excluded.metadata, + ), + ) + sess.execute(stmt) + return None + + def fetch(self, ids: Iterable[str]) -> List[Record]: + """ + Fetches vectors from the collection by their identifiers. + + Args: + ids (Iterable[str]): An iterable of vector identifiers. + + Returns: + List[Record]: A list of the fetched vectors. + """ + if isinstance(ids, str): + raise ArgError("ids must be a list of strings") + + chunk_size = 12 + records = [] + with self.client.Session() as sess: + with sess.begin(): + for id_chunk in flu(ids).chunk(chunk_size): + stmt = select(self.table).where( + self.table.c.id.in_(id_chunk) + ) + chunk_records = sess.execute(stmt) + records.extend(chunk_records) + return records + + def delete( + self, + ids: Optional[Iterable[str]] = None, + filters: Optional[Dict[str, Any]] = None, + ) -> List[str]: + """ + Deletes vectors from the collection by matching filters or ids. + + Args: + ids (Iterable[str], optional): An iterable of vector identifiers. + filters (Optional[Dict], optional): Filters to apply to the search. Defaults to None. + + Returns: + List[str]: A list of the document IDs of the deleted vectors. 
+ """ + if ids is None and filters is None: + raise ArgError("Either ids or filters must be provided.") + + if ids is not None and filters is not None: + raise ArgError("Either ids or filters must be provided, not both.") + + if isinstance(ids, str): + raise ArgError("ids must be a list of strings") + + ids = ids or [] + filters = filters or {} + del_document_ids = set([]) + + with self.client.Session() as sess: + with sess.begin(): + if ids: + for id_chunk in flu(ids).chunk(12): + stmt = select(self.table.c.metadata).where( + self.table.c.id.in_(id_chunk) + ) + results = sess.execute(stmt).fetchall() + for result in results: + metadata_json = result[0] + document_id = metadata_json.get("document_id") + if document_id: + del_document_ids.add(document_id) + + delete_stmt = ( + delete(self.table) + .where(self.table.c.id.in_(id_chunk)) + .returning(self.table.c.id) + ) + sess.execute(delete_stmt) + + if filters: + meta_filter = build_filters(self.table.c.metadata, filters) + stmt = select(self.table.c.metadata).where(meta_filter) + results = sess.execute(stmt).fetchall() + for result in results: + metadata_json = result[0] + document_id = metadata_json.get("document_id") + if document_id: + del_document_ids.add(document_id) + + delete_stmt = ( + delete(self.table) + .where(meta_filter) + .returning(self.table.c.id) + ) + sess.execute(delete_stmt) + + return list(del_document_ids) + + def __getitem__(self, items): + """ + Fetches a vector from the collection by its identifier. + + Args: + items (str): The identifier of the vector. + + Returns: + Record: The fetched vector. + """ + if not isinstance(items, str): + raise ArgError("items must be a string id") + + row = self.fetch([items]) + + if row == []: + raise KeyError("no item found with requested id") + return row[0] + + def query( + self, + data: Union[Iterable[Numeric], Any], + limit: int = 10, + filters: Optional[Dict] = None, + measure: Union[IndexMeasure, str] = IndexMeasure.cosine_distance, + include_value: bool = False, + include_metadata: bool = False, + *, + probes: Optional[int] = None, + ef_search: Optional[int] = None, + skip_adapter: bool = False, + ) -> Union[List[Record], List[str]]: + """ + Executes a similarity search in the collection. + + The return type is dependent on arguments *include_value* and *include_metadata* + + Args: + data (Any): The vector to use as the query. + limit (int, optional): The maximum number of results to return. Defaults to 10. + filters (Optional[Dict], optional): Filters to apply to the search. Defaults to None. + measure (Union[IndexMeasure, str], optional): The distance measure to use for the search. Defaults to 'cosine_distance'. + include_value (bool, optional): Whether to include the distance value in the results. Defaults to False. + include_metadata (bool, optional): Whether to include the metadata in the results. Defaults to False. + probes (Optional[Int], optional): Number of ivfflat index lists to query. Higher increases accuracy but decreases speed + ef_search (Optional[Int], optional): Size of the dynamic candidate list for HNSW index search. Higher increases accuracy but decreases speed + skip_adapter (bool, optional): When True, skips any associated adapter and queries using a literal vector provided to *data* + + Returns: + Union[List[Record], List[str]]: The result of the similarity search. 
+ """ + + if probes is None: + probes = 10 + + if ef_search is None: + ef_search = 40 + + if not isinstance(probes, int): + raise ArgError("probes must be an integer") + + if probes < 1: + raise ArgError("probes must be >= 1") + + if limit > 1000: + raise ArgError("limit must be <= 1000") + + # ValueError on bad input + try: + imeasure = IndexMeasure(measure) + except ValueError: + raise ArgError("Invalid index measure") + + if not self.is_indexed_for_measure(imeasure): + warnings.warn( + UserWarning( + f"Query does not have a covering index for {imeasure}. See Collection.create_index" + ) + ) + + if skip_adapter: + adapted_query = [("", data, {})] + else: + # Adapt the query using the pipeline + adapted_query = [ + x + for x in self.adapter( + records=[("", data, {})], + adapter_context=AdapterContext("query"), + ) + ] + + if len(adapted_query) != 1: + raise ArgError( + "Failed to produce exactly one query vector from input" + ) + + _, vec, _ = adapted_query[0] + + distance_lambda = INDEX_MEASURE_TO_SQLA_ACC.get(imeasure) + if distance_lambda is None: + # unreachable + raise ArgError("invalid distance_measure") # pragma: no cover + + distance_clause = distance_lambda(self.table.c.vec)(vec) + + cols = [self.table.c.id] + + if include_value: + cols.append(distance_clause) + + if include_metadata: + cols.append(self.table.c.metadata) + + stmt = select(*cols) + if filters: + stmt = stmt.filter( + build_filters(self.table.c.metadata, filters) # type: ignore + ) + + stmt = stmt.order_by(distance_clause) + stmt = stmt.limit(limit) + + with self.client.Session() as sess: + with sess.begin(): + # index ignored if greater than n_lists + sess.execute( + text("set local ivfflat.probes = :probes").bindparams( + probes=probes + ) + ) + if self.client._supports_hnsw(): + sess.execute( + text( + "set local hnsw.ef_search = :ef_search" + ).bindparams(ef_search=ef_search) + ) + if len(cols) == 1: + return [str(x) for x in sess.scalars(stmt).fetchall()] + return sess.execute(stmt).fetchall() or [] + + @classmethod + def _list_collections(cls, client: "Client") -> List["Collection"]: + """ + PRIVATE + + Retrieves all collections from the database. + + Args: + client (Client): The database client. + + Returns: + List[Collection]: A list of all existing collections. + """ + + query = text( + """ + select + relname as table_name, + atttypmod as embedding_dim + from + pg_class pc + join pg_attribute pa + on pc.oid = pa.attrelid + where + pc.relnamespace = 'vecs'::regnamespace + and pc.relkind = 'r' + and pa.attname = 'vec' + and not pc.relname ^@ '_' + """ + ) + xc = [] + with client.Session() as sess: + for name, dimension in sess.execute(query): + existing_collection = cls(name, dimension, client) + xc.append(existing_collection) + return xc + + @classmethod + def _does_collection_exist(cls, client: "Client", name: str) -> bool: + """ + PRIVATE + + Checks if a collection with a given name exists within the database + + Args: + client (Client): The database client. + name (str): The name of the collection + + Returns: + Exists: Whether the collection exists or not + """ + + try: + client.get_collection(name) + return True + except CollectionNotFound: + return False + + @property + def index(self) -> Optional[str]: + """ + PRIVATE + + Note: + The `index` property is private and expected to undergo refactoring. + Do not rely on it's output. + + Retrieves the SQL name of the collection's vector index, if it exists. + + Returns: + Optional[str]: The name of the index, or None if no index exists. 
+ """ + + if self._index is None: + query = text( + """ + select + relname as table_name + from + pg_class pc + where + pc.relnamespace = 'vecs'::regnamespace + and relname ilike 'ix_vector%' + and pc.relkind = 'i' + """ + ) + with self.client.Session() as sess: + ix_name = sess.execute(query).scalar() + self._index = ix_name + return self._index + + def is_indexed_for_measure(self, measure: IndexMeasure): + """ + Checks if the collection is indexed for a specific measure. + + Args: + measure (IndexMeasure): The measure to check for. + + Returns: + bool: True if the collection is indexed for the measure, False otherwise. + """ + + index_name = self.index + if index_name is None: + return False + + ops = INDEX_MEASURE_TO_OPS.get(measure) + if ops is None: + return False + + if ops in index_name: + return True + + return False + + def create_index( + self, + measure: IndexMeasure = IndexMeasure.cosine_distance, + method: IndexMethod = IndexMethod.auto, + index_arguments: Optional[ + Union[IndexArgsIVFFlat, IndexArgsHNSW] + ] = None, + replace=True, + ) -> None: + """ + Creates an index for the collection. + + Note: + When `vecs` creates an index on a pgvector column in PostgreSQL, it uses a multi-step + process that enables performant indexes to be built for large collections with low end + database hardware. + + Those steps are: + + - Creates a new table with a different name + - Randomly selects records from the existing table + - Inserts the random records from the existing table into the new table + - Creates the requested vector index on the new table + - Upserts all data from the existing table into the new table + - Drops the existing table + - Renames the new table to the existing tables name + + If you create dependencies (like views) on the table that underpins + a `vecs.Collection` the `create_index` step may require you to drop those dependencies before + it will succeed. + + Args: + measure (IndexMeasure, optional): The measure to index for. Defaults to 'cosine_distance'. + method (IndexMethod, optional): The indexing method to use. Defaults to 'auto'. + index_arguments: (IndexArgsIVFFlat | IndexArgsHNSW, optional): Index type specific arguments + replace (bool, optional): Whether to replace the existing index. Defaults to True. + + Raises: + ArgError: If an invalid index method is used, or if *replace* is False and an index already exists. + """ + + if method not in ( + IndexMethod.ivfflat, + IndexMethod.hnsw, + IndexMethod.auto, + ): + raise ArgError("invalid index method") + + if index_arguments: + # Disallow case where user submits index arguments but uses the + # IndexMethod.auto index (index build arguments should only be + # used with a specific index) + if method == IndexMethod.auto: + raise ArgError( + "Index build parameters are not allowed when using the IndexMethod.auto index." + ) + # Disallow case where user specifies one index type but submits + # index build arguments for the other index type + if ( + isinstance(index_arguments, IndexArgsHNSW) + and method != IndexMethod.hnsw + ) or ( + isinstance(index_arguments, IndexArgsIVFFlat) + and method != IndexMethod.ivfflat + ): + raise ArgError( + f"{index_arguments.__class__.__name__} build parameters were supplied but {method} index was specified." + ) + + if method == IndexMethod.auto: + if self.client._supports_hnsw(): + method = IndexMethod.hnsw + else: + method = IndexMethod.ivfflat + + if method == IndexMethod.hnsw and not self.client._supports_hnsw(): + raise ArgError( + "HNSW Unavailable. 
Upgrade your pgvector installation to > 0.5.0 to enable HNSW support" + ) + + ops = INDEX_MEASURE_TO_OPS.get(measure) + if ops is None: + raise ArgError("Unknown index measure") + + unique_string = str(uuid.uuid4()).replace("-", "_")[0:7] + + with self.client.Session() as sess: + with sess.begin(): + if self.index is not None: + if replace: + sess.execute(text(f'drop index vecs."{self.index}";')) + self._index = None + else: + raise ArgError( + "replace is set to False but an index exists" + ) + + if method == IndexMethod.ivfflat: + if not index_arguments: + n_records: int = sess.execute(func.count(self.table.c.id)).scalar() # type: ignore + + n_lists = ( + int(max(n_records / 1000, 30)) + if n_records < 1_000_000 + else int(math.sqrt(n_records)) + ) + else: + # The following mypy error is ignored because mypy + # complains that `index_arguments` is typed as a union + # of IndexArgsIVFFlat and IndexArgsHNSW types, + # which both don't necessarily contain the `n_lists` + # parameter, however we have validated that the + # correct type is being used above. + n_lists = index_arguments.n_lists # type: ignore + + sess.execute( + text( + f""" + create index ix_{ops}_ivfflat_nl{n_lists}_{unique_string} + on vecs."{self.table.name}" + using ivfflat (vec {ops}) with (lists={n_lists}) + """ + ) + ) + + if method == IndexMethod.hnsw: + if not index_arguments: + index_arguments = IndexArgsHNSW() + + # See above for explanation of why the following lines + # are ignored + m = index_arguments.m # type: ignore + ef_construction = index_arguments.ef_construction # type: ignore + + sess.execute( + text( + f""" + create index ix_{ops}_hnsw_m{m}_efc{ef_construction}_{unique_string} + on vecs."{self.table.name}" + using hnsw (vec {ops}) WITH (m={m}, ef_construction={ef_construction}); + """ + ) + ) + + return None + + +def build_filters(json_col: Column, filters: Dict): + """ + Builds filters for SQL query based on provided dictionary. + + Args: + json_col (Column): The column in the database table. + filters (Dict): The dictionary specifying filter conditions. + + Raises: + FilterError: If filter conditions are not correctly formatted. + + Returns: + The filter clause for the SQL query. 
+ """ + if not isinstance(filters, dict): + raise FilterError("filters must be a dict") + + filter_clauses = [] + + for key, value in filters.items(): + if not isinstance(key, str): + raise FilterError("*filters* keys must be strings") + + if isinstance(value, dict): + if len(value) > 1: + raise FilterError("only one operator permitted per key") + for operator, clause in value.items(): + if operator not in ( + "$eq", + "$ne", + "$lt", + "$lte", + "$gt", + "$gte", + "$in", + ): + raise FilterError("unknown operator") + + if operator == "$eq" and not hasattr(clause, "__len__"): + contains_value = cast({key: clause}, postgresql.JSONB) + filter_clauses.append(json_col.op("@>")(contains_value)) + elif operator == "$in": + if not isinstance(clause, list): + raise FilterError( + "argument to $in filter must be a list" + ) + for elem in clause: + if not isinstance(elem, (int, str, float)): + raise FilterError( + "argument to $in filter must be a list of scalars" + ) + contains_value = [ + cast(elem, postgresql.JSONB) for elem in clause + ] + filter_clauses.append( + json_col.op("->")(key).in_(contains_value) + ) + else: + matches_value = cast(clause, postgresql.JSONB) + if operator == "$eq": + filter_clauses.append( + json_col.op("->")(key) == matches_value + ) + elif operator == "$ne": + filter_clauses.append( + json_col.op("->")(key) != matches_value + ) + elif operator == "$lt": + filter_clauses.append( + json_col.op("->")(key) < matches_value + ) + elif operator == "$lte": + filter_clauses.append( + json_col.op("->")(key) <= matches_value + ) + elif operator == "$gt": + filter_clauses.append( + json_col.op("->")(key) > matches_value + ) + elif operator == "$gte": + filter_clauses.append( + json_col.op("->")(key) >= matches_value + ) + else: + raise Unreachable() + else: + raise FilterError("Filter value must be a dict with an operator") + + if len(filter_clauses) == 1: + return filter_clauses[0] + else: + return and_(*filter_clauses) + + +def build_table(name: str, meta: MetaData, dimension: int) -> Table: + """ + PRIVATE + + Builds a SQLAlchemy model underpinning a `vecs.Collection`. + + Args: + name (str): The name of the table. + meta (MetaData): MetaData instance associated with the SQL database. + dimension: The dimension of the vectors in the collection. + + Returns: + Table: The constructed SQL table. + """ + return Table( + name, + meta, + Column("id", String, primary_key=True), + Column("vec", Vector(dimension), nullable=False), + Column( + "metadata", + postgresql.JSONB, + server_default=text("'{}'::jsonb"), + nullable=False, + ), + extend_existing=True, + ) diff --git a/R2R/r2r/vecs/exc.py b/R2R/r2r/vecs/exc.py new file mode 100755 index 00000000..0ae4500c --- /dev/null +++ b/R2R/r2r/vecs/exc.py @@ -0,0 +1,83 @@ +__all__ = [ + "VecsException", + "CollectionAlreadyExists", + "CollectionNotFound", + "ArgError", + "FilterError", + "IndexNotFound", + "Unreachable", +] + + +class VecsException(Exception): + """ + Base exception class for the 'vecs' package. + All custom exceptions in the 'vecs' package should derive from this class. + """ + + ... + + +class CollectionAlreadyExists(VecsException): + """ + Exception raised when attempting to create a collection that already exists. + """ + + ... + + +class CollectionNotFound(VecsException): + """ + Exception raised when attempting to access or manipulate a collection that does not exist. + """ + + ... + + +class ArgError(VecsException): + """ + Exception raised for invalid arguments when calling a method. + """ + + ... 
+
+
+class MismatchedDimension(ArgError):
+    """
+    Exception raised when multiple sources of truth for a collection's embedding dimension do not match.
+    """
+
+    ...
+
+
+class FilterError(VecsException):
+    """
+    Exception raised when there's an error related to filter usage in a query.
+    """
+
+    ...
+
+
+class IndexNotFound(VecsException):
+    """
+    Exception raised when attempting to access an index that does not exist.
+    """
+
+    ...
+
+
+class Unreachable(VecsException):
+    """
+    Exception raised when an unreachable part of the code is executed.
+    This is typically used for error handling in cases that should be logically impossible.
+    """
+
+    ...
+
+
+class MissingDependency(VecsException, ImportError):
+    """
+    Exception raised when attempting to access a feature that requires an optional dependency when the optional dependency is not present.
+    """
+
+    ...
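
For text input rather than raw vectors, the adapter steps added in `adapter/text.py` can be chained in front of a collection. The following is a minimal sketch, not part of the diff: it assumes the optional sentence-transformers dependency is installed, and the connection string, collection name, sample text, and model choice ("all-MiniLM-L12-v2", one of the TextEmbeddingModel literals above) are illustrative.

import r2r.vecs as vecs
from r2r.vecs.adapter import Adapter, ParagraphChunker, TextEmbedding

DB_CONNECTION = "postgresql://user:password@localhost:5432/postgres"  # placeholder

vx = vecs.create_client(DB_CONNECTION)

# The TextEmbedding step reports its output dimension, so `dimension=` is omitted here.
notes = vx.get_or_create_collection(
    name="notes",
    adapter=Adapter(
        steps=[
            ParagraphChunker(skip_during_query=True),
            TextEmbedding(model="all-MiniLM-L12-v2"),
        ]
    ),
)

# Raw text goes in; ParagraphChunker splits on blank lines and TextEmbedding
# converts each paragraph into a vector before the rows reach the table.
notes.upsert(
    [("note_1", "First paragraph.\n\nSecond paragraph.", {"source": "demo"})]
)

# At query time the chunker is skipped (skip_during_query=True) and only the
# embedding step runs on the query string.
hits = notes.query(data="something about paragraphs", limit=3, include_metadata=True)

vx.disconnect()

Because each chunker suffixes the record id (for example "note_1_para_000"), the same source document can yield several rows per collection while remaining traceable to its original identifier.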