Diffstat (limited to 'R2R/r2r/vecs/collection.py')
-rwxr-xr-x | R2R/r2r/vecs/collection.py | 1132
1 files changed, 1132 insertions, 0 deletions
diff --git a/R2R/r2r/vecs/collection.py b/R2R/r2r/vecs/collection.py new file mode 100755 index 00000000..2293d49b --- /dev/null +++ b/R2R/r2r/vecs/collection.py @@ -0,0 +1,1132 @@ +""" +Defines the 'Collection' class + +Importing from the `vecs.collection` directly is not supported. +All public classes, enums, and functions are re-exported by the top level `vecs` module. +""" + +from __future__ import annotations + +import math +import uuid +import warnings +from dataclasses import dataclass +from enum import Enum +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Iterable, + List, + Optional, + Tuple, + Union, +) + +import psycopg2 +from flupy import flu +from sqlalchemy import ( + Column, + MetaData, + String, + Table, + alias, + and_, + cast, + delete, + distinct, + func, + or_, + select, + text, +) +from sqlalchemy.dialects import postgresql +from sqlalchemy.types import Float, UserDefinedType + +from .adapter import Adapter, AdapterContext, NoOp +from .exc import ( + ArgError, + CollectionAlreadyExists, + CollectionNotFound, + FilterError, + MismatchedDimension, + Unreachable, +) + +if TYPE_CHECKING: + from vecs.client import Client + + +MetadataValues = Union[str, int, float, bool, List[str]] +Metadata = Dict[str, MetadataValues] +Numeric = Union[int, float, complex] +Record = Tuple[str, Iterable[Numeric], Metadata] + + +class IndexMethod(str, Enum): + """ + An enum representing the index methods available. + + This class currently supports the 'ivfflat' and 'hnsw' methods and may + expand in the future. + + Attributes: + auto (str): Automatically choose the best available index method. + ivfflat (str): The ivfflat index method. + hnsw (str): The hnsw index method. + """ + + auto = "auto" + ivfflat = "ivfflat" + hnsw = "hnsw" + + +class IndexMeasure(str, Enum): + """ + An enum representing the types of distance measures available for indexing. + + Attributes: + cosine_distance (str): The cosine distance measure for indexing. + l2_distance (str): The Euclidean (L2) distance measure for indexing. + max_inner_product (str): The maximum inner product measure for indexing. + """ + + cosine_distance = "cosine_distance" + l2_distance = "l2_distance" + max_inner_product = "max_inner_product" + + +@dataclass +class IndexArgsIVFFlat: + """ + A class for arguments that can optionally be supplied to the index creation + method when building an IVFFlat type index. + + Attributes: + n_lists (int): The number of IVF centroids that the index should use + """ + + n_lists: int + + +@dataclass +class IndexArgsHNSW: + """ + A class for arguments that can optionally be supplied to the index creation + method when building an HNSW type index.
+ + Ref: https://github.com/pgvector/pgvector#index-options + + Both attributes are Optional in case the user only wants to specify one and + leave the other as default + + Attributes: + m (int): Maximum number of connections per node per layer (default: 16) + ef_construction (int): Size of the dynamic candidate list for + constructing the graph (default: 64) + """ + + m: Optional[int] = 16 + ef_construction: Optional[int] = 64 + + +INDEX_MEASURE_TO_OPS = { + # Maps the IndexMeasure enum options to the SQL ops string required by + # the pgvector `create index` statement + IndexMeasure.cosine_distance: "vector_cosine_ops", + IndexMeasure.l2_distance: "vector_l2_ops", + IndexMeasure.max_inner_product: "vector_ip_ops", +} + +INDEX_MEASURE_TO_SQLA_ACC = { + IndexMeasure.cosine_distance: lambda x: x.cosine_distance, + IndexMeasure.l2_distance: lambda x: x.l2_distance, + IndexMeasure.max_inner_product: lambda x: x.max_inner_product, +} + + +class Vector(UserDefinedType): + cache_ok = True + + def __init__(self, dim=None): + super(UserDefinedType, self).__init__() + self.dim = dim + + def get_col_spec(self, **kw): + return "VECTOR" if self.dim is None else f"VECTOR({self.dim})" + + def bind_processor(self, dialect): + def process(value): + if value is None: + return value + if not isinstance(value, list): + raise ValueError("Expected a list") + if self.dim is not None and len(value) != self.dim: + raise ValueError( + f"Expected {self.dim} dimensions, not {len(value)}" + ) + return "[" + ",".join(str(float(v)) for v in value) + "]" + + return process + + def result_processor(self, dialect, coltype): + return lambda value: ( + value + if value is None + else [float(v) for v in value[1:-1].split(",")] + ) + + class comparator_factory(UserDefinedType.Comparator): + def l2_distance(self, other): + return self.op("<->", return_type=Float)(other) + + def max_inner_product(self, other): + return self.op("<#>", return_type=Float)(other) + + def cosine_distance(self, other): + return self.op("<=>", return_type=Float)(other) + + +class Collection: + """ + The `vecs.Collection` class represents a collection of vectors within a PostgreSQL database with pgvector support. + It provides methods to manage (create, delete, fetch, upsert), index, and perform similarity searches on these vector collections. + + The collections are stored in separate tables in the database, with each vector associated with an identifier and optional metadata. + + Example usage: + + with vecs.create_client(DB_CONNECTION) as vx: + collection = vx.create_collection(name="docs", dimension=3) + collection.upsert([("id1", [1, 1, 1], {"key": "value"})]) + # Further operations on 'collection' + + Public Attributes: + name: The name of the vector collection. + dimension: The dimension of vectors in the collection. + + Note: Some methods of this class can raise exceptions from the `vecs.exc` module if errors occur. + """ + + def __init__( + self, + name: str, + dimension: int, + client: Client, + adapter: Optional[Adapter] = None, + ): + """ + Initializes a new instance of the `Collection` class. + + During expected use, developers initialize instances of `Collection` using the + `vecs.Client` with `vecs.Client.create_collection(...)` rather than directly. + + Args: + name (str): The name of the collection. + dimension (int): The dimension of the vectors in the collection. + client (Client): The client to use for interacting with the database. 
+ """ + from r2r.vecs.adapter import Adapter + + self.client = client + self.name = name + self.dimension = dimension + self.table = build_table(name, client.meta, dimension) + self._index: Optional[str] = None + self.adapter = adapter or Adapter(steps=[NoOp(dimension=dimension)]) + + reported_dimensions = set( + [ + x + for x in [ + dimension, + adapter.exported_dimension if adapter else None, + ] + if x is not None + ] + ) + if len(reported_dimensions) == 0: + raise ArgError( + "One of dimension or adapter must provide a dimension" + ) + elif len(reported_dimensions) > 1: + raise MismatchedDimension( + "Mismatch in the reported dimensions of the selected vector collection and embedding model. Correct the selected embedding model or specify a new vector collection by modifying the `POSTGRES_VECS_COLLECTION` environment variable." + ) + + def __repr__(self): + """ + Returns a string representation of the `Collection` instance. + + Returns: + str: A string representation of the `Collection` instance. + """ + return ( + f'vecs.Collection(name="{self.name}", dimension={self.dimension})' + ) + + def __len__(self) -> int: + """ + Returns the number of vectors in the collection. + + Returns: + int: The number of vectors in the collection. + """ + with self.client.Session() as sess: + with sess.begin(): + stmt = select(func.count()).select_from(self.table) + return sess.execute(stmt).scalar() or 0 + + def _create_if_not_exists(self): + """ + PRIVATE + + Creates a new collection in the database if it doesn't already exist + + Returns: + Collection: The found or created collection. + """ + query = text( + f""" + select + relname as table_name, + atttypmod as embedding_dim + from + pg_class pc + join pg_attribute pa + on pc.oid = pa.attrelid + where + pc.relnamespace = 'vecs'::regnamespace + and pc.relkind = 'r' + and pa.attname = 'vec' + and not pc.relname ^@ '_' + and pc.relname = :name + """ + ).bindparams(name=self.name) + with self.client.Session() as sess: + query_result = sess.execute(query).fetchone() + + if query_result: + _, collection_dimension = query_result + else: + collection_dimension = None + + reported_dimensions = set( + [ + x + for x in [self.dimension, collection_dimension] + if x is not None + ] + ) + if len(reported_dimensions) > 1: + raise MismatchedDimension( + "Dimensions reported by adapter, dimension, and collection do not match. The likely cause of this is a mismatch between the dimensions of the selected vector collection and embedding model. Select the correct embedding model, or specify a new vector collection by modifying your `POSTGRES_VECS_COLLECTION` environment variable. If the selected colelction does not exist then it will be automatically with dimensions that match the selected embedding model." + ) + + if not collection_dimension: + self.table.create(self.client.engine) + + return self + + def _create(self): + """ + PRIVATE + + Creates a new collection in the database. Raises a `vecs.exc.CollectionAlreadyExists` + exception if a collection with the specified name already exists. + + Returns: + Collection: The newly created collection. 
+ """ + + collection_exists = self.__class__._does_collection_exist( + self.client, self.name + ) + if collection_exists: + raise CollectionAlreadyExists( + "Collection with requested name already exists" + ) + self.table.create(self.client.engine) + + unique_string = str(uuid.uuid4()).replace("-", "_")[0:7] + with self.client.Session() as sess: + sess.execute( + text( + f""" + create index ix_meta_{unique_string} + on vecs."{self.table.name}" + using gin ( metadata jsonb_path_ops ) + """ + ) + ) + return self + + def _drop(self): + """ + PRIVATE + + Deletes the collection from the database. Raises a `vecs.exc.CollectionNotFound` + exception if no collection with the specified name exists. + + Returns: + Collection: The deleted collection. + """ + with self.client.Session() as sess: + sess.execute(text(f"DROP TABLE IF EXISTS {self.name} CASCADE")) + sess.commit() + + return self + + def get_unique_metadata_values( + self, + field: str, + filter_field: Optional[str] = None, + filter_value: Optional[MetadataValues] = None, + ) -> List[MetadataValues]: + """ + Fetches all unique metadata values of a specific field, optionally filtered by another metadata field. + Args: + field (str): The metadata field for which to fetch unique values. + filter_field (Optional[str], optional): The metadata field to filter on. Defaults to None. + filter_value (Optional[MetadataValues], optional): The value to filter the metadata field with. Defaults to None. + Returns: + List[MetadataValues]: A list of unique metadata values for the specified field. + """ + with self.client.Session() as sess: + with sess.begin(): + stmt = select( + distinct(self.table.c.metadata[field].astext) + ).where(self.table.c.metadata[field] != None) + + if filter_field is not None and filter_value is not None: + stmt = stmt.where( + self.table.c.metadata[filter_field].astext + == str(filter_value) + ) + + result = sess.execute(stmt) + unique_values = result.scalars().all() + + return unique_values + + def copy( + self, + records: Iterable[Tuple[str, Any, Metadata]], + skip_adapter: bool = False, + ) -> None: + """ + Copies records into the collection. + + Args: + records (Iterable[Tuple[str, Any, Metadata]]): An iterable of content to copy. + Each record is a tuple where: + - the first element is a unique string identifier + - the second element is an iterable of numeric values or relevant input type for the + adapter assigned to the collection + - the third element is metadata associated with the vector + + skip_adapter (bool): Should the adapter be skipped while copying. i.e. 
if vectors are being + provided, rather than a media type that needs to be transformed + """ + import csv + import io + import json + import os + + pipeline = flu(records) + for record in pipeline: + with psycopg2.connect( + database=os.getenv("POSTGRES_DBNAME"), + user=os.getenv("POSTGRES_USER"), + password=os.getenv("POSTGRES_PASSWORD"), + host=os.getenv("POSTGRES_HOST"), + port=os.getenv("POSTGRES_PORT"), + ) as conn: + with conn.cursor() as cur: + f = io.StringIO() + id, vec, metadata = record + + writer = csv.writer(f, delimiter=",", quotechar='"') + writer.writerow( + [ + str(id), + [float(ele) for ele in vec], + json.dumps(metadata), + ] + ) + f.seek(0) + result = f.getvalue() + + writer_name = ( + f'vecs."{self.table.fullname.split(".")[-1]}"' + ) + g = io.StringIO(result) + cur.copy_expert( + f"COPY {writer_name}(id, vec, metadata) FROM STDIN WITH (FORMAT csv)", + g, + ) + conn.commit() + cur.close() + conn.close() + + def upsert( + self, + records: Iterable[Tuple[str, Any, Metadata]], + skip_adapter: bool = False, + ) -> None: + """ + Inserts or updates *vectors* records in the collection. + + Args: + records (Iterable[Tuple[str, Any, Metadata]]): An iterable of content to upsert. + Each record is a tuple where: + - the first element is a unique string identifier + - the second element is an iterable of numeric values or relevant input type for the + adapter assigned to the collection + - the third element is metadata associated with the vector + + skip_adapter (bool): Should the adapter be skipped while upserting. i.e. if vectors are being + provided, rather than a media type that needs to be transformed + """ + + chunk_size = 512 + + if skip_adapter: + pipeline = flu(records).chunk(chunk_size) + else: + # Construct a lazy pipeline of steps to transform and chunk user input + pipeline = flu( + self.adapter(records, AdapterContext("upsert")) + ).chunk(chunk_size) + + with self.client.Session() as sess: + with sess.begin(): + for chunk in pipeline: + stmt = postgresql.insert(self.table).values(chunk) + stmt = stmt.on_conflict_do_update( + index_elements=[self.table.c.id], + set_=dict( + vec=stmt.excluded.vec, + metadata=stmt.excluded.metadata, + ), + ) + sess.execute(stmt) + return None + + def fetch(self, ids: Iterable[str]) -> List[Record]: + """ + Fetches vectors from the collection by their identifiers. + + Args: + ids (Iterable[str]): An iterable of vector identifiers. + + Returns: + List[Record]: A list of the fetched vectors. + """ + if isinstance(ids, str): + raise ArgError("ids must be a list of strings") + + chunk_size = 12 + records = [] + with self.client.Session() as sess: + with sess.begin(): + for id_chunk in flu(ids).chunk(chunk_size): + stmt = select(self.table).where( + self.table.c.id.in_(id_chunk) + ) + chunk_records = sess.execute(stmt) + records.extend(chunk_records) + return records + + def delete( + self, + ids: Optional[Iterable[str]] = None, + filters: Optional[Dict[str, Any]] = None, + ) -> List[str]: + """ + Deletes vectors from the collection by matching filters or ids. + + Args: + ids (Iterable[str], optional): An iterable of vector identifiers. + filters (Optional[Dict], optional): Filters to apply to the search. Defaults to None. + + Returns: + List[str]: A list of the document IDs of the deleted vectors. 
+ """ + if ids is None and filters is None: + raise ArgError("Either ids or filters must be provided.") + + if ids is not None and filters is not None: + raise ArgError("Either ids or filters must be provided, not both.") + + if isinstance(ids, str): + raise ArgError("ids must be a list of strings") + + ids = ids or [] + filters = filters or {} + del_document_ids = set([]) + + with self.client.Session() as sess: + with sess.begin(): + if ids: + for id_chunk in flu(ids).chunk(12): + stmt = select(self.table.c.metadata).where( + self.table.c.id.in_(id_chunk) + ) + results = sess.execute(stmt).fetchall() + for result in results: + metadata_json = result[0] + document_id = metadata_json.get("document_id") + if document_id: + del_document_ids.add(document_id) + + delete_stmt = ( + delete(self.table) + .where(self.table.c.id.in_(id_chunk)) + .returning(self.table.c.id) + ) + sess.execute(delete_stmt) + + if filters: + meta_filter = build_filters(self.table.c.metadata, filters) + stmt = select(self.table.c.metadata).where(meta_filter) + results = sess.execute(stmt).fetchall() + for result in results: + metadata_json = result[0] + document_id = metadata_json.get("document_id") + if document_id: + del_document_ids.add(document_id) + + delete_stmt = ( + delete(self.table) + .where(meta_filter) + .returning(self.table.c.id) + ) + sess.execute(delete_stmt) + + return list(del_document_ids) + + def __getitem__(self, items): + """ + Fetches a vector from the collection by its identifier. + + Args: + items (str): The identifier of the vector. + + Returns: + Record: The fetched vector. + """ + if not isinstance(items, str): + raise ArgError("items must be a string id") + + row = self.fetch([items]) + + if row == []: + raise KeyError("no item found with requested id") + return row[0] + + def query( + self, + data: Union[Iterable[Numeric], Any], + limit: int = 10, + filters: Optional[Dict] = None, + measure: Union[IndexMeasure, str] = IndexMeasure.cosine_distance, + include_value: bool = False, + include_metadata: bool = False, + *, + probes: Optional[int] = None, + ef_search: Optional[int] = None, + skip_adapter: bool = False, + ) -> Union[List[Record], List[str]]: + """ + Executes a similarity search in the collection. + + The return type is dependent on arguments *include_value* and *include_metadata* + + Args: + data (Any): The vector to use as the query. + limit (int, optional): The maximum number of results to return. Defaults to 10. + filters (Optional[Dict], optional): Filters to apply to the search. Defaults to None. + measure (Union[IndexMeasure, str], optional): The distance measure to use for the search. Defaults to 'cosine_distance'. + include_value (bool, optional): Whether to include the distance value in the results. Defaults to False. + include_metadata (bool, optional): Whether to include the metadata in the results. Defaults to False. + probes (Optional[Int], optional): Number of ivfflat index lists to query. Higher increases accuracy but decreases speed + ef_search (Optional[Int], optional): Size of the dynamic candidate list for HNSW index search. Higher increases accuracy but decreases speed + skip_adapter (bool, optional): When True, skips any associated adapter and queries using a literal vector provided to *data* + + Returns: + Union[List[Record], List[str]]: The result of the similarity search. 
+ """ + + if probes is None: + probes = 10 + + if ef_search is None: + ef_search = 40 + + if not isinstance(probes, int): + raise ArgError("probes must be an integer") + + if probes < 1: + raise ArgError("probes must be >= 1") + + if limit > 1000: + raise ArgError("limit must be <= 1000") + + # ValueError on bad input + try: + imeasure = IndexMeasure(measure) + except ValueError: + raise ArgError("Invalid index measure") + + if not self.is_indexed_for_measure(imeasure): + warnings.warn( + UserWarning( + f"Query does not have a covering index for {imeasure}. See Collection.create_index" + ) + ) + + if skip_adapter: + adapted_query = [("", data, {})] + else: + # Adapt the query using the pipeline + adapted_query = [ + x + for x in self.adapter( + records=[("", data, {})], + adapter_context=AdapterContext("query"), + ) + ] + + if len(adapted_query) != 1: + raise ArgError( + "Failed to produce exactly one query vector from input" + ) + + _, vec, _ = adapted_query[0] + + distance_lambda = INDEX_MEASURE_TO_SQLA_ACC.get(imeasure) + if distance_lambda is None: + # unreachable + raise ArgError("invalid distance_measure") # pragma: no cover + + distance_clause = distance_lambda(self.table.c.vec)(vec) + + cols = [self.table.c.id] + + if include_value: + cols.append(distance_clause) + + if include_metadata: + cols.append(self.table.c.metadata) + + stmt = select(*cols) + if filters: + stmt = stmt.filter( + build_filters(self.table.c.metadata, filters) # type: ignore + ) + + stmt = stmt.order_by(distance_clause) + stmt = stmt.limit(limit) + + with self.client.Session() as sess: + with sess.begin(): + # index ignored if greater than n_lists + sess.execute( + text("set local ivfflat.probes = :probes").bindparams( + probes=probes + ) + ) + if self.client._supports_hnsw(): + sess.execute( + text( + "set local hnsw.ef_search = :ef_search" + ).bindparams(ef_search=ef_search) + ) + if len(cols) == 1: + return [str(x) for x in sess.scalars(stmt).fetchall()] + return sess.execute(stmt).fetchall() or [] + + @classmethod + def _list_collections(cls, client: "Client") -> List["Collection"]: + """ + PRIVATE + + Retrieves all collections from the database. + + Args: + client (Client): The database client. + + Returns: + List[Collection]: A list of all existing collections. + """ + + query = text( + """ + select + relname as table_name, + atttypmod as embedding_dim + from + pg_class pc + join pg_attribute pa + on pc.oid = pa.attrelid + where + pc.relnamespace = 'vecs'::regnamespace + and pc.relkind = 'r' + and pa.attname = 'vec' + and not pc.relname ^@ '_' + """ + ) + xc = [] + with client.Session() as sess: + for name, dimension in sess.execute(query): + existing_collection = cls(name, dimension, client) + xc.append(existing_collection) + return xc + + @classmethod + def _does_collection_exist(cls, client: "Client", name: str) -> bool: + """ + PRIVATE + + Checks if a collection with a given name exists within the database + + Args: + client (Client): The database client. + name (str): The name of the collection + + Returns: + Exists: Whether the collection exists or not + """ + + try: + client.get_collection(name) + return True + except CollectionNotFound: + return False + + @property + def index(self) -> Optional[str]: + """ + PRIVATE + + Note: + The `index` property is private and expected to undergo refactoring. + Do not rely on it's output. + + Retrieves the SQL name of the collection's vector index, if it exists. + + Returns: + Optional[str]: The name of the index, or None if no index exists. 
+ """ + + if self._index is None: + query = text( + """ + select + relname as table_name + from + pg_class pc + where + pc.relnamespace = 'vecs'::regnamespace + and relname ilike 'ix_vector%' + and pc.relkind = 'i' + """ + ) + with self.client.Session() as sess: + ix_name = sess.execute(query).scalar() + self._index = ix_name + return self._index + + def is_indexed_for_measure(self, measure: IndexMeasure): + """ + Checks if the collection is indexed for a specific measure. + + Args: + measure (IndexMeasure): The measure to check for. + + Returns: + bool: True if the collection is indexed for the measure, False otherwise. + """ + + index_name = self.index + if index_name is None: + return False + + ops = INDEX_MEASURE_TO_OPS.get(measure) + if ops is None: + return False + + if ops in index_name: + return True + + return False + + def create_index( + self, + measure: IndexMeasure = IndexMeasure.cosine_distance, + method: IndexMethod = IndexMethod.auto, + index_arguments: Optional[ + Union[IndexArgsIVFFlat, IndexArgsHNSW] + ] = None, + replace=True, + ) -> None: + """ + Creates an index for the collection. + + Note: + When `vecs` creates an index on a pgvector column in PostgreSQL, it uses a multi-step + process that enables performant indexes to be built for large collections with low end + database hardware. + + Those steps are: + + - Creates a new table with a different name + - Randomly selects records from the existing table + - Inserts the random records from the existing table into the new table + - Creates the requested vector index on the new table + - Upserts all data from the existing table into the new table + - Drops the existing table + - Renames the new table to the existing tables name + + If you create dependencies (like views) on the table that underpins + a `vecs.Collection` the `create_index` step may require you to drop those dependencies before + it will succeed. + + Args: + measure (IndexMeasure, optional): The measure to index for. Defaults to 'cosine_distance'. + method (IndexMethod, optional): The indexing method to use. Defaults to 'auto'. + index_arguments: (IndexArgsIVFFlat | IndexArgsHNSW, optional): Index type specific arguments + replace (bool, optional): Whether to replace the existing index. Defaults to True. + + Raises: + ArgError: If an invalid index method is used, or if *replace* is False and an index already exists. + """ + + if method not in ( + IndexMethod.ivfflat, + IndexMethod.hnsw, + IndexMethod.auto, + ): + raise ArgError("invalid index method") + + if index_arguments: + # Disallow case where user submits index arguments but uses the + # IndexMethod.auto index (index build arguments should only be + # used with a specific index) + if method == IndexMethod.auto: + raise ArgError( + "Index build parameters are not allowed when using the IndexMethod.auto index." + ) + # Disallow case where user specifies one index type but submits + # index build arguments for the other index type + if ( + isinstance(index_arguments, IndexArgsHNSW) + and method != IndexMethod.hnsw + ) or ( + isinstance(index_arguments, IndexArgsIVFFlat) + and method != IndexMethod.ivfflat + ): + raise ArgError( + f"{index_arguments.__class__.__name__} build parameters were supplied but {method} index was specified." + ) + + if method == IndexMethod.auto: + if self.client._supports_hnsw(): + method = IndexMethod.hnsw + else: + method = IndexMethod.ivfflat + + if method == IndexMethod.hnsw and not self.client._supports_hnsw(): + raise ArgError( + "HNSW Unavailable. 
Upgrade your pgvector installation to > 0.5.0 to enable HNSW support" + ) + + ops = INDEX_MEASURE_TO_OPS.get(measure) + if ops is None: + raise ArgError("Unknown index measure") + + unique_string = str(uuid.uuid4()).replace("-", "_")[0:7] + + with self.client.Session() as sess: + with sess.begin(): + if self.index is not None: + if replace: + sess.execute(text(f'drop index vecs."{self.index}";')) + self._index = None + else: + raise ArgError( + "replace is set to False but an index exists" + ) + + if method == IndexMethod.ivfflat: + if not index_arguments: + n_records: int = sess.execute(func.count(self.table.c.id)).scalar() # type: ignore + + n_lists = ( + int(max(n_records / 1000, 30)) + if n_records < 1_000_000 + else int(math.sqrt(n_records)) + ) + else: + # The following mypy error is ignored because mypy + # complains that `index_arguments` is typed as a union + # of IndexArgsIVFFlat and IndexArgsHNSW types, + # which both don't necessarily contain the `n_lists` + # parameter, however we have validated that the + # correct type is being used above. + n_lists = index_arguments.n_lists # type: ignore + + sess.execute( + text( + f""" + create index ix_{ops}_ivfflat_nl{n_lists}_{unique_string} + on vecs."{self.table.name}" + using ivfflat (vec {ops}) with (lists={n_lists}) + """ + ) + ) + + if method == IndexMethod.hnsw: + if not index_arguments: + index_arguments = IndexArgsHNSW() + + # See above for explanation of why the following lines + # are ignored + m = index_arguments.m # type: ignore + ef_construction = index_arguments.ef_construction # type: ignore + + sess.execute( + text( + f""" + create index ix_{ops}_hnsw_m{m}_efc{ef_construction}_{unique_string} + on vecs."{self.table.name}" + using hnsw (vec {ops}) WITH (m={m}, ef_construction={ef_construction}); + """ + ) + ) + + return None + + +def build_filters(json_col: Column, filters: Dict): + """ + Builds filters for SQL query based on provided dictionary. + + Args: + json_col (Column): The column in the database table. + filters (Dict): The dictionary specifying filter conditions. + + Raises: + FilterError: If filter conditions are not correctly formatted. + + Returns: + The filter clause for the SQL query. 
+ """ + if not isinstance(filters, dict): + raise FilterError("filters must be a dict") + + filter_clauses = [] + + for key, value in filters.items(): + if not isinstance(key, str): + raise FilterError("*filters* keys must be strings") + + if isinstance(value, dict): + if len(value) > 1: + raise FilterError("only one operator permitted per key") + for operator, clause in value.items(): + if operator not in ( + "$eq", + "$ne", + "$lt", + "$lte", + "$gt", + "$gte", + "$in", + ): + raise FilterError("unknown operator") + + if operator == "$eq" and not hasattr(clause, "__len__"): + contains_value = cast({key: clause}, postgresql.JSONB) + filter_clauses.append(json_col.op("@>")(contains_value)) + elif operator == "$in": + if not isinstance(clause, list): + raise FilterError( + "argument to $in filter must be a list" + ) + for elem in clause: + if not isinstance(elem, (int, str, float)): + raise FilterError( + "argument to $in filter must be a list of scalars" + ) + contains_value = [ + cast(elem, postgresql.JSONB) for elem in clause + ] + filter_clauses.append( + json_col.op("->")(key).in_(contains_value) + ) + else: + matches_value = cast(clause, postgresql.JSONB) + if operator == "$eq": + filter_clauses.append( + json_col.op("->")(key) == matches_value + ) + elif operator == "$ne": + filter_clauses.append( + json_col.op("->")(key) != matches_value + ) + elif operator == "$lt": + filter_clauses.append( + json_col.op("->")(key) < matches_value + ) + elif operator == "$lte": + filter_clauses.append( + json_col.op("->")(key) <= matches_value + ) + elif operator == "$gt": + filter_clauses.append( + json_col.op("->")(key) > matches_value + ) + elif operator == "$gte": + filter_clauses.append( + json_col.op("->")(key) >= matches_value + ) + else: + raise Unreachable() + else: + raise FilterError("Filter value must be a dict with an operator") + + if len(filter_clauses) == 1: + return filter_clauses[0] + else: + return and_(*filter_clauses) + + +def build_table(name: str, meta: MetaData, dimension: int) -> Table: + """ + PRIVATE + + Builds a SQLAlchemy model underpinning a `vecs.Collection`. + + Args: + name (str): The name of the table. + meta (MetaData): MetaData instance associated with the SQL database. + dimension: The dimension of the vectors in the collection. + + Returns: + Table: The constructed SQL table. + """ + return Table( + name, + meta, + Column("id", String, primary_key=True), + Column("vec", Vector(dimension), nullable=False), + Column( + "metadata", + postgresql.JSONB, + server_default=text("'{}'::jsonb"), + nullable=False, + ), + extend_existing=True, + ) |