Diffstat (limited to 'R2R/r2r/providers/vector_dbs')
-rwxr-xr-x  R2R/r2r/providers/vector_dbs/__init__.py               5
-rwxr-xr-x  R2R/r2r/providers/vector_dbs/pgvector/pgvector_db.py  610
2 files changed, 615 insertions, 0 deletions
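
Note on the hybrid search added below: the new hybrid_search_{collection} SQL function fuses a full-text ranking and a semantic (vector) ranking with reciprocal rank fusion (RRF), where each row scores weight / (rrf_k + rank) from every list it appears in, and rows are ordered by the summed score. A minimal Python sketch of that scoring, outside the database (the weights and the rrf_k default of 50 mirror the SQL function's signature; the sample ranks are invented):

    # Reciprocal rank fusion as computed by hybrid_search_<collection> below:
    # a row contributes weight / (rrf_k + rank) from each ranking it appears in,
    # and rows absent from a list contribute 0 for it (the COALESCE in the SQL).
    def fuse(full_text_rank, semantic_rank, full_text_weight=1.0,
             semantic_weight=1.0, rrf_k=50):
        score = 0.0
        if full_text_rank is not None:
            score += full_text_weight / (rrf_k + full_text_rank)
        if semantic_rank is not None:
            score += semantic_weight / (rrf_k + semantic_rank)
        return score

    # A row ranked first in both lists outscores a row ranked first in only one:
    assert fuse(1, 1) > fuse(1, None) > fuse(None, 2)
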
diff --git a/R2R/r2r/providers/vector_dbs/__init__.py b/R2R/r2r/providers/vector_dbs/__init__.py
new file mode 100755
index 00000000..38ea0890
--- /dev/null
+++ b/R2R/r2r/providers/vector_dbs/__init__.py
@@ -0,0 +1,5 @@
+from .pgvector.pgvector_db import PGVectorDB
+
+__all__ = [
+    "PGVectorDB",
+]
diff --git a/R2R/r2r/providers/vector_dbs/pgvector/pgvector_db.py b/R2R/r2r/providers/vector_dbs/pgvector/pgvector_db.py
new file mode 100755
index 00000000..8cf728d1
--- /dev/null
+++ b/R2R/r2r/providers/vector_dbs/pgvector/pgvector_db.py
@@ -0,0 +1,610 @@
+import json
+import logging
+import os
+import time
+from typing import Literal, Optional, Union
+
+from sqlalchemy import exc, text
+from sqlalchemy.engine.url import make_url
+
+from r2r.base import (
+    DocumentInfo,
+    UserStats,
+    VectorDBConfig,
+    VectorDBProvider,
+    VectorEntry,
+    VectorSearchResult,
+)
+from r2r.vecs.client import Client
+from r2r.vecs.collection import Collection
+
+logger = logging.getLogger(__name__)
+
+
+class PGVectorDB(VectorDBProvider):
+    def __init__(self, config: VectorDBConfig) -> None:
+        super().__init__(config)
+        try:
+            import r2r.vecs
+        except ImportError:
+            raise ValueError(
+                "Error, PGVectorDB requires the vecs library. Please run `pip install vecs`."
+            )
+
+        # Check if a complete Postgres URI is provided
+        postgres_uri = self.config.extra_fields.get(
+            "postgres_uri"
+        ) or os.getenv("POSTGRES_URI")
+
+        if postgres_uri:
+            # Log loudly that the Postgres URI is being used
+            logger.warning("=" * 50)
+            logger.warning(
+                "ATTENTION: Using provided Postgres URI for connection"
+            )
+            logger.warning("=" * 50)
+
+            # Validate and use the provided URI
+            try:
+                parsed_uri = make_url(postgres_uri)
+                if not all([parsed_uri.username, parsed_uri.database]):
+                    raise ValueError(
+                        "The provided Postgres URI is missing required components."
+                    )
+                DB_CONNECTION = postgres_uri
+
+                # Log the sanitized URI (without password)
+                sanitized_uri = parsed_uri.set(password="*****")
+                logger.info(f"Connecting using URI: {sanitized_uri}")
+            except Exception as e:
+                raise ValueError(f"Invalid Postgres URI provided: {e}")
+        else:
+            # Fall back to existing logic for individual connection parameters
+            user = self.config.extra_fields.get("user", None) or os.getenv(
+                "POSTGRES_USER"
+            )
+            password = self.config.extra_fields.get(
+                "password", None
+            ) or os.getenv("POSTGRES_PASSWORD")
+            host = self.config.extra_fields.get("host", None) or os.getenv(
+                "POSTGRES_HOST"
+            )
+            port = self.config.extra_fields.get("port", None) or os.getenv(
+                "POSTGRES_PORT"
+            )
+            db_name = self.config.extra_fields.get(
+                "db_name", None
+            ) or os.getenv("POSTGRES_DBNAME")
+
+            if not all([user, password, host, db_name]):
+                raise ValueError(
+                    "Error, please set the POSTGRES_USER, POSTGRES_PASSWORD, POSTGRES_HOST, and POSTGRES_DBNAME environment variables or provide them in the config."
+                )
+
+            # Check if it's a Unix socket connection
+            if host.startswith("/") and not port:
+                DB_CONNECTION = (
+                    f"postgresql://{user}:{password}@/{db_name}?host={host}"
+                )
+                logger.info("Using Unix socket connection")
+            else:
+                DB_CONNECTION = (
+                    f"postgresql://{user}:{password}@{host}:{port}/{db_name}"
+                )
+                logger.info("Using TCP connection")
+
+        # The rest of the initialization remains the same
+        try:
+            self.vx: Client = r2r.vecs.create_client(DB_CONNECTION)
+        except Exception as e:
+            raise ValueError(
+                f"Error {e} occurred while attempting to connect to the pgvector provider with {DB_CONNECTION}."
+            )
+
+        self.collection_name = self.config.extra_fields.get(
+            "vecs_collection"
+        ) or os.getenv("POSTGRES_VECS_COLLECTION")
+        if not self.collection_name:
+            raise ValueError(
+                "Error, please set a valid POSTGRES_VECS_COLLECTION environment variable or set a 'vecs_collection' in the 'vector_database' settings of your `config.json`."
+            )
+
+        self.collection: Optional[Collection] = None
+
+        logger.info(
+            f"Successfully initialized PGVectorDB with collection: {self.collection_name}"
+        )
+
+    def initialize_collection(self, dimension: int) -> None:
+        self.collection = self.vx.get_or_create_collection(
+            name=self.collection_name, dimension=dimension
+        )
+        self._create_document_info_table()
+        self._create_hybrid_search_function()
+
+    def _create_document_info_table(self):
+        with self.vx.Session() as sess:
+            with sess.begin():
+                try:
+                    # Enable the uuid-ossp extension
+                    sess.execute(
+                        text('CREATE EXTENSION IF NOT EXISTS "uuid-ossp";')
+                    )
+                except exc.ProgrammingError as e:
+                    logger.error(f"Error enabling uuid-ossp extension: {e}")
+                    raise
+
+                # Create the table if it doesn't exist
+                create_table_query = f"""
+                CREATE TABLE IF NOT EXISTS "document_info_{self.collection_name}" (
+                    document_id UUID PRIMARY KEY,
+                    title TEXT,
+                    user_id UUID NULL,
+                    version TEXT,
+                    size_in_bytes INT,
+                    created_at TIMESTAMPTZ DEFAULT NOW(),
+                    updated_at TIMESTAMPTZ DEFAULT NOW(),
+                    metadata JSONB,
+                    status TEXT
+                );
+                """
+                sess.execute(text(create_table_query))
+
+                # Add the status column if it doesn't exist
+                add_column_query = f"""
+                DO $$
+                BEGIN
+                    IF NOT EXISTS (
+                        SELECT 1
+                        FROM information_schema.columns
+                        WHERE table_name = 'document_info_{self.collection_name}'
+                        AND column_name = 'status'
+                    ) THEN
+                        ALTER TABLE "document_info_{self.collection_name}"
+                        ADD COLUMN status TEXT DEFAULT 'processing';
+                    END IF;
+                END $$;
+                """
+                sess.execute(text(add_column_query))
+
+                sess.commit()
+
+    def _create_hybrid_search_function(self):
+        hybrid_search_function = f"""
+        CREATE OR REPLACE FUNCTION hybrid_search_{self.collection_name}(
+            query_text TEXT,
+            query_embedding VECTOR(512),
+            match_limit INT,
+            full_text_weight FLOAT = 1,
+            semantic_weight FLOAT = 1,
+            rrf_k INT = 50,
+            filter_condition JSONB = NULL
+        )
+        RETURNS SETOF vecs."{self.collection_name}"
+        LANGUAGE sql
+        AS $$
+        WITH full_text AS (
+            SELECT
+                id,
+                ROW_NUMBER() OVER (ORDER BY ts_rank(to_tsvector('english', metadata->>'text'), websearch_to_tsquery(query_text)) DESC) AS rank_ix
+            FROM vecs."{self.collection_name}"
+            WHERE to_tsvector('english', metadata->>'text') @@ websearch_to_tsquery(query_text)
+            AND (filter_condition IS NULL OR (metadata @> filter_condition))
+            ORDER BY rank_ix
+            LIMIT LEAST(match_limit, 30) * 2
+        ),
+        semantic AS (
+            SELECT
+                id,
+                ROW_NUMBER() OVER (ORDER BY vec <#> query_embedding) AS rank_ix
+            FROM vecs."{self.collection_name}"
+            WHERE filter_condition IS NULL OR (metadata @> filter_condition)
+            ORDER BY rank_ix
+            LIMIT LEAST(match_limit, 30) * 2
+        )
+        SELECT
+            vecs."{self.collection_name}".*
+        FROM
+            full_text
+            FULL OUTER JOIN semantic
+                ON full_text.id = semantic.id
+            JOIN vecs."{self.collection_name}"
+                ON vecs."{self.collection_name}".id = COALESCE(full_text.id, semantic.id)
+        ORDER BY
+            COALESCE(1.0 / (rrf_k + full_text.rank_ix), 0.0) * full_text_weight +
+            COALESCE(1.0 / (rrf_k + semantic.rank_ix), 0.0) * semantic_weight
+            DESC
+        LIMIT
+            LEAST(match_limit, 30);
+        $$;
+        """
+        retry_attempts = 5
+        for attempt in range(retry_attempts):
+            try:
+                with self.vx.Session() as sess:
+                    # Acquire an advisory lock
+                    sess.execute(text("SELECT pg_advisory_lock(123456789)"))
+                    try:
+                        sess.execute(text(hybrid_search_function))
+                        sess.commit()
+                    finally:
+                        # Release the advisory lock
+                        sess.execute(
+                            text("SELECT pg_advisory_unlock(123456789)")
+                        )
+                break  # Break the loop if successful
+            except exc.InternalError as e:
+                if "tuple concurrently updated" in str(e):
+                    time.sleep(2**attempt)  # Exponential backoff
+                else:
+                    raise  # Re-raise the exception if it's not a concurrency issue
+        else:
+            raise RuntimeError(
+                "Failed to create hybrid search function after multiple attempts"
+            )
+
+    def copy(self, entry: VectorEntry, commit=True) -> None:
+        if self.collection is None:
+            raise ValueError(
+                "Please call `initialize_collection` before attempting to run `copy`."
+            )
+
+        serializable_entry = entry.to_serializable()
+
+        self.collection.copy(
+            records=[
+                (
+                    serializable_entry["id"],
+                    serializable_entry["vector"],
+                    serializable_entry["metadata"],
+                )
+            ]
+        )
+
+    def copy_entries(
+        self, entries: list[VectorEntry], commit: bool = True
+    ) -> None:
+        if self.collection is None:
+            raise ValueError(
+                "Please call `initialize_collection` before attempting to run `copy_entries`."
+            )
+
+        self.collection.copy(
+            records=[
+                (
+                    str(entry.id),
+                    entry.vector.data,
+                    entry.to_serializable()["metadata"],
+                )
+                for entry in entries
+            ]
+        )
+
+    def upsert(self, entry: VectorEntry, commit=True) -> None:
+        if self.collection is None:
+            raise ValueError(
+                "Please call `initialize_collection` before attempting to run `upsert`."
+            )
+
+        self.collection.upsert(
+            records=[
+                (
+                    str(entry.id),
+                    entry.vector.data,
+                    entry.to_serializable()["metadata"],
+                )
+            ]
+        )
+
+    def upsert_entries(
+        self, entries: list[VectorEntry], commit: bool = True
+    ) -> None:
+        if self.collection is None:
+            raise ValueError(
+                "Please call `initialize_collection` before attempting to run `upsert_entries`."
+            )
+
+        self.collection.upsert(
+            records=[
+                (
+                    str(entry.id),
+                    entry.vector.data,
+                    entry.to_serializable()["metadata"],
+                )
+                for entry in entries
+            ]
+        )
+
+    def search(
+        self,
+        query_vector: list[float],
+        filters: dict[str, Union[bool, int, str]] = {},
+        limit: int = 10,
+        *args,
+        **kwargs,
+    ) -> list[VectorSearchResult]:
+        if self.collection is None:
+            raise ValueError(
+                "Please call `initialize_collection` before attempting to run `search`."
+            )
+        measure = kwargs.get("measure", "cosine_distance")
+        mapped_filters = {
+            key: {"$eq": value} for key, value in filters.items()
+        }
+
+        return [
+            VectorSearchResult(id=ele[0], score=float(1 - ele[1]), metadata=ele[2])  # type: ignore
+            for ele in self.collection.query(
+                data=query_vector,
+                limit=limit,
+                filters=mapped_filters,
+                measure=measure,
+                include_value=True,
+                include_metadata=True,
+            )
+        ]
+
+    def hybrid_search(
+        self,
+        query_text: str,
+        query_vector: list[float],
+        limit: int = 10,
+        filters: Optional[dict[str, Union[bool, int, str]]] = None,
+        # Hybrid search parameters
+        full_text_weight: float = 1.0,
+        semantic_weight: float = 1.0,
+        rrf_k: int = 20,  # typical value is ~2x the number of results you want
+        *args,
+        **kwargs,
+    ) -> list[VectorSearchResult]:
+        if self.collection is None:
+            raise ValueError(
+                "Please call `initialize_collection` before attempting to run `hybrid_search`."
+            )
+
+        # Convert filters to a JSON-compatible format
+        filter_condition = None
+        if filters:
+            filter_condition = json.dumps(filters)
+
+        query = text(
+            f"""
+            SELECT * FROM hybrid_search_{self.collection_name}(
+                cast(:query_text as TEXT), cast(:query_embedding as VECTOR), cast(:match_limit as INT),
+                cast(:full_text_weight as FLOAT), cast(:semantic_weight as FLOAT), cast(:rrf_k as INT),
+                cast(:filter_condition as JSONB)
+            )
+            """
+        )
+
+        params = {
+            "query_text": str(query_text),
+            "query_embedding": list(query_vector),
+            "match_limit": limit,
+            "full_text_weight": full_text_weight,
+            "semantic_weight": semantic_weight,
+            "rrf_k": rrf_k,
+            "filter_condition": filter_condition,
+        }
+
+        with self.vx.Session() as session:
+            result = session.execute(query, params).fetchall()
+            return [
+                VectorSearchResult(id=row[0], score=1.0, metadata=row[-1])
+                for row in result
+            ]
+
+    def create_index(self, index_type, column_name, index_options):
+        pass
+
+    def delete_by_metadata(
+        self,
+        metadata_fields: list[str],
+        metadata_values: list[Union[bool, int, str]],
+        logic: Literal["AND", "OR"] = "AND",
+    ) -> list[str]:
+        if logic == "OR":
+            raise ValueError(
+                "OR logic is still being tested before official support for `delete_by_metadata` in pgvector."
+            )
+        if self.collection is None:
+            raise ValueError(
+                "Please call `initialize_collection` before attempting to run `delete_by_metadata`."
+            )
+
+        if len(metadata_fields) != len(metadata_values):
+            raise ValueError(
+                "The number of metadata fields must match the number of metadata values."
+            )
+
+        # Construct the filter
+        if logic == "AND":
+            filters = {
+                k: {"$eq": v} for k, v in zip(metadata_fields, metadata_values)
+            }
+        else:  # OR logic
+            # TODO - Test 'or' logic and remove the check above
+            filters = {
+                "$or": [
+                    {k: {"$eq": v}}
+                    for k, v in zip(metadata_fields, metadata_values)
+                ]
+            }
+        return self.collection.delete(filters=filters)
+
+    def get_metadatas(
+        self,
+        metadata_fields: list[str],
+        filter_field: Optional[str] = None,
+        filter_value: Optional[Union[bool, int, str]] = None,
+    ) -> list[dict]:
+        if self.collection is None:
+            raise ValueError(
+                "Please call `initialize_collection` before attempting to run `get_metadatas`."
+            )
+
+        results = {tuple(metadata_fields): {}}
+        for field in metadata_fields:
+            unique_values = self.collection.get_unique_metadata_values(
+                field=field,
+                filter_field=filter_field,
+                filter_value=filter_value,
+            )
+            for value in unique_values:
+                if value not in results:
+                    results[value] = {}
+                results[value][field] = value
+
+        return [
+            results[key] for key in results if key != tuple(metadata_fields)
+        ]
+
+    def upsert_documents_overview(
+        self, documents_overview: list[DocumentInfo]
+    ) -> None:
+        for document_info in documents_overview:
+            db_entry = document_info.convert_to_db_entry()
+
+            # Convert the 'None' string to a None type for user_id
+            if db_entry["user_id"] == "None":
+                db_entry["user_id"] = None
+
+            query = text(
+                f"""
+                INSERT INTO "document_info_{self.collection_name}" (document_id, title, user_id, version, created_at, updated_at, size_in_bytes, metadata, status)
+                VALUES (:document_id, :title, :user_id, :version, :created_at, :updated_at, :size_in_bytes, :metadata, :status)
+                ON CONFLICT (document_id) DO UPDATE SET
+                    title = EXCLUDED.title,
+                    user_id = EXCLUDED.user_id,
+                    version = EXCLUDED.version,
+                    updated_at = EXCLUDED.updated_at,
+                    size_in_bytes = EXCLUDED.size_in_bytes,
+                    metadata = EXCLUDED.metadata,
+                    status = EXCLUDED.status;
+                """
+            )
+            with self.vx.Session() as sess:
+                sess.execute(query, db_entry)
+                sess.commit()
+
+    def delete_from_documents_overview(
+        self, document_id: str, version: Optional[str] = None
+    ) -> None:
+        query = f"""
+        DELETE FROM "document_info_{self.collection_name}"
+        WHERE document_id = :document_id
+        """
+        params = {"document_id": document_id}
+
+        if version is not None:
+            query += " AND version = :version"
+            params["version"] = version
+
+        with self.vx.Session() as sess:
+            with sess.begin():
+                sess.execute(text(query), params)
+                sess.commit()
+
+    def get_documents_overview(
+        self,
+        filter_document_ids: Optional[list[str]] = None,
+        filter_user_ids: Optional[list[str]] = None,
+    ):
+        conditions = []
+        params = {}
+
+        if filter_document_ids:
+            placeholders = ", ".join(
+                f":doc_id_{i}" for i in range(len(filter_document_ids))
+            )
+            conditions.append(f"document_id IN ({placeholders})")
+            params.update(
+                {
+                    f"doc_id_{i}": str(document_id)
+                    for i, document_id in enumerate(filter_document_ids)
+                }
+            )
+        if filter_user_ids:
+            placeholders = ", ".join(
+                f":user_id_{i}" for i in range(len(filter_user_ids))
+            )
+            conditions.append(f"user_id IN ({placeholders})")
+            params.update(
+                {
+                    f"user_id_{i}": str(user_id)
+                    for i, user_id in enumerate(filter_user_ids)
+                }
+            )
+
+        query = f"""
+        SELECT document_id, title, user_id, version, size_in_bytes, created_at, updated_at, metadata, status
+        FROM "document_info_{self.collection_name}"
+        """
+        if conditions:
+            query += " WHERE " + " AND ".join(conditions)
+
+        with self.vx.Session() as sess:
+            results = sess.execute(text(query), params).fetchall()
+            return [
+                DocumentInfo(
+                    document_id=row[0],
+                    title=row[1],
+                    user_id=row[2],
+                    version=row[3],
+                    size_in_bytes=row[4],
+                    created_at=row[5],
+                    updated_at=row[6],
+                    metadata=row[7],
+                    status=row[8],
+                )
+                for row in results
+            ]
+
+    def get_document_chunks(self, document_id: str) -> list[dict]:
+        if not self.collection:
+            raise ValueError("Collection is not initialized.")
+
+        table_name = self.collection.table.name
+        query = text(
+            f"""
+            SELECT metadata
+            FROM vecs."{table_name}"
+            WHERE metadata->>'document_id' = :document_id
+            ORDER BY CAST(metadata->>'chunk_order' AS INTEGER)
+            """
+        )
+
+        params = {"document_id": document_id}
+
+        with self.vx.Session() as sess:
+            results = sess.execute(query, params).fetchall()
+            return [result[0] for result in results]
+
+    def get_users_overview(self, user_ids: Optional[list[str]] = None):
+        user_ids_condition = ""
+        params = {}
+        if user_ids:
+            user_ids_condition = "WHERE user_id IN :user_ids"
+            params["user_ids"] = tuple(
+                map(str, user_ids)
+            )  # Convert UUIDs to strings
+
+        query = f"""
+        SELECT user_id, COUNT(document_id) AS num_files, SUM(size_in_bytes) AS total_size_in_bytes, ARRAY_AGG(document_id) AS document_ids
+        FROM "document_info_{self.collection_name}"
+        {user_ids_condition}
+        GROUP BY user_id
+        """
+
+        with self.vx.Session() as sess:
+            results = sess.execute(text(query), params).fetchall()
+            return [
+                UserStats(
+                    user_id=row[0],
+                    num_files=row[1],
+                    total_size_in_bytes=row[2],
+                    document_ids=row[3],
+                )
+                for row in results
+                if row[0] is not None
+            ]
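
For orientation, a minimal usage sketch of the provider added in this commit (not part of the diff). The VectorDBConfig and VectorEntry constructor shapes are assumptions inferred from the calls above, not a documented API; the dimension of 512 matches the VECTOR(512) parameter hard-coded into the hybrid search function:

    # Hypothetical usage sketch; assumes POSTGRES_USER/PASSWORD/HOST/PORT/DBNAME
    # and POSTGRES_VECS_COLLECTION are set in the environment.
    import uuid

    from r2r.base import VectorDBConfig, VectorEntry
    from r2r.providers.vector_dbs import PGVectorDB

    config = VectorDBConfig(provider="pgvector")  # assumed constructor shape
    db = PGVectorDB(config)
    db.initialize_collection(dimension=512)  # required before any other call

    # Assumed VectorEntry shape: an id, a vector, and a metadata dict whose
    # 'text' field feeds the full-text side of the hybrid search.
    entry = VectorEntry(
        id=uuid.uuid4(),
        vector=[0.1] * 512,
        metadata={"document_id": str(uuid.uuid4()), "text": "hello world"},
    )
    db.upsert(entry)

    semantic_hits = db.search(query_vector=[0.1] * 512, limit=5)
    hybrid_hits = db.hybrid_search(
        query_text="hello", query_vector=[0.1] * 512, limit=5
    )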