path: root/R2R/r2r/vecs
Diffstat (limited to 'R2R/r2r/vecs')
-rwxr-xr-x  R2R/r2r/vecs/__init__.py            28
-rwxr-xr-x  R2R/r2r/vecs/adapter/__init__.py    15
-rwxr-xr-x  R2R/r2r/vecs/adapter/base.py       111
-rwxr-xr-x  R2R/r2r/vecs/adapter/markdown.py    88
-rwxr-xr-x  R2R/r2r/vecs/adapter/noop.py        55
-rwxr-xr-x  R2R/r2r/vecs/adapter/text.py       151
-rwxr-xr-x  R2R/r2r/vecs/client.py             313
-rwxr-xr-x  R2R/r2r/vecs/collection.py        1132
-rwxr-xr-x  R2R/r2r/vecs/exc.py                 83
9 files changed, 1976 insertions(+), 0 deletions(-)
diff --git a/R2R/r2r/vecs/__init__.py b/R2R/r2r/vecs/__init__.py
new file mode 100755
index 00000000..9d4f1d7e
--- /dev/null
+++ b/R2R/r2r/vecs/__init__.py
@@ -0,0 +1,28 @@
+from . import exc
+from .client import Client
+from .collection import (
+ Collection,
+ IndexArgsHNSW,
+ IndexArgsIVFFlat,
+ IndexMeasure,
+ IndexMethod,
+)
+
+__project__ = "vecs"
+__version__ = "0.4.2"
+
+
+__all__ = [
+ "IndexArgsIVFFlat",
+ "IndexArgsHNSW",
+ "IndexMethod",
+ "IndexMeasure",
+ "Collection",
+ "Client",
+ "exc",
+]
+
+
+def create_client(connection_string: str, *args, **kwargs) -> Client:
+ """Creates a client from a Postgres connection string"""
+ return Client(connection_string, *args, **kwargs)
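A minimal usage sketch for the package entry point above; the connection string is illustrative and not part of the diff.

import r2r.vecs as vecs

# Illustrative connection string; substitute real credentials.
DB_CONNECTION = "postgresql://user:password@localhost:5432/postgres"

# create_client is a thin wrapper around Client, so pool_size, max_retries,
# and retry_delay can be passed through as keyword arguments.
with vecs.create_client(DB_CONNECTION) as vx:
    docs = vx.get_or_create_collection(name="docs", dimension=3)
    print(docs)  # vecs.Collection(name="docs", dimension=3)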
diff --git a/R2R/r2r/vecs/adapter/__init__.py b/R2R/r2r/vecs/adapter/__init__.py
new file mode 100755
index 00000000..9cd9860d
--- /dev/null
+++ b/R2R/r2r/vecs/adapter/__init__.py
@@ -0,0 +1,15 @@
+from .base import Adapter, AdapterContext, AdapterStep
+from .markdown import MarkdownChunker
+from .noop import NoOp
+from .text import ParagraphChunker, TextEmbedding, TextEmbeddingModel
+
+__all__ = [
+ "Adapter",
+ "AdapterContext",
+ "AdapterStep",
+ "NoOp",
+ "ParagraphChunker",
+ "TextEmbedding",
+ "TextEmbeddingModel",
+ "MarkdownChunker",
+]
diff --git a/R2R/r2r/vecs/adapter/base.py b/R2R/r2r/vecs/adapter/base.py
new file mode 100755
index 00000000..7734e802
--- /dev/null
+++ b/R2R/r2r/vecs/adapter/base.py
@@ -0,0 +1,111 @@
+"""
+The `vecs.experimental.adapter.base` module provides abstract classes and utilities
+for creating and handling adapters in vecs. Adapters allow users to interact with
+a collection using media types other than vectors.
+
+All public classes, enums, and functions are re-exported by `vecs.adapters` module.
+"""
+
+from abc import ABC, abstractmethod
+from enum import Enum
+from typing import Any, Dict, Generator, Iterable, Optional, Tuple
+
+from ..exc import ArgError
+
+
+class AdapterContext(str, Enum):
+ """
+ An enum representing the different contexts in which a Pipeline
+ will be invoked.
+
+ Attributes:
+ upsert (str): The Collection.upsert method
+ query (str): The Collection.query method
+ """
+
+ upsert = "upsert"
+ query = "query"
+
+
+class AdapterStep(ABC):
+ """
+ Abstract class representing a step in the adapter pipeline.
+
+    Each adapter step should adapt a user's media into a tuple of:
+        - id (str)
+        - media (unknown type)
+        - metadata (dict)
+
+    If the user provides an id or metadata, they override the defaults produced by the step.
+ """
+
+ @property
+ def exported_dimension(self) -> Optional[int]:
+ """
+ Property that should be overridden by subclasses to provide the output dimension
+ of the adapter step.
+ """
+ return None
+
+ @abstractmethod
+ def __call__(
+ self,
+ records: Iterable[Tuple[str, Any, Optional[Dict]]],
+ adapter_context: AdapterContext,
+ ) -> Generator[Tuple[str, Any, Dict], None, None]:
+ """
+ Abstract method that should be overridden by subclasses to handle each record.
+ """
+
+
+class Adapter:
+ """
+ Class representing a sequence of AdapterStep instances forming a pipeline.
+ """
+
+ def __init__(self, steps: list[AdapterStep]):
+ """
+ Initialize an Adapter instance with a list of AdapterStep instances.
+
+ Args:
+ steps: list of AdapterStep instances.
+
+ Raises:
+ ArgError: Raised if the steps list is empty.
+ """
+ self.steps = steps
+ if len(steps) < 1:
+ raise ArgError("Adapter must contain at least 1 step")
+
+ @property
+ def exported_dimension(self) -> Optional[int]:
+ """
+ The output dimension of the adapter. Returns the exported dimension of the last
+ AdapterStep that provides one (from end to start of the steps list).
+ """
+ for step in reversed(self.steps):
+ step_dim = step.exported_dimension
+ if step_dim is not None:
+ return step_dim
+ return None
+
+ def __call__(
+ self,
+ records: Iterable[Tuple[str, Any, Optional[Dict]]],
+ adapter_context: AdapterContext,
+ ) -> Generator[Tuple[str, Any, Dict], None, None]:
+ """
+ Invokes the adapter pipeline on an iterable of records.
+
+ Args:
+ records: Iterable of tuples each containing an id, a media and an optional dict.
+ adapter_context: Context of the adapter.
+
+ Yields:
+ Tuples each containing an id, a media and a dict.
+ """
+ pipeline = records
+ for step in self.steps:
+ pipeline = step(pipeline, adapter_context)
+
+ yield from pipeline # type: ignore
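A sketch of composing an Adapter pipeline directly, assuming the package is importable as r2r.vecs; the UpperCaser step below is hypothetical and only illustrates the AdapterStep contract.

from r2r.vecs.adapter import Adapter, AdapterContext, AdapterStep


class UpperCaser(AdapterStep):
    # Hypothetical step: upper-cases text media, passing ids and metadata through.
    def __call__(self, records, adapter_context):
        for id, media, metadata in records:
            yield (id, media.upper(), metadata or {})


adapter = Adapter(steps=[UpperCaser()])
out = list(adapter([("doc1", "hello world", None)], AdapterContext("upsert")))
# out == [("doc1", "HELLO WORLD", {})]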
diff --git a/R2R/r2r/vecs/adapter/markdown.py b/R2R/r2r/vecs/adapter/markdown.py
new file mode 100755
index 00000000..149573f4
--- /dev/null
+++ b/R2R/r2r/vecs/adapter/markdown.py
@@ -0,0 +1,88 @@
+import re
+from typing import Any, Dict, Generator, Iterable, Optional, Tuple
+
+from flupy import flu
+
+from .base import AdapterContext, AdapterStep
+
+
+class MarkdownChunker(AdapterStep):
+ """
+ MarkdownChunker is an AdapterStep that splits a markdown string into chunks where a heading signifies the start of a chunk, and yields each chunk as a separate record.
+ """
+
+ def __init__(self, *, skip_during_query: bool):
+ """
+ Initializes the MarkdownChunker adapter.
+
+ Args:
+ skip_during_query (bool): Whether to skip chunking during querying.
+ """
+ self.skip_during_query = skip_during_query
+
+ @staticmethod
+ def split_by_heading(
+ md: str, max_tokens: int
+ ) -> Generator[str, None, None]:
+ regex_split = r"^(#{1,6}\s+.+)$"
+ headings = [
+ match.span()[0]
+ for match in re.finditer(regex_split, md, flags=re.MULTILINE)
+ ]
+
+ if headings == [] or headings[0] != 0:
+ headings.insert(0, 0)
+
+ sections = [md[i:j] for i, j in zip(headings, headings[1:] + [None])]
+
+ for section in sections:
+ chunks = flu(section.split(" ")).chunk(max_tokens)
+
+            is_not_useless_chunk = lambda i: i not in ["", "\n", []]
+
+ joined_chunks = filter(
+ is_not_useless_chunk, [" ".join(chunk) for chunk in chunks]
+ )
+
+ for joined_chunk in joined_chunks:
+ yield joined_chunk
+
+ def __call__(
+ self,
+ records: Iterable[Tuple[str, Any, Optional[Dict]]],
+ adapter_context: AdapterContext,
+ max_tokens: int = 99999999,
+ ) -> Generator[Tuple[str, Any, Dict], None, None]:
+ """
+ Splits each markdown string in the records into chunks where each heading starts a new chunk, and yields each chunk
+ as a separate record. If the `skip_during_query` attribute is set to True,
+ this step is skipped during querying.
+
+ Args:
+ records (Iterable[Tuple[str, Any, Optional[Dict]]]): Iterable of tuples each containing an id, a markdown string and an optional dict.
+ adapter_context (AdapterContext): Context of the adapter.
+ max_tokens (int): The maximum number of tokens per chunk
+
+ Yields:
+ Tuple[str, Any, Dict]: The id appended with chunk index, the chunk, and the metadata.
+ """
+ if max_tokens and max_tokens < 1:
+ raise ValueError("max_tokens must be a nonzero positive integer")
+
+ if (
+ adapter_context == AdapterContext("query")
+ and self.skip_during_query
+ ):
+ for id, markdown, metadata in records:
+ yield (id, markdown, metadata or {})
+ else:
+ for id, markdown, metadata in records:
+ headings = MarkdownChunker.split_by_heading(
+ markdown, max_tokens
+ )
+ for heading_ix, heading in enumerate(headings):
+ yield (
+ f"{id}_head_{str(heading_ix).zfill(3)}",
+ heading,
+ metadata or {},
+ )
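A sketch of running MarkdownChunker on its own; the document and ids are illustrative.

from r2r.vecs.adapter import AdapterContext, MarkdownChunker

chunker = MarkdownChunker(skip_during_query=True)
md = "# Intro\nSome text.\n\n## Details\nMore text."
records = [("doc1", md, {"source": "readme"})]

# During upsert the text is chunked; during query it would pass through
# unchanged because skip_during_query=True.
for record_id, chunk, metadata in chunker(records, AdapterContext("upsert")):
    print(record_id, repr(chunk), metadata)
# Yields ids such as "doc1_head_000" and "doc1_head_001", one per heading-delimited section.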
diff --git a/R2R/r2r/vecs/adapter/noop.py b/R2R/r2r/vecs/adapter/noop.py
new file mode 100755
index 00000000..b587a552
--- /dev/null
+++ b/R2R/r2r/vecs/adapter/noop.py
@@ -0,0 +1,55 @@
+"""
+The `vecs.experimental.adapter.noop` module provides a default no-op (no operation) adapter
+that passes the inputs through without any modification. This can be useful when no specific
+adapter processing is required.
+
+All public classes, enums, and functions are re-exported by `vecs.adapters` module.
+"""
+
+from typing import Any, Dict, Generator, Iterable, Optional, Tuple
+
+from .base import AdapterContext, AdapterStep
+
+
+class NoOp(AdapterStep):
+ """
+ NoOp is a no-operation AdapterStep. It is a default adapter that passes through
+ the input records without any modifications.
+ """
+
+ def __init__(self, dimension: int):
+ """
+ Initializes the NoOp adapter with a dimension.
+
+ Args:
+ dimension (int): The dimension of the input vectors.
+ """
+ self._dimension = dimension
+
+ @property
+ def exported_dimension(self) -> Optional[int]:
+ """
+ Returns the dimension of the adapter.
+
+ Returns:
+ int: The dimension of the input vectors.
+ """
+ return self._dimension
+
+ def __call__(
+ self,
+ records: Iterable[Tuple[str, Any, Optional[Dict]]],
+ adapter_context: AdapterContext,
+ ) -> Generator[Tuple[str, Any, Dict], None, None]:
+ """
+ Yields the input records without any modification.
+
+ Args:
+ records: Iterable of tuples each containing an id, a media and an optional dict.
+ adapter_context: Context of the adapter.
+
+ Yields:
+ Tuple[str, Any, Dict]: The input record.
+ """
+ for id, media, metadata in records:
+ yield (id, media, metadata or {})
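A short sketch of the pass-through behavior above, assuming the vectors are already computed.

from r2r.vecs.adapter import Adapter, AdapterContext, NoOp

adapter = Adapter(steps=[NoOp(dimension=3)])
out = list(adapter([("vec1", [0.1, 0.2, 0.3], None)], AdapterContext("upsert")))
# out == [("vec1", [0.1, 0.2, 0.3], {})]  -- records are forwarded unchanged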
diff --git a/R2R/r2r/vecs/adapter/text.py b/R2R/r2r/vecs/adapter/text.py
new file mode 100755
index 00000000..78ae7732
--- /dev/null
+++ b/R2R/r2r/vecs/adapter/text.py
@@ -0,0 +1,151 @@
+"""
+The `vecs.experimental.adapter.text` module provides adapter steps specifically designed for
+handling text data. It provides two main classes, `TextEmbedding` and `ParagraphChunker`.
+
+All public classes, enums, and functions are re-exported by `vecs.adapters` module.
+"""
+
+from typing import Any, Dict, Generator, Iterable, Literal, Optional, Tuple
+
+from flupy import flu
+from ..exc import MissingDependency
+
+from .base import AdapterContext, AdapterStep
+
+TextEmbeddingModel = Literal[
+ "all-mpnet-base-v2",
+ "multi-qa-mpnet-base-dot-v1",
+ "all-distilroberta-v1",
+ "all-MiniLM-L12-v2",
+ "multi-qa-distilbert-cos-v1",
+ "mixedbread-ai/mxbai-embed-large-v1",
+ "multi-qa-MiniLM-L6-cos-v1",
+ "paraphrase-multilingual-mpnet-base-v2",
+ "paraphrase-albert-small-v2",
+ "paraphrase-multilingual-MiniLM-L12-v2",
+ "paraphrase-MiniLM-L3-v2",
+ "distiluse-base-multilingual-cased-v1",
+ "distiluse-base-multilingual-cased-v2",
+]
+
+
+class TextEmbedding(AdapterStep):
+ """
+ TextEmbedding is an AdapterStep that converts text media into
+ embeddings using a specified sentence transformers model.
+ """
+
+ def __init__(
+ self,
+ *,
+ model: TextEmbeddingModel,
+ batch_size: int = 8,
+        use_auth_token: Optional[str] = None,
+ ):
+ """
+ Initializes the TextEmbedding adapter with a sentence transformers model.
+
+ Args:
+ model (TextEmbeddingModel): The sentence transformers model to use for embeddings.
+ batch_size (int): The number of records to encode simultaneously.
+ use_auth_token (str): The HuggingFace Hub auth token to use for private models.
+
+ Raises:
+ MissingDependency: If the sentence_transformers library is not installed.
+ """
+ try:
+ from sentence_transformers import SentenceTransformer as ST
+ except ImportError:
+ raise MissingDependency(
+ "Missing feature vecs[text_embedding]. Hint: `pip install 'vecs[text_embedding]'`"
+ )
+
+ self.model = ST(model, use_auth_token=use_auth_token)
+ self._exported_dimension = (
+ self.model.get_sentence_embedding_dimension()
+ )
+ self.batch_size = batch_size
+
+ @property
+ def exported_dimension(self) -> Optional[int]:
+ """
+ Returns the dimension of the embeddings produced by the sentence transformers model.
+
+ Returns:
+ int: The dimension of the embeddings.
+ """
+ return self._exported_dimension
+
+ def __call__(
+ self,
+ records: Iterable[Tuple[str, Any, Optional[Dict]]],
+ adapter_context: AdapterContext, # pyright: ignore
+ ) -> Generator[Tuple[str, Any, Dict], None, None]:
+ """
+ Converts each media in the records to an embedding and yields the result.
+
+ Args:
+ records: Iterable of tuples each containing an id, a media and an optional dict.
+ adapter_context: Context of the adapter.
+
+ Yields:
+ Tuple[str, Any, Dict]: The id, the embedding, and the metadata.
+ """
+ for batch in flu(records).chunk(self.batch_size):
+ batch_records = [x for x in batch]
+ media = [text for _, text, _ in batch_records]
+
+ embeddings = self.model.encode(media, normalize_embeddings=True)
+
+ for (id, _, metadata), embedding in zip(batch_records, embeddings): # type: ignore
+ yield (id, embedding, metadata or {})
+
+
+class ParagraphChunker(AdapterStep):
+ """
+ ParagraphChunker is an AdapterStep that splits text media into
+ paragraphs and yields each paragraph as a separate record.
+ """
+
+ def __init__(self, *, skip_during_query: bool):
+ """
+ Initializes the ParagraphChunker adapter.
+
+ Args:
+ skip_during_query (bool): Whether to skip chunking during querying.
+ """
+ self.skip_during_query = skip_during_query
+
+ def __call__(
+ self,
+ records: Iterable[Tuple[str, Any, Optional[Dict]]],
+ adapter_context: AdapterContext,
+ ) -> Generator[Tuple[str, Any, Dict], None, None]:
+ """
+ Splits each media in the records into paragraphs and yields each paragraph
+ as a separate record. If the `skip_during_query` attribute is set to True,
+ this step is skipped during querying.
+
+ Args:
+ records (Iterable[Tuple[str, Any, Optional[Dict]]]): Iterable of tuples each containing an id, a media and an optional dict.
+ adapter_context (AdapterContext): Context of the adapter.
+
+ Yields:
+ Tuple[str, Any, Dict]: The id appended with paragraph index, the paragraph, and the metadata.
+ """
+ if (
+ adapter_context == AdapterContext("query")
+ and self.skip_during_query
+ ):
+ for id, media, metadata in records:
+ yield (id, media, metadata or {})
+ else:
+ for id, media, metadata in records:
+ paragraphs = media.split("\n\n")
+
+ for paragraph_ix, paragraph in enumerate(paragraphs):
+ yield (
+ f"{id}_para_{str(paragraph_ix).zfill(3)}",
+ paragraph,
+ metadata or {},
+ )
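A sketch of a text adapter pipeline attached to a collection. It assumes the optional sentence-transformers dependency is installed and the model can be downloaded; the connection string is illustrative.

import r2r.vecs as vecs
from r2r.vecs.adapter import Adapter, ParagraphChunker, TextEmbedding

vx = vecs.create_client("postgresql://user:password@localhost:5432/postgres")
docs = vx.get_or_create_collection(
    name="docs",
    adapter=Adapter(
        steps=[
            ParagraphChunker(skip_during_query=True),
            TextEmbedding(model="all-MiniLM-L12-v2"),
        ]
    ),
)
# The adapter reports its output dimension, so `dimension` can be omitted.
docs.upsert([("doc1", "First paragraph.\n\nSecond paragraph.", {"lang": "en"})])
vx.disconnect()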
diff --git a/R2R/r2r/vecs/client.py b/R2R/r2r/vecs/client.py
new file mode 100755
index 00000000..6259f1d8
--- /dev/null
+++ b/R2R/r2r/vecs/client.py
@@ -0,0 +1,313 @@
+"""
+Defines the 'Client' class
+
+Importing from the `vecs.client` directly is not supported.
+All public classes, enums, and functions are re-exported by the top level `vecs` module.
+"""
+
+from __future__ import annotations
+
+import logging
+import time
+from typing import TYPE_CHECKING, List, Optional
+
+import sqlalchemy
+from deprecated import deprecated
+from sqlalchemy import MetaData, create_engine, text
+from sqlalchemy.orm import sessionmaker
+from sqlalchemy.pool import QueuePool
+
+from .adapter import Adapter
+from .exc import CollectionNotFound
+
+if TYPE_CHECKING:
+ from r2r.vecs.collection import Collection
+
+logger = logging.getLogger(__name__)
+
+
+class Client:
+ """
+ The `vecs.Client` class serves as an interface to a PostgreSQL database with pgvector support. It facilitates
+ the creation, retrieval, listing and deletion of vector collections, while managing connections to the
+ database.
+
+ A `Client` instance represents a connection to a PostgreSQL database. This connection can be used to create
+ and manipulate vector collections, where each collection is a group of vector records in a PostgreSQL table.
+
+    The `vecs.Client` class also supports usage as a context manager, ensuring the connection to the database
+    is properly closed after operations; alternatively, it can be used directly.
+
+ Example usage:
+
+ DB_CONNECTION = "postgresql://<user>:<password>@<host>:<port>/<db_name>"
+
+ with vecs.create_client(DB_CONNECTION) as vx:
+ # do some work
+ pass
+
+ # OR
+
+ vx = vecs.create_client(DB_CONNECTION)
+ # do some work
+ vx.disconnect()
+ """
+
+ def __init__(
+ self,
+ connection_string: str,
+ pool_size: int = 1,
+ max_retries: int = 3,
+ retry_delay: int = 1,
+ ):
+ self.engine = create_engine(
+ connection_string,
+ pool_size=pool_size,
+ poolclass=QueuePool,
+ pool_recycle=300, # Recycle connections after 5 min
+ )
+ self.meta = MetaData(schema="vecs")
+ self.Session = sessionmaker(self.engine)
+ self.max_retries = max_retries
+ self.retry_delay = retry_delay
+ self.vector_version: Optional[str] = None
+ self._initialize_database()
+
+ def _initialize_database(self):
+ retries = 0
+ error = None
+ while retries < self.max_retries:
+ try:
+ with self.Session() as sess:
+ with sess.begin():
+ self._create_schema(sess)
+ self._create_extension(sess)
+ self._get_vector_version(sess)
+ return
+ except Exception as e:
+ logger.warning(
+ f"Database connection error: {str(e)}. Retrying in {self.retry_delay} seconds..."
+ )
+ retries += 1
+ time.sleep(self.retry_delay)
+ error = e
+
+ error_message = f"Failed to initialize database after {self.max_retries} retries with error: {str(error)}"
+ logger.error(error_message)
+ raise RuntimeError(error_message)
+
+ def _create_schema(self, sess):
+ try:
+ sess.execute(text("CREATE SCHEMA IF NOT EXISTS vecs;"))
+ except Exception as e:
+ logger.warning(f"Failed to create schema: {str(e)}")
+
+ def _create_extension(self, sess):
+ try:
+ sess.execute(text("CREATE EXTENSION IF NOT EXISTS vector;"))
+ sess.execute(text("CREATE EXTENSION IF NOT EXISTS pg_trgm;"))
+ sess.execute(text("CREATE EXTENSION IF NOT EXISTS fuzzystrmatch;"))
+ except Exception as e:
+ logger.warning(f"Failed to create extension: {str(e)}")
+
+ def _get_vector_version(self, sess):
+ try:
+ self.vector_version = sess.execute(
+ text(
+ "SELECT installed_version FROM pg_available_extensions WHERE name = 'vector' LIMIT 1;"
+ )
+ ).scalar_one()
+ except sqlalchemy.exc.InternalError as e:
+ logger.error(f"Failed with internal alchemy error: {str(e)}")
+
+ import psycopg2
+
+ if isinstance(e.orig, psycopg2.errors.InFailedSqlTransaction):
+ sess.rollback()
+ self.vector_version = sess.execute(
+ text(
+ "SELECT installed_version FROM pg_available_extensions WHERE name = 'vector' LIMIT 1;"
+ )
+ ).scalar_one()
+ else:
+ raise e
+ except Exception as e:
+ logger.error(f"Failed to retrieve vector version: {str(e)}")
+ raise e
+
+ def _supports_hnsw(self):
+ return (
+ not self.vector_version.startswith("0.4")
+ and not self.vector_version.startswith("0.3")
+ and not self.vector_version.startswith("0.2")
+ and not self.vector_version.startswith("0.1")
+ and not self.vector_version.startswith("0.0")
+ )
+
+ def get_or_create_collection(
+ self,
+ name: str,
+ *,
+ dimension: Optional[int] = None,
+ adapter: Optional[Adapter] = None,
+ ) -> Collection:
+ """
+ Get a vector collection by name, or create it if no collection with
+ *name* exists.
+
+ Args:
+ name (str): The name of the collection.
+
+ Keyword Args:
+ dimension (int): The dimensionality of the vectors in the collection.
+            adapter (Adapter): The adapter to use when upserting into or querying the collection.
+
+ Returns:
+            Collection: The existing or newly created collection.
+
+ Raises:
+            MismatchedDimension: If the requested dimension does not match the dimension of an existing collection or of the adapter.
+ """
+ from r2r.vecs.collection import Collection
+
+ adapter_dimension = adapter.exported_dimension if adapter else None
+
+ collection = Collection(
+ name=name,
+ dimension=dimension or adapter_dimension, # type: ignore
+ client=self,
+ adapter=adapter,
+ )
+
+ return collection._create_if_not_exists()
+
+ @deprecated("use Client.get_or_create_collection")
+ def create_collection(self, name: str, dimension: int) -> Collection:
+ """
+ Create a new vector collection.
+
+ Args:
+ name (str): The name of the collection.
+ dimension (int): The dimensionality of the vectors in the collection.
+
+ Returns:
+ Collection: The created collection.
+
+ Raises:
+ CollectionAlreadyExists: If a collection with the same name already exists
+ """
+ from r2r.vecs.collection import Collection
+
+ return Collection(name, dimension, self)._create()
+
+ @deprecated("use Client.get_or_create_collection")
+ def get_collection(self, name: str) -> Collection:
+ """
+ Retrieve an existing vector collection.
+
+ Args:
+ name (str): The name of the collection.
+
+ Returns:
+ Collection: The retrieved collection.
+
+ Raises:
+ CollectionNotFound: If no collection with the given name exists.
+ """
+ from r2r.vecs.collection import Collection
+
+ query = text(
+ f"""
+ select
+ relname as table_name,
+ atttypmod as embedding_dim
+ from
+ pg_class pc
+ join pg_attribute pa
+ on pc.oid = pa.attrelid
+ where
+ pc.relnamespace = 'vecs'::regnamespace
+ and pc.relkind = 'r'
+ and pa.attname = 'vec'
+ and not pc.relname ^@ '_'
+ and pc.relname = :name
+ """
+ ).bindparams(name=name)
+ with self.Session() as sess:
+ query_result = sess.execute(query).fetchone()
+
+ if query_result is None:
+ raise CollectionNotFound(
+ "No collection found with requested name"
+ )
+
+ name, dimension = query_result
+ return Collection(
+ name,
+ dimension,
+ self,
+ )
+
+ def list_collections(self) -> List["Collection"]:
+ """
+ List all vector collections.
+
+ Returns:
+ list[Collection]: A list of all collections.
+ """
+ from r2r.vecs.collection import Collection
+
+ return Collection._list_collections(self)
+
+ def delete_collection(self, name: str) -> None:
+ """
+ Delete a vector collection.
+
+ If no collection with requested name exists, does nothing.
+
+ Args:
+ name (str): The name of the collection.
+
+ Returns:
+ None
+ """
+ from r2r.vecs.collection import Collection
+
+ Collection(name, -1, self)._drop()
+ return
+
+ def disconnect(self) -> None:
+ """
+ Disconnect the client from the database.
+
+ Returns:
+ None
+ """
+ self.engine.dispose()
+ logger.info("Disconnected from the database.")
+ return
+
+ def __enter__(self) -> "Client":
+ """
+ Enable use of the 'with' statement.
+
+ Returns:
+ Client: The current instance of the Client.
+ """
+
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ """
+ Disconnect the client on exiting the 'with' statement context.
+
+ Args:
+ exc_type: The exception type, if any.
+ exc_val: The exception value, if any.
+ exc_tb: The traceback, if any.
+
+ Returns:
+ None
+ """
+ self.disconnect()
+ return
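A sketch of the collection-management surface of the Client defined above; the connection string and collection name are illustrative.

import r2r.vecs as vecs

vx = vecs.Client(
    "postgresql://user:password@localhost:5432/postgres",
    pool_size=2,
    max_retries=3,
    retry_delay=1,
)
for collection in vx.list_collections():
    print(collection.name, collection.dimension)
vx.delete_collection("scratch")  # silently does nothing if the collection is absent
vx.disconnect()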
diff --git a/R2R/r2r/vecs/collection.py b/R2R/r2r/vecs/collection.py
new file mode 100755
index 00000000..2293d49b
--- /dev/null
+++ b/R2R/r2r/vecs/collection.py
@@ -0,0 +1,1132 @@
+"""
+Defines the 'Collection' class
+
+Importing from the `vecs.collection` directly is not supported.
+All public classes, enums, and functions are re-exported by the top level `vecs` module.
+"""
+
+from __future__ import annotations
+
+import math
+import uuid
+import warnings
+from dataclasses import dataclass
+from enum import Enum
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Tuple,
+ Union,
+)
+
+import psycopg2
+from flupy import flu
+from sqlalchemy import (
+ Column,
+ MetaData,
+ String,
+ Table,
+ alias,
+ and_,
+ cast,
+ delete,
+ distinct,
+ func,
+ or_,
+ select,
+ text,
+)
+from sqlalchemy.dialects import postgresql
+from sqlalchemy.types import Float, UserDefinedType
+
+from .adapter import Adapter, AdapterContext, NoOp
+from .exc import (
+ ArgError,
+ CollectionAlreadyExists,
+ CollectionNotFound,
+ FilterError,
+ MismatchedDimension,
+ Unreachable,
+)
+
+if TYPE_CHECKING:
+    from r2r.vecs.client import Client
+
+
+MetadataValues = Union[str, int, float, bool, List[str]]
+Metadata = Dict[str, MetadataValues]
+Numeric = Union[int, float, complex]
+Record = Tuple[str, Iterable[Numeric], Metadata]
+
+
+class IndexMethod(str, Enum):
+ """
+ An enum representing the index methods available.
+
+    The 'auto' method selects the best method available for the installed
+    pgvector version; 'ivfflat' and 'hnsw' force a specific index type.
+
+ Attributes:
+ auto (str): Automatically choose the best available index method.
+ ivfflat (str): The ivfflat index method.
+ hnsw (str): The hnsw index method.
+ """
+
+ auto = "auto"
+ ivfflat = "ivfflat"
+ hnsw = "hnsw"
+
+
+class IndexMeasure(str, Enum):
+ """
+ An enum representing the types of distance measures available for indexing.
+
+ Attributes:
+ cosine_distance (str): The cosine distance measure for indexing.
+ l2_distance (str): The Euclidean (L2) distance measure for indexing.
+ max_inner_product (str): The maximum inner product measure for indexing.
+ """
+
+ cosine_distance = "cosine_distance"
+ l2_distance = "l2_distance"
+ max_inner_product = "max_inner_product"
+
+
+@dataclass
+class IndexArgsIVFFlat:
+ """
+ A class for arguments that can optionally be supplied to the index creation
+ method when building an IVFFlat type index.
+
+ Attributes:
+        n_lists (int): The number of IVF centroids that the index should use
+ """
+
+ n_lists: int
+
+
+@dataclass
+class IndexArgsHNSW:
+ """
+ A class for arguments that can optionally be supplied to the index creation
+ method when building an HNSW type index.
+
+ Ref: https://github.com/pgvector/pgvector#index-options
+
+ Both attributes are Optional in case the user only wants to specify one and
+ leave the other as default
+
+ Attributes:
+ m (int): Maximum number of connections per node per layer (default: 16)
+ ef_construction (int): Size of the dynamic candidate list for
+ constructing the graph (default: 64)
+ """
+
+ m: Optional[int] = 16
+ ef_construction: Optional[int] = 64
+
+
+INDEX_MEASURE_TO_OPS = {
+ # Maps the IndexMeasure enum options to the SQL ops string required by
+ # the pgvector `create index` statement
+ IndexMeasure.cosine_distance: "vector_cosine_ops",
+ IndexMeasure.l2_distance: "vector_l2_ops",
+ IndexMeasure.max_inner_product: "vector_ip_ops",
+}
+
+INDEX_MEASURE_TO_SQLA_ACC = {
+ IndexMeasure.cosine_distance: lambda x: x.cosine_distance,
+ IndexMeasure.l2_distance: lambda x: x.l2_distance,
+ IndexMeasure.max_inner_product: lambda x: x.max_inner_product,
+}
+
+
+class Vector(UserDefinedType):
+ cache_ok = True
+
+ def __init__(self, dim=None):
+ super(UserDefinedType, self).__init__()
+ self.dim = dim
+
+ def get_col_spec(self, **kw):
+ return "VECTOR" if self.dim is None else f"VECTOR({self.dim})"
+
+ def bind_processor(self, dialect):
+ def process(value):
+ if value is None:
+ return value
+ if not isinstance(value, list):
+ raise ValueError("Expected a list")
+ if self.dim is not None and len(value) != self.dim:
+ raise ValueError(
+ f"Expected {self.dim} dimensions, not {len(value)}"
+ )
+ return "[" + ",".join(str(float(v)) for v in value) + "]"
+
+ return process
+
+ def result_processor(self, dialect, coltype):
+ return lambda value: (
+ value
+ if value is None
+ else [float(v) for v in value[1:-1].split(",")]
+ )
+
+ class comparator_factory(UserDefinedType.Comparator):
+ def l2_distance(self, other):
+ return self.op("<->", return_type=Float)(other)
+
+ def max_inner_product(self, other):
+ return self.op("<#>", return_type=Float)(other)
+
+ def cosine_distance(self, other):
+ return self.op("<=>", return_type=Float)(other)
+
+
+class Collection:
+ """
+ The `vecs.Collection` class represents a collection of vectors within a PostgreSQL database with pgvector support.
+ It provides methods to manage (create, delete, fetch, upsert), index, and perform similarity searches on these vector collections.
+
+ The collections are stored in separate tables in the database, with each vector associated with an identifier and optional metadata.
+
+ Example usage:
+
+ with vecs.create_client(DB_CONNECTION) as vx:
+ collection = vx.create_collection(name="docs", dimension=3)
+ collection.upsert([("id1", [1, 1, 1], {"key": "value"})])
+ # Further operations on 'collection'
+
+ Public Attributes:
+ name: The name of the vector collection.
+ dimension: The dimension of vectors in the collection.
+
+ Note: Some methods of this class can raise exceptions from the `vecs.exc` module if errors occur.
+ """
+
+ def __init__(
+ self,
+ name: str,
+ dimension: int,
+ client: Client,
+ adapter: Optional[Adapter] = None,
+ ):
+ """
+ Initializes a new instance of the `Collection` class.
+
+ During expected use, developers initialize instances of `Collection` using the
+ `vecs.Client` with `vecs.Client.create_collection(...)` rather than directly.
+
+ Args:
+ name (str): The name of the collection.
+ dimension (int): The dimension of the vectors in the collection.
+ client (Client): The client to use for interacting with the database.
+ """
+ from r2r.vecs.adapter import Adapter
+
+ self.client = client
+ self.name = name
+ self.dimension = dimension
+ self.table = build_table(name, client.meta, dimension)
+ self._index: Optional[str] = None
+ self.adapter = adapter or Adapter(steps=[NoOp(dimension=dimension)])
+
+ reported_dimensions = set(
+ [
+ x
+ for x in [
+ dimension,
+ adapter.exported_dimension if adapter else None,
+ ]
+ if x is not None
+ ]
+ )
+ if len(reported_dimensions) == 0:
+ raise ArgError(
+ "One of dimension or adapter must provide a dimension"
+ )
+ elif len(reported_dimensions) > 1:
+ raise MismatchedDimension(
+ "Mismatch in the reported dimensions of the selected vector collection and embedding model. Correct the selected embedding model or specify a new vector collection by modifying the `POSTGRES_VECS_COLLECTION` environment variable."
+ )
+
+ def __repr__(self):
+ """
+ Returns a string representation of the `Collection` instance.
+
+ Returns:
+ str: A string representation of the `Collection` instance.
+ """
+ return (
+ f'vecs.Collection(name="{self.name}", dimension={self.dimension})'
+ )
+
+ def __len__(self) -> int:
+ """
+ Returns the number of vectors in the collection.
+
+ Returns:
+ int: The number of vectors in the collection.
+ """
+ with self.client.Session() as sess:
+ with sess.begin():
+ stmt = select(func.count()).select_from(self.table)
+ return sess.execute(stmt).scalar() or 0
+
+ def _create_if_not_exists(self):
+ """
+ PRIVATE
+
+ Creates a new collection in the database if it doesn't already exist
+
+ Returns:
+ Collection: The found or created collection.
+ """
+ query = text(
+ f"""
+ select
+ relname as table_name,
+ atttypmod as embedding_dim
+ from
+ pg_class pc
+ join pg_attribute pa
+ on pc.oid = pa.attrelid
+ where
+ pc.relnamespace = 'vecs'::regnamespace
+ and pc.relkind = 'r'
+ and pa.attname = 'vec'
+ and not pc.relname ^@ '_'
+ and pc.relname = :name
+ """
+ ).bindparams(name=self.name)
+ with self.client.Session() as sess:
+ query_result = sess.execute(query).fetchone()
+
+ if query_result:
+ _, collection_dimension = query_result
+ else:
+ collection_dimension = None
+
+ reported_dimensions = set(
+ [
+ x
+ for x in [self.dimension, collection_dimension]
+ if x is not None
+ ]
+ )
+ if len(reported_dimensions) > 1:
+ raise MismatchedDimension(
+ "Dimensions reported by adapter, dimension, and collection do not match. The likely cause of this is a mismatch between the dimensions of the selected vector collection and embedding model. Select the correct embedding model, or specify a new vector collection by modifying your `POSTGRES_VECS_COLLECTION` environment variable. If the selected colelction does not exist then it will be automatically with dimensions that match the selected embedding model."
+ )
+
+ if not collection_dimension:
+ self.table.create(self.client.engine)
+
+ return self
+
+ def _create(self):
+ """
+ PRIVATE
+
+ Creates a new collection in the database. Raises a `vecs.exc.CollectionAlreadyExists`
+ exception if a collection with the specified name already exists.
+
+ Returns:
+ Collection: The newly created collection.
+ """
+
+ collection_exists = self.__class__._does_collection_exist(
+ self.client, self.name
+ )
+ if collection_exists:
+ raise CollectionAlreadyExists(
+ "Collection with requested name already exists"
+ )
+ self.table.create(self.client.engine)
+
+ unique_string = str(uuid.uuid4()).replace("-", "_")[0:7]
+ with self.client.Session() as sess:
+ sess.execute(
+ text(
+ f"""
+ create index ix_meta_{unique_string}
+ on vecs."{self.table.name}"
+ using gin ( metadata jsonb_path_ops )
+ """
+ )
+ )
+ return self
+
+ def _drop(self):
+ """
+ PRIVATE
+
+ Deletes the collection from the database. Raises a `vecs.exc.CollectionNotFound`
+ exception if no collection with the specified name exists.
+
+ Returns:
+ Collection: The deleted collection.
+ """
+ with self.client.Session() as sess:
+ sess.execute(text(f"DROP TABLE IF EXISTS {self.name} CASCADE"))
+ sess.commit()
+
+ return self
+
+ def get_unique_metadata_values(
+ self,
+ field: str,
+ filter_field: Optional[str] = None,
+ filter_value: Optional[MetadataValues] = None,
+ ) -> List[MetadataValues]:
+ """
+ Fetches all unique metadata values of a specific field, optionally filtered by another metadata field.
+ Args:
+ field (str): The metadata field for which to fetch unique values.
+ filter_field (Optional[str], optional): The metadata field to filter on. Defaults to None.
+ filter_value (Optional[MetadataValues], optional): The value to filter the metadata field with. Defaults to None.
+ Returns:
+ List[MetadataValues]: A list of unique metadata values for the specified field.
+ """
+ with self.client.Session() as sess:
+ with sess.begin():
+ stmt = select(
+ distinct(self.table.c.metadata[field].astext)
+ ).where(self.table.c.metadata[field] != None)
+
+ if filter_field is not None and filter_value is not None:
+ stmt = stmt.where(
+ self.table.c.metadata[filter_field].astext
+ == str(filter_value)
+ )
+
+ result = sess.execute(stmt)
+ unique_values = result.scalars().all()
+
+ return unique_values
+
+ def copy(
+ self,
+ records: Iterable[Tuple[str, Any, Metadata]],
+ skip_adapter: bool = False,
+ ) -> None:
+ """
+ Copies records into the collection.
+
+ Args:
+ records (Iterable[Tuple[str, Any, Metadata]]): An iterable of content to copy.
+ Each record is a tuple where:
+ - the first element is a unique string identifier
+ - the second element is an iterable of numeric values or relevant input type for the
+ adapter assigned to the collection
+ - the third element is metadata associated with the vector
+
+ skip_adapter (bool): Should the adapter be skipped while copying. i.e. if vectors are being
+ provided, rather than a media type that needs to be transformed
+ """
+ import csv
+ import io
+ import json
+ import os
+
+ pipeline = flu(records)
+ for record in pipeline:
+ with psycopg2.connect(
+ database=os.getenv("POSTGRES_DBNAME"),
+ user=os.getenv("POSTGRES_USER"),
+ password=os.getenv("POSTGRES_PASSWORD"),
+ host=os.getenv("POSTGRES_HOST"),
+ port=os.getenv("POSTGRES_PORT"),
+ ) as conn:
+ with conn.cursor() as cur:
+ f = io.StringIO()
+ id, vec, metadata = record
+
+ writer = csv.writer(f, delimiter=",", quotechar='"')
+ writer.writerow(
+ [
+ str(id),
+ [float(ele) for ele in vec],
+ json.dumps(metadata),
+ ]
+ )
+ f.seek(0)
+ result = f.getvalue()
+
+ writer_name = (
+ f'vecs."{self.table.fullname.split(".")[-1]}"'
+ )
+ g = io.StringIO(result)
+ cur.copy_expert(
+ f"COPY {writer_name}(id, vec, metadata) FROM STDIN WITH (FORMAT csv)",
+ g,
+ )
+ conn.commit()
+ cur.close()
+ conn.close()
+
+ def upsert(
+ self,
+ records: Iterable[Tuple[str, Any, Metadata]],
+ skip_adapter: bool = False,
+ ) -> None:
+ """
+ Inserts or updates *vectors* records in the collection.
+
+ Args:
+ records (Iterable[Tuple[str, Any, Metadata]]): An iterable of content to upsert.
+ Each record is a tuple where:
+ - the first element is a unique string identifier
+ - the second element is an iterable of numeric values or relevant input type for the
+ adapter assigned to the collection
+ - the third element is metadata associated with the vector
+
+ skip_adapter (bool): Should the adapter be skipped while upserting. i.e. if vectors are being
+ provided, rather than a media type that needs to be transformed
+ """
+
+ chunk_size = 512
+
+ if skip_adapter:
+ pipeline = flu(records).chunk(chunk_size)
+ else:
+ # Construct a lazy pipeline of steps to transform and chunk user input
+ pipeline = flu(
+ self.adapter(records, AdapterContext("upsert"))
+ ).chunk(chunk_size)
+
+ with self.client.Session() as sess:
+ with sess.begin():
+ for chunk in pipeline:
+ stmt = postgresql.insert(self.table).values(chunk)
+ stmt = stmt.on_conflict_do_update(
+ index_elements=[self.table.c.id],
+ set_=dict(
+ vec=stmt.excluded.vec,
+ metadata=stmt.excluded.metadata,
+ ),
+ )
+ sess.execute(stmt)
+ return None
+
+ def fetch(self, ids: Iterable[str]) -> List[Record]:
+ """
+ Fetches vectors from the collection by their identifiers.
+
+ Args:
+ ids (Iterable[str]): An iterable of vector identifiers.
+
+ Returns:
+ List[Record]: A list of the fetched vectors.
+ """
+ if isinstance(ids, str):
+ raise ArgError("ids must be a list of strings")
+
+ chunk_size = 12
+ records = []
+ with self.client.Session() as sess:
+ with sess.begin():
+ for id_chunk in flu(ids).chunk(chunk_size):
+ stmt = select(self.table).where(
+ self.table.c.id.in_(id_chunk)
+ )
+ chunk_records = sess.execute(stmt)
+ records.extend(chunk_records)
+ return records
+
+ def delete(
+ self,
+ ids: Optional[Iterable[str]] = None,
+ filters: Optional[Dict[str, Any]] = None,
+ ) -> List[str]:
+ """
+ Deletes vectors from the collection by matching filters or ids.
+
+ Args:
+ ids (Iterable[str], optional): An iterable of vector identifiers.
+ filters (Optional[Dict], optional): Filters to apply to the search. Defaults to None.
+
+ Returns:
+ List[str]: A list of the document IDs of the deleted vectors.
+ """
+ if ids is None and filters is None:
+ raise ArgError("Either ids or filters must be provided.")
+
+ if ids is not None and filters is not None:
+ raise ArgError("Either ids or filters must be provided, not both.")
+
+ if isinstance(ids, str):
+ raise ArgError("ids must be a list of strings")
+
+ ids = ids or []
+ filters = filters or {}
+        del_document_ids = set()
+
+ with self.client.Session() as sess:
+ with sess.begin():
+ if ids:
+ for id_chunk in flu(ids).chunk(12):
+ stmt = select(self.table.c.metadata).where(
+ self.table.c.id.in_(id_chunk)
+ )
+ results = sess.execute(stmt).fetchall()
+ for result in results:
+ metadata_json = result[0]
+ document_id = metadata_json.get("document_id")
+ if document_id:
+ del_document_ids.add(document_id)
+
+ delete_stmt = (
+ delete(self.table)
+ .where(self.table.c.id.in_(id_chunk))
+ .returning(self.table.c.id)
+ )
+ sess.execute(delete_stmt)
+
+ if filters:
+ meta_filter = build_filters(self.table.c.metadata, filters)
+ stmt = select(self.table.c.metadata).where(meta_filter)
+ results = sess.execute(stmt).fetchall()
+ for result in results:
+ metadata_json = result[0]
+ document_id = metadata_json.get("document_id")
+ if document_id:
+ del_document_ids.add(document_id)
+
+ delete_stmt = (
+ delete(self.table)
+ .where(meta_filter)
+ .returning(self.table.c.id)
+ )
+ sess.execute(delete_stmt)
+
+ return list(del_document_ids)
+
+ def __getitem__(self, items):
+ """
+ Fetches a vector from the collection by its identifier.
+
+ Args:
+ items (str): The identifier of the vector.
+
+ Returns:
+ Record: The fetched vector.
+ """
+ if not isinstance(items, str):
+ raise ArgError("items must be a string id")
+
+ row = self.fetch([items])
+
+ if row == []:
+ raise KeyError("no item found with requested id")
+ return row[0]
+
+ def query(
+ self,
+ data: Union[Iterable[Numeric], Any],
+ limit: int = 10,
+ filters: Optional[Dict] = None,
+ measure: Union[IndexMeasure, str] = IndexMeasure.cosine_distance,
+ include_value: bool = False,
+ include_metadata: bool = False,
+ *,
+ probes: Optional[int] = None,
+ ef_search: Optional[int] = None,
+ skip_adapter: bool = False,
+ ) -> Union[List[Record], List[str]]:
+ """
+ Executes a similarity search in the collection.
+
+ The return type is dependent on arguments *include_value* and *include_metadata*
+
+ Args:
+ data (Any): The vector to use as the query.
+ limit (int, optional): The maximum number of results to return. Defaults to 10.
+ filters (Optional[Dict], optional): Filters to apply to the search. Defaults to None.
+ measure (Union[IndexMeasure, str], optional): The distance measure to use for the search. Defaults to 'cosine_distance'.
+ include_value (bool, optional): Whether to include the distance value in the results. Defaults to False.
+ include_metadata (bool, optional): Whether to include the metadata in the results. Defaults to False.
+ probes (Optional[Int], optional): Number of ivfflat index lists to query. Higher increases accuracy but decreases speed
+ ef_search (Optional[Int], optional): Size of the dynamic candidate list for HNSW index search. Higher increases accuracy but decreases speed
+ skip_adapter (bool, optional): When True, skips any associated adapter and queries using a literal vector provided to *data*
+
+ Returns:
+ Union[List[Record], List[str]]: The result of the similarity search.
+ """
+
+ if probes is None:
+ probes = 10
+
+ if ef_search is None:
+ ef_search = 40
+
+ if not isinstance(probes, int):
+ raise ArgError("probes must be an integer")
+
+ if probes < 1:
+ raise ArgError("probes must be >= 1")
+
+ if limit > 1000:
+ raise ArgError("limit must be <= 1000")
+
+ # ValueError on bad input
+ try:
+ imeasure = IndexMeasure(measure)
+ except ValueError:
+ raise ArgError("Invalid index measure")
+
+ if not self.is_indexed_for_measure(imeasure):
+ warnings.warn(
+ UserWarning(
+ f"Query does not have a covering index for {imeasure}. See Collection.create_index"
+ )
+ )
+
+ if skip_adapter:
+ adapted_query = [("", data, {})]
+ else:
+ # Adapt the query using the pipeline
+ adapted_query = [
+ x
+ for x in self.adapter(
+ records=[("", data, {})],
+ adapter_context=AdapterContext("query"),
+ )
+ ]
+
+ if len(adapted_query) != 1:
+ raise ArgError(
+ "Failed to produce exactly one query vector from input"
+ )
+
+ _, vec, _ = adapted_query[0]
+
+ distance_lambda = INDEX_MEASURE_TO_SQLA_ACC.get(imeasure)
+ if distance_lambda is None:
+ # unreachable
+ raise ArgError("invalid distance_measure") # pragma: no cover
+
+ distance_clause = distance_lambda(self.table.c.vec)(vec)
+
+ cols = [self.table.c.id]
+
+ if include_value:
+ cols.append(distance_clause)
+
+ if include_metadata:
+ cols.append(self.table.c.metadata)
+
+ stmt = select(*cols)
+ if filters:
+ stmt = stmt.filter(
+ build_filters(self.table.c.metadata, filters) # type: ignore
+ )
+
+ stmt = stmt.order_by(distance_clause)
+ stmt = stmt.limit(limit)
+
+ with self.client.Session() as sess:
+ with sess.begin():
+                # probes greater than the index's n_lists are ignored
+ sess.execute(
+ text("set local ivfflat.probes = :probes").bindparams(
+ probes=probes
+ )
+ )
+ if self.client._supports_hnsw():
+ sess.execute(
+ text(
+ "set local hnsw.ef_search = :ef_search"
+ ).bindparams(ef_search=ef_search)
+ )
+ if len(cols) == 1:
+ return [str(x) for x in sess.scalars(stmt).fetchall()]
+ return sess.execute(stmt).fetchall() or []
+
+ @classmethod
+ def _list_collections(cls, client: "Client") -> List["Collection"]:
+ """
+ PRIVATE
+
+ Retrieves all collections from the database.
+
+ Args:
+ client (Client): The database client.
+
+ Returns:
+ List[Collection]: A list of all existing collections.
+ """
+
+ query = text(
+ """
+ select
+ relname as table_name,
+ atttypmod as embedding_dim
+ from
+ pg_class pc
+ join pg_attribute pa
+ on pc.oid = pa.attrelid
+ where
+ pc.relnamespace = 'vecs'::regnamespace
+ and pc.relkind = 'r'
+ and pa.attname = 'vec'
+ and not pc.relname ^@ '_'
+ """
+ )
+ xc = []
+ with client.Session() as sess:
+ for name, dimension in sess.execute(query):
+ existing_collection = cls(name, dimension, client)
+ xc.append(existing_collection)
+ return xc
+
+ @classmethod
+ def _does_collection_exist(cls, client: "Client", name: str) -> bool:
+ """
+ PRIVATE
+
+ Checks if a collection with a given name exists within the database
+
+ Args:
+ client (Client): The database client.
+ name (str): The name of the collection
+
+ Returns:
+ Exists: Whether the collection exists or not
+ """
+
+ try:
+ client.get_collection(name)
+ return True
+ except CollectionNotFound:
+ return False
+
+ @property
+ def index(self) -> Optional[str]:
+ """
+ PRIVATE
+
+ Note:
+ The `index` property is private and expected to undergo refactoring.
+            Do not rely on its output.
+
+ Retrieves the SQL name of the collection's vector index, if it exists.
+
+ Returns:
+ Optional[str]: The name of the index, or None if no index exists.
+ """
+
+ if self._index is None:
+ query = text(
+ """
+ select
+ relname as table_name
+ from
+ pg_class pc
+ where
+ pc.relnamespace = 'vecs'::regnamespace
+ and relname ilike 'ix_vector%'
+ and pc.relkind = 'i'
+ """
+ )
+ with self.client.Session() as sess:
+ ix_name = sess.execute(query).scalar()
+ self._index = ix_name
+ return self._index
+
+ def is_indexed_for_measure(self, measure: IndexMeasure):
+ """
+ Checks if the collection is indexed for a specific measure.
+
+ Args:
+ measure (IndexMeasure): The measure to check for.
+
+ Returns:
+ bool: True if the collection is indexed for the measure, False otherwise.
+ """
+
+ index_name = self.index
+ if index_name is None:
+ return False
+
+ ops = INDEX_MEASURE_TO_OPS.get(measure)
+ if ops is None:
+ return False
+
+ if ops in index_name:
+ return True
+
+ return False
+
+ def create_index(
+ self,
+ measure: IndexMeasure = IndexMeasure.cosine_distance,
+ method: IndexMethod = IndexMethod.auto,
+ index_arguments: Optional[
+ Union[IndexArgsIVFFlat, IndexArgsHNSW]
+ ] = None,
+ replace=True,
+ ) -> None:
+ """
+ Creates an index for the collection.
+
+ Note:
+ When `vecs` creates an index on a pgvector column in PostgreSQL, it uses a multi-step
+ process that enables performant indexes to be built for large collections with low end
+ database hardware.
+
+ Those steps are:
+
+ - Creates a new table with a different name
+ - Randomly selects records from the existing table
+ - Inserts the random records from the existing table into the new table
+ - Creates the requested vector index on the new table
+ - Upserts all data from the existing table into the new table
+ - Drops the existing table
+ - Renames the new table to the existing tables name
+
+ If you create dependencies (like views) on the table that underpins
+ a `vecs.Collection` the `create_index` step may require you to drop those dependencies before
+ it will succeed.
+
+ Args:
+ measure (IndexMeasure, optional): The measure to index for. Defaults to 'cosine_distance'.
+ method (IndexMethod, optional): The indexing method to use. Defaults to 'auto'.
+ index_arguments: (IndexArgsIVFFlat | IndexArgsHNSW, optional): Index type specific arguments
+ replace (bool, optional): Whether to replace the existing index. Defaults to True.
+
+ Raises:
+ ArgError: If an invalid index method is used, or if *replace* is False and an index already exists.
+ """
+
+ if method not in (
+ IndexMethod.ivfflat,
+ IndexMethod.hnsw,
+ IndexMethod.auto,
+ ):
+ raise ArgError("invalid index method")
+
+ if index_arguments:
+ # Disallow case where user submits index arguments but uses the
+ # IndexMethod.auto index (index build arguments should only be
+ # used with a specific index)
+ if method == IndexMethod.auto:
+ raise ArgError(
+ "Index build parameters are not allowed when using the IndexMethod.auto index."
+ )
+ # Disallow case where user specifies one index type but submits
+ # index build arguments for the other index type
+ if (
+ isinstance(index_arguments, IndexArgsHNSW)
+ and method != IndexMethod.hnsw
+ ) or (
+ isinstance(index_arguments, IndexArgsIVFFlat)
+ and method != IndexMethod.ivfflat
+ ):
+ raise ArgError(
+ f"{index_arguments.__class__.__name__} build parameters were supplied but {method} index was specified."
+ )
+
+ if method == IndexMethod.auto:
+ if self.client._supports_hnsw():
+ method = IndexMethod.hnsw
+ else:
+ method = IndexMethod.ivfflat
+
+ if method == IndexMethod.hnsw and not self.client._supports_hnsw():
+ raise ArgError(
+ "HNSW Unavailable. Upgrade your pgvector installation to > 0.5.0 to enable HNSW support"
+ )
+
+ ops = INDEX_MEASURE_TO_OPS.get(measure)
+ if ops is None:
+ raise ArgError("Unknown index measure")
+
+ unique_string = str(uuid.uuid4()).replace("-", "_")[0:7]
+
+ with self.client.Session() as sess:
+ with sess.begin():
+ if self.index is not None:
+ if replace:
+ sess.execute(text(f'drop index vecs."{self.index}";'))
+ self._index = None
+ else:
+ raise ArgError(
+ "replace is set to False but an index exists"
+ )
+
+ if method == IndexMethod.ivfflat:
+ if not index_arguments:
+ n_records: int = sess.execute(func.count(self.table.c.id)).scalar() # type: ignore
+
+ n_lists = (
+ int(max(n_records / 1000, 30))
+ if n_records < 1_000_000
+ else int(math.sqrt(n_records))
+ )
+ else:
+ # The following mypy error is ignored because mypy
+ # complains that `index_arguments` is typed as a union
+ # of IndexArgsIVFFlat and IndexArgsHNSW types,
+ # which both don't necessarily contain the `n_lists`
+ # parameter, however we have validated that the
+ # correct type is being used above.
+ n_lists = index_arguments.n_lists # type: ignore
+
+ sess.execute(
+ text(
+ f"""
+ create index ix_{ops}_ivfflat_nl{n_lists}_{unique_string}
+ on vecs."{self.table.name}"
+ using ivfflat (vec {ops}) with (lists={n_lists})
+ """
+ )
+ )
+
+ if method == IndexMethod.hnsw:
+ if not index_arguments:
+ index_arguments = IndexArgsHNSW()
+
+ # See above for explanation of why the following lines
+ # are ignored
+ m = index_arguments.m # type: ignore
+ ef_construction = index_arguments.ef_construction # type: ignore
+
+ sess.execute(
+ text(
+ f"""
+ create index ix_{ops}_hnsw_m{m}_efc{ef_construction}_{unique_string}
+ on vecs."{self.table.name}"
+ using hnsw (vec {ops}) WITH (m={m}, ef_construction={ef_construction});
+ """
+ )
+ )
+
+ return None
+
+
+def build_filters(json_col: Column, filters: Dict):
+ """
+ Builds filters for SQL query based on provided dictionary.
+
+ Args:
+ json_col (Column): The column in the database table.
+ filters (Dict): The dictionary specifying filter conditions.
+
+ Raises:
+ FilterError: If filter conditions are not correctly formatted.
+
+ Returns:
+ The filter clause for the SQL query.
+ """
+ if not isinstance(filters, dict):
+ raise FilterError("filters must be a dict")
+
+ filter_clauses = []
+
+ for key, value in filters.items():
+ if not isinstance(key, str):
+ raise FilterError("*filters* keys must be strings")
+
+ if isinstance(value, dict):
+ if len(value) > 1:
+ raise FilterError("only one operator permitted per key")
+ for operator, clause in value.items():
+ if operator not in (
+ "$eq",
+ "$ne",
+ "$lt",
+ "$lte",
+ "$gt",
+ "$gte",
+ "$in",
+ ):
+ raise FilterError("unknown operator")
+
+ if operator == "$eq" and not hasattr(clause, "__len__"):
+ contains_value = cast({key: clause}, postgresql.JSONB)
+ filter_clauses.append(json_col.op("@>")(contains_value))
+ elif operator == "$in":
+ if not isinstance(clause, list):
+ raise FilterError(
+ "argument to $in filter must be a list"
+ )
+ for elem in clause:
+ if not isinstance(elem, (int, str, float)):
+ raise FilterError(
+ "argument to $in filter must be a list of scalars"
+ )
+ contains_value = [
+ cast(elem, postgresql.JSONB) for elem in clause
+ ]
+ filter_clauses.append(
+ json_col.op("->")(key).in_(contains_value)
+ )
+ else:
+ matches_value = cast(clause, postgresql.JSONB)
+ if operator == "$eq":
+ filter_clauses.append(
+ json_col.op("->")(key) == matches_value
+ )
+ elif operator == "$ne":
+ filter_clauses.append(
+ json_col.op("->")(key) != matches_value
+ )
+ elif operator == "$lt":
+ filter_clauses.append(
+ json_col.op("->")(key) < matches_value
+ )
+ elif operator == "$lte":
+ filter_clauses.append(
+ json_col.op("->")(key) <= matches_value
+ )
+ elif operator == "$gt":
+ filter_clauses.append(
+ json_col.op("->")(key) > matches_value
+ )
+ elif operator == "$gte":
+ filter_clauses.append(
+ json_col.op("->")(key) >= matches_value
+ )
+ else:
+ raise Unreachable()
+ else:
+ raise FilterError("Filter value must be a dict with an operator")
+
+ if len(filter_clauses) == 1:
+ return filter_clauses[0]
+ else:
+ return and_(*filter_clauses)
+
+
+def build_table(name: str, meta: MetaData, dimension: int) -> Table:
+ """
+ PRIVATE
+
+ Builds a SQLAlchemy model underpinning a `vecs.Collection`.
+
+ Args:
+ name (str): The name of the table.
+ meta (MetaData): MetaData instance associated with the SQL database.
+ dimension: The dimension of the vectors in the collection.
+
+ Returns:
+ Table: The constructed SQL table.
+ """
+ return Table(
+ name,
+ meta,
+ Column("id", String, primary_key=True),
+ Column("vec", Vector(dimension), nullable=False),
+ Column(
+ "metadata",
+ postgresql.JSONB,
+ server_default=text("'{}'::jsonb"),
+ nullable=False,
+ ),
+ extend_existing=True,
+ )
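A sketch that exercises upsert, create_index, and query with metadata filters; the vectors, ids, and filter values are illustrative.

import r2r.vecs as vecs
from r2r.vecs import IndexMeasure, IndexMethod

vx = vecs.create_client("postgresql://user:password@localhost:5432/postgres")
docs = vx.get_or_create_collection(name="docs", dimension=3)
docs.upsert(
    [
        ("id1", [0.1, 0.2, 0.3], {"year": 2023, "topic": "db"}),
        ("id2", [0.9, 0.8, 0.7], {"year": 2024, "topic": "ml"}),
    ]
)
# IndexMethod.auto picks hnsw when the installed pgvector supports it,
# otherwise it falls back to ivfflat (see create_index above).
docs.create_index(measure=IndexMeasure.cosine_distance, method=IndexMethod.auto)
results = docs.query(
    data=[0.1, 0.2, 0.3],
    limit=5,
    filters={"year": {"$gte": 2023}},  # translated to a JSONB clause by build_filters
    include_value=True,
    include_metadata=True,
)
vx.disconnect()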
diff --git a/R2R/r2r/vecs/exc.py b/R2R/r2r/vecs/exc.py
new file mode 100755
index 00000000..0ae4500c
--- /dev/null
+++ b/R2R/r2r/vecs/exc.py
@@ -0,0 +1,83 @@
+__all__ = [
+ "VecsException",
+ "CollectionAlreadyExists",
+ "CollectionNotFound",
+ "ArgError",
+ "FilterError",
+ "IndexNotFound",
+ "Unreachable",
+]
+
+
+class VecsException(Exception):
+ """
+ Base exception class for the 'vecs' package.
+ All custom exceptions in the 'vecs' package should derive from this class.
+ """
+
+ ...
+
+
+class CollectionAlreadyExists(VecsException):
+ """
+ Exception raised when attempting to create a collection that already exists.
+ """
+
+ ...
+
+
+class CollectionNotFound(VecsException):
+ """
+ Exception raised when attempting to access or manipulate a collection that does not exist.
+ """
+
+ ...
+
+
+class ArgError(VecsException):
+ """
+ Exception raised for invalid arguments when calling a method.
+ """
+
+ ...
+
+
+class MismatchedDimension(ArgError):
+ """
+ Exception raised when multiple sources of truth for a collection's embedding dimension do not match.
+ """
+
+ ...
+
+
+class FilterError(VecsException):
+ """
+ Exception raised when there's an error related to filter usage in a query.
+ """
+
+ ...
+
+
+class IndexNotFound(VecsException):
+ """
+ Exception raised when attempting to access an index that does not exist.
+ """
+
+ ...
+
+
+class Unreachable(VecsException):
+ """
+ Exception raised when an unreachable part of the code is executed.
+ This is typically used for error handling in cases that should be logically impossible.
+ """
+
+ ...
+
+
+class MissingDependency(VecsException, ImportError):
+ """
+ Exception raised when attempting to access a feature that requires an optional dependency when the optional dependency is not present.
+ """
+
+ ...
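Since every exception above derives from VecsException, callers can catch narrowly or broadly; the collection name below is illustrative.

import r2r.vecs as vecs
from r2r.vecs.exc import CollectionNotFound, VecsException

vx = vecs.create_client("postgresql://user:password@localhost:5432/postgres")
try:
    docs = vx.get_collection("missing_collection")  # deprecated; raises if absent
except CollectionNotFound:
    docs = vx.get_or_create_collection(name="missing_collection", dimension=3)
except VecsException as exc:
    print(f"vecs error: {exc}")
finally:
    vx.disconnect()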