aboutsummaryrefslogtreecommitdiff
import copy
import json
import logging
import math
import time
import uuid
from typing import Any, Optional, TypedDict
from uuid import UUID

import numpy as np

from core.base import (
    ChunkSearchResult,
    Handler,
    IndexArgsHNSW,
    IndexArgsIVFFlat,
    IndexMeasure,
    IndexMethod,
    R2RException,
    SearchSettings,
    VectorEntry,
    VectorQuantizationType,
    VectorTableName,
)
from core.base.utils import _decorate_vector_type

from .base import PostgresConnectionManager
from .filters import apply_filters

# Module-level logger. getLogger is called without a name, so this is the
# root logger itself rather than a per-module child logger.
# NOTE(review): logging.getLogger(__name__) is the usual convention — confirm
# nothing relies on records going through the root logger before changing it.
logger = logging.getLogger(None)


def psql_quote_literal(value: str) -> str:
    """Render *value* as a single-quoted PostgreSQL string literal.

    Embedded single quotes are escaped by doubling them, as the SQL standard
    requires. This is a minimal helper; prefer bound query parameters (or the
    database driver's own quoting) whenever possible.
    """
    escaped = value.replace("'", "''")
    return f"'{escaped}'"


def index_measure_to_ops(
    measure: IndexMeasure,
    quantization_type: VectorQuantizationType = VectorQuantizationType.FP32,
):
    """Return the operator-class name used when indexing with ``measure``.

    The measure's base ``ops`` name is adjusted for ``quantization_type``
    (FP32 by default) by ``_decorate_vector_type``.
    """
    base_ops = measure.ops
    return _decorate_vector_type(base_ops, quantization_type)


def quantize_vector_to_binary(
    vector: list[float] | np.ndarray,
    threshold: float = 0.0,
) -> bytes:
    """Quantizes a float vector to a binary vector string for PostgreSQL bit
    type. Used when quantization_type is INT1.

    Args:
        vector (List[float] | np.ndarray): Input vector of floats
        threshold (float, optional): Threshold for binarization. Defaults to 0.0.

    Returns:
        str: Binary string representation for PostgreSQL bit type
    """
    # Convert input to numpy array if it isn't already
    if not isinstance(vector, np.ndarray):
        vector = np.array(vector)

    # Convert to binary (1 where value > threshold, 0 otherwise)
    binary_vector = (vector > threshold).astype(int)

    # Convert to string of 1s and 0s
    # Convert to string of 1s and 0s, then to bytes
    binary_string = "".join(map(str, binary_vector))
    return binary_string.encode("ascii")


# Per-chunk bookkeeping used by hybrid search while fusing the semantic and
# full-text result lists: the 1-based rank in each list (or a fallback rank
# when the chunk is missing from that list), the original search hit, and the
# reciprocal-rank-fusion score computed from those two ranks.
HybridSearchIntermediateResult = TypedDict(
    "HybridSearchIntermediateResult",
    {
        "semantic_rank": int,
        "full_text_rank": int,
        "data": ChunkSearchResult,
        "rrf_score": float,
    },
)


