import json
import logging
import os
import time
from typing import Literal, Optional, Union

from sqlalchemy import exc, text
from sqlalchemy.engine.url import make_url

from r2r.base import (
    DocumentInfo,
    UserStats,
    VectorDBConfig,
    VectorDBProvider,
    VectorEntry,
    VectorSearchResult,
)
from r2r.vecs.client import Client
from r2r.vecs.collection import Collection

logger = logging.getLogger(__name__)


class PGVectorDB(VectorDBProvider):
    def __init__(self, config: VectorDBConfig) -> None:
        super().__init__(config)
        try:
            import r2r.vecs
        except ImportError:
            raise ValueError(
                "PGVectorDB requires the vecs library. Please run `pip install vecs`."
            )

        # Check if a complete Postgres URI is provided
        postgres_uri = self.config.extra_fields.get(
            "postgres_uri"
        ) or os.getenv("POSTGRES_URI")

        if postgres_uri:
            # Log loudly that the Postgres URI is being used
            logger.warning("=" * 50)
            logger.warning(
                "ATTENTION: Using provided Postgres URI for connection"
            )
            logger.warning("=" * 50)

            # Validate and use the provided URI
            try:
                parsed_uri = make_url(postgres_uri)
                if not all([parsed_uri.username, parsed_uri.database]):
                    raise ValueError(
                        "The provided Postgres URI is missing required components."
                    )
                DB_CONNECTION = postgres_uri

                # Log the sanitized URI (without password)
                sanitized_uri = parsed_uri.set(password="*****")
                logger.info(f"Connecting using URI: {sanitized_uri}")
            except Exception as e:
                raise ValueError(f"Invalid Postgres URI provided: {e}")
        else:
            # Fall back to individual connection parameters
            user = self.config.extra_fields.get("user", None) or os.getenv(
                "POSTGRES_USER"
            )
            password = self.config.extra_fields.get(
                "password", None
            ) or os.getenv("POSTGRES_PASSWORD")
            host = self.config.extra_fields.get("host", None) or os.getenv(
                "POSTGRES_HOST"
            )
            port = self.config.extra_fields.get("port", None) or os.getenv(
                "POSTGRES_PORT"
            )
            db_name = self.config.extra_fields.get(
                "db_name", None
            ) or os.getenv("POSTGRES_DBNAME")

            if not all([user, password, host, db_name]):
                raise ValueError(
                    "Error, please set the POSTGRES_USER, POSTGRES_PASSWORD, "
                    "POSTGRES_HOST, and POSTGRES_DBNAME environment variables "
                    "or provide them in the config."
                )

            # Check if it's a Unix socket connection
            if host.startswith("/") and not port:
                DB_CONNECTION = (
                    f"postgresql://{user}:{password}@/{db_name}?host={host}"
                )
                logger.info("Using Unix socket connection")
            else:
                DB_CONNECTION = (
                    f"postgresql://{user}:{password}@{host}:{port}/{db_name}"
                )
                logger.info("Using TCP connection")

        try:
            self.vx: Client = r2r.vecs.create_client(DB_CONNECTION)
        except Exception as e:
            raise ValueError(
                f"Error {e} occurred while attempting to connect to the "
                f"pgvector provider with {DB_CONNECTION}."
            )

        self.collection_name = self.config.extra_fields.get(
            "vecs_collection"
        ) or os.getenv("POSTGRES_VECS_COLLECTION")
        if not self.collection_name:
            raise ValueError(
                "Error, please set a valid POSTGRES_VECS_COLLECTION "
                "environment variable or set a 'vecs_collection' in the "
                "'vector_database' settings of your `config.json`."
            )

        self.collection: Optional[Collection] = None
        logger.info(
            f"Successfully initialized PGVectorDB with collection: {self.collection_name}"
        )
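
    # Construction sketch (illustrative, not executed): the connection logic
    # above accepts either a single URI or individual parameters. Exactly how
    # a VectorDBConfig is built lives in r2r.base; the extra fields assumed
    # here are the ones read by __init__:
    #
    #     extra_fields = {
    #         "postgres_uri": "postgresql://user:password@localhost:5432/mydb",
    #         "vecs_collection": "demo_vecs",
    #     }
    #
    # or, equivalently, export POSTGRES_USER, POSTGRES_PASSWORD,
    # POSTGRES_HOST, POSTGRES_PORT, POSTGRES_DBNAME, and
    # POSTGRES_VECS_COLLECTION before constructing the provider.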

    def initialize_collection(self, dimension: int) -> None:
        self.collection = self.vx.get_or_create_collection(
            name=self.collection_name, dimension=dimension
        )
        self._create_document_info_table()
        self._create_hybrid_search_function()

    def _create_document_info_table(self):
        with self.vx.Session() as sess:
            with sess.begin():
                try:
                    # Enable the uuid-ossp extension
                    sess.execute(
                        text('CREATE EXTENSION IF NOT EXISTS "uuid-ossp";')
                    )
                except exc.ProgrammingError as e:
                    logger.error(f"Error enabling uuid-ossp extension: {e}")
                    raise

                # Create the table if it doesn't exist
                create_table_query = f"""
                CREATE TABLE IF NOT EXISTS "document_info_{self.collection_name}" (
                    document_id UUID PRIMARY KEY,
                    title TEXT,
                    user_id UUID NULL,
                    version TEXT,
                    size_in_bytes INT,
                    created_at TIMESTAMPTZ DEFAULT NOW(),
                    updated_at TIMESTAMPTZ DEFAULT NOW(),
                    metadata JSONB,
                    status TEXT
                );
                """
                sess.execute(text(create_table_query))

                # Add the status column if it doesn't exist (for tables
                # created before the column was introduced)
                add_column_query = f"""
                DO $$
                BEGIN
                    IF NOT EXISTS (
                        SELECT 1
                        FROM information_schema.columns
                        WHERE table_name = 'document_info_{self.collection_name}'
                        AND column_name = 'status'
                    ) THEN
                        ALTER TABLE "document_info_{self.collection_name}"
                        ADD COLUMN status TEXT DEFAULT 'processing';
                    END IF;
                END $$;
                """
                sess.execute(text(add_column_query))
                # `sess.begin()` commits on successful exit, so no explicit
                # commit is needed here.

    def _create_hybrid_search_function(self):
        # NOTE: the embedding argument is hardcoded to VECTOR(512); the
        # dimension passed to `initialize_collection` must match.
        hybrid_search_function = f"""
        CREATE OR REPLACE FUNCTION hybrid_search_{self.collection_name}(
            query_text TEXT,
            query_embedding VECTOR(512),
            match_limit INT,
            full_text_weight FLOAT = 1,
            semantic_weight FLOAT = 1,
            rrf_k INT = 50,
            filter_condition JSONB = NULL
        )
        RETURNS SETOF vecs."{self.collection_name}"
        LANGUAGE sql
        AS $$
        WITH full_text AS (
            SELECT
                id,
                ROW_NUMBER() OVER (
                    ORDER BY ts_rank(
                        to_tsvector('english', metadata->>'text'),
                        websearch_to_tsquery(query_text)
                    ) DESC
                ) AS rank_ix
            FROM vecs."{self.collection_name}"
            WHERE
                to_tsvector('english', metadata->>'text') @@ websearch_to_tsquery(query_text)
                AND (filter_condition IS NULL OR (metadata @> filter_condition))
            ORDER BY rank_ix
            LIMIT LEAST(match_limit, 30) * 2
        ),
        semantic AS (
            SELECT
                id,
                ROW_NUMBER() OVER (ORDER BY vec <#> query_embedding) AS rank_ix
            FROM vecs."{self.collection_name}"
            WHERE filter_condition IS NULL OR (metadata @> filter_condition)
            ORDER BY rank_ix
            LIMIT LEAST(match_limit, 30) * 2
        )
        SELECT vecs."{self.collection_name}".*
        FROM full_text
        FULL OUTER JOIN semantic ON full_text.id = semantic.id
        JOIN vecs."{self.collection_name}"
            ON vecs."{self.collection_name}".id = COALESCE(full_text.id, semantic.id)
        ORDER BY
            COALESCE(1.0 / (rrf_k + full_text.rank_ix), 0.0) * full_text_weight +
            COALESCE(1.0 / (rrf_k + semantic.rank_ix), 0.0) * semantic_weight
            DESC
        LIMIT LEAST(match_limit, 30);
        $$;
        """

        retry_attempts = 5
        for attempt in range(retry_attempts):
            try:
                with self.vx.Session() as sess:
                    # Acquire an advisory lock so concurrent workers don't
                    # race on CREATE OR REPLACE FUNCTION
                    sess.execute(text("SELECT pg_advisory_lock(123456789)"))
                    try:
                        sess.execute(text(hybrid_search_function))
                        sess.commit()
                    finally:
                        # Release the advisory lock
                        sess.execute(
                            text("SELECT pg_advisory_unlock(123456789)")
                        )
                break  # Success; stop retrying
            except exc.InternalError as e:
                if "tuple concurrently updated" in str(e):
                    time.sleep(2**attempt)  # Exponential backoff
                else:
                    raise  # Not a concurrency issue; re-raise
        else:
            raise RuntimeError(
                "Failed to create hybrid search function after multiple attempts"
            )
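
    # How the SQL function above orders results (reciprocal rank fusion):
    # each row's combined score is
    #
    #     full_text_weight / (rrf_k + full_text_rank)
    #         + semantic_weight / (rrf_k + semantic_rank)
    #
    # With the default rrf_k = 50, a row ranked first by full-text search
    # contributes 1 / 51 ≈ 0.0196 from the full-text term; a row absent from
    # one ranking contributes 0.0 for that term via COALESCE.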

    def copy(self, entry: VectorEntry, commit: bool = True) -> None:
        if self.collection is None:
            raise ValueError(
                "Please call `initialize_collection` before attempting to run `copy`."
            )

        serializable_entry = entry.to_serializable()
        self.collection.copy(
            records=[
                (
                    serializable_entry["id"],
                    serializable_entry["vector"],
                    serializable_entry["metadata"],
                )
            ]
        )

    def copy_entries(
        self, entries: list[VectorEntry], commit: bool = True
    ) -> None:
        if self.collection is None:
            raise ValueError(
                "Please call `initialize_collection` before attempting to run `copy_entries`."
            )

        self.collection.copy(
            records=[
                (
                    str(entry.id),
                    entry.vector.data,
                    entry.to_serializable()["metadata"],
                )
                for entry in entries
            ]
        )

    def upsert(self, entry: VectorEntry, commit: bool = True) -> None:
        if self.collection is None:
            raise ValueError(
                "Please call `initialize_collection` before attempting to run `upsert`."
            )

        self.collection.upsert(
            records=[
                (
                    str(entry.id),
                    entry.vector.data,
                    entry.to_serializable()["metadata"],
                )
            ]
        )

    def upsert_entries(
        self, entries: list[VectorEntry], commit: bool = True
    ) -> None:
        if self.collection is None:
            raise ValueError(
                "Please call `initialize_collection` before attempting to run `upsert_entries`."
            )

        self.collection.upsert(
            records=[
                (
                    str(entry.id),
                    entry.vector.data,
                    entry.to_serializable()["metadata"],
                )
                for entry in entries
            ]
        )

    def search(
        self,
        query_vector: list[float],
        filters: Optional[dict[str, Union[bool, int, str]]] = None,
        limit: int = 10,
        *args,
        **kwargs,
    ) -> list[VectorSearchResult]:
        if self.collection is None:
            raise ValueError(
                "Please call `initialize_collection` before attempting to run `search`."
            )
        measure = kwargs.get("measure", "cosine_distance")
        mapped_filters = {
            key: {"$eq": value} for key, value in (filters or {}).items()
        }

        return [
            # Convert cosine distance to a similarity score
            VectorSearchResult(id=ele[0], score=float(1 - ele[1]), metadata=ele[2])  # type: ignore
            for ele in self.collection.query(
                data=query_vector,
                limit=limit,
                filters=mapped_filters,
                measure=measure,
                include_value=True,
                include_metadata=True,
            )
        ]
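
    # Example call for `search` above (illustrative values): equality filters
    # are mapped onto the vecs `$eq` operator, so
    #
    #     db.search(
    #         query_vector=[0.1] * 512,
    #         filters={"user_id": "some-user-uuid"},
    #         limit=5,
    #     )
    #
    # queries with {"user_id": {"$eq": "some-user-uuid"}} and, under the
    # default cosine_distance measure, returns scores of 1 - distance.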

    def hybrid_search(
        self,
        query_text: str,
        query_vector: list[float],
        limit: int = 10,
        filters: Optional[dict[str, Union[bool, int, str]]] = None,
        # Hybrid search parameters
        full_text_weight: float = 1.0,
        semantic_weight: float = 1.0,
        rrf_k: int = 20,  # a typical value is ~2x the number of results you want
        *args,
        **kwargs,
    ) -> list[VectorSearchResult]:
        if self.collection is None:
            raise ValueError(
                "Please call `initialize_collection` before attempting to run `hybrid_search`."
            )

        # Convert filters to a JSON-compatible format
        filter_condition = None
        if filters:
            filter_condition = json.dumps(filters)

        query = text(
            f"""
            SELECT * FROM hybrid_search_{self.collection_name}(
                cast(:query_text as TEXT),
                cast(:query_embedding as VECTOR),
                cast(:match_limit as INT),
                cast(:full_text_weight as FLOAT),
                cast(:semantic_weight as FLOAT),
                cast(:rrf_k as INT),
                cast(:filter_condition as JSONB)
            )
            """
        )

        params = {
            "query_text": str(query_text),
            "query_embedding": list(query_vector),
            "match_limit": limit,
            "full_text_weight": full_text_weight,
            "semantic_weight": semantic_weight,
            "rrf_k": rrf_k,
            "filter_condition": filter_condition,
        }

        with self.vx.Session() as session:
            result = session.execute(query, params).fetchall()

        return [
            # The SQL function returns rows rather than scores, so a
            # placeholder score of 1.0 is used; metadata is the last column.
            VectorSearchResult(id=row[0], score=1.0, metadata=row[-1])
            for row in result
        ]

    def create_index(self, index_type, column_name, index_options):
        pass

    def delete_by_metadata(
        self,
        metadata_fields: list[str],
        metadata_values: list[Union[bool, int, str]],
        logic: Literal["AND", "OR"] = "AND",
    ) -> list[str]:
        if logic == "OR":
            raise ValueError(
                "OR logic is still being tested before official support for `delete_by_metadata` in pgvector."
            )
        if self.collection is None:
            raise ValueError(
                "Please call `initialize_collection` before attempting to run `delete_by_metadata`."
            )

        if len(metadata_fields) != len(metadata_values):
            raise ValueError(
                "The number of metadata fields must match the number of metadata values."
            )

        # Construct the filter
        if logic == "AND":
            filters = {
                k: {"$eq": v}
                for k, v in zip(metadata_fields, metadata_values)
            }
        else:  # OR logic
            # TODO - Test 'or' logic and remove the check above
            filters = {
                "$or": [
                    {k: {"$eq": v}}
                    for k, v in zip(metadata_fields, metadata_values)
                ]
            }
        return self.collection.delete(filters=filters)
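
    # Example call for `delete_by_metadata` above (illustrative values):
    # removing every chunk of one version of a document with the default
    # "AND" logic,
    #
    #     db.delete_by_metadata(
    #         metadata_fields=["document_id", "version"],
    #         metadata_values=["some-document-uuid", "v0"],
    #     )
    #
    # builds the filter {"document_id": {"$eq": "some-document-uuid"},
    # "version": {"$eq": "v0"}}.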

    def get_metadatas(
        self,
        metadata_fields: list[str],
        filter_field: Optional[str] = None,
        filter_value: Optional[Union[bool, int, str]] = None,
    ) -> list[dict]:
        if self.collection is None:
            raise ValueError(
                "Please call `initialize_collection` before attempting to run `get_metadatas`."
            )

        # Seed the results with a sentinel key that is filtered out below
        results = {tuple(metadata_fields): {}}
        for field in metadata_fields:
            unique_values = self.collection.get_unique_metadata_values(
                field=field,
                filter_field=filter_field,
                filter_value=filter_value,
            )
            for value in unique_values:
                if value not in results:
                    results[value] = {}
                results[value][field] = value

        return [
            results[key] for key in results if key != tuple(metadata_fields)
        ]

    def upsert_documents_overview(
        self, documents_overview: list[DocumentInfo]
    ) -> None:
        for document_info in documents_overview:
            db_entry = document_info.convert_to_db_entry()

            # Convert the 'None' string to a None type for user_id
            if db_entry["user_id"] == "None":
                db_entry["user_id"] = None

            query = text(
                f"""
                INSERT INTO "document_info_{self.collection_name}" (document_id, title, user_id, version, created_at, updated_at, size_in_bytes, metadata, status)
                VALUES (:document_id, :title, :user_id, :version, :created_at, :updated_at, :size_in_bytes, :metadata, :status)
                ON CONFLICT (document_id) DO UPDATE SET
                    title = EXCLUDED.title,
                    user_id = EXCLUDED.user_id,
                    version = EXCLUDED.version,
                    updated_at = EXCLUDED.updated_at,
                    size_in_bytes = EXCLUDED.size_in_bytes,
                    metadata = EXCLUDED.metadata,
                    status = EXCLUDED.status;
                """
            )
            with self.vx.Session() as sess:
                sess.execute(query, db_entry)
                sess.commit()

    def delete_from_documents_overview(
        self, document_id: str, version: Optional[str] = None
    ) -> None:
        query = f"""
            DELETE FROM "document_info_{self.collection_name}"
            WHERE document_id = :document_id
        """
        params = {"document_id": document_id}

        if version is not None:
            query += " AND version = :version"
            params["version"] = version

        with self.vx.Session() as sess:
            with sess.begin():
                sess.execute(text(query), params)
            # `sess.begin()` commits on successful exit

    def get_documents_overview(
        self,
        filter_document_ids: Optional[list[str]] = None,
        filter_user_ids: Optional[list[str]] = None,
    ):
        conditions = []
        params = {}

        if filter_document_ids:
            placeholders = ", ".join(
                f":doc_id_{i}" for i in range(len(filter_document_ids))
            )
            conditions.append(f"document_id IN ({placeholders})")
            params.update(
                {
                    f"doc_id_{i}": str(document_id)
                    for i, document_id in enumerate(filter_document_ids)
                }
            )
        if filter_user_ids:
            placeholders = ", ".join(
                f":user_id_{i}" for i in range(len(filter_user_ids))
            )
            conditions.append(f"user_id IN ({placeholders})")
            params.update(
                {
                    f"user_id_{i}": str(user_id)
                    for i, user_id in enumerate(filter_user_ids)
                }
            )

        query = f"""
            SELECT document_id, title, user_id, version, size_in_bytes, created_at, updated_at, metadata, status
            FROM "document_info_{self.collection_name}"
        """
        if conditions:
            query += " WHERE " + " AND ".join(conditions)

        with self.vx.Session() as sess:
            results = sess.execute(text(query), params).fetchall()

        return [
            DocumentInfo(
                document_id=row[0],
                title=row[1],
                user_id=row[2],
                version=row[3],
                size_in_bytes=row[4],
                created_at=row[5],
                updated_at=row[6],
                metadata=row[7],
                status=row[8],
            )
            for row in results
        ]

    def get_document_chunks(self, document_id: str) -> list[dict]:
        if not self.collection:
            raise ValueError("Collection is not initialized.")

        table_name = self.collection.table.name
        query = text(
            f"""
            SELECT metadata
            FROM vecs."{table_name}"
            WHERE metadata->>'document_id' = :document_id
            ORDER BY CAST(metadata->>'chunk_order' AS INTEGER)
            """
        )

        params = {"document_id": document_id}

        with self.vx.Session() as sess:
            results = sess.execute(query, params).fetchall()
            return [result[0] for result in results]
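
    # Example call for `get_document_chunks` above (illustrative): chunks are
    # returned in 'chunk_order', so a document's text can be reassembled with
    #
    #     chunks = db.get_document_chunks("some-document-uuid")
    #     full_text = " ".join(chunk["text"] for chunk in chunks)
    #
    # assuming each chunk's metadata carries a 'text' field, as the hybrid
    # search function above also expects.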
params["user_ids"] = tuple( map(str, user_ids) ) # Convert UUIDs to strings query = f""" SELECT user_id, COUNT(document_id) AS num_files, SUM(size_in_bytes) AS total_size_in_bytes, ARRAY_AGG(document_id) AS document_ids FROM "document_info_{self.collection_name}" {user_ids_condition} GROUP BY user_id """ with self.vx.Session() as sess: results = sess.execute(text(query), params).fetchall() return [ UserStats( user_id=row[0], num_files=row[1], total_size_in_bytes=row[2], document_ids=row[3], ) for row in results if row[0] is not None ]