author     S. Solomon Darnell   2025-03-28 21:52:21 -0500
committer  S. Solomon Darnell   2025-03-28 21:52:21 -0500
commit     4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) (HEAD, master)
tree       ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/core/providers/database/documents.py
parent     cc961e04ba734dd72309fb548a2f97d67d578813 (diff)

two version of R2R are here
Diffstat (limited to '.venv/lib/python3.12/site-packages/core/providers/database/documents.py')
-rw-r--r--  .venv/lib/python3.12/site-packages/core/providers/database/documents.py  1172
1 files changed, 1172 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/core/providers/database/documents.py b/.venv/lib/python3.12/site-packages/core/providers/database/documents.py
new file mode 100644
index 00000000..19781037
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/database/documents.py
@@ -0,0 +1,1172 @@
+import asyncio
+import copy
+import csv
+import json
+import logging
+import math
+import tempfile
+from typing import IO, Any, Optional
+from uuid import UUID
+
+import asyncpg
+from fastapi import HTTPException
+
+from core.base import (
+ DocumentResponse,
+ DocumentType,
+ GraphConstructionStatus,
+ GraphExtractionStatus,
+ Handler,
+ IngestionStatus,
+ R2RException,
+ SearchSettings,
+)
+
+from .base import PostgresConnectionManager
+from .filters import apply_filters
+
+logger = logging.getLogger()
+
+
+def transform_filter_fields(filters: dict[str, Any]) -> dict[str, Any]:
+ """Recursively transform filter field names by replacing 'document_id' with
+ 'id'. Handles nested logical operators like $and, $or, etc.
+
+ Args:
+ filters (dict[str, Any]): The original filters dictionary
+
+ Returns:
+ dict[str, Any]: A new dictionary with transformed field names
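+
+    Example (illustrative):
+        {"$and": [{"document_id": {"$eq": some_id}}]}
+        is transformed into
+        {"$and": [{"id": {"$eq": some_id}}]}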
+ """
+ if not filters:
+ return {}
+
+ transformed = {}
+
+ for key, value in filters.items():
+ # Handle logical operators recursively
+ if key in ("$and", "$or", "$not"):
+ if isinstance(value, list):
+ transformed[key] = [
+ transform_filter_fields(item) for item in value
+ ]
+ else:
+ transformed[key] = transform_filter_fields(value) # type: ignore
+ continue
+
+ # Replace 'document_id' with 'id'
+ new_key = "id" if key == "document_id" else key
+
+ # Handle nested dictionary cases (e.g., for operators like $eq, $gt, etc.)
+ if isinstance(value, dict):
+ transformed[new_key] = transform_filter_fields(value) # type: ignore
+ else:
+ transformed[new_key] = value
+
+ logger.debug(f"Transformed filters from {filters} to {transformed}")
+ return transformed
+
+
+class PostgresDocumentsHandler(Handler):
+ TABLE_NAME = "documents"
+
+ def __init__(
+ self,
+ project_name: str,
+ connection_manager: PostgresConnectionManager,
+ dimension: int | float,
+ ):
+ self.dimension = dimension
+ super().__init__(project_name, connection_manager)
+
+ async def create_tables(self):
+ logger.info(
+ f"Creating table, if it does not exist: {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}"
+ )
+
+ vector_dim = (
+ "" if math.isnan(self.dimension) else f"({self.dimension})"
+ )
+ vector_type = f"vector{vector_dim}"
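+        # A NaN dimension yields an untyped "vector" column; otherwise "vector(<dimension>)".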
+
+ try:
+ query = f"""
+ CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)} (
+ id UUID PRIMARY KEY,
+ collection_ids UUID[],
+ owner_id UUID,
+ type TEXT,
+ metadata JSONB,
+ title TEXT,
+ summary TEXT NULL,
+ summary_embedding {vector_type} NULL,
+ version TEXT,
+ size_in_bytes INT,
+ ingestion_status TEXT DEFAULT 'pending',
+ extraction_status TEXT DEFAULT 'pending',
+ created_at TIMESTAMPTZ DEFAULT NOW(),
+ updated_at TIMESTAMPTZ DEFAULT NOW(),
+ ingestion_attempt_number INT DEFAULT 0,
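+                -- Weighted full-text document: title (A) outranks summary (B) and metadata description (C)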
+ raw_tsvector tsvector GENERATED ALWAYS AS (
+ setweight(to_tsvector('english', COALESCE(title, '')), 'A') ||
+ setweight(to_tsvector('english', COALESCE(summary, '')), 'B') ||
+ setweight(to_tsvector('english', COALESCE((metadata->>'description')::text, '')), 'C')
+ ) STORED,
+ total_tokens INT DEFAULT 0
+ );
+ CREATE INDEX IF NOT EXISTS idx_collection_ids_{self.project_name}
+ ON {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)} USING GIN (collection_ids);
+
+ -- Full text search index
+ CREATE INDEX IF NOT EXISTS idx_doc_search_{self.project_name}
+ ON {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}
+ USING GIN (raw_tsvector);
+ """
+ await self.connection_manager.execute_query(query)
+
+ # ---------------------------------------------------------------
+ # Now check if total_tokens column exists in the 'documents' table
+ # ---------------------------------------------------------------
+ # 1) See what columns exist
+ # column_check_query = f"""
+ # SELECT column_name
+ # FROM information_schema.columns
+ # WHERE table_name = '{self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}'
+ # AND table_schema = CURRENT_SCHEMA()
+ # """
+ # existing_columns = await self.connection_manager.fetch_query(column_check_query)
+ # 2) Parse the table name for schema checks
+ table_full_name = self._get_table_name(
+ PostgresDocumentsHandler.TABLE_NAME
+ )
+ parsed_schema = "public"
+ parsed_table_name = table_full_name
+ if "." in table_full_name:
+ parts = table_full_name.split(".", maxsplit=1)
+ parsed_schema = parts[0].replace('"', "").strip()
+ parsed_table_name = parts[1].replace('"', "").strip()
+ else:
+ parsed_table_name = parsed_table_name.replace('"', "").strip()
+
+ # 3) Check columns
+ column_check_query = f"""
+ SELECT column_name
+ FROM information_schema.columns
+ WHERE table_name = '{parsed_table_name}'
+ AND table_schema = '{parsed_schema}'
+ """
+ existing_columns = await self.connection_manager.fetch_query(
+ column_check_query
+ )
+
+ existing_column_names = {
+ row["column_name"] for row in existing_columns
+ }
+
+ if "total_tokens" not in existing_column_names:
+                # 4) If the column is missing, check whether the table already has data
+ # doc_count_query = f"SELECT COUNT(*) FROM {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}"
+ # doc_count = await self.connection_manager.fetchval(doc_count_query)
+ doc_count_query = f"SELECT COUNT(*) AS doc_count FROM {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}"
+ row = await self.connection_manager.fetchrow_query(
+ doc_count_query
+ )
+ if row is None:
+ doc_count = 0
+ else:
+ doc_count = row[
+ "doc_count"
+ ] # or row[0] if you prefer positional indexing
+
+                if doc_count > 0:
+                    # The table already holds documents but lacks the
+                    # total_tokens column, so adding it will touch existing rows.
+                    logger.warning(
+                        "Adding the missing 'total_tokens' column to the 'documents' table; this will affect existing files."
+                    )
+
+ create_tokens_col = f"""
+ ALTER TABLE {table_full_name}
+ ADD COLUMN total_tokens INT DEFAULT 0
+ """
+ await self.connection_manager.execute_query(create_tokens_col)
+
+ except Exception as e:
+            logger.warning(f"Error when creating document table: {e}")
+ raise e
+
+ async def upsert_documents_overview(
+ self, documents_overview: DocumentResponse | list[DocumentResponse]
+ ) -> None:
+ if isinstance(documents_overview, DocumentResponse):
+ documents_overview = [documents_overview]
+
+ # TODO: make this an arg
+ max_retries = 20
+ for document in documents_overview:
+ retries = 0
+ while retries < max_retries:
+ try:
+ async with (
+ self.connection_manager.pool.get_connection() as conn # type: ignore
+ ):
+ async with conn.transaction():
+ # Lock the row for update
+ check_query = f"""
+ SELECT ingestion_attempt_number, ingestion_status FROM {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}
+ WHERE id = $1 FOR UPDATE
+ """
+ existing_doc = await conn.fetchrow(
+ check_query, document.id
+ )
+
+ db_entry = document.convert_to_db_entry()
+
+ if existing_doc:
+ db_version = existing_doc[
+ "ingestion_attempt_number"
+ ]
+ db_status = existing_doc["ingestion_status"]
+ new_version = db_entry[
+ "ingestion_attempt_number"
+ ]
+
+                                # Only increment the attempt number when the status flips to
+                                # 'success' or the incoming attempt number is higher.
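+                                # e.g. an upsert that moves 'pending' -> 'success' bumps the
+                                # attempt number by one; re-upserting an unchanged row does not.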
+ if (
+ db_status != "success"
+ and db_entry["ingestion_status"]
+ == "success"
+ ) or (new_version > db_version):
+ new_attempt_number = db_version + 1
+ else:
+ new_attempt_number = db_version
+
+ db_entry["ingestion_attempt_number"] = (
+ new_attempt_number
+ )
+
+ update_query = f"""
+ UPDATE {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}
+ SET collection_ids = $1,
+ owner_id = $2,
+ type = $3,
+ metadata = $4,
+ title = $5,
+ version = $6,
+ size_in_bytes = $7,
+ ingestion_status = $8,
+ extraction_status = $9,
+ updated_at = $10,
+ ingestion_attempt_number = $11,
+ summary = $12,
+ summary_embedding = $13,
+ total_tokens = $14
+ WHERE id = $15
+ """
+
+ await conn.execute(
+ update_query,
+ db_entry["collection_ids"],
+ db_entry["owner_id"],
+ db_entry["document_type"],
+ db_entry["metadata"],
+ db_entry["title"],
+ db_entry["version"],
+ db_entry["size_in_bytes"],
+ db_entry["ingestion_status"],
+ db_entry["extraction_status"],
+ db_entry["updated_at"],
+ db_entry["ingestion_attempt_number"],
+ db_entry["summary"],
+ db_entry["summary_embedding"],
+ db_entry[
+ "total_tokens"
+ ], # pass the new field here
+ document.id,
+ )
+ else:
+ insert_query = f"""
+ INSERT INTO {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}
+ (id, collection_ids, owner_id, type, metadata, title, version,
+ size_in_bytes, ingestion_status, extraction_status, created_at,
+ updated_at, ingestion_attempt_number, summary, summary_embedding, total_tokens)
+ VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16)
+ """
+ await conn.execute(
+ insert_query,
+ db_entry["id"],
+ db_entry["collection_ids"],
+ db_entry["owner_id"],
+ db_entry["document_type"],
+ db_entry["metadata"],
+ db_entry["title"],
+ db_entry["version"],
+ db_entry["size_in_bytes"],
+ db_entry["ingestion_status"],
+ db_entry["extraction_status"],
+ db_entry["created_at"],
+ db_entry["updated_at"],
+ db_entry["ingestion_attempt_number"],
+ db_entry["summary"],
+ db_entry["summary_embedding"],
+ db_entry["total_tokens"],
+ )
+
+ break # Success, exit the retry loop
+ except (
+ asyncpg.exceptions.UniqueViolationError,
+ asyncpg.exceptions.DeadlockDetectedError,
+ ) as e:
+ retries += 1
+ if retries == max_retries:
+ logger.error(
+ f"Failed to update document {document.id} after {max_retries} attempts. Error: {str(e)}"
+ )
+ raise
+ else:
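+                        # Backoff doubles on every retry: 0.2s, 0.4s, 0.8s, ...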
+ wait_time = 0.1 * (2**retries) # Exponential backoff
+ await asyncio.sleep(wait_time)
+
+ async def delete(
+ self, document_id: UUID, version: Optional[str] = None
+ ) -> None:
+ query = f"""
+ DELETE FROM {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}
+ WHERE id = $1
+ """
+
+ params = [str(document_id)]
+
+ if version:
+ query += " AND version = $2"
+ params.append(version)
+
+ await self.connection_manager.execute_query(query=query, params=params)
+
+ async def _get_status_from_table(
+ self,
+ ids: list[UUID],
+ table_name: str,
+ status_type: str,
+ column_name: str,
+ ):
+ """Get the workflow status for a given document or list of documents.
+
+ Args:
+ ids (list[UUID]): The document IDs.
+ table_name (str): The table name.
+            status_type (str): The type of status to retrieve.
+            column_name (str): The column used to match the given IDs.
+
+ Returns:
+ The workflow status for the given document or list of documents.
+ """
+ query = f"""
+ SELECT {status_type} FROM {self._get_table_name(table_name)}
+ WHERE {column_name} = ANY($1)
+ """
+ return [
+ row[status_type]
+ for row in await self.connection_manager.fetch_query(query, [ids])
+ ]
+
+ async def _get_ids_from_table(
+ self,
+ status: list[str],
+ table_name: str,
+ status_type: str,
+ collection_id: Optional[UUID] = None,
+ ):
+ """Get the IDs from a given table.
+
+ Args:
+            status (list[str]): The list of statuses to match.
+            table_name (str): The table name.
+            status_type (str): The type of status to filter on.
+            collection_id (Optional[UUID]): Restrict results to documents in this collection.
+ """
+ query = f"""
+ SELECT id FROM {self._get_table_name(table_name)}
+ WHERE {status_type} = ANY($1) and $2 = ANY(collection_ids)
+ """
+ records = await self.connection_manager.fetch_query(
+ query, [status, collection_id]
+ )
+ return [record["id"] for record in records]
+
+ async def _set_status_in_table(
+ self,
+ ids: list[UUID],
+ status: str,
+ table_name: str,
+ status_type: str,
+ column_name: str,
+ ):
+ """Set the workflow status for a given document or list of documents.
+
+ Args:
+ ids (list[UUID]): The document IDs.
+ status (str): The status to set.
+ table_name (str): The table name.
+ status_type (str): The type of status to set.
+ column_name (str): The column name in the table to update.
+ """
+ query = f"""
+ UPDATE {self._get_table_name(table_name)}
+ SET {status_type} = $1
+            WHERE {column_name} = ANY($2)
+ """
+ await self.connection_manager.execute_query(query, [status, ids])
+
+ def _get_status_model(self, status_type: str):
+ """Get the status model for a given status type.
+
+ Args:
+ status_type (str): The type of status to retrieve.
+
+ Returns:
+ The status model for the given status type.
+ """
+ if status_type == "ingestion":
+ return IngestionStatus
+ elif status_type == "extraction_status":
+ return GraphExtractionStatus
+ elif status_type in {"graph_cluster_status", "graph_sync_status"}:
+ return GraphConstructionStatus
+ else:
+ raise R2RException(
+ status_code=400, message=f"Invalid status type: {status_type}"
+ )
+
+ async def get_workflow_status(
+ self, id: UUID | list[UUID], status_type: str
+ ):
+ """Get the workflow status for a given document or list of documents.
+
+ Args:
+ id (UUID | list[UUID]): The document ID or list of document IDs.
+ status_type (str): The type of status to retrieve.
+
+ Returns:
+ The workflow status for the given document or list of documents.
+ """
+
+ ids = [id] if isinstance(id, UUID) else id
+ out_model = self._get_status_model(status_type)
+ result = await self._get_status_from_table(
+ ids,
+ out_model.table_name(),
+ status_type,
+ out_model.id_column(),
+ )
+
+ result = [out_model[status.upper()] for status in result]
+ return result[0] if isinstance(id, UUID) else result
+
+ async def set_workflow_status(
+ self, id: UUID | list[UUID], status_type: str, status: str
+ ):
+ """Set the workflow status for a given document or list of documents.
+
+ Args:
+ id (UUID | list[UUID]): The document ID or list of document IDs.
+ status_type (str): The type of status to set.
+ status (str): The status to set.
+ """
+ ids = [id] if isinstance(id, UUID) else id
+ out_model = self._get_status_model(status_type)
+
+ return await self._set_status_in_table(
+ ids,
+ status,
+ out_model.table_name(),
+ status_type,
+ out_model.id_column(),
+ )
+
+ async def get_document_ids_by_status(
+ self,
+ status_type: str,
+ status: str | list[str],
+ collection_id: Optional[UUID] = None,
+ ):
+ """Get the IDs for a given status.
+
+ Args:
+            status_type (str): The type of status to retrieve.
+            status (str | list[str]): The status or list of statuses to retrieve.
+            collection_id (Optional[UUID]): Restrict results to documents in this collection.
+ """
+
+ if isinstance(status, str):
+ status = [status]
+
+ out_model = self._get_status_model(status_type)
+ return await self._get_ids_from_table(
+ status, out_model.table_name(), status_type, collection_id
+ )
+
+ async def get_documents_overview(
+ self,
+ offset: int,
+ limit: int,
+ filter_user_ids: Optional[list[UUID]] = None,
+ filter_document_ids: Optional[list[UUID]] = None,
+ filter_collection_ids: Optional[list[UUID]] = None,
+ include_summary_embedding: Optional[bool] = True,
+ filters: Optional[dict[str, Any]] = None,
+        sort_order: str = "DESC",  # "ASC" or "DESC"; defaults to newest first
+ ) -> dict[str, Any]:
+ """Fetch overviews of documents with optional offset/limit pagination.
+
+ You can use either:
+ - Traditional filters: `filter_user_ids`, `filter_document_ids`, `filter_collection_ids`
+ - A `filters` dict (e.g., like we do in semantic search), which will be passed to `apply_filters`.
+
+ If both the `filters` dict and any of the traditional filter arguments are provided,
+ this method will raise an error.
+ """
+
+ filters = copy.deepcopy(filters)
+ filters = transform_filter_fields(filters) # type: ignore
+
+ # Safety check: We do not allow mixing the old filter arguments with the new `filters` dict.
+ # This keeps the query logic unambiguous.
+ if filters and any(
+ [
+ filter_user_ids,
+ filter_document_ids,
+ filter_collection_ids,
+ ]
+ ):
+ raise HTTPException(
+ status_code=400,
+ detail=(
+ "Cannot use both the 'filters' dictionary "
+ "and the 'filter_*_ids' parameters simultaneously."
+ ),
+ )
+
+ conditions = []
+ params: list[Any] = []
+ param_index = 1
+
+ # -------------------------------------------
+ # 1) If using the new `filters` dict approach
+ # -------------------------------------------
+ if filters:
+ # Apply the filters to generate a WHERE clause
+ filter_condition, filter_params = apply_filters(
+ filters, params, mode="condition_only"
+ )
+ if filter_condition:
+ conditions.append(filter_condition)
+ # Make sure we keep adding to the same params list
+ # params.extend(filter_params)
+ param_index += len(filter_params)
+
+ # -------------------------------------------
+ # 2) If using the old filter_*_ids approach
+ # -------------------------------------------
+ else:
+ # Handle document IDs with AND
+ if filter_document_ids:
+ conditions.append(f"id = ANY(${param_index})")
+ params.append(filter_document_ids)
+ param_index += 1
+
+ # For owner/collection filters, we used OR logic previously
+ # so we combine them into a single sub-condition in parentheses
+ or_conditions = []
+ if filter_user_ids:
+ or_conditions.append(f"owner_id = ANY(${param_index})")
+ params.append(filter_user_ids)
+ param_index += 1
+
+ if filter_collection_ids:
+ or_conditions.append(f"collection_ids && ${param_index}")
+ params.append(filter_collection_ids)
+ param_index += 1
+
+ if or_conditions:
+ conditions.append(f"({' OR '.join(or_conditions)})")
+
+ # -------------------------
+ # Build the full query
+ # -------------------------
+ base_query = (
+ f"FROM {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}"
+ )
+ if conditions:
+ # Combine everything with AND
+ base_query += " WHERE " + " AND ".join(conditions)
+
+ # Construct SELECT fields (including total_entries via window function)
+ select_fields = """
+ SELECT
+ id,
+ collection_ids,
+ owner_id,
+ type,
+ metadata,
+ title,
+ version,
+ size_in_bytes,
+ ingestion_status,
+ extraction_status,
+ created_at,
+ updated_at,
+ summary,
+ summary_embedding,
+ total_tokens,
+ COUNT(*) OVER() AS total_entries
+ """
+
+ query = f"""
+ {select_fields}
+ {base_query}
+ ORDER BY created_at {sort_order}
+ OFFSET ${param_index}
+ """
+ params.append(offset)
+ param_index += 1
+
+ if limit != -1:
+ query += f" LIMIT ${param_index}"
+ params.append(limit)
+ param_index += 1
+
+ try:
+ results = await self.connection_manager.fetch_query(query, params)
+ total_entries = results[0]["total_entries"] if results else 0
+
+ documents = []
+ for row in results:
+ # Safely handle the embedding
+ embedding = None
+ if (
+ "summary_embedding" in row
+ and row["summary_embedding"] is not None
+ ):
+ try:
+ # The embedding is stored as a string like "[0.1, 0.2, ...]"
+ embedding_str = row["summary_embedding"]
+ if embedding_str.startswith(
+ "["
+ ) and embedding_str.endswith("]"):
+ embedding = [
+ float(x)
+ for x in embedding_str[1:-1].split(",")
+ if x
+ ]
+ except Exception as e:
+ logger.warning(
+ f"Failed to parse embedding for document {row['id']}: {e}"
+ )
+
+ documents.append(
+ DocumentResponse(
+ id=row["id"],
+ collection_ids=row["collection_ids"],
+ owner_id=row["owner_id"],
+ document_type=DocumentType(row["type"]),
+ metadata=json.loads(row["metadata"]),
+ title=row["title"],
+ version=row["version"],
+ size_in_bytes=row["size_in_bytes"],
+ ingestion_status=IngestionStatus(
+ row["ingestion_status"]
+ ),
+ extraction_status=GraphExtractionStatus(
+ row["extraction_status"]
+ ),
+ created_at=row["created_at"],
+ updated_at=row["updated_at"],
+ summary=row["summary"] if "summary" in row else None,
+ summary_embedding=(
+ embedding if include_summary_embedding else None
+ ),
+ total_tokens=row["total_tokens"],
+ )
+ )
+ return {"results": documents, "total_entries": total_entries}
+ except Exception as e:
+ logger.error(f"Error in get_documents_overview: {str(e)}")
+ raise HTTPException(
+ status_code=500,
+ detail="Database query failed",
+ ) from e
+
+ async def update_document_metadata(
+ self,
+ document_id: UUID,
+ metadata: list[dict],
+ overwrite: bool = False,
+ ) -> DocumentResponse:
+ """
+ Update the metadata of a document, either by appending to existing metadata or overwriting it.
+ Accepts a list of metadata dictionaries.
+ """
+
+ doc_result = await self.get_documents_overview(
+ offset=0,
+ limit=1,
+ filter_document_ids=[document_id],
+ )
+
+ if not doc_result["results"]:
+ raise HTTPException(
+ status_code=404,
+ detail=f"Document with ID {document_id} not found",
+ )
+
+ existing_doc = doc_result["results"][0]
+
+ if overwrite:
+ combined_metadata: dict[str, Any] = {}
+ for meta_item in metadata:
+ combined_metadata |= meta_item
+ existing_doc.metadata = combined_metadata
+ else:
+ for meta_item in metadata:
+ existing_doc.metadata.update(meta_item)
+
+ await self.upsert_documents_overview(existing_doc)
+
+ return existing_doc
+
+ async def semantic_document_search(
+ self, query_embedding: list[float], search_settings: SearchSettings
+ ) -> list[DocumentResponse]:
+ """Search documents using semantic similarity with their summary
+ embeddings."""
+
+ where_clauses = ["summary_embedding IS NOT NULL"]
+ params: list[str | int | bytes] = [str(query_embedding)]
+
+ vector_dim = (
+ "" if math.isnan(self.dimension) else f"({self.dimension})"
+ )
+ filters = copy.deepcopy(search_settings.filters)
+ if filters:
+ filter_condition, params = apply_filters(
+ transform_filter_fields(filters), params, mode="condition_only"
+ )
+ if filter_condition:
+ where_clauses.append(filter_condition)
+
+ where_clause = " AND ".join(where_clauses)
+
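+        # pgvector's <=> operator returns cosine distance; semantic_score below is 1 - distance.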
+ query = f"""
+ WITH document_scores AS (
+ SELECT
+ id,
+ collection_ids,
+ owner_id,
+ type,
+ metadata,
+ title,
+ version,
+ size_in_bytes,
+ ingestion_status,
+ extraction_status,
+ created_at,
+ updated_at,
+ summary,
+ summary_embedding,
+ total_tokens,
+                    (summary_embedding <=> $1::vector{vector_dim}) as semantic_distance
+ FROM {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}
+ WHERE {where_clause}
+ ORDER BY semantic_distance ASC
+ LIMIT ${len(params) + 1}
+ OFFSET ${len(params) + 2}
+ )
+ SELECT *,
+ 1.0 - semantic_distance as semantic_score
+ FROM document_scores
+ """
+
+ params.extend([search_settings.limit, search_settings.offset])
+
+ results = await self.connection_manager.fetch_query(query, params)
+
+ return [
+ DocumentResponse(
+ id=row["id"],
+ collection_ids=row["collection_ids"],
+ owner_id=row["owner_id"],
+ document_type=DocumentType(row["type"]),
+ metadata={
+ **(
+ json.loads(row["metadata"])
+ if search_settings.include_metadatas
+ else {}
+ ),
+ "search_score": float(row["semantic_score"]),
+ "search_type": "semantic",
+ },
+ title=row["title"],
+ version=row["version"],
+ size_in_bytes=row["size_in_bytes"],
+ ingestion_status=IngestionStatus(row["ingestion_status"]),
+ extraction_status=GraphExtractionStatus(
+ row["extraction_status"]
+ ),
+ created_at=row["created_at"],
+ updated_at=row["updated_at"],
+ summary=row["summary"],
+ summary_embedding=[
+ float(x)
+ for x in row["summary_embedding"][1:-1].split(",")
+ if x
+ ],
+ total_tokens=row["total_tokens"],
+ )
+ for row in results
+ ]
+
+ async def full_text_document_search(
+ self, query_text: str, search_settings: SearchSettings
+ ) -> list[DocumentResponse]:
+ """Enhanced full-text search using generated tsvector."""
+
+ where_clauses = ["raw_tsvector @@ websearch_to_tsquery('english', $1)"]
+ params: list[str | int | bytes] = [query_text]
+
+ filters = copy.deepcopy(search_settings.filters)
+ if filters:
+ filter_condition, params = apply_filters(
+ transform_filter_fields(filters), params, mode="condition_only"
+ )
+ if filter_condition:
+ where_clauses.append(filter_condition)
+
+ where_clause = " AND ".join(where_clauses)
+
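+        # ts_rank_cd normalization flag 32 scales the rank into [0, 1) via rank / (rank + 1).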
+ query = f"""
+ WITH document_scores AS (
+ SELECT
+ id,
+ collection_ids,
+ owner_id,
+ type,
+ metadata,
+ title,
+ version,
+ size_in_bytes,
+ ingestion_status,
+ extraction_status,
+ created_at,
+ updated_at,
+ summary,
+ summary_embedding,
+ total_tokens,
+ ts_rank_cd(raw_tsvector, websearch_to_tsquery('english', $1), 32) as text_score
+ FROM {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}
+ WHERE {where_clause}
+ ORDER BY text_score DESC
+ LIMIT ${len(params) + 1}
+ OFFSET ${len(params) + 2}
+ )
+ SELECT * FROM document_scores
+ """
+
+ params.extend([search_settings.limit, search_settings.offset])
+
+ results = await self.connection_manager.fetch_query(query, params)
+
+ return [
+ DocumentResponse(
+ id=row["id"],
+ collection_ids=row["collection_ids"],
+ owner_id=row["owner_id"],
+ document_type=DocumentType(row["type"]),
+ metadata={
+ **(
+ json.loads(row["metadata"])
+ if search_settings.include_metadatas
+ else {}
+ ),
+ "search_score": float(row["text_score"]),
+ "search_type": "full_text",
+ },
+ title=row["title"],
+ version=row["version"],
+ size_in_bytes=row["size_in_bytes"],
+ ingestion_status=IngestionStatus(row["ingestion_status"]),
+ extraction_status=GraphExtractionStatus(
+ row["extraction_status"]
+ ),
+ created_at=row["created_at"],
+ updated_at=row["updated_at"],
+ summary=row["summary"],
+ summary_embedding=(
+ [
+ float(x)
+ for x in row["summary_embedding"][1:-1].split(",")
+ if x
+ ]
+ if row["summary_embedding"]
+ else None
+ ),
+ total_tokens=row["total_tokens"],
+ )
+ for row in results
+ ]
+
+ async def hybrid_document_search(
+ self,
+ query_text: str,
+ query_embedding: list[float],
+ search_settings: SearchSettings,
+ ) -> list[DocumentResponse]:
+ """Search documents using both semantic and full-text search with RRF
+ fusion."""
+
+ # Get more results than needed for better fusion
+ extended_settings = copy.deepcopy(search_settings)
+ extended_settings.limit = search_settings.limit * 3
+
+ # Get results from both search methods
+ semantic_results = await self.semantic_document_search(
+ query_embedding, extended_settings
+ )
+ full_text_results = await self.full_text_document_search(
+ query_text, extended_settings
+ )
+
+ # Combine results using RRF
+ doc_scores: dict[str, dict] = {}
+
+ # Process semantic results
+ for rank, result in enumerate(semantic_results, 1):
+ doc_id = str(result.id)
+ doc_scores[doc_id] = {
+ "semantic_rank": rank,
+ "full_text_rank": len(full_text_results)
+ + 1, # Default rank if not found
+ "data": result,
+ }
+
+ # Process full-text results
+ for rank, result in enumerate(full_text_results, 1):
+ doc_id = str(result.id)
+ if doc_id in doc_scores:
+ doc_scores[doc_id]["full_text_rank"] = rank
+ else:
+ doc_scores[doc_id] = {
+ "semantic_rank": len(semantic_results)
+ + 1, # Default rank if not found
+ "full_text_rank": rank,
+ "data": result,
+ }
+
+ # Calculate RRF scores using hybrid search settings
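+        # Weighted reciprocal rank fusion:
+        #   score = (w_sem / (k + rank_sem) + w_ft / (k + rank_ft)) / (w_sem + w_ft)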
+ rrf_k = search_settings.hybrid_settings.rrf_k
+ semantic_weight = search_settings.hybrid_settings.semantic_weight
+ full_text_weight = search_settings.hybrid_settings.full_text_weight
+
+ for scores in doc_scores.values():
+ semantic_score = 1 / (rrf_k + scores["semantic_rank"])
+ full_text_score = 1 / (rrf_k + scores["full_text_rank"])
+
+ # Weighted combination
+ combined_score = (
+ semantic_score * semantic_weight
+ + full_text_score * full_text_weight
+ ) / (semantic_weight + full_text_weight)
+
+ scores["final_score"] = combined_score
+
+ # Sort by final score and apply offset/limit
+ sorted_results = sorted(
+ doc_scores.values(), key=lambda x: x["final_score"], reverse=True
+ )[
+ search_settings.offset : search_settings.offset
+ + search_settings.limit
+ ]
+
+ return [
+ DocumentResponse(
+ **{
+ **result["data"].__dict__,
+ "metadata": {
+ **(
+ result["data"].metadata
+ if search_settings.include_metadatas
+ else {}
+ ),
+ "search_score": result["final_score"],
+ "semantic_rank": result["semantic_rank"],
+ "full_text_rank": result["full_text_rank"],
+ "search_type": "hybrid",
+ },
+ }
+ )
+ for result in sorted_results
+ ]
+
+ async def search_documents(
+ self,
+ query_text: str,
+ query_embedding: Optional[list[float]] = None,
+ settings: Optional[SearchSettings] = None,
+ ) -> list[DocumentResponse]:
+ """Main search method that delegates to the appropriate search method
+ based on settings."""
+ if settings is None:
+ settings = SearchSettings()
+
+ if (
+ settings.use_semantic_search and settings.use_fulltext_search
+ ) or settings.use_hybrid_search:
+ if query_embedding is None:
+ raise ValueError(
+ "query_embedding is required for hybrid search"
+ )
+ return await self.hybrid_document_search(
+ query_text, query_embedding, settings
+ )
+ elif settings.use_semantic_search:
+ if query_embedding is None:
+ raise ValueError(
+ "query_embedding is required for vector search"
+ )
+ return await self.semantic_document_search(
+ query_embedding, settings
+ )
+ else:
+ return await self.full_text_document_search(query_text, settings)
+
+ async def export_to_csv(
+ self,
+ columns: Optional[list[str]] = None,
+ filters: Optional[dict] = None,
+ include_header: bool = True,
+ ) -> tuple[str, IO]:
+ """Creates a CSV file from the PostgreSQL data and returns the path to
+ the temp file."""
+ valid_columns = {
+ "id",
+ "collection_ids",
+ "owner_id",
+ "type",
+ "metadata",
+ "title",
+ "summary",
+ "version",
+ "size_in_bytes",
+ "ingestion_status",
+ "extraction_status",
+ "created_at",
+ "updated_at",
+ "total_tokens",
+ }
+ filters = copy.deepcopy(filters)
+ filters = transform_filter_fields(filters) # type: ignore
+
+ if not columns:
+ columns = list(valid_columns)
+ elif invalid_cols := set(columns) - valid_columns:
+ raise ValueError(f"Invalid columns: {invalid_cols}")
+
+ select_stmt = f"""
+ SELECT
+ id::text,
+ collection_ids::text,
+ owner_id::text,
+ type::text,
+ metadata::text AS metadata,
+ title,
+ summary,
+ version,
+ size_in_bytes,
+ ingestion_status,
+ extraction_status,
+ to_char(created_at, 'YYYY-MM-DD HH24:MI:SS') AS created_at,
+ to_char(updated_at, 'YYYY-MM-DD HH24:MI:SS') AS updated_at,
+ total_tokens
+ FROM {self._get_table_name(self.TABLE_NAME)}
+ """
+
+ conditions = []
+ params: list[Any] = []
+ param_index = 1
+
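+        # Only flat filters on exportable columns are honored here; supported operators
+        # are $eq, $gt and $lt, plus direct equality for scalar values.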
+ if filters:
+ for field, value in filters.items():
+ if field not in valid_columns:
+ continue
+
+ if isinstance(value, dict):
+ for op, val in value.items():
+ if op == "$eq":
+ conditions.append(f"{field} = ${param_index}")
+ params.append(val)
+ param_index += 1
+ elif op == "$gt":
+ conditions.append(f"{field} > ${param_index}")
+ params.append(val)
+ param_index += 1
+ elif op == "$lt":
+ conditions.append(f"{field} < ${param_index}")
+ params.append(val)
+ param_index += 1
+ else:
+ # Direct equality
+ conditions.append(f"{field} = ${param_index}")
+ params.append(value)
+ param_index += 1
+
+ if conditions:
+ select_stmt = f"{select_stmt} WHERE {' AND '.join(conditions)}"
+
+ select_stmt = f"{select_stmt} ORDER BY created_at DESC"
+
+ temp_file = None
+ try:
+ temp_file = tempfile.NamedTemporaryFile(
+ mode="w", delete=True, suffix=".csv"
+ )
+ writer = csv.writer(temp_file, quoting=csv.QUOTE_ALL)
+
+ async with self.connection_manager.pool.get_connection() as conn: # type: ignore
+ async with conn.transaction():
+ cursor = await conn.cursor(select_stmt, *params)
+
+ if include_header:
+ writer.writerow(columns)
+
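+                    # Stream rows from the server-side cursor in batches to keep memory bounded.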
+ chunk_size = 1000
+ while True:
+ rows = await cursor.fetch(chunk_size)
+ if not rows:
+ break
+ for row in rows:
+ row_dict = {
+ "id": row[0],
+ "collection_ids": row[1],
+ "owner_id": row[2],
+ "type": row[3],
+ "metadata": row[4],
+ "title": row[5],
+ "summary": row[6],
+ "version": row[7],
+ "size_in_bytes": row[8],
+ "ingestion_status": row[9],
+ "extraction_status": row[10],
+ "created_at": row[11],
+ "updated_at": row[12],
+ "total_tokens": row[13],
+ }
+ writer.writerow([row_dict[col] for col in columns])
+
+ temp_file.flush()
+ return temp_file.name, temp_file
+
+ except Exception as e:
+ if temp_file:
+ temp_file.close()
+ raise HTTPException(
+ status_code=500,
+ detail=f"Failed to export data: {str(e)}",
+ ) from e