aboutsummaryrefslogtreecommitdiff
import asyncio
import copy
import csv
import json
import logging
import math
import tempfile
from typing import IO, Any, Optional
from uuid import UUID

import asyncpg
from fastapi import HTTPException

from core.base import (
    DocumentResponse,
    DocumentType,
    GraphConstructionStatus,
    GraphExtractionStatus,
    Handler,
    IngestionStatus,
    R2RException,
    SearchSettings,
)

from .base import PostgresConnectionManager
from .filters import apply_filters

# Module-wide logger. `logging.getLogger()` is called with no name, so this is
# the root logger rather than a module-named child logger.
logger = logging.getLogger()


def transform_filter_fields(filters: dict[str, Any]) -> dict[str, Any]:
    """Recursively rename the 'document_id' filter field to 'id'.

    The documents table stores its identifier in the ``id`` column, while
    callers filter on ``document_id``. The field is renamed wherever it is
    used as a *field name*: at the top level and inside the logical
    operators ``$and``, ``$or`` and ``$not``, whose operands are themselves
    filter dicts.

    The value attached to a field (an operator dict such as ``{"$eq": ...}``
    or a JSON payload given to an operator) is copied verbatim. A key that
    happens to be spelled ``document_id`` inside such a value is data, not a
    field reference, so it is not rewritten.

    Args:
        filters (dict[str, Any]): The original filters dictionary. ``None``
            or an empty dict yields ``{}``.

    Returns:
        dict[str, Any]: A new dictionary with transformed field names. Dict
        values are deep-copied; other values (e.g. lists) are passed through
        by reference.
    """
    if not filters:
        return {}

    transformed: dict[str, Any] = {}

    for key, value in filters.items():
        # Logical operators wrap sub-filters whose keys are field names again,
        # so those are the only places recursion is needed.
        if key in ("$and", "$or", "$not"):
            if isinstance(value, list):
                transformed[key] = [
                    transform_filter_fields(item) for item in value
                ]
            else:
                transformed[key] = transform_filter_fields(value)  # type: ignore
            continue

        new_key = "id" if key == "document_id" else key

        # Operator payloads are copied (so the result does not alias the
        # caller's dicts) but never renamed.
        transformed[new_key] = (
            copy.deepcopy(value) if isinstance(value, dict) else value
        )

    # %-style arguments: the dicts are only stringified if debug is enabled.
    logger.debug("Transformed filters from %s to %s", filters, transformed)
    return transformed