class PostgresChunksHandler(Handler):
    TABLE_NAME = VectorTableName.CHUNKS

    def __init__(
        self,
        project_name: str,
        connection_manager: PostgresConnectionManager,
        dimension: int | float,
        quantization_type: VectorQuantizationType,
    ):
        """Bind the chunks handler to a project schema and vector settings.

        Args:
            project_name: Schema that holds the chunks table.
            connection_manager: Executes all queries issued by this handler.
            dimension: Embedding dimension; NaN (or a non-positive value)
                means no fixed dimension is declared.
            quantization_type: Selects whether the INT1 ``vec_binary``
                column is maintained alongside ``vec``.
        """
        super().__init__(project_name, connection_manager)
        self.quantization_type = quantization_type
        self.dimension = dimension

    async def create_tables(self):
        """Create the chunks table and its secondary indexes if they are missing.

        Before issuing any DDL this method:

        * checks that an already-existing chunks table in the project schema
          was declared with the same vector dimension as ``self.dimension``;
        * refuses to continue when an old vector table named after the
          project itself still exists (it must be migrated first).

        The ``vec_binary`` column is only created for INT1 quantization.

        Raises:
            ValueError: On a dimension mismatch with an existing table, or
                when the old (pre-migration) table is found.
        """
        # One existence probe, reused for both the current and the old table;
        # both are looked up by (schema, table name) in pg_tables.
        table_exists_query = """
        SELECT EXISTS (
            SELECT FROM pg_tables
            WHERE schemaname = $1
            AND tablename = $2
        );
        """
        table_name = VectorTableName.CHUNKS
        table_exists = await self.connection_manager.fetch_query(
            table_exists_query, (self.project_name, table_name)
        )

        if len(table_exists) > 0 and table_exists[0]["exists"]:
            # Table exists, check vector dimension
            vector_dim_query = """
            SELECT a.atttypmod as dimension
            FROM pg_attribute a
            JOIN pg_class c ON a.attrelid = c.oid
            JOIN pg_namespace n ON c.relnamespace = n.oid
            WHERE n.nspname = $1
            AND c.relname = $2
            AND a.attname = 'vec';
            """

            vector_dim_result = await self.connection_manager.fetch_query(
                vector_dim_query, (self.project_name, table_name)
            )

            if vector_dim_result and len(vector_dim_result) > 0:
                existing_dimension = vector_dim_result[0]["dimension"]
                # pgvector stores the declared dimension itself as the
                # column's atttypmod (a plain ``vector`` column has -1), so
                # only a positive value pins the table to a dimension. A NaN
                # or non-positive self.dimension means "unspecified" and
                # skips the comparison.
                if (
                    existing_dimension > 0
                    and self.dimension > 0
                    and existing_dimension != self.dimension
                ):
                    raise ValueError(
                        f"Dimension mismatch: Table '{self.project_name}.{table_name}' was created with "
                        f"dimension {existing_dimension}, but {self.dimension} was provided. "
                        f"You must use the same dimension for existing tables."
                    )

        # An old vector table named exactly like the project must be migrated
        # before the new schema can be used.
        old_table_exists = await self.connection_manager.fetch_query(
            table_exists_query, (self.project_name, self.project_name)
        )

        if len(old_table_exists) > 0 and old_table_exists[0]["exists"]:
            raise ValueError(
                f"Found old vector table '{self.project_name}.{self.project_name}'. "
                "Please run `r2r db upgrade` with the CLI, or to run manually, "
                "run in R2R/py/migrations with 'alembic upgrade head' to update "
                "your database schema to the new version."
            )

        # NOTE(review): with INT1 and an unset (NaN) dimension this renders
        # ``bit(nan)``, which PostgreSQL should reject when the table actually
        # has to be created — INT1 appears to need a concrete dimension.
        binary_col = (
            f"vec_binary bit({self.dimension}),"
            if self.quantization_type == VectorQuantizationType.INT1
            else ""
        )

        vector_col = (
            f"vec vector({self.dimension})"
            if self.dimension > 0
            else "vec vector"
        )

        chunks_table = self._get_table_name(PostgresChunksHandler.TABLE_NAME)
        query = f"""
        CREATE TABLE IF NOT EXISTS {chunks_table} (
            id UUID PRIMARY KEY,
            document_id UUID,
            owner_id UUID,
            collection_ids UUID[],
            {vector_col},
            {binary_col}
            text TEXT,
            metadata JSONB,
            fts tsvector GENERATED ALWAYS AS (to_tsvector('english', text)) STORED
        );
        CREATE INDEX IF NOT EXISTS idx_vectors_document_id ON {chunks_table} (document_id);
        CREATE INDEX IF NOT EXISTS idx_vectors_owner_id ON {chunks_table} (owner_id);
        CREATE INDEX IF NOT EXISTS idx_vectors_collection_ids ON {chunks_table} USING GIN (collection_ids);
        CREATE INDEX IF NOT EXISTS idx_vectors_text ON {chunks_table} USING GIN (to_tsvector('english', text));
        """

        await self.connection_manager.execute_query(query)

    def _build_upsert_query(self) -> str:
        """Return the INSERT ... ON CONFLICT (id) statement for this handler.

        The ``vec_binary`` column is only written for INT1 quantization, the
        only mode in which ``create_tables`` defines it.
        """
        table = self._get_table_name(PostgresChunksHandler.TABLE_NAME)
        if self.quantization_type == VectorQuantizationType.INT1:
            # Cast the quantized value to the column's fixed-width bit type.
            # Without a positive dimension fall back to ``varbit``: a bare
            # ``bit`` means ``bit(1)`` and an explicit cast would truncate.
            bit_type = (
                f"bit({self.dimension})" if self.dimension > 0 else "varbit"
            )
            return f"""
            INSERT INTO {table}
            (id, document_id, owner_id, collection_ids, vec, vec_binary, text, metadata)
            VALUES ($1, $2, $3, $4, $5, $6::{bit_type}, $7, $8)
            ON CONFLICT (id) DO UPDATE SET
            document_id = EXCLUDED.document_id,
            owner_id = EXCLUDED.owner_id,
            collection_ids = EXCLUDED.collection_ids,
            vec = EXCLUDED.vec,
            vec_binary = EXCLUDED.vec_binary,
            text = EXCLUDED.text,
            metadata = EXCLUDED.metadata;
            """
        return f"""
            INSERT INTO {table}
            (id, document_id, owner_id, collection_ids, vec, text, metadata)
            VALUES ($1, $2, $3, $4, $5, $6, $7)
            ON CONFLICT (id) DO UPDATE SET
            document_id = EXCLUDED.document_id,
            owner_id = EXCLUDED.owner_id,
            collection_ids = EXCLUDED.collection_ids,
            vec = EXCLUDED.vec,
            text = EXCLUDED.text,
            metadata = EXCLUDED.metadata;
            """

    def _upsert_row_params(self, entry: VectorEntry) -> tuple:
        """Return the positional parameters for one entry, in query order.

        The vector is sent in its ``str()`` form and metadata as a JSON
        string; for INT1 the quantized bit string is inserted as ``$6``.
        """
        if self.quantization_type == VectorQuantizationType.INT1:
            return (
                entry.id,
                entry.document_id,
                entry.owner_id,
                entry.collection_ids,
                str(entry.vector.data),
                quantize_vector_to_binary(entry.vector.data),
                entry.text,
                json.dumps(entry.metadata),
            )
        return (
            entry.id,
            entry.document_id,
            entry.owner_id,
            entry.collection_ids,
            str(entry.vector.data),
            entry.text,
            json.dumps(entry.metadata),
        )

    async def upsert(self, entry: VectorEntry) -> None:
        """Insert one chunk, or overwrite every column of an existing chunk
        with the same id.

        ``vec_binary`` is only written when quantization_type is INT1,
        matching the table schema.
        """
        await self.connection_manager.execute_query(
            self._build_upsert_query(), self._upsert_row_params(entry)
        )

    async def upsert_entries(self, entries: list[VectorEntry]) -> None:
        """Batch variant of ``upsert``: one statement executed per entry via
        ``execute_many``.

        ``vec_binary`` is only written when quantization_type is INT1,
        matching the table schema.
        """
        await self.connection_manager.execute_many(
            self._build_upsert_query(),
            [self._upsert_row_params(entry) for entry in entries],
        )

    async def semantic_search(
        self, query_vector: list[float], search_settings: SearchSettings
    ) -> list[ChunkSearchResult]:
        """Return the chunks nearest to ``query_vector``.

        With INT1 quantization a two-stage search runs. First, candidates are
        ordered by a binary distance over ``vec_binary``: the configured
        measure when it is Hamming/Jaccard, Hamming otherwise. Second, the
        candidates are re-ranked with ``<=>`` (pgvector cosine distance) over
        the full-precision ``vec`` column. For every other quantization,
        ``vec`` is searched directly with the configured measure.

        In both modes ``search_settings.offset`` / ``limit`` paginate the
        final ranking. Each hit's score is ``1 - distance`` when a distance
        column is returned, and ``-1`` otherwise.

        Raises:
            ValueError: If ``chunk_settings.index_measure`` is not a valid
                ``IndexMeasure``.
        """
        try:
            imeasure_obj = IndexMeasure(
                search_settings.chunk_settings.index_measure
            )
        except ValueError:
            raise ValueError("Invalid index measure") from None

        table_name = self._get_table_name(PostgresChunksHandler.TABLE_NAME)
        cols = [
            f"{table_name}.id",
            f"{table_name}.document_id",
            f"{table_name}.owner_id",
            f"{table_name}.collection_ids",
            f"{table_name}.text",
        ]
        # Type modifier shared by the ``bit`` and ``vector`` casts; empty when
        # the dimension is unknown (NaN).
        dim_suffix = (
            "" if math.isnan(self.dimension) else f"({self.dimension})"
        )

        params: list[str | int | bytes] = []

        if self.quantization_type == VectorQuantizationType.INT1:
            binary_query = quantize_vector_to_binary(query_vector)
            # TODO - Put depth multiplier in config / settings
            # Over-fetch 20x the full page depth (offset + limit) so that
            # pagination is applied after re-ranking, not to the candidates.
            extended_limit = (
                search_settings.offset + search_settings.limit
            ) * 20

            if imeasure_obj in (
                IndexMeasure.hamming_distance,
                IndexMeasure.jaccard_distance,
            ):
                binary_search_measure_repr = imeasure_obj.pgvector_repr
            else:
                binary_search_measure_repr = (
                    IndexMeasure.hamming_distance.pgvector_repr
                )

            stage1_distance = f"{table_name}.vec_binary {binary_search_measure_repr} $1::bit{dim_suffix}"

            # The full-precision vector is needed for re-ranking.
            cols.append(f"{table_name}.vec")
            if search_settings.include_metadatas:
                cols.append(f"{table_name}.metadata")

            select_clause = ", ".join(cols)
            where_clause = ""
            params.append(binary_query)

            if search_settings.filters:
                where_clause, params = apply_filters(
                    search_settings.filters, params, mode="where_clause"
                )

            n = len(params)
            metadata_col = (
                "metadata," if search_settings.include_metadatas else ""
            )

            # First stage: Get candidates using binary search
            query = f"""
            WITH candidates AS (
                SELECT {select_clause},
                    ({stage1_distance}) as binary_distance
                FROM {table_name}
                {where_clause}
                ORDER BY {stage1_distance}
                LIMIT ${n + 1}
            )
            -- Second stage: Re-rank using original vectors
            SELECT
                id,
                document_id,
                owner_id,
                collection_ids,
                text,
                {metadata_col}
                (vec <=> ${n + 2}::vector{dim_suffix}) as distance
            FROM candidates
            ORDER BY distance
            LIMIT ${n + 3}
            OFFSET ${n + 4}
            """

            params.extend(
                [
                    extended_limit,  # Candidate pool size
                    str(query_vector),  # For re-ranking
                    search_settings.limit,  # Final page size
                    search_settings.offset,  # Final page start
                ]
            )

        else:
            # Standard float vector handling
            distance_calc = f"{table_name}.vec {imeasure_obj.pgvector_repr} $1::vector{dim_suffix}"

            if search_settings.include_scores:
                cols.append(f"({distance_calc}) AS distance")
            if search_settings.include_metadatas:
                cols.append(f"{table_name}.metadata")

            select_clause = ", ".join(cols)
            where_clause = ""
            params.append(str(query_vector))

            if search_settings.filters:
                where_clause, params = apply_filters(
                    search_settings.filters, params, mode="where_clause"
                )

            n = len(params)
            query = f"""
            SELECT {select_clause}
            FROM {table_name}
            {where_clause}
            ORDER BY {distance_calc}
            LIMIT ${n + 1}
            OFFSET ${n + 2}
            """
            params.extend([search_settings.limit, search_settings.offset])

        results = await self.connection_manager.fetch_query(query, params)

        return [
            ChunkSearchResult(
                id=UUID(str(result["id"])),
                document_id=UUID(str(result["document_id"])),
                owner_id=UUID(str(result["owner_id"])),
                collection_ids=result["collection_ids"],
                text=result["text"],
                score=(
                    (1 - float(result["distance"]))
                    if "distance" in result
                    else -1
                ),
                metadata=(
                    json.loads(result["metadata"])
                    if search_settings.include_metadatas
                    else {}
                ),
            )
            for result in results
        ]

    async def full_text_search(
        self, query_text: str, search_settings: SearchSettings
    ) -> list[ChunkSearchResult]:
        """Rank chunks against ``query_text`` using PostgreSQL full-text search.

        The text is parsed with ``websearch_to_tsquery('english', ...)`` and
        matched against the generated ``fts`` column. Matches are ordered by
        ``ts_rank(..., 32)`` descending, and that rank is returned as each
        hit's score. Any ``search_settings.filters`` are AND-ed onto the
        match condition. Pagination uses ``search_settings.offset`` and
        ``search_settings.hybrid_settings.full_text_limit``.
        """
        params: list[str | int | bytes] = [query_text]
        conditions = ["fts @@ websearch_to_tsquery('english', $1)"]

        if search_settings.filters:
            filter_condition, params = apply_filters(
                search_settings.filters, params, mode="condition_only"
            )
            if filter_condition:
                conditions.append(filter_condition)

        where_clause = "WHERE " + " AND ".join(conditions)
        fts_table = self._get_table_name(PostgresChunksHandler.TABLE_NAME)
        # OFFSET / LIMIT placeholders come right after the query + filter params.
        offset_ph = f"${len(params) + 1}"
        limit_ph = f"${len(params) + 2}"

        query = f"""
            SELECT
                id,
                document_id,
                owner_id,
                collection_ids,
                text,
                metadata,
                ts_rank(fts, websearch_to_tsquery('english', $1), 32) as rank
            FROM {fts_table}
            {where_clause}
            ORDER BY rank DESC
            OFFSET {offset_ph}
            LIMIT {limit_ph}
        """

        params.extend(
            [
                search_settings.offset,
                search_settings.hybrid_settings.full_text_limit,
            ]
        )

        rows = await self.connection_manager.fetch_query(query, params)
        return [
            ChunkSearchResult(
                id=UUID(str(row["id"])),
                document_id=UUID(str(row["document_id"])),
                owner_id=UUID(str(row["owner_id"])),
                collection_ids=row["collection_ids"],
                text=row["text"],
                score=float(row["rank"]),
                metadata=json.loads(row["metadata"]),
            )
            for row in rows
        ]

    async def hybrid_search(
        self,
        query_text: str,
        query_vector: list[float],
        search_settings: SearchSettings,
        *args,
        **kwargs,
    ) -> list[ChunkSearchResult]:
        """Fuse semantic and full-text hits with weighted Reciprocal Rank Fusion.

        Each chunk scores ``1 / (rrf_k + rank)`` in each result list. The two
        scores are weighted by ``semantic_weight`` / ``full_text_weight`` and
        normalized by the sum of the weights. A chunk missing from one list
        is given that list's fetch depth as its rank there.

        Both sub-searches start at their first hit and fetch ``offset``
        extra results. The offset is applied exactly once, to the fused
        ranking, which is then cut to ``limit`` results.

        Raises:
            ValueError: If ``hybrid_settings`` is missing or
                ``full_text_limit`` is smaller than ``limit``.
        """
        if search_settings.hybrid_settings is None:
            raise ValueError(
                "Please provide a valid `hybrid_settings` in the `search_settings`."
            )
        if (
            search_settings.hybrid_settings.full_text_limit
            < search_settings.limit
        ):
            raise ValueError(
                "The `full_text_limit` must be greater than or equal to the `limit`."
            )

        # The sub-searches must not skip their own top hits: fetch
        # ``offset + limit`` from the start and paginate the fused list below.
        semantic_settings = copy.deepcopy(search_settings)
        semantic_settings.limit += search_settings.offset
        semantic_settings.offset = 0

        full_text_settings = copy.deepcopy(search_settings)
        full_text_settings.hybrid_settings.full_text_limit += (
            search_settings.offset
        )
        full_text_settings.offset = 0

        semantic_results: list[ChunkSearchResult] = await self.semantic_search(
            query_vector, semantic_settings
        )
        full_text_results: list[
            ChunkSearchResult
        ] = await self.full_text_search(query_text, full_text_settings)

        # Fallback ranks for chunks absent from one list: the depth that list
        # was actually fetched to, i.e. never better than its last possible hit.
        semantic_limit = semantic_settings.limit
        full_text_limit = full_text_settings.hybrid_settings.full_text_limit
        semantic_weight = search_settings.hybrid_settings.semantic_weight
        full_text_weight = search_settings.hybrid_settings.full_text_weight
        rrf_k = search_settings.hybrid_settings.rrf_k

        combined_results: dict[uuid.UUID, HybridSearchIntermediateResult] = {}

        for rank, result in enumerate(semantic_results, 1):
            combined_results[result.id] = {
                "semantic_rank": rank,
                "full_text_rank": full_text_limit,
                "data": result,
                "rrf_score": 0.0,  # Initialize with 0, will be calculated later
            }

        for rank, result in enumerate(full_text_results, 1):
            if result.id in combined_results:
                combined_results[result.id]["full_text_rank"] = rank
            else:
                combined_results[result.id] = {
                    "semantic_rank": semantic_limit,
                    "full_text_rank": rank,
                    "data": result,
                    "rrf_score": 0.0,  # Initialize with 0, will be calculated later
                }

        # Every rank is now bounded by its list's fetch depth, so no further
        # depth filtering is needed before scoring.
        for hyb_result in combined_results.values():
            semantic_score = 1 / (rrf_k + hyb_result["semantic_rank"])
            full_text_score = 1 / (rrf_k + hyb_result["full_text_rank"])
            hyb_result["rrf_score"] = (
                semantic_score * semantic_weight
                + full_text_score * full_text_weight
            ) / (semantic_weight + full_text_weight)

        sorted_results = sorted(
            combined_results.values(),
            key=lambda x: x["rrf_score"],
            reverse=True,
        )
        offset_results = sorted_results[
            search_settings.offset : search_settings.offset
            + search_settings.limit
        ]

        return [
            ChunkSearchResult(
                id=result["data"].id,
                document_id=result["data"].document_id,
                owner_id=result["data"].owner_id,
                collection_ids=result["data"].collection_ids,
                text=result["data"].text,
                score=result["rrf_score"],
                metadata={
                    **result["data"].metadata,
                    "semantic_rank": result["semantic_rank"],
                    "full_text_rank": result["full_text_rank"],
                },
            )
            for result in offset_results
        ]

    async def delete(
        self, filters: dict[str, Any]
    ) -> dict[str, dict[str, str]]:
        """Delete every chunk matching ``filters`` and report what was removed.

        Returns:
            A mapping from chunk id (as ``str``) to a summary dict with
            ``status`` (always ``"deleted"``), ``id``, ``document_id`` and
            ``text``.
        """
        # NOTE(review): if apply_filters yields an empty condition for empty
        # filters, the statement ends in a bare WHERE and fails — confirm.
        condition, bound = apply_filters(filters, [], mode="condition_only")

        query = f"""
        DELETE FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
        WHERE {condition}
        RETURNING id, document_id, text;
        """

        rows = await self.connection_manager.fetch_query(query, bound)

        deleted: dict[str, dict[str, str]] = {}
        for row in rows:
            chunk_id = str(row["id"])
            deleted[chunk_id] = {
                "status": "deleted",
                "id": chunk_id,
                "document_id": str(row["document_id"]),
                "text": row["text"],
            }
        return deleted

    async def assign_document_chunks_to_collection(
        self, document_id: UUID, collection_id: UUID
    ) -> None:
        """Add ``collection_id`` to every chunk of ``document_id`` lacking it.

        Chunks whose ``collection_ids`` is NULL are included: on its own,
        ``NOT ($1 = ANY(NULL))`` evaluates to NULL and would skip them, while
        ``array_append(NULL, $1)`` correctly yields a one-element array.

        Returns whatever ``execute_query`` returns, despite the ``None``
        annotation (kept for backward compatibility).
        """
        query = f"""
        UPDATE {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
        SET collection_ids = array_append(collection_ids, $1)
        WHERE document_id = $2
        AND (collection_ids IS NULL OR NOT ($1 = ANY(collection_ids)));
        """
        return await self.connection_manager.execute_query(
            query, (str(collection_id), str(document_id))
        )

    async def remove_document_from_collection_vector(
        self, document_id: UUID, collection_id: UUID
    ) -> None:
        """Strip ``collection_id`` from the ``collection_ids`` of a document's chunks.

        ``array_remove`` drops every occurrence of the id. All chunks of the
        document are updated, including those that never contained it.
        """
        chunks_table = self._get_table_name(PostgresChunksHandler.TABLE_NAME)
        query = f"""
        UPDATE {chunks_table}
        SET collection_ids = array_remove(collection_ids, $1)
        WHERE document_id = $2;
        """
        bind = (collection_id, document_id)
        await self.connection_manager.execute_query(query, bind)

    async def delete_user_vector(self, owner_id: UUID) -> None:
        """Delete every chunk whose ``owner_id`` matches the given user."""
        chunks_table = self._get_table_name(PostgresChunksHandler.TABLE_NAME)
        query = f"""
        DELETE FROM {chunks_table}
        WHERE owner_id = $1;
        """
        bind = (owner_id,)
        await self.connection_manager.execute_query(query, bind)

    async def delete_collection_vector(self, collection_id: UUID) -> None:
        """Delete every chunk whose ``collection_ids`` contains ``collection_id``.

        Nothing about the deleted rows is used, so the statement is run
        without ``RETURNING`` through ``execute_query``, like
        ``delete_user_vector``.
        """
        query = f"""
        DELETE FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
        WHERE $1 = ANY(collection_ids);
        """
        await self.connection_manager.execute_query(query, (collection_id,))
        return None

    async def list_document_chunks(
        self,
        document_id: UUID,
        offset: int,
        limit: int,
        include_vectors: bool = False,
    ) -> dict[str, Any]:
        """List a document's chunks ordered by ``metadata.chunk_order``.

        Args:
            document_id: Document whose chunks are listed.
            offset: Number of leading chunks to skip.
            limit: Maximum number of chunks to return; ``-1`` (or any other
                negative value) returns all remaining chunks.
            include_vectors: Also return each chunk's ``vec``, decoded with
                ``json.loads`` (``None`` when the stored vector is NULL).

        Returns:
            ``{"results": [...], "total_entries": N}``, where ``N`` comes from
            ``COUNT(*) OVER()`` and counts all of the document's chunks. If
            ``offset`` is past the last chunk, no rows come back and ``N`` is
            reported as 0.
        """
        vector_select = ", vec" if include_vectors else ""
        params: list[Any] = [document_id, offset]
        # Bind LIMIT as a parameter rather than interpolating it into the SQL.
        limit_clause = ""
        if limit > -1:
            params.append(limit)
            limit_clause = f"LIMIT ${len(params)}"

        query = f"""
        SELECT id, document_id, owner_id, collection_ids, text, metadata{vector_select}, COUNT(*) OVER() AS total
        FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
        WHERE document_id = $1
        ORDER BY (metadata->>'chunk_order')::integer
        OFFSET $2
        {limit_clause};
        """

        results = await self.connection_manager.fetch_query(query, params)

        chunks = []
        total = 0
        if results:
            total = results[0].get("total", 0)
            chunks = [
                {
                    "id": result["id"],
                    "document_id": result["document_id"],
                    "owner_id": result["owner_id"],
                    "collection_ids": result["collection_ids"],
                    "text": result["text"],
                    "metadata": json.loads(result["metadata"]),
                    "vector": (
                        json.loads(result["vec"])
                        if include_vectors and result["vec"] is not None
                        else None
                    ),
                }
                for result in results
            ]

        return {"results": chunks, "total_entries": total}

    async def get_chunk(self, id: UUID) -> dict:
        """Fetch one chunk by its primary key.

        Returns:
            A dict with ``id``, ``document_id``, ``owner_id``,
            ``collection_ids``, ``text`` and the JSON-decoded ``metadata``.

        Raises:
            R2RException: With status 404 when no chunk has this ``id``.
        """
        query = f"""
        SELECT id, document_id, owner_id, collection_ids, text, metadata
        FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
        WHERE id = $1;
        """

        row = await self.connection_manager.fetchrow_query(query, (id,))

        if not row:
            raise R2RException(
                message=f"Chunk with ID {id} not found", status_code=404
            )

        chunk = {
            key: row[key]
            for key in ("id", "document_id", "owner_id", "collection_ids", "text")
        }
        chunk["metadata"] = json.loads(row["metadata"])
        return chunk

    async def create_index(
        self,
        table_name: Optional[VectorTableName] = None,
        index_measure: IndexMeasure = IndexMeasure.cosine_distance,
        index_method: IndexMethod = IndexMethod.auto,
        index_arguments: Optional[IndexArgsIVFFlat | IndexArgsHNSW] = None,
        index_name: Optional[str] = None,
        index_column: Optional[str] = None,
        concurrently: bool = True,
    ) -> None:
        """Creates an index for the collection.

        Note:
            When `vecs` creates an index on a pgvector column in PostgreSQL, it uses a multi-step
            process that enables performant indexes to be built for large collections with low end
            database hardware.

            Those steps are:

            - Creates a new table with a different name
            - Randomly selects records from the existing table
            - Inserts the random records from the existing table into the new table
            - Creates the requested vector index on the new table
            - Upserts all data from the existing table into the new table
            - Drops the existing table
            - Renames the new table to the existing tables name

            If you create dependencies (like views) on the table that underpins
            a `vecs.Collection` the `create_index` step may require you to drop those dependencies before
            it will succeed.

        Args:
            index_measure (IndexMeasure, optional): The measure to index for. Defaults to 'cosine_distance'.
            index_method (IndexMethod, optional): The indexing method to use. Defaults to 'auto'.
            index_arguments: (IndexArgsIVFFlat | IndexArgsHNSW, optional): Index type specific arguments
            index_name (str, optional): The name of the index to create. Defaults to None.
            concurrently (bool, optional): Whether to create the index concurrently. Defaults to True.
        Raises:
            ValueError: If an invalid index method is used, or if *replace* is False and an index already exists.
        """

        if table_name == VectorTableName.CHUNKS:
            table_name_str = f"{self.project_name}.{VectorTableName.CHUNKS}"  # TODO - Fix bug in vector table naming convention
            if index_column:
                col_name = index_column
            else:
                col_name = (
                    "vec"
                    if (
                        index_measure != IndexMeasure.hamming_distance
                        and index_measure != IndexMeasure.jaccard_distance
                    )
                    else "vec_binary"
                )
        elif table_name == VectorTableName.ENTITIES_DOCUMENT:
            table_name_str = (
                f"{self.project_name}.{VectorTableName.ENTITIES_DOCUMENT}"
            )
            col_name = "description_embedding"
        elif table_name == VectorTableName.GRAPHS_ENTITIES:
            table_name_str = (
                f"{self.project_name}.{VectorTableName.GRAPHS_ENTITIES}"
            )
            col_name = "description_embedding"
        elif table_name == VectorTableName.COMMUNITIES:
            table_name_str = (
                f"{self.project_name}.{VectorTableName.COMMUNITIES}"
            )
            col_name = "embedding"
        else:
            raise ValueError("invalid table name")

        if index_method not in (
            IndexMethod.ivfflat,
            IndexMethod.hnsw,
            IndexMethod.auto,
        ):
            raise ValueError("invalid index method")

        if index_arguments:
            # Disallow case where user submits index arguments but uses the
            # IndexMethod.auto index (index build arguments should only be
            # used with a specific index)
            if index_method == IndexMethod.auto:
                raise ValueError(
                    "Index build parameters are not allowed when using the IndexMethod.auto index."
                )
            # Disallow case where user specifies one index type but submits
            # index build arguments for the other index type
            if (
                isinstance(index_arguments, IndexArgsHNSW)
                and index_method != IndexMethod.hnsw
            ) or (
                isinstance(index_arguments, IndexArgsIVFFlat)
                and index_method != IndexMethod.ivfflat
            ):
                raise ValueError(
                    f"{index_arguments.__class__.__name__} build parameters were supplied but {index_method} index was specified."
                )

        if index_method == IndexMethod.auto:
            index_method = IndexMethod.hnsw

        ops = index_measure_to_ops(
            index_measure  # , quantization_type=self.quantization_type
        )

        if ops is None:
            raise ValueError("Unknown index measure")

        concurrently_sql = "CONCURRENTLY" if concurrently else ""

        index_name = (
            index_name
            or f"ix_{ops}_{index_method}__{col_name}_{time.strftime('%Y%m%d%H%M%S')}"
        )

        create_index_sql = f"""
        CREATE INDEX {concurrently_sql} {index_name}
        ON {table_name_str}
        USING {index_method} ({col_name} {ops}) {self._get_index_options(index_method, index_arguments)};
        """

        try:
            if concurrently:
                async with (
                    self.connection_manager.pool.get_connection() as conn  # type: ignore
                ):
                    # Disable automatic transaction management
                    await conn.execute(
                        "SET SESSION CHARACTERISTICS AS TRANSACTION ISOLATION LEVEL READ COMMITTED"
                    )
                    await conn.execute(create_index_sql)
            else:
                # Non-concurrent index creation can use normal query execution
                await self.connection_manager.execute_query(create_index_sql)
        except Exception as e:
            raise Exception(f"Failed to create index: {e}") from e
        return None

    async def list_indices(
        self,
        offset: int,
        limit: int,
        filters: Optional[dict[str, Any]] = None,
    ) -> dict:
        """List vector indices in the project schema with usage statistics.

        Only indexes whose definition text contains ``vector`` are returned.

        Args:
            offset (int): Number of matching indices to skip.
            limit (int): Maximum number of indices to return.
            filters (dict, optional): Optional filters. Supported keys:
                ``table_name`` (exact match on the indexed table),
                ``index_method`` (exact match on the access method name,
                e.g. ``hnsw``), and ``index_name`` (case-insensitive
                substring match). Other keys are ignored.

        Returns:
            dict: ``{"indices": [...], "total_entries": int}``, where
            ``total_entries`` counts all matching indices before pagination
            (0 when the requested page is empty).
        """
        where_clauses: list[str] = []
        params: list[Any] = [self.project_name]  # $1 is the schema name

        filters = filters or {}
        # Each filter appends its value first, so its placeholder number is
        # simply the current length of params.
        if "table_name" in filters:
            params.append(filters["table_name"])
            where_clauses.append(f"i.tablename = ${len(params)}")
        if "index_method" in filters:
            params.append(filters["index_method"])
            where_clauses.append(f"am.amname = ${len(params)}")
        if "index_name" in filters:
            params.append(f"%{filters['index_name']}%")
            where_clauses.append(
                f"LOWER(i.indexname) LIKE LOWER(${len(params)})"
            )

        where_clause = "".join(f"AND {clause}\n" for clause in where_clauses)
        limit_param = len(params) + 1
        offset_param = len(params) + 2

        # pg_class.relname is only unique within a namespace, so the index
        # relation is resolved through pg_namespace; otherwise an equally
        # named index in another schema would duplicate rows and skew the
        # size/method columns and total_count.
        query = f"""
        WITH index_info AS (
            SELECT
                i.indexname as name,
                i.tablename as table_name,
                i.indexdef as definition,
                am.amname as method,
                pg_relation_size(c.oid) as size_in_bytes,
                c.reltuples::bigint as row_estimate,
                COALESCE(psat.idx_scan, 0) as number_of_scans,
                COALESCE(psat.idx_tup_read, 0) as tuples_read,
                COALESCE(psat.idx_tup_fetch, 0) as tuples_fetched,
                COUNT(*) OVER() as total_count
            FROM pg_indexes i
            JOIN pg_namespace n ON n.nspname = i.schemaname
            JOIN pg_class c ON c.relname = i.indexname
                AND c.relnamespace = n.oid
            JOIN pg_am am ON c.relam = am.oid
            LEFT JOIN pg_stat_user_indexes psat ON psat.indexrelid = c.oid
            WHERE i.schemaname = $1
            AND i.indexdef LIKE '%vector%'
            {where_clause}
        )
        SELECT *
        FROM index_info
        ORDER BY name
        LIMIT ${limit_param}
        OFFSET ${offset_param}
        """

        params.extend([limit, offset])

        results = await self.connection_manager.fetch_query(query, params)

        indices = []
        total_entries = 0

        if results:
            # The window count is identical on every row of the page.
            total_entries = results[0]["total_count"]
            for result in results:
                indices.append(
                    {
                        "name": result["name"],
                        "table_name": result["table_name"],
                        "definition": result["definition"],
                        "size_in_bytes": result["size_in_bytes"],
                        "row_estimate": result["row_estimate"],
                        "number_of_scans": result["number_of_scans"],
                        "tuples_read": result["tuples_read"],
                        "tuples_fetched": result["tuples_fetched"],
                    }
                )

        return {"indices": indices, "total_entries": total_entries}

    async def delete_index(
        self,
        index_name: str,
        table_name: Optional[VectorTableName] = None,
        concurrently: bool = True,
    ) -> None:
        """Delete a vector index from one of the project's vector tables.

        Before dropping, the index is looked up in ``pg_indexes``. It must
        live in the project schema, belong to the given table, and its
        definition must reference that table's embedding column.

        Args:
            index_name (str): Exact name of the index to delete.
            table_name (VectorTableName, optional): Table the index belongs
                to. Must be one of the supported vector tables; ``None``
                raises ``ValueError``.
            concurrently (bool): Whether to use ``DROP INDEX CONCURRENTLY``.

        Raises:
            ValueError: If the table name is invalid or no matching vector
                index exists.
            Exception: If the ``DROP INDEX`` statement fails.
        """
        # Resolve the qualified table and the embedding column whose index
        # may be dropped.
        if table_name == VectorTableName.CHUNKS:
            table_name_str = f"{self.project_name}.{VectorTableName.CHUNKS}"
            col_name = "vec"
        elif table_name == VectorTableName.ENTITIES_DOCUMENT:
            table_name_str = (
                f"{self.project_name}.{VectorTableName.ENTITIES_DOCUMENT}"
            )
            col_name = "description_embedding"
        elif table_name == VectorTableName.GRAPHS_ENTITIES:
            table_name_str = (
                f"{self.project_name}.{VectorTableName.GRAPHS_ENTITIES}"
            )
            col_name = "description_embedding"
        elif table_name == VectorTableName.COMMUNITIES:
            table_name_str = (
                f"{self.project_name}.{VectorTableName.COMMUNITIES}"
            )
            # Must match the column create_index indexes for this table;
            # community indexes are built on "embedding".
            col_name = "embedding"
        else:
            raise ValueError("invalid table name")

        # Split on the last dot only, so a dotted project name still yields
        # exactly (schema, table).
        schema_name, base_table_name = table_name_str.rsplit(".", 1)

        # Verify the index exists on this table and covers the vector column.
        query = """
        SELECT indexdef
        FROM pg_indexes
        WHERE indexname = $1
        AND schemaname = $2
        AND tablename = $3
        AND indexdef LIKE $4
        """

        result = await self.connection_manager.fetchrow_query(
            query, (index_name, schema_name, base_table_name, f"%({col_name}%")
        )

        if not result:
            raise ValueError(
                f"Vector index '{index_name}' does not exist on table {table_name_str}"
            )

        # Quote the identifier (doubling embedded quotes) so the DROP targets
        # exactly the name verified above. Unquoted, Postgres would case-fold
        # it to lowercase.
        quoted_index_name = '"' + index_name.replace('"', '""') + '"'
        concurrently_sql = "CONCURRENTLY" if concurrently else ""
        drop_query = (
            f"DROP INDEX {concurrently_sql} {schema_name}.{quoted_index_name}"
        )

        try:
            if concurrently:
                async with (
                    self.connection_manager.pool.get_connection() as conn  # type: ignore
                ):
                    # Disable automatic transaction management
                    await conn.execute(
                        "SET SESSION CHARACTERISTICS AS TRANSACTION ISOLATION LEVEL READ COMMITTED"
                    )
                    await conn.execute(drop_query)
            else:
                await self.connection_manager.execute_query(drop_query)
        except Exception as e:
            raise Exception(f"Failed to delete index: {e}") from e

    async def list_chunks(
        self,
        offset: int,
        limit: int,
        filters: Optional[dict[str, Any]] = None,
        include_vectors: bool = False,
    ) -> dict[str, Any]:
        """List chunks with pagination support.

        Args:
            offset (int): Number of records to skip.
            limit (int): Maximum number of records to return.
            filters (dict, optional): Filters passed to ``apply_filters`` to
                build the WHERE clause. Defaults to None (no filtering).
            include_vectors (bool, optional): Whether to also select and
                return the ``vec`` column. Defaults to False.

        Returns:
            dict: Dictionary containing:
                - results: List of chunk records (``vector`` is None unless
                  ``include_vectors`` is True)
                - total_entries: Total number of chunks matching the filters
        """

        def _decode_json(value: Any) -> Any:
            # JSON-ish columns may arrive as text or already decoded,
            # depending on the driver's registered codecs; NULL stays None.
            if isinstance(value, (str, bytes, bytearray)):
                return json.loads(value)
            return value

        vector_select = ", vec" if include_vectors else ""
        select_clause = f"""
            id, document_id, owner_id, collection_ids,
            text, metadata{vector_select}, COUNT(*) OVER() AS total_entries
        """

        params: list[str | int | bytes] = []
        where_clause = ""
        if filters:
            where_clause, params = apply_filters(
                filters, params, mode="where_clause"
            )

        # LIMIT/OFFSET placeholders follow whatever the filters consumed.
        query = f"""
        SELECT {select_clause}
        FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
        {where_clause}
        LIMIT ${len(params) + 1}
        OFFSET ${len(params) + 2}
        """

        params.extend([limit, offset])

        results = await self.connection_manager.fetch_query(query, params)

        chunks = []
        total_entries = 0
        if results:
            # The window count is identical on every row of the page.
            total_entries = results[0].get("total_entries", 0)
            chunks = [
                {
                    "id": str(result["id"]),
                    "document_id": str(result["document_id"]),
                    "owner_id": str(result["owner_id"]),
                    "collection_ids": result["collection_ids"],
                    "text": result["text"],
                    "metadata": _decode_json(result["metadata"]),
                    "vector": (
                        _decode_json(result["vec"])
                        if include_vectors
                        else None
                    ),
                }
                for result in results
            ]

        return {"results": chunks, "total_entries": total_entries}

    async def search_documents(
        self,
        query_text: str,
        settings: SearchSettings,
    ) -> list[dict[str, Any]]:
        """Search for documents by full-text rank over chunk metadata and/or
        chunk body text, joined with the documents table for metadata.

        Per-document scores:
          * metadata rank: ``ts_rank_cd`` of the configured chunk metadata
            fields (``settings.metadata_keys``, default ``["title",
            "description"]``) against ``query_text``. The best-ranked chunk
            of each document is used.
          * body rank: average ``ts_rank_cd`` of chunk text over the
            document's chunks. When ``search_over_body`` is enabled, only
            chunks matching the query are averaged.
          * combined: when both modes are enabled,
            ``metadata_rank * metadata_weight + body_rank * title_weight``.
            Otherwise it is whichever single rank is enabled.

        Only documents with a positive combined score are returned. They are
        ordered by score descending and paginated with ``settings.offset`` /
        ``settings.limit``.

        Args:
            query_text (str): The search query text (websearch syntax).
            settings (SearchSettings): Search settings including search
                preferences, weights and filters.

        Returns:
            list[dict[str, Any]]: Documents with ``document_id``, the
            documents-table ``metadata``, ``score`` and the two debug ranks.

        Raises:
            TypeError, ValueError: If a weight setting is not numeric.
        """
        where_clauses: list[str] = []
        params: list[str | int | bytes] = [query_text]

        search_over_body = getattr(settings, "search_over_body", True)
        search_over_metadata = getattr(settings, "search_over_metadata", True)
        # Weights are interpolated into the SQL text, so force them to be
        # numbers. A non-numeric value fails here instead of reaching SQL.
        metadata_weight = float(getattr(settings, "metadata_weight", 3.0))
        title_weight = float(getattr(settings, "title_weight", 1.0))
        metadata_keys = (
            getattr(settings, "metadata_keys", ["title", "description"]) or []
        )

        # SQL boolean literals for the enabled search modes.
        sql_search_metadata = "true" if search_over_metadata else "false"
        sql_search_body = "true" if search_over_body else "false"

        # Concatenate the configured metadata fields into one text
        # expression, quoting keys as SQL literals. With no keys, use an
        # empty string so to_tsvector() still receives an argument.
        metadata_fields_expr = (
            " || ' ' || ".join(
                f"COALESCE(v.metadata->>{psql_quote_literal(key)}, '')"
                for key in metadata_keys  # type: ignore
            )
            or "''"
        )

        # DISTINCT ON keeps the first row per document in ORDER BY order, so
        # ordering by rank makes it keep the best-matching chunk
        # deterministically.
        query = f"""
            WITH
            -- Metadata search scores
            metadata_scores AS (
                SELECT DISTINCT ON (v.document_id)
                    v.document_id,
                    d.metadata as doc_metadata,
                    CASE WHEN $1 = '' THEN 0.0
                    ELSE
                        ts_rank_cd(
                            setweight(to_tsvector('english', {metadata_fields_expr}), 'A'),
                            websearch_to_tsquery('english', $1),
                            32
                        )
                    END as metadata_rank
                FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)} v
                LEFT JOIN {self._get_table_name("documents")} d ON v.document_id = d.id
                WHERE v.metadata IS NOT NULL
                ORDER BY v.document_id, metadata_rank DESC
            ),
            -- Body search scores
            body_scores AS (
                SELECT
                    document_id,
                    AVG(
                        ts_rank_cd(
                            setweight(to_tsvector('english', COALESCE(text, '')), 'B'),
                            websearch_to_tsquery('english', $1),
                            32
                        )
                    ) as body_rank
                FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
                WHERE $1 != ''
                {"AND to_tsvector('english', text) @@ websearch_to_tsquery('english', $1)" if search_over_body else ""}
                GROUP BY document_id
            ),
            -- Combined scores with document metadata
            combined_scores AS (
                SELECT
                    COALESCE(m.document_id, b.document_id) as document_id,
                    m.doc_metadata as metadata,
                    COALESCE(m.metadata_rank, 0) as debug_metadata_rank,
                    COALESCE(b.body_rank, 0) as debug_body_rank,
                    CASE
                        WHEN {sql_search_metadata} AND {sql_search_body} THEN
                            COALESCE(m.metadata_rank, 0) * {metadata_weight} + COALESCE(b.body_rank, 0) * {title_weight}
                        WHEN {sql_search_metadata} THEN
                            COALESCE(m.metadata_rank, 0)
                        WHEN {sql_search_body} THEN
                            COALESCE(b.body_rank, 0)
                        ELSE 0
                    END as rank
                FROM metadata_scores m
                FULL OUTER JOIN body_scores b ON m.document_id = b.document_id
                WHERE (
                    ($1 = '') OR
                    ({sql_search_metadata} AND m.metadata_rank > 0) OR
                    ({sql_search_body} AND b.body_rank > 0)
                )
        """

        # Additional filters are ANDed into the combined_scores WHERE clause.
        # NOTE(review): they must reference columns visible there; confirm
        # against apply_filters' output.
        if settings.filters:
            filter_clause, params = apply_filters(settings.filters, params)
            where_clauses.append(filter_clause)

        if where_clauses:
            query += f" AND {' AND '.join(where_clauses)}"

        query += """
            )
            SELECT
                document_id,
                metadata,
                rank as score,
                debug_metadata_rank,
                debug_body_rank
            FROM combined_scores
            WHERE rank > 0
            ORDER BY rank DESC
            OFFSET ${offset_param} LIMIT ${limit_param}
        """.format(
            offset_param=len(params) + 1,
            limit_param=len(params) + 2,
        )

        params.extend([settings.offset, settings.limit])

        results = await self.connection_manager.fetch_query(query, params)

        # Documents-table metadata may come back as JSON text or decoded.
        return [
            {
                "document_id": str(r["document_id"]),
                "metadata": (
                    json.loads(r["metadata"])
                    if isinstance(r["metadata"], str)
                    else r["metadata"]
                ),
                "score": float(r["score"]),
                "debug_metadata_rank": float(r["debug_metadata_rank"]),
                "debug_body_rank": float(r["debug_body_rank"]),
            }
            for r in results
        ]

    def _get_index_options(
        self,
        method: IndexMethod,
        index_arguments: Optional[IndexArgsIVFFlat | IndexArgsHNSW],
    ) -> str:
        """Build the ``WITH (...)`` storage-parameter clause for CREATE INDEX.

        Supplied arguments are used only when their type matches ``method``.
        Otherwise built-in defaults apply: ``lists=100`` for IVFFlat and
        ``m=16, ef_construction=64`` for HNSW. Any other method yields an
        empty string (no options).
        """
        if method == IndexMethod.ivfflat:
            n_lists = (
                index_arguments.n_lists
                if isinstance(index_arguments, IndexArgsIVFFlat)
                else 100
            )
            return f"WITH (lists={n_lists})"

        if method == IndexMethod.hnsw:
            if not isinstance(index_arguments, IndexArgsHNSW):
                return "WITH (m=16, ef_construction=64)"
            return f"WITH (m={index_arguments.m}, ef_construction={index_arguments.ef_construction})"

        return ""