Diffstat (limited to 'R2R/r2r/providers/vector_dbs')
-rwxr-xr-x  R2R/r2r/providers/vector_dbs/__init__.py                5
-rwxr-xr-x  R2R/r2r/providers/vector_dbs/pgvector/pgvector_db.py  610
2 files changed, 615 insertions, 0 deletions
diff --git a/R2R/r2r/providers/vector_dbs/__init__.py b/R2R/r2r/providers/vector_dbs/__init__.py
new file mode 100755
index 00000000..38ea0890
--- /dev/null
+++ b/R2R/r2r/providers/vector_dbs/__init__.py
@@ -0,0 +1,5 @@
+from .pgvector.pgvector_db import PGVectorDB
+
+__all__ = [
+ "PGVectorDB",
+]
diff --git a/R2R/r2r/providers/vector_dbs/pgvector/pgvector_db.py b/R2R/r2r/providers/vector_dbs/pgvector/pgvector_db.py
new file mode 100755
index 00000000..8cf728d1
--- /dev/null
+++ b/R2R/r2r/providers/vector_dbs/pgvector/pgvector_db.py
@@ -0,0 +1,610 @@
+import json
+import logging
+import os
+import time
+from typing import Literal, Optional, Union
+
+from sqlalchemy import exc, text
+from sqlalchemy.engine.url import make_url
+
+from r2r.base import (
+ DocumentInfo,
+ UserStats,
+ VectorDBConfig,
+ VectorDBProvider,
+ VectorEntry,
+ VectorSearchResult,
+)
+from r2r.vecs.client import Client
+from r2r.vecs.collection import Collection
+
+logger = logging.getLogger(__name__)
+
+
+class PGVectorDB(VectorDBProvider):
+ def __init__(self, config: VectorDBConfig) -> None:
+ super().__init__(config)
+ try:
+ import r2r.vecs
+ except ImportError:
+ raise ValueError(
+ f"Error, PGVectorDB requires the vecs library. Please run `pip install vecs`."
+ )
+
+ # Check if a complete Postgres URI is provided
+ postgres_uri = self.config.extra_fields.get(
+ "postgres_uri"
+ ) or os.getenv("POSTGRES_URI")
+
+ if postgres_uri:
+ # Log loudly that Postgres URI is being used
+ logger.warning("=" * 50)
+ logger.warning(
+ "ATTENTION: Using provided Postgres URI for connection"
+ )
+ logger.warning("=" * 50)
+
+ # Validate and use the provided URI
+ try:
+ parsed_uri = make_url(postgres_uri)
+ if not all([parsed_uri.username, parsed_uri.database]):
+ raise ValueError(
+ "The provided Postgres URI is missing required components."
+ )
+ DB_CONNECTION = postgres_uri
+
+ # Log the sanitized URI (without password)
+ sanitized_uri = parsed_uri.set(password="*****")
+ logger.info(f"Connecting using URI: {sanitized_uri}")
+ except Exception as e:
+ raise ValueError(f"Invalid Postgres URI provided: {e}")
+ else:
+ # Fall back to existing logic for individual connection parameters
+ user = self.config.extra_fields.get("user", None) or os.getenv(
+ "POSTGRES_USER"
+ )
+ password = self.config.extra_fields.get(
+ "password", None
+ ) or os.getenv("POSTGRES_PASSWORD")
+ host = self.config.extra_fields.get("host", None) or os.getenv(
+ "POSTGRES_HOST"
+ )
+ port = self.config.extra_fields.get("port", None) or os.getenv(
+ "POSTGRES_PORT"
+ )
+ db_name = self.config.extra_fields.get(
+ "db_name", None
+ ) or os.getenv("POSTGRES_DBNAME")
+
+ if not all([user, password, host, db_name]):
+ raise ValueError(
+ "Error, please set the POSTGRES_USER, POSTGRES_PASSWORD, POSTGRES_HOST, POSTGRES_DBNAME environment variables or provide them in the config."
+ )
+
+ # Check if it's a Unix socket connection
+ if host.startswith("/") and not port:
+ DB_CONNECTION = (
+ f"postgresql://{user}:{password}@/{db_name}?host={host}"
+ )
+ logger.info("Using Unix socket connection")
+ else:
+ DB_CONNECTION = (
+ f"postgresql://{user}:{password}@{host}:{port}/{db_name}"
+ )
+ logger.info("Using TCP connection")
+
+ # Create the vecs client
+ try:
+ self.vx: Client = r2r.vecs.create_client(DB_CONNECTION)
+ except Exception as e:
+ raise ValueError(
+ f"Error {e} occurred while attempting to connect to the pgvector provider."
+ ) from e
+
+ self.collection_name = self.config.extra_fields.get(
+ "vecs_collection"
+ ) or os.getenv("POSTGRES_VECS_COLLECTION")
+ if not self.collection_name:
+ raise ValueError(
+ "Error, please set a valid POSTGRES_VECS_COLLECTION environment variable or set a 'vecs_collection' in the 'vector_database' settings of your `config.json`."
+ )
+
+ self.collection: Optional[Collection] = None
+
+ logger.info(
+ f"Successfully initialized PGVectorDB with collection: {self.collection_name}"
+ )
+
+ def initialize_collection(self, dimension: int) -> None:
+ self.collection = self.vx.get_or_create_collection(
+ name=self.collection_name, dimension=dimension
+ )
+ self._create_document_info_table()
+ self._create_hybrid_search_function()
+
+ def _create_document_info_table(self):
+ with self.vx.Session() as sess:
+ with sess.begin():
+ try:
+ # Enable uuid-ossp extension
+ sess.execute(
+ text('CREATE EXTENSION IF NOT EXISTS "uuid-ossp";')
+ )
+ except exc.ProgrammingError as e:
+ logger.error(f"Error enabling uuid-ossp extension: {e}")
+ raise
+
+ # Create the table if it doesn't exist
+ create_table_query = f"""
+ CREATE TABLE IF NOT EXISTS "document_info_{self.collection_name}" (
+ document_id UUID PRIMARY KEY,
+ title TEXT,
+ user_id UUID NULL,
+ version TEXT,
+ size_in_bytes INT,
+ created_at TIMESTAMPTZ DEFAULT NOW(),
+ updated_at TIMESTAMPTZ DEFAULT NOW(),
+ metadata JSONB,
+ status TEXT
+ );
+ """
+ sess.execute(text(create_table_query))
+
+ # Add the new column if it doesn't exist
+ add_column_query = f"""
+ DO $$
+ BEGIN
+ IF NOT EXISTS (
+ SELECT 1
+ FROM information_schema.columns
+ WHERE table_name = 'document_info_{self.collection_name}'
+ AND column_name = 'status'
+ ) THEN
+ ALTER TABLE "document_info_{self.collection_name}"
+ ADD COLUMN status TEXT DEFAULT 'processing';
+ END IF;
+ END $$;
+ """
+ sess.execute(text(add_column_query))
+
+ sess.commit()
+
+ def _create_hybrid_search_function(self):
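+ # NOTE: the query embedding is hard-coded to VECTOR(512) below and must
+ # match the dimension passed to `initialize_collection`.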
+ hybrid_search_function = f"""
+ CREATE OR REPLACE FUNCTION hybrid_search_{self.collection_name}(
+ query_text TEXT,
+ query_embedding VECTOR(512),
+ match_limit INT,
+ full_text_weight FLOAT = 1,
+ semantic_weight FLOAT = 1,
+ rrf_k INT = 50,
+ filter_condition JSONB = NULL
+ )
+ RETURNS SETOF vecs."{self.collection_name}"
+ LANGUAGE sql
+ AS $$
+ WITH full_text AS (
+ SELECT
+ id,
+ ROW_NUMBER() OVER (ORDER BY ts_rank(to_tsvector('english', metadata->>'text'), websearch_to_tsquery(query_text)) DESC) AS rank_ix
+ FROM vecs."{self.collection_name}"
+ WHERE to_tsvector('english', metadata->>'text') @@ websearch_to_tsquery(query_text)
+ AND (filter_condition IS NULL OR (metadata @> filter_condition))
+ ORDER BY rank_ix
+ LIMIT LEAST(match_limit, 30) * 2
+ ),
+ semantic AS (
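+ -- vec <#> query_embedding is pgvector's negative inner product,
+ -- so ascending order ranks the closest vectors first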
+ SELECT
+ id,
+ ROW_NUMBER() OVER (ORDER BY vec <#> query_embedding) AS rank_ix
+ FROM vecs."{self.collection_name}"
+ WHERE filter_condition IS NULL OR (metadata @> filter_condition)
+ ORDER BY rank_ix
+ LIMIT LEAST(match_limit, 30) * 2
+ )
+ SELECT
+ vecs."{self.collection_name}".*
+ FROM
+ full_text
+ FULL OUTER JOIN semantic
+ ON full_text.id = semantic.id
+ JOIN vecs."{self.collection_name}"
+ ON vecs."{self.collection_name}".id = COALESCE(full_text.id, semantic.id)
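+ -- Reciprocal Rank Fusion: each modality contributes 1 / (rrf_k + rank),
+ -- scaled by its weight; rows missing from a modality contribute 0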
+ ORDER BY
+ COALESCE(1.0 / (rrf_k + full_text.rank_ix), 0.0) * full_text_weight +
+ COALESCE(1.0 / (rrf_k + semantic.rank_ix), 0.0) * semantic_weight
+ DESC
+ LIMIT
+ LEAST(match_limit, 30);
+ $$;
+ """
+ retry_attempts = 5
+ for attempt in range(retry_attempts):
+ try:
+ with self.vx.Session() as sess:
+ # Acquire an advisory lock
+ sess.execute(text("SELECT pg_advisory_lock(123456789)"))
+ try:
+ sess.execute(text(hybrid_search_function))
+ sess.commit()
+ finally:
+ # Release the advisory lock
+ sess.execute(
+ text("SELECT pg_advisory_unlock(123456789)")
+ )
+ break # Break the loop if successful
+ except exc.InternalError as e:
+ if "tuple concurrently updated" in str(e):
+ time.sleep(2**attempt) # Exponential backoff
+ else:
+ raise # Re-raise the exception if it's not a concurrency issue
+ else:
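+ # for/else: reached only if every attempt failed without a break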
+ raise RuntimeError(
+ "Failed to create hybrid search function after multiple attempts"
+ )
+
+ def copy(self, entry: VectorEntry, commit=True) -> None:
+ if self.collection is None:
+ raise ValueError(
+ "Please call `initialize_collection` before attempting to run `copy`."
+ )
+
+ serializable_entry = entry.to_serializable()
+
+ self.collection.copy(
+ records=[
+ (
+ serializable_entry["id"],
+ serializable_entry["vector"],
+ serializable_entry["metadata"],
+ )
+ ]
+ )
+
+ def copy_entries(
+ self, entries: list[VectorEntry], commit: bool = True
+ ) -> None:
+ if self.collection is None:
+ raise ValueError(
+ "Please call `initialize_collection` before attempting to run `copy_entries`."
+ )
+
+ self.collection.copy(
+ records=[
+ (
+ str(entry.id),
+ entry.vector.data,
+ entry.to_serializable()["metadata"],
+ )
+ for entry in entries
+ ]
+ )
+
+ def upsert(self, entry: VectorEntry, commit=True) -> None:
+ if self.collection is None:
+ raise ValueError(
+ "Please call `initialize_collection` before attempting to run `upsert`."
+ )
+
+ self.collection.upsert(
+ records=[
+ (
+ str(entry.id),
+ entry.vector.data,
+ entry.to_serializable()["metadata"],
+ )
+ ]
+ )
+
+ def upsert_entries(
+ self, entries: list[VectorEntry], commit: bool = True
+ ) -> None:
+ if self.collection is None:
+ raise ValueError(
+ "Please call `initialize_collection` before attempting to run `upsert_entries`."
+ )
+
+ self.collection.upsert(
+ records=[
+ (
+ str(entry.id),
+ entry.vector.data,
+ entry.to_serializable()["metadata"],
+ )
+ for entry in entries
+ ]
+ )
+
+ def search(
+ self,
+ query_vector: list[float],
+ filters: dict[str, Union[bool, int, str]] = {},
+ limit: int = 10,
+ *args,
+ **kwargs,
+ ) -> list[VectorSearchResult]:
+ if self.collection is None:
+ raise ValueError(
+ "Please call `initialize_collection` before attempting to run `search`."
+ )
+ measure = kwargs.get("measure", "cosine_distance")
+ mapped_filters = {
+ key: {"$eq": value} for key, value in filters.items()
+ }
+
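+ # vecs returns a distance (smaller is better); convert it to a
+ # similarity-style score via 1 - distance.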
+ return [
+ VectorSearchResult(id=ele[0], score=float(1 - ele[1]), metadata=ele[2]) # type: ignore
+ for ele in self.collection.query(
+ data=query_vector,
+ limit=limit,
+ filters=mapped_filters,
+ measure=measure,
+ include_value=True,
+ include_metadata=True,
+ )
+ ]
+
+ def hybrid_search(
+ self,
+ query_text: str,
+ query_vector: list[float],
+ limit: int = 10,
+ filters: Optional[dict[str, Union[bool, int, str]]] = None,
+ # Hybrid search parameters
+ full_text_weight: float = 1.0,
+ semantic_weight: float = 1.0,
+ rrf_k: int = 20, # typical value is ~2x the number of results you want
+ *args,
+ **kwargs,
+ ) -> list[VectorSearchResult]:
+ if self.collection is None:
+ raise ValueError(
+ "Please call `initialize_collection` before attempting to run `hybrid_search`."
+ )
+
+ # Convert filters to a JSON-compatible format
+ filter_condition = None
+ if filters:
+ filter_condition = json.dumps(filters)
+
+ query = text(
+ f"""
+ SELECT * FROM hybrid_search_{self.collection_name}(
+ cast(:query_text as TEXT), cast(:query_embedding as VECTOR), cast(:match_limit as INT),
+ cast(:full_text_weight as FLOAT), cast(:semantic_weight as FLOAT), cast(:rrf_k as INT),
+ cast(:filter_condition as JSONB)
+ )
+ """
+ )
+
+ params = {
+ "query_text": str(query_text),
+ "query_embedding": list(query_vector),
+ "match_limit": limit,
+ "full_text_weight": full_text_weight,
+ "semantic_weight": semantic_weight,
+ "rrf_k": rrf_k,
+ "filter_condition": filter_condition,
+ }
+
+ with self.vx.Session() as session:
+ result = session.execute(query, params).fetchall()
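+ # The SQL function returns matched rows but not the fused RRF score,
+ # so a placeholder score of 1.0 is reported for each result.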
+ return [
+ VectorSearchResult(id=row[0], score=1.0, metadata=row[-1])
+ for row in result
+ ]
+
+ def create_index(self, index_type, column_name, index_options):
+ pass
+
+ def delete_by_metadata(
+ self,
+ metadata_fields: list[str],
+ metadata_values: list[Union[bool, int, str]],
+ logic: Literal["AND", "OR"] = "AND",
+ ) -> list[str]:
+ if logic == "OR":
+ raise ValueError(
+ "OR logic is still being tested before official support for `delete_by_metadata` in pgvector."
+ )
+ if self.collection is None:
+ raise ValueError(
+ "Please call `initialize_collection` before attempting to run `delete_by_metadata`."
+ )
+
+ if len(metadata_fields) != len(metadata_values):
+ raise ValueError(
+ "The number of metadata fields must match the number of metadata values."
+ )
+
+ # Construct the filter
+ if logic == "AND":
+ filters = {
+ k: {"$eq": v} for k, v in zip(metadata_fields, metadata_values)
+ }
+ else: # OR logic
+ # TODO - Test 'or' logic and remove check above
+ filters = {
+ "$or": [
+ {k: {"$eq": v}}
+ for k, v in zip(metadata_fields, metadata_values)
+ ]
+ }
+ return self.collection.delete(filters=filters)
+
+ def get_metadatas(
+ self,
+ metadata_fields: list[str],
+ filter_field: Optional[str] = None,
+ filter_value: Optional[Union[bool, int, str]] = None,
+ ) -> list[dict]:
+ if self.collection is None:
+ raise ValueError(
+ "Please call `initialize_collection` before attempting to run `get_metadatas`."
+ )
+
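+ # Seed with a sentinel key (the full field tuple) that no metadata value
+ # can collide with; it is filtered out when building the return list below.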
+ results = {tuple(metadata_fields): {}}
+ for field in metadata_fields:
+ unique_values = self.collection.get_unique_metadata_values(
+ field=field,
+ filter_field=filter_field,
+ filter_value=filter_value,
+ )
+ for value in unique_values:
+ if value not in results:
+ results[value] = {}
+ results[value][field] = value
+
+ return [
+ results[key] for key in results if key != tuple(metadata_fields)
+ ]
+
+ def upsert_documents_overview(
+ self, documents_overview: list[DocumentInfo]
+ ) -> None:
+ for document_info in documents_overview:
+ db_entry = document_info.convert_to_db_entry()
+
+ # Convert 'None' string to None type for user_id
+ if db_entry["user_id"] == "None":
+ db_entry["user_id"] = None
+
+ query = text(
+ f"""
+ INSERT INTO "document_info_{self.collection_name}" (document_id, title, user_id, version, created_at, updated_at, size_in_bytes, metadata, status)
+ VALUES (:document_id, :title, :user_id, :version, :created_at, :updated_at, :size_in_bytes, :metadata, :status)
+ ON CONFLICT (document_id) DO UPDATE SET
+ title = EXCLUDED.title,
+ user_id = EXCLUDED.user_id,
+ version = EXCLUDED.version,
+ updated_at = EXCLUDED.updated_at,
+ size_in_bytes = EXCLUDED.size_in_bytes,
+ metadata = EXCLUDED.metadata,
+ status = EXCLUDED.status;
+ """
+ )
+ with self.vx.Session() as sess:
+ sess.execute(query, db_entry)
+ sess.commit()
+
+ def delete_from_documents_overview(
+ self, document_id: str, version: Optional[str] = None
+ ) -> None:
+ query = f"""
+ DELETE FROM "document_info_{self.collection_name}"
+ WHERE document_id = :document_id
+ """
+ params = {"document_id": document_id}
+
+ if version is not None:
+ query += " AND version = :version"
+ params["version"] = version
+
+ with self.vx.Session() as sess:
+ with sess.begin():
+ sess.execute(text(query), params)
+ sess.commit()
+
+ def get_documents_overview(
+ self,
+ filter_document_ids: Optional[list[str]] = None,
+ filter_user_ids: Optional[list[str]] = None,
+ ):
+ conditions = []
+ params = {}
+
+ if filter_document_ids:
+ placeholders = ", ".join(
+ f":doc_id_{i}" for i in range(len(filter_document_ids))
+ )
+ conditions.append(f"document_id IN ({placeholders})")
+ params.update(
+ {
+ f"doc_id_{i}": str(document_id)
+ for i, document_id in enumerate(filter_document_ids)
+ }
+ )
+ if filter_user_ids:
+ placeholders = ", ".join(
+ f":user_id_{i}" for i in range(len(filter_user_ids))
+ )
+ conditions.append(f"user_id IN ({placeholders})")
+ params.update(
+ {
+ f"user_id_{i}": str(user_id)
+ for i, user_id in enumerate(filter_user_ids)
+ }
+ )
+
+ query = f"""
+ SELECT document_id, title, user_id, version, size_in_bytes, created_at, updated_at, metadata, status
+ FROM "document_info_{self.collection_name}"
+ """
+ if conditions:
+ query += " WHERE " + " AND ".join(conditions)
+
+ with self.vx.Session() as sess:
+ results = sess.execute(text(query), params).fetchall()
+ return [
+ DocumentInfo(
+ document_id=row[0],
+ title=row[1],
+ user_id=row[2],
+ version=row[3],
+ size_in_bytes=row[4],
+ created_at=row[5],
+ updated_at=row[6],
+ metadata=row[7],
+ status=row[8],
+ )
+ for row in results
+ ]
+
+ def get_document_chunks(self, document_id: str) -> list[dict]:
+ if not self.collection:
+ raise ValueError("Collection is not initialized.")
+
+ table_name = self.collection.table.name
+ query = text(
+ f"""
+ SELECT metadata
+ FROM vecs."{table_name}"
+ WHERE metadata->>'document_id' = :document_id
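+ -- chunk_order is stored as text inside JSONB, so cast it for numeric ordering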
+ ORDER BY CAST(metadata->>'chunk_order' AS INTEGER)
+ """
+ )
+
+ params = {"document_id": document_id}
+
+ with self.vx.Session() as sess:
+ results = sess.execute(query, params).fetchall()
+ return [result[0] for result in results]
+
+ def get_users_overview(self, user_ids: Optional[list[str]] = None):
+ user_ids_condition = ""
+ params = {}
+ if user_ids:
+ user_ids_condition = "WHERE user_id IN :user_ids"
+ params["user_ids"] = tuple(
+ map(str, user_ids)
+ ) # Convert UUIDs to strings
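+ # NOTE: binding a tuple to "IN :user_ids" relies on the DB driver
+ # (psycopg2 adapts tuples to a parenthesized list); other drivers
+ # would need an expanding bindparam.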
+
+ query = f"""
+ SELECT user_id, COUNT(document_id) AS num_files, SUM(size_in_bytes) AS total_size_in_bytes, ARRAY_AGG(document_id) AS document_ids
+ FROM "document_info_{self.collection_name}"
+ {user_ids_condition}
+ GROUP BY user_id
+ """
+
+ with self.vx.Session() as sess:
+ results = sess.execute(text(query), params).fetchall()
+ return [
+ UserStats(
+ user_id=row[0],
+ num_files=row[1],
+ total_size_in_bytes=row[2],
+ document_ids=row[3],
+ )
+ for row in results
+ if row[0] is not None
+ ]
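
For reference, a minimal usage sketch of the new provider. The connection values
are hypothetical, and the VectorDBConfig constructor details are assumed (they
are not shown in this diff); the environment-variable fallbacks mirror the
lookups in __init__ above.

import os

from r2r.base import VectorDBConfig
from r2r.providers.vector_dbs import PGVectorDB

# Hypothetical connection settings; POSTGRES_USER/PASSWORD/HOST/PORT/DBNAME
# would work equally well, per the fallback logic in __init__.
os.environ["POSTGRES_URI"] = "postgresql://user:password@localhost:5432/mydb"
os.environ["POSTGRES_VECS_COLLECTION"] = "demo_vectors"

config = VectorDBConfig(provider="pgvector")  # constructor assumed; see r2r.base
db = PGVectorDB(config)

# Dimension must match the VECTOR(512) declared in the hybrid search function.
db.initialize_collection(dimension=512)

# Semantic search with an equality filter on metadata.
results = db.search(
    query_vector=[0.0] * 512,  # placeholder embedding
    filters={"user_id": "some-user-id"},
    limit=5,
)
for r in results:
    print(r.id, r.score)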