class PostgresDocumentsHandler(Handler):
    TABLE_NAME = "documents"

    def __init__(
        self,
        project_name: str,
        connection_manager: PostgresConnectionManager,
        dimension: int | float,
    ):
        """Store the summary-embedding dimension and initialise the base handler.

        Args:
            project_name: Passed through positionally to ``Handler.__init__``.
            connection_manager: Passed through positionally to
                ``Handler.__init__``.
            dimension: Size used for the ``summary_embedding`` vector column.
                ``create_tables`` treats a NaN value as "no fixed dimension"
                and declares a plain ``vector`` column.
        """
        # Set before the base initialiser runs, matching the original order.
        self.dimension = dimension
        super().__init__(project_name, connection_manager)

    async def create_tables(self):
        """Create the documents table and its indexes if they do not exist.

        ``summary_embedding`` is declared as ``vector(<dimension>)``, or as an
        unsized ``vector`` when ``self.dimension`` is NaN. A documents table
        created before the ``total_tokens`` column existed is upgraded in
        place by ``_ensure_total_tokens_column``.

        Raises:
            Exception: Any database error is logged and re-raised unchanged.
        """
        table_name = self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)
        logger.info(f"Creating table, if it does not exist: {table_name}")

        vector_dim = (
            "" if math.isnan(self.dimension) else f"({self.dimension})"
        )
        vector_type = f"vector{vector_dim}"

        try:
            query = f"""
            CREATE TABLE IF NOT EXISTS {table_name} (
                id UUID PRIMARY KEY,
                collection_ids UUID[],
                owner_id UUID,
                type TEXT,
                metadata JSONB,
                title TEXT,
                summary TEXT NULL,
                summary_embedding {vector_type} NULL,
                version TEXT,
                size_in_bytes INT,
                ingestion_status TEXT DEFAULT 'pending',
                extraction_status TEXT DEFAULT 'pending',
                created_at TIMESTAMPTZ DEFAULT NOW(),
                updated_at TIMESTAMPTZ DEFAULT NOW(),
                ingestion_attempt_number INT DEFAULT 0,
                raw_tsvector tsvector GENERATED ALWAYS AS (
                    setweight(to_tsvector('english', COALESCE(title, '')), 'A') ||
                    setweight(to_tsvector('english', COALESCE(summary, '')), 'B') ||
                    setweight(to_tsvector('english', COALESCE((metadata->>'description')::text, '')), 'C')
                ) STORED,
                total_tokens INT DEFAULT 0
            );
            CREATE INDEX IF NOT EXISTS idx_collection_ids_{self.project_name}
            ON {table_name} USING GIN (collection_ids);

            -- Full text search index
            CREATE INDEX IF NOT EXISTS idx_doc_search_{self.project_name}
            ON {table_name}
            USING GIN (raw_tsvector);
            """
            await self.connection_manager.execute_query(query)

            # CREATE TABLE IF NOT EXISTS does not alter an existing table, so
            # tables from older versions may still lack total_tokens.
            await self._ensure_total_tokens_column(table_name)

        except Exception as e:
            logger.warning(f"Error {e} when creating document table.")
            # Bare raise re-raises the active exception without re-binding it.
            raise

    async def _ensure_total_tokens_column(self, table_full_name: str) -> None:
        """Add the ``total_tokens`` column to a documents table that lacks it.

        Args:
            table_full_name: The table name as returned by ``_get_table_name``;
                may be schema-qualified and/or double-quoted.
        """
        # information_schema stores bare identifiers: split off the schema and
        # strip quotes. Unqualified names are looked up in "public".
        if "." in table_full_name:
            schema_part, table_part = table_full_name.split(".", maxsplit=1)
        else:
            schema_part, table_part = "public", table_full_name
        parsed_schema = schema_part.replace('"', "").strip()
        parsed_table_name = table_part.replace('"', "").strip()

        # Identifiers are bound as parameters rather than spliced into SQL.
        column_check_query = """
            SELECT column_name
            FROM information_schema.columns
            WHERE table_name = $1
            AND table_schema = $2
            """
        existing_columns = await self.connection_manager.fetch_query(
            column_check_query, [parsed_table_name, parsed_schema]
        )
        if any(
            row["column_name"] == "total_tokens" for row in existing_columns
        ):
            return

        # The row count only decides whether to warn; existing rows receive
        # the column default (0) when the column is added.
        doc_count_query = (
            f"SELECT COUNT(*) AS doc_count FROM {table_full_name}"
        )
        row = await self.connection_manager.fetchrow_query(doc_count_query)
        doc_count = 0 if row is None else row["doc_count"]
        if doc_count > 0:
            logger.warning(
                "Adding the missing 'total_tokens' column to the 'documents' table, this will impact existing files."
            )

        # IF NOT EXISTS keeps this idempotent when several processes run
        # create_tables concurrently and all pass the check above.
        await self.connection_manager.execute_query(
            f"ALTER TABLE {table_full_name} "
            "ADD COLUMN IF NOT EXISTS total_tokens INT DEFAULT 0"
        )

    async def upsert_documents_overview(
        self,
        documents_overview: DocumentResponse | list[DocumentResponse],
        *,
        max_retries: int = 20,
        max_backoff: float = 5.0,
    ) -> None:
        """Insert or update one or more document overview rows.

        Each document is written in its own transaction (see
        ``_upsert_document_row``). Unique-violation errors, e.g. two writers
        inserting the same new id, and deadlocks are retried with
        exponential backoff capped at ``max_backoff`` seconds.

        Args:
            documents_overview: A single document or a list of documents.
            max_retries: Maximum number of attempts per document.
            max_backoff: Upper bound, in seconds, on the sleep between
                attempts. Without a cap the 0.1 * 2**n schedule would sleep
                for hours on the later attempts.

        Raises:
            ValueError: If ``max_retries`` is less than 1.
            asyncpg.exceptions.UniqueViolationError: If a document still
                fails after ``max_retries`` attempts.
            asyncpg.exceptions.DeadlockDetectedError: Likewise.
        """
        if max_retries < 1:
            raise ValueError("max_retries must be at least 1")

        if isinstance(documents_overview, DocumentResponse):
            documents_overview = [documents_overview]

        for document in documents_overview:
            retries = 0
            while retries < max_retries:
                try:
                    async with (
                        self.connection_manager.pool.get_connection() as conn  # type: ignore
                    ):
                        async with conn.transaction():
                            await self._upsert_document_row(conn, document)
                    break  # Success, exit the retry loop
                except (
                    asyncpg.exceptions.UniqueViolationError,
                    asyncpg.exceptions.DeadlockDetectedError,
                ) as e:
                    retries += 1
                    if retries == max_retries:
                        logger.error(
                            f"Failed to update document {document.id} after {max_retries} attempts. Error: {str(e)}"
                        )
                        raise
                    # Capped exponential backoff.
                    await asyncio.sleep(min(0.1 * (2**retries), max_backoff))

    async def _upsert_document_row(
        self, conn: Any, document: DocumentResponse
    ) -> None:
        """Update or insert a single document row through ``conn``.

        Must be called inside a transaction owned by the caller so that the
        ``FOR UPDATE`` row lock is held until the write commits.
        ``ingestion_attempt_number`` is bumped when the stored status changes
        to ``"success"`` or when the incoming attempt number is higher than
        the stored one.
        """
        table_name = self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)

        # Lock the row (if it exists) so concurrent writers serialize on it.
        check_query = f"""
        SELECT ingestion_attempt_number, ingestion_status FROM {table_name}
        WHERE id = $1 FOR UPDATE
        """
        existing_doc = await conn.fetchrow(check_query, document.id)

        db_entry = document.convert_to_db_entry()

        if not existing_doc:
            insert_query = f"""
            INSERT INTO {table_name}
            (id, collection_ids, owner_id, type, metadata, title, version,
            size_in_bytes, ingestion_status, extraction_status, created_at,
            updated_at, ingestion_attempt_number, summary, summary_embedding, total_tokens)
            VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16)
            """
            await conn.execute(
                insert_query,
                db_entry["id"],
                db_entry["collection_ids"],
                db_entry["owner_id"],
                db_entry["document_type"],
                db_entry["metadata"],
                db_entry["title"],
                db_entry["version"],
                db_entry["size_in_bytes"],
                db_entry["ingestion_status"],
                db_entry["extraction_status"],
                db_entry["created_at"],
                db_entry["updated_at"],
                db_entry["ingestion_attempt_number"],
                db_entry["summary"],
                db_entry["summary_embedding"],
                db_entry["total_tokens"],
            )
            return

        db_version = existing_doc["ingestion_attempt_number"]
        db_status = existing_doc["ingestion_status"]
        new_version = db_entry["ingestion_attempt_number"]

        # Only increment the attempt number when the status is becoming
        # 'success' or the incoming entry is a newer attempt.
        if (
            db_status != "success"
            and db_entry["ingestion_status"] == "success"
        ) or (new_version > db_version):
            db_entry["ingestion_attempt_number"] = db_version + 1
        else:
            db_entry["ingestion_attempt_number"] = db_version

        update_query = f"""
        UPDATE {table_name}
        SET collection_ids = $1,
            owner_id = $2,
            type = $3,
            metadata = $4,
            title = $5,
            version = $6,
            size_in_bytes = $7,
            ingestion_status = $8,
            extraction_status = $9,
            updated_at = $10,
            ingestion_attempt_number = $11,
            summary = $12,
            summary_embedding = $13,
            total_tokens = $14
        WHERE id = $15
        """
        await conn.execute(
            update_query,
            db_entry["collection_ids"],
            db_entry["owner_id"],
            db_entry["document_type"],
            db_entry["metadata"],
            db_entry["title"],
            db_entry["version"],
            db_entry["size_in_bytes"],
            db_entry["ingestion_status"],
            db_entry["extraction_status"],
            db_entry["updated_at"],
            db_entry["ingestion_attempt_number"],
            db_entry["summary"],
            db_entry["summary_embedding"],
            db_entry["total_tokens"],
            document.id,
        )

    async def delete(
        self, document_id: UUID, version: Optional[str] = None
    ) -> None:
        """Delete a document row by id.

        Args:
            document_id: Id of the document row to delete (bound as a string).
            version: When truthy, the row is only deleted if its ``version``
                column equals this value.
        """
        bind_args = [str(document_id)]
        sql = f"""
        DELETE FROM {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}
        WHERE id = $1
        """

        if not version:
            await self.connection_manager.execute_query(
                query=sql, params=bind_args
            )
            return

        bind_args.append(version)
        await self.connection_manager.execute_query(
            query=sql + " AND version = $2", params=bind_args
        )

    async def _get_status_from_table(
        self,
        ids: list[UUID],
        table_name: str,
        status_type: str,
        column_name: str,
    ):
        """Read one status column for every row whose key is in ``ids``.

        Args:
            ids (list[UUID]): Key values to match against ``column_name``.
            table_name (str): Unqualified table name; resolved through
                ``_get_table_name``.
            status_type (str): Name of the status column to select.
            column_name (str): Name of the key column compared with ``ids``.

        Returns:
            list: The selected status values, in the order the rows are
            returned by the database.
        """
        source = self._get_table_name(table_name)
        sql = f"""
            SELECT {status_type} FROM {source}
            WHERE {column_name} = ANY($1)
        """
        rows = await self.connection_manager.fetch_query(sql, [ids])
        statuses = [record[status_type] for record in rows]
        return statuses

    async def _get_ids_from_table(
        self,
        status: list[str],
        table_name: str,
        status_type: str,
        collection_id: Optional[UUID] = None,
    ):
        """Get the ids of rows whose status column is one of ``status``.

        Args:
            status (list[str]): Accepted status values.
            table_name (str): Unqualified table name; resolved through
                ``_get_table_name``.
            status_type (str): Name of the status column to filter on.
            collection_id (Optional[UUID]): When given, only rows whose
                ``collection_ids`` array contains this id are returned. When
                ``None`` no collection filter is applied. Previously the
                query compared ``NULL = ANY(collection_ids)``, which never
                matches, so the default always returned an empty list.

        Returns:
            list: The ``id`` values of the matching rows.
        """
        conditions = [f"{status_type} = ANY($1)"]
        params: list[Any] = [status]

        if collection_id is not None:
            params.append(collection_id)
            conditions.append(f"${len(params)} = ANY(collection_ids)")

        query = f"""
            SELECT id FROM {self._get_table_name(table_name)}
            WHERE {' AND '.join(conditions)}
        """
        records = await self.connection_manager.fetch_query(query, params)
        return [record["id"] for record in records]

    async def _set_status_in_table(
        self,
        ids: list[UUID],
        status: str,
        table_name: str,
        status_type: str,
        column_name: str,
    ):
        """Write ``status`` into a status column for every row keyed by ``ids``.

        Args:
            ids (list[UUID]): Key values matched against ``column_name``.
            status (str): Value to store.
            table_name (str): Unqualified table name; resolved through
                ``_get_table_name``.
            status_type (str): Name of the status column to update.
            column_name (str): Name of the key column compared with ``ids``.
        """
        target = self._get_table_name(table_name)
        statement = f"""
            UPDATE {target}
            SET {status_type} = $1
            WHERE {column_name} = Any($2)
        """
        bind_values = [status, ids]
        await self.connection_manager.execute_query(statement, bind_values)

    def _get_status_model(self, status_type: str):
        """Map a status-type name to the status enum that describes it.

        Args:
            status_type (str): One of ``"ingestion"``, ``"extraction_status"``,
                ``"graph_cluster_status"`` or ``"graph_sync_status"``.

        Returns:
            The status model class for ``status_type``.

        Raises:
            R2RException: With status code 400 for any other value.
        """
        # NOTE(review): callers also use status_type as the column name in
        # SQL, while the documents table's column is "ingestion_status", not
        # "ingestion". Confirm which spelling callers actually pass.
        if status_type in {"graph_cluster_status", "graph_sync_status"}:
            return GraphConstructionStatus

        model = {
            "ingestion": IngestionStatus,
            "extraction_status": GraphExtractionStatus,
        }.get(status_type)
        if model is None:
            raise R2RException(
                status_code=400, message=f"Invalid status type: {status_type}"
            )
        return model

    async def get_workflow_status(
        self, id: UUID | list[UUID], status_type: str
    ):
        """Look up the workflow status of one document or several.

        Args:
            id (UUID | list[UUID]): A single document id or a list of ids.
            status_type (str): Status kind; see ``_get_status_model``.

        Returns:
            A single status enum member when ``id`` is a UUID, otherwise a
            list of members. Raw column values are upper-cased and looked up
            by enum member name.
        """
        single = isinstance(id, UUID)
        model = self._get_status_model(status_type)

        raw_statuses = await self._get_status_from_table(
            [id] if single else id,
            model.table_name(),
            status_type,
            model.id_column(),
        )
        statuses = [model[value.upper()] for value in raw_statuses]

        if single:
            return statuses[0]
        return statuses

    async def set_workflow_status(
        self, id: UUID | list[UUID], status_type: str, status: str
    ):
        """Write a workflow status for one document or several.

        Args:
            id (UUID | list[UUID]): A single document id or a list of ids.
            status_type (str): Status kind; see ``_get_status_model``.
            status (str): The value to store.
        """
        model = self._get_status_model(status_type)
        targets = [id] if isinstance(id, UUID) else id

        return await self._set_status_in_table(
            targets,
            status,
            model.table_name(),
            status_type,
            model.id_column(),
        )

    async def get_document_ids_by_status(
        self,
        status_type: str,
        status: str | list[str],
        collection_id: Optional[UUID] = None,
    ):
        """Return the ids of rows whose ``status_type`` column matches ``status``.

        Args:
            status_type (str): Status kind; see ``_get_status_model``.
            status (str | list[str]): One accepted status value or a list of
                them.
            collection_id (Optional[UUID]): Collection id forwarded to
                ``_get_ids_from_table``.
        """
        statuses = [status] if isinstance(status, str) else status
        model = self._get_status_model(status_type)
        return await self._get_ids_from_table(
            statuses,
            model.table_name(),
            status_type,
            collection_id,
        )

    @staticmethod
    def _parse_summary_embedding(row: Any) -> Optional[list[float]]:
        """Parse a ``summary_embedding`` value read back as ``"[x, y, ...]"``.

        Returns ``None`` when the column is absent, NULL, not bracketed, or
        unparseable. An unparseable value is also logged as a warning.
        """
        if "summary_embedding" not in row or row["summary_embedding"] is None:
            return None
        try:
            # The embedding is stored as a string like "[0.1, 0.2, ...]"
            embedding_str = row["summary_embedding"]
            if embedding_str.startswith("[") and embedding_str.endswith("]"):
                return [
                    float(x) for x in embedding_str[1:-1].split(",") if x
                ]
        except Exception as e:
            logger.warning(
                f"Failed to parse embedding for document {row['id']}: {e}"
            )
        return None

    async def get_documents_overview(
        self,
        offset: int,
        limit: int,
        filter_user_ids: Optional[list[UUID]] = None,
        filter_document_ids: Optional[list[UUID]] = None,
        filter_collection_ids: Optional[list[UUID]] = None,
        include_summary_embedding: Optional[bool] = True,
        filters: Optional[dict[str, Any]] = None,
        sort_order: str = "DESC",
    ) -> dict[str, Any]:
        """Fetch overviews of documents with optional offset/limit pagination.

        You can use either:
          - Traditional filters: `filter_user_ids`, `filter_document_ids`, `filter_collection_ids`
          - A `filters` dict (e.g., like we do in semantic search), which will be passed to `apply_filters`.

        Document ids are matched with AND; owner and collection filters are
        combined with OR.

        Args:
            offset: Number of rows to skip.
            limit: Maximum number of rows to return; ``-1`` means no limit.
            include_summary_embedding: When falsy, ``summary_embedding`` is
                returned as ``None`` and the stored vector is not parsed.
            sort_order: ``"ASC"`` or ``"DESC"`` (case-insensitive) ordering on
                ``created_at``.

        Returns:
            ``{"results": [DocumentResponse, ...], "total_entries": int}``.

        Raises:
            HTTPException: 400 when both filter styles are supplied or
                ``sort_order`` is invalid; 500 when the query fails.
        """
        filters = transform_filter_fields(copy.deepcopy(filters))  # type: ignore

        # Safety check: We do not allow mixing the old filter arguments with the new `filters` dict.
        # This keeps the query logic unambiguous.
        if filters and any(
            [
                filter_user_ids,
                filter_document_ids,
                filter_collection_ids,
            ]
        ):
            raise HTTPException(
                status_code=400,
                detail=(
                    "Cannot use both the 'filters' dictionary "
                    "and the 'filter_*_ids' parameters simultaneously."
                ),
            )

        # ORDER BY direction cannot be a bind parameter, so it is spliced into
        # the SQL text; whitelist it to prevent injection.
        direction = (sort_order or "DESC").strip().upper()
        if direction not in ("ASC", "DESC"):
            raise HTTPException(
                status_code=400,
                detail="sort_order must be either 'ASC' or 'DESC'.",
            )

        conditions: list[str] = []
        params: list[Any] = []
        param_index = 1

        if filters:
            # Use the returned params (as the search methods do) instead of
            # relying on apply_filters mutating the list in place.
            filter_condition, params = apply_filters(
                filters, params, mode="condition_only"
            )
            if filter_condition:
                conditions.append(filter_condition)
            param_index = len(params) + 1
        else:
            if filter_document_ids:
                conditions.append(f"id = ANY(${param_index})")
                params.append(filter_document_ids)
                param_index += 1

            # Owner and collection filters are OR-ed into one sub-condition.
            or_conditions = []
            if filter_user_ids:
                or_conditions.append(f"owner_id = ANY(${param_index})")
                params.append(filter_user_ids)
                param_index += 1

            if filter_collection_ids:
                or_conditions.append(f"collection_ids && ${param_index}")
                params.append(filter_collection_ids)
                param_index += 1

            if or_conditions:
                conditions.append(f"({' OR '.join(or_conditions)})")

        base_query = (
            f"FROM {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}"
        )
        if conditions:
            base_query += " WHERE " + " AND ".join(conditions)

        # total_entries comes from a window function over the filtered rows.
        select_fields = """
            SELECT
                id,
                collection_ids,
                owner_id,
                type,
                metadata,
                title,
                version,
                size_in_bytes,
                ingestion_status,
                extraction_status,
                created_at,
                updated_at,
                summary,
                summary_embedding,
                total_tokens,
                COUNT(*) OVER() AS total_entries
        """

        query = f"""
            {select_fields}
            {base_query}
            ORDER BY created_at {direction}
            OFFSET ${param_index}
        """
        params.append(offset)
        param_index += 1

        if limit != -1:
            query += f" LIMIT ${param_index}"
            params.append(limit)
            param_index += 1

        try:
            results = await self.connection_manager.fetch_query(query, params)
            total_entries = results[0]["total_entries"] if results else 0

            documents = [
                DocumentResponse(
                    id=row["id"],
                    collection_ids=row["collection_ids"],
                    owner_id=row["owner_id"],
                    document_type=DocumentType(row["type"]),
                    metadata=json.loads(row["metadata"]),
                    title=row["title"],
                    version=row["version"],
                    size_in_bytes=row["size_in_bytes"],
                    ingestion_status=IngestionStatus(row["ingestion_status"]),
                    extraction_status=GraphExtractionStatus(
                        row["extraction_status"]
                    ),
                    created_at=row["created_at"],
                    updated_at=row["updated_at"],
                    summary=row["summary"] if "summary" in row else None,
                    # Only parse the vector when it will actually be returned.
                    summary_embedding=(
                        self._parse_summary_embedding(row)
                        if include_summary_embedding
                        else None
                    ),
                    total_tokens=row["total_tokens"],
                )
                for row in results
            ]
            return {"results": documents, "total_entries": total_entries}
        except Exception as e:
            logger.error(f"Error in get_documents_overview: {str(e)}")
            raise HTTPException(
                status_code=500,
                detail="Database query failed",
            ) from e

    async def update_document_metadata(
        self,
        document_id: UUID,
        metadata: list[dict],
        overwrite: bool = False,
    ) -> DocumentResponse:
        """Merge into, or replace, a document's metadata and persist it.

        Args:
            document_id: Id of the document to modify.
            metadata: Dicts applied in order; on key collisions later dicts win.
            overwrite: When True the existing metadata is discarded and
                replaced by the merge of ``metadata``. When False each dict
                is merged into the existing metadata.

        Returns:
            The updated ``DocumentResponse`` after it has been upserted.

        Raises:
            HTTPException: 404 when no document with ``document_id`` exists.
        """
        # NOTE(review): get_documents_overview's default
        # include_summary_embedding=True means the fetched embedding is
        # presumably written back unchanged by the upsert below. Confirm
        # before changing that default.
        lookup = await self.get_documents_overview(
            offset=0,
            limit=1,
            filter_document_ids=[document_id],
        )
        matches = lookup["results"]
        if not matches:
            raise HTTPException(
                status_code=404,
                detail=f"Document with ID {document_id} not found",
            )

        document = matches[0]
        if not overwrite:
            for item in metadata:
                document.metadata.update(item)
        else:
            replacement: dict[str, Any] = {}
            for item in metadata:
                replacement.update(item)
            document.metadata = replacement

        await self.upsert_documents_overview(document)
        return document

    async def semantic_document_search(
        self, query_embedding: list[float], search_settings: SearchSettings
    ) -> list[DocumentResponse]:
        """Search documents using semantic similarity with their summary
        embeddings.

        Rows are ranked by pgvector's ``<=>`` distance between
        ``summary_embedding`` and ``query_embedding``. The reported score is
        ``1.0 - distance``. Only rows with a non-NULL summary embedding are
        considered, and ``search_settings.filters`` are applied.

        Args:
            query_embedding: Query vector; sent to Postgres as its text form.
            search_settings: Supplies ``filters``, ``limit``, ``offset`` and
                ``include_metadatas``.

        Returns:
            One ``DocumentResponse`` per hit. Its metadata carries
            ``search_score`` and ``search_type`` ("semantic").
        """
        where_clauses = ["summary_embedding IS NOT NULL"]
        params: list[str | int | bytes] = [str(query_embedding)]

        # vector_dim already includes its parentheses (or is empty when the
        # dimension is NaN), so it is appended directly: "vector(512)" or
        # "vector". Wrapping it in another "(...)" yielded "vector()" for
        # the unsized case, which is a SQL syntax error.
        vector_dim = (
            "" if math.isnan(self.dimension) else f"({self.dimension})"
        )
        vector_type = f"vector{vector_dim}"

        filters = copy.deepcopy(search_settings.filters)
        if filters:
            filter_condition, params = apply_filters(
                transform_filter_fields(filters), params, mode="condition_only"
            )
            if filter_condition:
                where_clauses.append(filter_condition)

        where_clause = " AND ".join(where_clauses)

        query = f"""
        WITH document_scores AS (
            SELECT
                id,
                collection_ids,
                owner_id,
                type,
                metadata,
                title,
                version,
                size_in_bytes,
                ingestion_status,
                extraction_status,
                created_at,
                updated_at,
                summary,
                summary_embedding,
                total_tokens,
                (summary_embedding <=> $1::{vector_type}) as semantic_distance
            FROM {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}
            WHERE {where_clause}
            ORDER BY semantic_distance ASC
            LIMIT ${len(params) + 1}
            OFFSET ${len(params) + 2}
        )
        SELECT *,
            1.0 - semantic_distance as semantic_score
        FROM document_scores
        """

        params.extend([search_settings.limit, search_settings.offset])

        results = await self.connection_manager.fetch_query(query, params)

        return [
            DocumentResponse(
                id=row["id"],
                collection_ids=row["collection_ids"],
                owner_id=row["owner_id"],
                document_type=DocumentType(row["type"]),
                metadata={
                    **(
                        json.loads(row["metadata"])
                        if search_settings.include_metadatas
                        else {}
                    ),
                    "search_score": float(row["semantic_score"]),
                    "search_type": "semantic",
                },
                title=row["title"],
                version=row["version"],
                size_in_bytes=row["size_in_bytes"],
                ingestion_status=IngestionStatus(row["ingestion_status"]),
                extraction_status=GraphExtractionStatus(
                    row["extraction_status"]
                ),
                created_at=row["created_at"],
                updated_at=row["updated_at"],
                summary=row["summary"],
                # Non-NULL is guaranteed by the WHERE clause; the value is
                # read back as "[x, y, ...]".
                summary_embedding=[
                    float(x)
                    for x in row["summary_embedding"][1:-1].split(",")
                    if x
                ],
                total_tokens=row["total_tokens"],
            )
            for row in results
        ]

    async def full_text_document_search(
        self, query_text: str, search_settings: SearchSettings
    ) -> list[DocumentResponse]:
        """Rank documents against ``query_text`` using Postgres full-text search.

        Rows whose ``raw_tsvector`` matches
        ``websearch_to_tsquery('english', query_text)`` are ordered by
        ``ts_rank_cd`` (normalization flag 32). Any ``search_settings.filters``
        (with ``document_id`` renamed to ``id``) are ANDed into the WHERE
        clause, and ``search_settings.limit`` / ``offset`` page the results.

        Args:
            query_text: Web-search style query string, bound as ``$1``.
            search_settings: Supplies filters, pagination and whether the
                stored metadata is included in each response.

        Returns:
            One ``DocumentResponse`` per row. Its ``metadata`` holds the stored
            metadata (when ``include_metadatas`` is set) plus ``search_score``
            (the ``ts_rank_cd`` value) and ``search_type="full_text"``.
        """
        include_metadata = search_settings.include_metadatas

        predicates = ["raw_tsvector @@ websearch_to_tsquery('english', $1)"]
        params: list[str | int | bytes] = [query_text]

        requested_filters = copy.deepcopy(search_settings.filters)
        if requested_filters:
            extra_predicate, params = apply_filters(
                transform_filter_fields(requested_filters),
                params,
                mode="condition_only",
            )
            if extra_predicate:
                predicates.append(extra_predicate)

        where_sql = " AND ".join(predicates)
        # The LIMIT/OFFSET placeholders come after every parameter the
        # filters consumed.
        limit_slot = len(params) + 1
        offset_slot = len(params) + 2

        query = f"""
        WITH document_scores AS (
            SELECT
                id,
                collection_ids,
                owner_id,
                type,
                metadata,
                title,
                version,
                size_in_bytes,
                ingestion_status,
                extraction_status,
                created_at,
                updated_at,
                summary,
                summary_embedding,
                total_tokens,
                ts_rank_cd(raw_tsvector, websearch_to_tsquery('english', $1), 32) as text_score
            FROM {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}
            WHERE {where_sql}
            ORDER BY text_score DESC
            LIMIT ${limit_slot}
            OFFSET ${offset_slot}
        )
        SELECT * FROM document_scores
        """

        params += [search_settings.limit, search_settings.offset]

        def to_response(row: Any) -> DocumentResponse:
            # Convert one result row; conversions run in the same order as
            # the keyword arguments they feed.
            doc_type = DocumentType(row["type"])
            stored = json.loads(row["metadata"]) if include_metadata else {}
            score = float(row["text_score"])
            ingestion = IngestionStatus(row["ingestion_status"])
            extraction = GraphExtractionStatus(row["extraction_status"])

            # The embedding arrives as a bracketed text literal, e.g.
            # "[0.1,0.2]". An empty or NULL value maps to None.
            embedding_text = row["summary_embedding"]
            if embedding_text:
                embedding = [
                    float(part)
                    for part in embedding_text[1:-1].split(",")
                    if part
                ]
            else:
                embedding = None

            return DocumentResponse(
                id=row["id"],
                collection_ids=row["collection_ids"],
                owner_id=row["owner_id"],
                document_type=doc_type,
                metadata={
                    **stored,
                    "search_score": score,
                    "search_type": "full_text",
                },
                title=row["title"],
                version=row["version"],
                size_in_bytes=row["size_in_bytes"],
                ingestion_status=ingestion,
                extraction_status=extraction,
                created_at=row["created_at"],
                updated_at=row["updated_at"],
                summary=row["summary"],
                summary_embedding=embedding,
                total_tokens=row["total_tokens"],
            )

        rows = await self.connection_manager.fetch_query(query, params)
        return [to_response(row) for row in rows]

    async def hybrid_document_search(
        self,
        query_text: str,
        query_embedding: list[float],
        search_settings: SearchSettings,
    ) -> list[DocumentResponse]:
        """Search documents using both semantic and full-text search with RRF
        fusion.

        Each sub-search returns an enlarged candidate window ranked from its
        own top result. The two rankings are combined with weighted
        Reciprocal Rank Fusion using ``search_settings.hybrid_settings``
        (``rrf_k``, ``semantic_weight``, ``full_text_weight``). The requested
        ``offset``/``limit`` page is then cut from the fused ordering.

        Args:
            query_text: Query text for the full-text sub-search.
            query_embedding: Query embedding for the semantic sub-search.
            search_settings: Search settings. ``offset`` and ``limit`` apply
                to the fused result list.

        Returns:
            Documents whose ``metadata`` holds ``search_score`` (the fused
            score), ``semantic_rank``, ``full_text_rank`` and
            ``search_type="hybrid"``. Stored metadata is included only when
            ``include_metadatas`` is set.
        """
        # Pagination is applied exactly once, after fusion (see the slice
        # below). Both sub-searches must therefore start at offset 0 and
        # return enough rows to cover the requested page. Passing the caller's
        # offset through as well would skip the first ``offset`` candidates of
        # each ranking and then skip ``offset`` fused results a second time.
        extended_settings = copy.deepcopy(search_settings)
        extended_settings.offset = 0
        # Over-fetch (3x) so the fusion has candidates beyond the page itself.
        extended_settings.limit = (
            search_settings.offset + search_settings.limit
        ) * 3

        # Get results from both search methods
        semantic_results = await self.semantic_document_search(
            query_embedding, extended_settings
        )
        full_text_results = await self.full_text_document_search(
            query_text, extended_settings
        )

        # Combine results using RRF, keyed by document id.
        doc_scores: dict[str, dict] = {}

        # Process semantic results
        for rank, result in enumerate(semantic_results, 1):
            doc_id = str(result.id)
            doc_scores[doc_id] = {
                "semantic_rank": rank,
                # A document absent from the other list is ranked just past
                # that list's end.
                "full_text_rank": len(full_text_results) + 1,
                "data": result,
            }

        # Process full-text results
        for rank, result in enumerate(full_text_results, 1):
            doc_id = str(result.id)
            if doc_id in doc_scores:
                doc_scores[doc_id]["full_text_rank"] = rank
            else:
                doc_scores[doc_id] = {
                    "semantic_rank": len(semantic_results) + 1,
                    "full_text_rank": rank,
                    "data": result,
                }

        # Calculate RRF scores using hybrid search settings
        rrf_k = search_settings.hybrid_settings.rrf_k
        semantic_weight = search_settings.hybrid_settings.semantic_weight
        full_text_weight = search_settings.hybrid_settings.full_text_weight

        for scores in doc_scores.values():
            semantic_score = 1 / (rrf_k + scores["semantic_rank"])
            full_text_score = 1 / (rrf_k + scores["full_text_rank"])

            # Weighted mean of the two reciprocal-rank contributions.
            scores["final_score"] = (
                semantic_score * semantic_weight
                + full_text_score * full_text_weight
            ) / (semantic_weight + full_text_weight)

        # Sort by fused score, then apply the caller's offset/limit once.
        sorted_results = sorted(
            doc_scores.values(), key=lambda x: x["final_score"], reverse=True
        )[
            search_settings.offset : search_settings.offset
            + search_settings.limit
        ]

        return [
            DocumentResponse(
                **{
                    **result["data"].__dict__,
                    "metadata": {
                        **(
                            result["data"].metadata
                            if search_settings.include_metadatas
                            else {}
                        ),
                        "search_score": result["final_score"],
                        "semantic_rank": result["semantic_rank"],
                        "full_text_rank": result["full_text_rank"],
                        "search_type": "hybrid",
                    },
                }
            )
            for result in sorted_results
        ]

    async def search_documents(
        self,
        query_text: str,
        query_embedding: Optional[list[float]] = None,
        settings: Optional[SearchSettings] = None,
    ) -> list[DocumentResponse]:
        """Route a document search to the strategy selected by ``settings``.

        Selection rules:
        - Hybrid search runs when ``use_hybrid_search`` is set, or when both
          ``use_semantic_search`` and ``use_fulltext_search`` are set.
        - Otherwise, semantic search runs when ``use_semantic_search`` is set.
        - Otherwise, full-text search runs.

        Args:
            query_text: Query text for full-text and hybrid search.
            query_embedding: Query embedding. Required for semantic and
                hybrid search.
            settings: Search settings. A default ``SearchSettings()`` is used
                when omitted.

        Raises:
            ValueError: If the selected strategy needs ``query_embedding``
                and it is ``None``.
        """
        active = settings if settings is not None else SearchSettings()

        run_hybrid = (
            active.use_semantic_search and active.use_fulltext_search
        ) or active.use_hybrid_search

        if run_hybrid:
            if query_embedding is None:
                raise ValueError(
                    "query_embedding is required for hybrid search"
                )
            return await self.hybrid_document_search(
                query_text, query_embedding, active
            )

        if not active.use_semantic_search:
            return await self.full_text_document_search(query_text, active)

        if query_embedding is None:
            raise ValueError("query_embedding is required for vector search")
        return await self.semantic_document_search(query_embedding, active)

    async def export_to_csv(
        self,
        columns: Optional[list[str]] = None,
        filters: Optional[dict] = None,
        include_header: bool = True,
    ) -> tuple[str, IO]:
        """Creates a CSV file from the PostgreSQL data and returns the path to
        the temp file.

        Args:
            columns: Columns to export, in output order. When omitted, every
                exportable column is written in table order.
            filters: Optional filter dict. ``document_id`` is accepted as an
                alias for ``id``. Only top-level fields naming an exportable
                column are applied, with the operators ``$eq``, ``$ne``,
                ``$gt``, ``$gte``, ``$lt``, ``$lte``, ``$in`` and ``$nin``. A
                bare value means equality. Any other field or operator is
                skipped and logged as a warning.
            include_header: Whether to write a header row first.

        Returns:
            ``(path, file)`` for the temporary CSV file. The file is created
            with ``delete=True``, so it disappears once the returned handle
            is closed. The caller is responsible for closing it.

        Raises:
            ValueError: If ``columns`` contains an unknown column name.
            HTTPException: Status 500 if querying or writing fails.
        """
        # Ordered exactly like the SELECT list below, so a column's position
        # in this tuple is its index in every fetched row.
        export_columns = (
            "id",
            "collection_ids",
            "owner_id",
            "type",
            "metadata",
            "title",
            "summary",
            "version",
            "size_in_bytes",
            "ingestion_status",
            "extraction_status",
            "created_at",
            "updated_at",
            "total_tokens",
        )
        valid_columns = set(export_columns)

        filters = copy.deepcopy(filters)
        filters = transform_filter_fields(filters)  # type: ignore

        if not columns:
            # Use the fixed table order. The previous ``list(valid_columns)``
            # iterated a set of strings, whose order depends on the hash seed
            # and so varied between processes.
            columns = list(export_columns)
        elif invalid_cols := set(columns) - valid_columns:
            raise ValueError(f"Invalid columns: {invalid_cols}")

        # Row index of each requested column, computed once for all rows.
        column_positions = [export_columns.index(col) for col in columns]

        select_stmt = f"""
            SELECT
                id::text,
                collection_ids::text,
                owner_id::text,
                type::text,
                metadata::text AS metadata,
                title,
                summary,
                version,
                size_in_bytes,
                ingestion_status,
                extraction_status,
                to_char(created_at, 'YYYY-MM-DD HH24:MI:SS') AS created_at,
                to_char(updated_at, 'YYYY-MM-DD HH24:MI:SS') AS updated_at,
                total_tokens
            FROM {self._get_table_name(self.TABLE_NAME)}
        """

        conditions, params = self._export_filter_conditions(
            filters or {}, valid_columns
        )
        if conditions:
            select_stmt = f"{select_stmt} WHERE {' AND '.join(conditions)}"

        # NOTE: in ORDER BY, a bare name that matches an output column refers
        # to that output column. Here that is the zero-padded to_char text,
        # which sorts in the same order as the timestamp at second precision.
        select_stmt = f"{select_stmt} ORDER BY created_at DESC"

        temp_file = None
        try:
            # The csv module needs newline="" so it controls line endings
            # itself and newlines inside quoted fields are preserved.
            # An explicit utf-8 encoding avoids depending on the host locale
            # for non-ASCII titles and summaries.
            temp_file = tempfile.NamedTemporaryFile(
                mode="w",
                delete=True,
                suffix=".csv",
                newline="",
                encoding="utf-8",
            )
            writer = csv.writer(temp_file, quoting=csv.QUOTE_ALL)

            async with self.connection_manager.pool.get_connection() as conn:  # type: ignore
                # The cursor is opened inside an explicit transaction.
                async with conn.transaction():
                    cursor = await conn.cursor(select_stmt, *params)

                    if include_header:
                        writer.writerow(columns)

                    chunk_size = 1000
                    while rows := await cursor.fetch(chunk_size):
                        writer.writerows(
                            [row[pos] for pos in column_positions]
                            for row in rows
                        )

            temp_file.flush()
            return temp_file.name, temp_file

        except Exception as e:
            if temp_file:
                temp_file.close()
            raise HTTPException(
                status_code=500,
                detail=f"Failed to export data: {str(e)}",
            ) from e

    @staticmethod
    def _export_filter_conditions(
        filters: dict[str, Any], valid_columns: set[str]
    ) -> tuple[list[str], list[Any]]:
        """Translate export filters into SQL predicates and bind parameters.

        Field names are checked against ``valid_columns`` before being
        interpolated, so only known column identifiers reach the SQL text.
        Every value is passed as a ``$n`` parameter, numbered from 1.
        Unsupported fields (including logical operators such as ``$and``)
        and unsupported operators are skipped with a warning.
        """
        templates = {
            "$eq": "{col} = {param}",
            "$ne": "{col} != {param}",
            "$gt": "{col} > {param}",
            "$gte": "{col} >= {param}",
            "$lt": "{col} < {param}",
            "$lte": "{col} <= {param}",
            "$in": "{col} = ANY({param})",
            "$nin": "NOT ({col} = ANY({param}))",
        }
        conditions: list[str] = []
        params: list[Any] = []

        for field, value in filters.items():
            if field not in valid_columns:
                logger.warning(
                    "export_to_csv: ignoring unsupported filter field %r",
                    field,
                )
                continue

            # A bare value is shorthand for equality.
            operations = value if isinstance(value, dict) else {"$eq": value}
            for op, operand in operations.items():
                template = templates.get(op)
                if template is None:
                    logger.warning(
                        "export_to_csv: ignoring unsupported operator %r "
                        "on field %r",
                        op,
                        field,
                    )
                    continue
                params.append(operand)
                conditions.append(
                    template.format(col=field, param=f"${len(params)}")
                )

        return conditions, params