Diffstat (limited to 'R2R/r2r/providers/vector_dbs')
 R2R/r2r/providers/vector_dbs/__init__.py (-rwxr-xr-x)             |   5
 R2R/r2r/providers/vector_dbs/pgvector/pgvector_db.py (-rwxr-xr-x) | 610
 2 files changed, 615 insertions, 0 deletions
diff --git a/R2R/r2r/providers/vector_dbs/__init__.py b/R2R/r2r/providers/vector_dbs/__init__.py
new file mode 100755
index 00000000..38ea0890
--- /dev/null
+++ b/R2R/r2r/providers/vector_dbs/__init__.py
@@ -0,0 +1,5 @@
+from .pgvector.pgvector_db import PGVectorDB
+
+__all__ = [
+    "PGVectorDB",
+]
diff --git a/R2R/r2r/providers/vector_dbs/pgvector/pgvector_db.py b/R2R/r2r/providers/vector_dbs/pgvector/pgvector_db.py
new file mode 100755
index 00000000..8cf728d1
--- /dev/null
+++ b/R2R/r2r/providers/vector_dbs/pgvector/pgvector_db.py
@@ -0,0 +1,610 @@
+import json
+import logging
+import os
+import time
+from typing import Literal, Optional, Union
+
+from sqlalchemy import exc, text
+from sqlalchemy.engine.url import make_url
+
+from r2r.base import (
+    DocumentInfo,
+    UserStats,
+    VectorDBConfig,
+    VectorDBProvider,
+    VectorEntry,
+    VectorSearchResult,
+)
+from r2r.vecs.client import Client
+from r2r.vecs.collection import Collection
+
+logger = logging.getLogger(__name__)
+
+
+class PGVectorDB(VectorDBProvider):
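+    """Postgres/pgvector-backed `VectorDBProvider` built on the `vecs` client."""
+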
+    def __init__(self, config: VectorDBConfig) -> None:
+        super().__init__(config)
+        try:
+            import r2r.vecs
+        except ImportError as e:
+            raise ValueError(
+                "Error, PGVectorDB requires the vecs library. Please run `pip install vecs`."
+            ) from e
+
+        # Check if a complete Postgres URI is provided
+        postgres_uri = self.config.extra_fields.get(
+            "postgres_uri"
+        ) or os.getenv("POSTGRES_URI")
+
+        if postgres_uri:
+            # Log loudly that Postgres URI is being used
+            logger.warning("=" * 50)
+            logger.warning(
+                "ATTENTION: Using provided Postgres URI for connection"
+            )
+            logger.warning("=" * 50)
+
+            # Validate and use the provided URI
+            try:
+                parsed_uri = make_url(postgres_uri)
+                if not all([parsed_uri.username, parsed_uri.database]):
+                    raise ValueError(
+                        "The provided Postgres URI is missing required components."
+                    )
+                DB_CONNECTION = postgres_uri
+
+                # Log the sanitized URI (without password)
+                sanitized_uri = parsed_uri.set(password="*****")
+                logger.info(f"Connecting using URI: {sanitized_uri}")
+            except Exception as e:
+                raise ValueError(f"Invalid Postgres URI provided: {e}")
+        else:
+            # Fall back to existing logic for individual connection parameters
+            user = self.config.extra_fields.get("user", None) or os.getenv(
+                "POSTGRES_USER"
+            )
+            password = self.config.extra_fields.get(
+                "password", None
+            ) or os.getenv("POSTGRES_PASSWORD")
+            host = self.config.extra_fields.get("host", None) or os.getenv(
+                "POSTGRES_HOST"
+            )
+            port = self.config.extra_fields.get("port", None) or os.getenv(
+                "POSTGRES_PORT"
+            )
+            db_name = self.config.extra_fields.get(
+                "db_name", None
+            ) or os.getenv("POSTGRES_DBNAME")
+
+            if not all([user, password, host, db_name]):
+                raise ValueError(
+                    "Error, please set the POSTGRES_USER, POSTGRES_PASSWORD, POSTGRES_HOST, POSTGRES_DBNAME environment variables or provide them in the config."
+                )
+
+            # Check if it's a Unix socket connection
+            if host.startswith("/") and not port:
+                DB_CONNECTION = (
+                    f"postgresql://{user}:{password}@/{db_name}?host={host}"
+                )
+                logger.info("Using Unix socket connection")
+            else:
+                if not port:
+                    raise ValueError(
+                        "Error, please set the POSTGRES_PORT environment variable or provide a 'port' in the config when connecting over TCP."
+                    )
+                DB_CONNECTION = (
+                    f"postgresql://{user}:{password}@{host}:{port}/{db_name}"
+                )
+                logger.info("Using TCP connection")
+
+        # The rest of the initialization remains the same
+        try:
+            self.vx: Client = r2r.vecs.create_client(DB_CONNECTION)
+        except Exception as e:
+            # Avoid echoing DB_CONNECTION here so credentials are not leaked.
+            raise ValueError(
+                f"Error {e} occurred while attempting to connect to the pgvector provider."
+            )
+
+        self.collection_name = self.config.extra_fields.get(
+            "vecs_collection"
+        ) or os.getenv("POSTGRES_VECS_COLLECTION")
+        if not self.collection_name:
+            raise ValueError(
+                "Error, please set a valid POSTGRES_VECS_COLLECTION environment variable or set a 'vecs_collection' in the 'vector_database' settings of your `config.json`."
+            )
+
+        self.collection: Optional[Collection] = None
+
+        logger.info(
+            f"Successfully initialized PGVectorDB with collection: {self.collection_name}"
+        )
+
+    def initialize_collection(self, dimension: int) -> None:
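+        """Create or fetch the vecs collection for `dimension`-sized vectors, then ensure the bookkeeping table and hybrid-search function exist."""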
+        self.collection = self.vx.get_or_create_collection(
+            name=self.collection_name, dimension=dimension
+        )
+        self._create_document_info_table()
+        self._create_hybrid_search_function()
+
+    def _create_document_info_table(self):
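+        """Create the per-collection `document_info_<collection>` table (and its `status` column) if missing."""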
+        with self.vx.Session() as sess:
+            with sess.begin():
+                try:
+                    # Enable uuid-ossp extension
+                    sess.execute(
+                        text('CREATE EXTENSION IF NOT EXISTS "uuid-ossp";')
+                    )
+                except exc.ProgrammingError as e:
+                    logger.error(f"Error enabling uuid-ossp extension: {e}")
+                    raise
+
+                # Create the table if it doesn't exist
+                create_table_query = f"""
+                CREATE TABLE IF NOT EXISTS "document_info_{self.collection_name}" (
+                    document_id UUID PRIMARY KEY,
+                    title TEXT,
+                    user_id UUID NULL,
+                    version TEXT,
+                    size_in_bytes INT,
+                    created_at TIMESTAMPTZ DEFAULT NOW(),
+                    updated_at TIMESTAMPTZ DEFAULT NOW(),
+                    metadata JSONB,
+                    status TEXT
+                );
+                """
+                sess.execute(text(create_table_query))
+
+                # Add the new column if it doesn't exist
+                add_column_query = f"""
+                DO $$
+                BEGIN
+                    IF NOT EXISTS (
+                        SELECT 1
+                        FROM information_schema.columns
+                        WHERE table_name = 'document_info_{self.collection_name}'
+                        AND column_name = 'status'
+                    ) THEN
+                        ALTER TABLE "document_info_{self.collection_name}"
+                        ADD COLUMN status TEXT DEFAULT 'processing';
+                    END IF;
+                END $$;
+                """
+                sess.execute(text(add_column_query))
+
+    def _create_hybrid_search_function(self):
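+        # Define a SQL function that fuses full-text rank (ts_rank over
+        # metadata->>'text') and vector-similarity rank via reciprocal rank
+        # fusion (RRF). NOTE: the embedding dimension is hardcoded to 512 in
+        # the function signature below.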
+        hybrid_search_function = f"""
+        CREATE OR REPLACE FUNCTION hybrid_search_{self.collection_name}(
+            query_text TEXT,
+            query_embedding VECTOR(512),
+            match_limit INT,
+            full_text_weight FLOAT = 1,
+            semantic_weight FLOAT = 1,
+            rrf_k INT = 50,
+            filter_condition JSONB = NULL
+        )
+        RETURNS SETOF vecs."{self.collection_name}"
+        LANGUAGE sql
+        AS $$
+        WITH full_text AS (
+            SELECT
+                id,
+                ROW_NUMBER() OVER (ORDER BY ts_rank(to_tsvector('english', metadata->>'text'), websearch_to_tsquery(query_text)) DESC) AS rank_ix
+            FROM vecs."{self.collection_name}"
+            WHERE to_tsvector('english', metadata->>'text') @@ websearch_to_tsquery(query_text)
+            AND (filter_condition IS NULL OR (metadata @> filter_condition))
+            ORDER BY rank_ix
+            LIMIT LEAST(match_limit, 30) * 2
+        ),
+        semantic AS (
+            SELECT
+                id,
+                ROW_NUMBER() OVER (ORDER BY vec <#> query_embedding) AS rank_ix
+            FROM vecs."{self.collection_name}"
+            WHERE filter_condition IS NULL OR (metadata @> filter_condition)
+            ORDER BY rank_ix
+            LIMIT LEAST(match_limit, 30) * 2
+        )
+        SELECT
+            vecs."{self.collection_name}".*
+        FROM
+            full_text
+            FULL OUTER JOIN semantic
+                ON full_text.id = semantic.id
+            JOIN vecs."{self.collection_name}"
+                ON vecs."{self.collection_name}".id = COALESCE(full_text.id, semantic.id)
+        ORDER BY
+            COALESCE(1.0 / (rrf_k + full_text.rank_ix), 0.0) * full_text_weight +
+            COALESCE(1.0 / (rrf_k + semantic.rank_ix), 0.0) * semantic_weight
+            DESC
+        LIMIT
+            LEAST(match_limit, 30);
+        $$;
+        """
+        retry_attempts = 5
+        for attempt in range(retry_attempts):
+            try:
+                with self.vx.Session() as sess:
+                    # Acquire an advisory lock
+                    sess.execute(text("SELECT pg_advisory_lock(123456789)"))
+                    try:
+                        sess.execute(text(hybrid_search_function))
+                        sess.commit()
+                    finally:
+                        # Release the advisory lock
+                        sess.execute(
+                            text("SELECT pg_advisory_unlock(123456789)")
+                        )
+                break  # Break the loop if successful
+            except exc.InternalError as e:
+                if "tuple concurrently updated" in str(e):
+                    time.sleep(2**attempt)  # Exponential backoff
+                else:
+                    raise  # Re-raise the exception if it's not a concurrency issue
+        else:  # for/else: runs only if every retry attempt failed without a break
+            raise RuntimeError(
+                "Failed to create hybrid search function after multiple attempts"
+            )
+
+    def copy(self, entry: VectorEntry, commit=True) -> None:
+        if self.collection is None:
+            raise ValueError(
+                "Please call `initialize_collection` before attempting to run `copy`."
+            )
+
+        serializable_entry = entry.to_serializable()
+
+        self.collection.copy(
+            records=[
+                (
+                    serializable_entry["id"],
+                    serializable_entry["vector"],
+                    serializable_entry["metadata"],
+                )
+            ]
+        )
+
+    def copy_entries(
+        self, entries: list[VectorEntry], commit: bool = True
+    ) -> None:
+        if self.collection is None:
+            raise ValueError(
+                "Please call `initialize_collection` before attempting to run `copy_entries`."
+            )
+
+        self.collection.copy(
+            records=[
+                (
+                    str(entry.id),
+                    entry.vector.data,
+                    entry.to_serializable()["metadata"],
+                )
+                for entry in entries
+            ]
+        )
+
+    def upsert(self, entry: VectorEntry, commit=True) -> None:
+        if self.collection is None:
+            raise ValueError(
+                "Please call `initialize_collection` before attempting to run `upsert`."
+            )
+
+        self.collection.upsert(
+            records=[
+                (
+                    str(entry.id),
+                    entry.vector.data,
+                    entry.to_serializable()["metadata"],
+                )
+            ]
+        )
+
+    def upsert_entries(
+        self, entries: list[VectorEntry], commit: bool = True
+    ) -> None:
+        if self.collection is None:
+            raise ValueError(
+                "Please call `initialize_collection` before attempting to run `upsert_entries`."
+            )
+
+        self.collection.upsert(
+            records=[
+                (
+                    str(entry.id),
+                    entry.vector.data,
+                    entry.to_serializable()["metadata"],
+                )
+                for entry in entries
+            ]
+        )
+
+    def search(
+        self,
+        query_vector: list[float],
+        filters: Optional[dict[str, Union[bool, int, str]]] = None,
+        limit: int = 10,
+        *args,
+        **kwargs,
+    ) -> list[VectorSearchResult]:
+        if self.collection is None:
+            raise ValueError(
+                "Please call `initialize_collection` before attempting to run `search`."
+            )
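+        # Scores below are converted from distance to similarity as
+        # `1 - distance`, which assumes a cosine-style measure.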
+        measure = kwargs.get("measure", "cosine_distance")
+        mapped_filters = {
+            key: {"$eq": value} for key, value in (filters or {}).items()
+        }
+
+        return [
+            VectorSearchResult(id=ele[0], score=float(1 - ele[1]), metadata=ele[2])  # type: ignore
+            for ele in self.collection.query(
+                data=query_vector,
+                limit=limit,
+                filters=mapped_filters,
+                measure=measure,
+                include_value=True,
+                include_metadata=True,
+            )
+        ]
+
+    def hybrid_search(
+        self,
+        query_text: str,
+        query_vector: list[float],
+        limit: int = 10,
+        filters: Optional[dict[str, Union[bool, int, str]]] = None,
+        # Hybrid search parameters
+        full_text_weight: float = 1.0,
+        semantic_weight: float = 1.0,
+        rrf_k: int = 20,  # typical value is ~2x the number of results you want
+        *args,
+        **kwargs,
+    ) -> list[VectorSearchResult]:
+        if self.collection is None:
+            raise ValueError(
+                "Please call `initialize_collection` before attempting to run `hybrid_search`."
+            )
+
+        # Convert filters to a JSON-compatible format
+        filter_condition = None
+        if filters:
+            filter_condition = json.dumps(filters)
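+        # The SQL function applies this as a JSONB containment test
+        # (metadata @> filter_condition), i.e. exact key/value matches.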
+
+        query = text(
+            f"""
+            SELECT * FROM hybrid_search_{self.collection_name}(
+                cast(:query_text as TEXT), cast(:query_embedding as VECTOR), cast(:match_limit as INT),
+                cast(:full_text_weight as FLOAT), cast(:semantic_weight as FLOAT), cast(:rrf_k as INT),
+                cast(:filter_condition as JSONB)
+            )
+        """
+        )
+
+        params = {
+            "query_text": str(query_text),
+            "query_embedding": list(query_vector),
+            "match_limit": limit,
+            "full_text_weight": full_text_weight,
+            "semantic_weight": semantic_weight,
+            "rrf_k": rrf_k,
+            "filter_condition": filter_condition,
+        }
+
+        with self.vx.Session() as session:
+            result = session.execute(query, params).fetchall()
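+        # Rows arrive already ordered by the fused RRF score; the function
+        # does not return that score, so a placeholder of 1.0 is reported.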
+        return [
+            VectorSearchResult(id=row[0], score=1.0, metadata=row[-1])
+            for row in result
+        ]
+
+    def create_index(self, index_type, column_name, index_options):
+        pass
+
+    def delete_by_metadata(
+        self,
+        metadata_fields: list[str],
+        metadata_values: list[Union[bool, int, str]],
+        logic: Literal["AND", "OR"] = "AND",
+    ) -> list[str]:
+        if logic == "OR":
+            raise ValueError(
+                "OR logic is still being tested before official support for `delete_by_metadata` in pgvector."
+            )
+        if self.collection is None:
+            raise ValueError(
+                "Please call `initialize_collection` before attempting to run `delete_by_metadata`."
+            )
+
+        if len(metadata_fields) != len(metadata_values):
+            raise ValueError(
+                "The number of metadata fields must match the number of metadata values."
+            )
+
+        # Construct the filter
+        if logic == "AND":
+            filters = {
+                k: {"$eq": v} for k, v in zip(metadata_fields, metadata_values)
+            }
+        else:  # OR logic
+            # TODO - Test 'or' logic and remove check above
+            filters = {
+                "$or": [
+                    {k: {"$eq": v}}
+                    for k, v in zip(metadata_fields, metadata_values)
+                ]
+            }
+        return self.collection.delete(filters=filters)
+
+    def get_metadatas(
+        self,
+        metadata_fields: list[str],
+        filter_field: Optional[str] = None,
+        filter_value: Optional[Union[bool, int, str]] = None,
+    ) -> list[dict]:
+        if self.collection is None:
+            raise ValueError(
+                "Please call `initialize_collection` before attempting to run `get_metadatas`."
+            )
+
+        # One result row per unique metadata value, keyed by the value itself
+        # so repeated values across fields collapse into a single entry.
+        results: dict = {}
+        for field in metadata_fields:
+            unique_values = self.collection.get_unique_metadata_values(
+                field=field,
+                filter_field=filter_field,
+                filter_value=filter_value,
+            )
+            for value in unique_values:
+                if value not in results:
+                    results[value] = {}
+                results[value][field] = value
+
+        return list(results.values())
+
+    def upsert_documents_overview(
+        self, documents_overview: list[DocumentInfo]
+    ) -> None:
+        for document_info in documents_overview:
+            db_entry = document_info.convert_to_db_entry()
+
+            # Convert 'None' string to None type for user_id
+            if db_entry["user_id"] == "None":
+                db_entry["user_id"] = None
+
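+            # Insert-or-update keyed on document_id; created_at is preserved
+            # from the original row on conflict.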
+            query = text(
+                f"""
+                INSERT INTO "document_info_{self.collection_name}" (document_id, title, user_id, version, created_at, updated_at, size_in_bytes, metadata, status)
+                VALUES (:document_id, :title, :user_id, :version, :created_at, :updated_at, :size_in_bytes, :metadata, :status)
+                ON CONFLICT (document_id) DO UPDATE SET
+                    title = EXCLUDED.title,
+                    user_id = EXCLUDED.user_id,
+                    version = EXCLUDED.version,
+                    updated_at = EXCLUDED.updated_at,
+                    size_in_bytes = EXCLUDED.size_in_bytes,
+                    metadata = EXCLUDED.metadata,
+                    status = EXCLUDED.status;
+            """
+            )
+            with self.vx.Session() as sess:
+                sess.execute(query, db_entry)
+                sess.commit()
+
+    def delete_from_documents_overview(
+        self, document_id: str, version: Optional[str] = None
+    ) -> None:
+        query = f"""
+            DELETE FROM "document_info_{self.collection_name}"
+            WHERE document_id = :document_id
+        """
+        params = {"document_id": document_id}
+
+        if version is not None:
+            query += " AND version = :version"
+            params["version"] = version
+
+        with self.vx.Session() as sess:
+            with sess.begin():
+                sess.execute(text(query), params)
+
+    def get_documents_overview(
+        self,
+        filter_document_ids: Optional[list[str]] = None,
+        filter_user_ids: Optional[list[str]] = None,
+    ):
+        conditions = []
+        params = {}
+
+        if filter_document_ids:
+            placeholders = ", ".join(
+                f":doc_id_{i}" for i in range(len(filter_document_ids))
+            )
+            conditions.append(f"document_id IN ({placeholders})")
+            params.update(
+                {
+                    f"doc_id_{i}": str(document_id)
+                    for i, document_id in enumerate(filter_document_ids)
+                }
+            )
+        if filter_user_ids:
+            placeholders = ", ".join(
+                f":user_id_{i}" for i in range(len(filter_user_ids))
+            )
+            conditions.append(f"user_id IN ({placeholders})")
+            params.update(
+                {
+                    f"user_id_{i}": str(user_id)
+                    for i, user_id in enumerate(filter_user_ids)
+                }
+            )
+
+        query = f"""
+            SELECT document_id, title, user_id, version, size_in_bytes, created_at, updated_at, metadata, status
+            FROM "document_info_{self.collection_name}"
+        """
+        if conditions:
+            query += " WHERE " + " AND ".join(conditions)
+
+        with self.vx.Session() as sess:
+            results = sess.execute(text(query), params).fetchall()
+            return [
+                DocumentInfo(
+                    document_id=row[0],
+                    title=row[1],
+                    user_id=row[2],
+                    version=row[3],
+                    size_in_bytes=row[4],
+                    created_at=row[5],
+                    updated_at=row[6],
+                    metadata=row[7],
+                    status=row[8],
+                )
+                for row in results
+            ]
+
+    def get_document_chunks(self, document_id: str) -> list[dict]:
+        if not self.collection:
+            raise ValueError("Collection is not initialized.")
+
+        table_name = self.collection.table.name
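+        # Chunks live as rows in the vecs table; return their metadata in
+        # ascending 'chunk_order'.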
+        query = text(
+            f"""
+            SELECT metadata
+            FROM vecs."{table_name}"
+            WHERE metadata->>'document_id' = :document_id
+            ORDER BY CAST(metadata->>'chunk_order' AS INTEGER)
+        """
+        )
+
+        params = {"document_id": document_id}
+
+        with self.vx.Session() as sess:
+            results = sess.execute(query, params).fetchall()
+            return [result[0] for result in results]
+
+    def get_users_overview(self, user_ids: Optional[list[str]] = None):
+        user_ids_condition = ""
+        params = {}
+        if user_ids:
+            user_ids_condition = "WHERE user_id IN :user_ids"
+            params["user_ids"] = tuple(
+                map(str, user_ids)
+            )  # Convert UUIDs to strings
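+            # NOTE: binding a Python tuple to `IN :user_ids` relies on the
+            # driver expanding the tuple (psycopg2 does); an expanding
+            # bindparam would be the portable alternative.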
+
+        query = f"""
+            SELECT user_id, COUNT(document_id) AS num_files, SUM(size_in_bytes) AS total_size_in_bytes, ARRAY_AGG(document_id) AS document_ids
+            FROM "document_info_{self.collection_name}"
+            {user_ids_condition}
+            GROUP BY user_id
+        """
+
+        with self.vx.Session() as sess:
+            results = sess.execute(text(query), params).fetchall()
+        return [
+            UserStats(
+                user_id=row[0],
+                num_files=row[1],
+                total_size_in_bytes=row[2],
+                document_ids=row[3],
+            )
+            for row in results
+            if row[0] is not None
+        ]