author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500
---|---|---
committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500
commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) |
tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/core/providers/database/documents.py |
parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff) |
download | gn-ai-master.tar.gz |
Diffstat (limited to '.venv/lib/python3.12/site-packages/core/providers/database/documents.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/core/providers/database/documents.py | 1172 |
1 file changed, 1172 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/core/providers/database/documents.py b/.venv/lib/python3.12/site-packages/core/providers/database/documents.py new file mode 100644 index 00000000..19781037 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/core/providers/database/documents.py @@ -0,0 +1,1172 @@ +import asyncio +import copy +import csv +import json +import logging +import math +import tempfile +from typing import IO, Any, Optional +from uuid import UUID + +import asyncpg +from fastapi import HTTPException + +from core.base import ( + DocumentResponse, + DocumentType, + GraphConstructionStatus, + GraphExtractionStatus, + Handler, + IngestionStatus, + R2RException, + SearchSettings, +) + +from .base import PostgresConnectionManager +from .filters import apply_filters + +logger = logging.getLogger() + + +def transform_filter_fields(filters: dict[str, Any]) -> dict[str, Any]: + """Recursively transform filter field names by replacing 'document_id' with + 'id'. Handles nested logical operators like $and, $or, etc. + + Args: + filters (dict[str, Any]): The original filters dictionary + + Returns: + dict[str, Any]: A new dictionary with transformed field names + """ + if not filters: + return {} + + transformed = {} + + for key, value in filters.items(): + # Handle logical operators recursively + if key in ("$and", "$or", "$not"): + if isinstance(value, list): + transformed[key] = [ + transform_filter_fields(item) for item in value + ] + else: + transformed[key] = transform_filter_fields(value) # type: ignore + continue + + # Replace 'document_id' with 'id' + new_key = "id" if key == "document_id" else key + + # Handle nested dictionary cases (e.g., for operators like $eq, $gt, etc.) + if isinstance(value, dict): + transformed[new_key] = transform_filter_fields(value) # type: ignore + else: + transformed[new_key] = value + + logger.debug(f"Transformed filters from {filters} to {transformed}") + return transformed + + +class PostgresDocumentsHandler(Handler): + TABLE_NAME = "documents" + + def __init__( + self, + project_name: str, + connection_manager: PostgresConnectionManager, + dimension: int | float, + ): + self.dimension = dimension + super().__init__(project_name, connection_manager) + + async def create_tables(self): + logger.info( + f"Creating table, if it does not exist: {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}" + ) + + vector_dim = ( + "" if math.isnan(self.dimension) else f"({self.dimension})" + ) + vector_type = f"vector{vector_dim}" + + try: + query = f""" + CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)} ( + id UUID PRIMARY KEY, + collection_ids UUID[], + owner_id UUID, + type TEXT, + metadata JSONB, + title TEXT, + summary TEXT NULL, + summary_embedding {vector_type} NULL, + version TEXT, + size_in_bytes INT, + ingestion_status TEXT DEFAULT 'pending', + extraction_status TEXT DEFAULT 'pending', + created_at TIMESTAMPTZ DEFAULT NOW(), + updated_at TIMESTAMPTZ DEFAULT NOW(), + ingestion_attempt_number INT DEFAULT 0, + raw_tsvector tsvector GENERATED ALWAYS AS ( + setweight(to_tsvector('english', COALESCE(title, '')), 'A') || + setweight(to_tsvector('english', COALESCE(summary, '')), 'B') || + setweight(to_tsvector('english', COALESCE((metadata->>'description')::text, '')), 'C') + ) STORED, + total_tokens INT DEFAULT 0 + ); + CREATE INDEX IF NOT EXISTS idx_collection_ids_{self.project_name} + ON {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)} USING GIN (collection_ids); + + -- Full text search 
index + CREATE INDEX IF NOT EXISTS idx_doc_search_{self.project_name} + ON {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)} + USING GIN (raw_tsvector); + """ + await self.connection_manager.execute_query(query) + + # --------------------------------------------------------------- + # Now check if total_tokens column exists in the 'documents' table + # --------------------------------------------------------------- + # 1) See what columns exist + # column_check_query = f""" + # SELECT column_name + # FROM information_schema.columns + # WHERE table_name = '{self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}' + # AND table_schema = CURRENT_SCHEMA() + # """ + # existing_columns = await self.connection_manager.fetch_query(column_check_query) + # 2) Parse the table name for schema checks + table_full_name = self._get_table_name( + PostgresDocumentsHandler.TABLE_NAME + ) + parsed_schema = "public" + parsed_table_name = table_full_name + if "." in table_full_name: + parts = table_full_name.split(".", maxsplit=1) + parsed_schema = parts[0].replace('"', "").strip() + parsed_table_name = parts[1].replace('"', "").strip() + else: + parsed_table_name = parsed_table_name.replace('"', "").strip() + + # 3) Check columns + column_check_query = f""" + SELECT column_name + FROM information_schema.columns + WHERE table_name = '{parsed_table_name}' + AND table_schema = '{parsed_schema}' + """ + existing_columns = await self.connection_manager.fetch_query( + column_check_query + ) + + existing_column_names = { + row["column_name"] for row in existing_columns + } + + if "total_tokens" not in existing_column_names: + # 2) If missing, see if the table already has data + # doc_count_query = f"SELECT COUNT(*) FROM {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}" + # doc_count = await self.connection_manager.fetchval(doc_count_query) + doc_count_query = f"SELECT COUNT(*) AS doc_count FROM {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}" + row = await self.connection_manager.fetchrow_query( + doc_count_query + ) + if row is None: + doc_count = 0 + else: + doc_count = row[ + "doc_count" + ] # or row[0] if you prefer positional indexing + + if doc_count > 0: + # We already have documents, but no total_tokens column + # => ask user to run r2r db migrate + logger.warning( + "Adding the missing 'total_tokens' column to the 'documents' table, this will impact existing files." 
+ ) + + create_tokens_col = f""" + ALTER TABLE {table_full_name} + ADD COLUMN total_tokens INT DEFAULT 0 + """ + await self.connection_manager.execute_query(create_tokens_col) + + except Exception as e: + logger.warning(f"Error {e} when creating document table.") + raise e + + async def upsert_documents_overview( + self, documents_overview: DocumentResponse | list[DocumentResponse] + ) -> None: + if isinstance(documents_overview, DocumentResponse): + documents_overview = [documents_overview] + + # TODO: make this an arg + max_retries = 20 + for document in documents_overview: + retries = 0 + while retries < max_retries: + try: + async with ( + self.connection_manager.pool.get_connection() as conn # type: ignore + ): + async with conn.transaction(): + # Lock the row for update + check_query = f""" + SELECT ingestion_attempt_number, ingestion_status FROM {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)} + WHERE id = $1 FOR UPDATE + """ + existing_doc = await conn.fetchrow( + check_query, document.id + ) + + db_entry = document.convert_to_db_entry() + + if existing_doc: + db_version = existing_doc[ + "ingestion_attempt_number" + ] + db_status = existing_doc["ingestion_status"] + new_version = db_entry[ + "ingestion_attempt_number" + ] + + # Only increment version if status is changing to 'success' or if it's a new version + if ( + db_status != "success" + and db_entry["ingestion_status"] + == "success" + ) or (new_version > db_version): + new_attempt_number = db_version + 1 + else: + new_attempt_number = db_version + + db_entry["ingestion_attempt_number"] = ( + new_attempt_number + ) + + update_query = f""" + UPDATE {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)} + SET collection_ids = $1, + owner_id = $2, + type = $3, + metadata = $4, + title = $5, + version = $6, + size_in_bytes = $7, + ingestion_status = $8, + extraction_status = $9, + updated_at = $10, + ingestion_attempt_number = $11, + summary = $12, + summary_embedding = $13, + total_tokens = $14 + WHERE id = $15 + """ + + await conn.execute( + update_query, + db_entry["collection_ids"], + db_entry["owner_id"], + db_entry["document_type"], + db_entry["metadata"], + db_entry["title"], + db_entry["version"], + db_entry["size_in_bytes"], + db_entry["ingestion_status"], + db_entry["extraction_status"], + db_entry["updated_at"], + db_entry["ingestion_attempt_number"], + db_entry["summary"], + db_entry["summary_embedding"], + db_entry[ + "total_tokens" + ], # pass the new field here + document.id, + ) + else: + insert_query = f""" + INSERT INTO {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)} + (id, collection_ids, owner_id, type, metadata, title, version, + size_in_bytes, ingestion_status, extraction_status, created_at, + updated_at, ingestion_attempt_number, summary, summary_embedding, total_tokens) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16) + """ + await conn.execute( + insert_query, + db_entry["id"], + db_entry["collection_ids"], + db_entry["owner_id"], + db_entry["document_type"], + db_entry["metadata"], + db_entry["title"], + db_entry["version"], + db_entry["size_in_bytes"], + db_entry["ingestion_status"], + db_entry["extraction_status"], + db_entry["created_at"], + db_entry["updated_at"], + db_entry["ingestion_attempt_number"], + db_entry["summary"], + db_entry["summary_embedding"], + db_entry["total_tokens"], + ) + + break # Success, exit the retry loop + except ( + asyncpg.exceptions.UniqueViolationError, + asyncpg.exceptions.DeadlockDetectedError, + ) as e: + 
retries += 1 + if retries == max_retries: + logger.error( + f"Failed to update document {document.id} after {max_retries} attempts. Error: {str(e)}" + ) + raise + else: + wait_time = 0.1 * (2**retries) # Exponential backoff + await asyncio.sleep(wait_time) + + async def delete( + self, document_id: UUID, version: Optional[str] = None + ) -> None: + query = f""" + DELETE FROM {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)} + WHERE id = $1 + """ + + params = [str(document_id)] + + if version: + query += " AND version = $2" + params.append(version) + + await self.connection_manager.execute_query(query=query, params=params) + + async def _get_status_from_table( + self, + ids: list[UUID], + table_name: str, + status_type: str, + column_name: str, + ): + """Get the workflow status for a given document or list of documents. + + Args: + ids (list[UUID]): The document IDs. + table_name (str): The table name. + status_type (str): The type of status to retrieve. + + Returns: + The workflow status for the given document or list of documents. + """ + query = f""" + SELECT {status_type} FROM {self._get_table_name(table_name)} + WHERE {column_name} = ANY($1) + """ + return [ + row[status_type] + for row in await self.connection_manager.fetch_query(query, [ids]) + ] + + async def _get_ids_from_table( + self, + status: list[str], + table_name: str, + status_type: str, + collection_id: Optional[UUID] = None, + ): + """Get the IDs from a given table. + + Args: + status (str | list[str]): The status or list of statuses to retrieve. + table_name (str): The table name. + status_type (str): The type of status to retrieve. + """ + query = f""" + SELECT id FROM {self._get_table_name(table_name)} + WHERE {status_type} = ANY($1) and $2 = ANY(collection_ids) + """ + records = await self.connection_manager.fetch_query( + query, [status, collection_id] + ) + return [record["id"] for record in records] + + async def _set_status_in_table( + self, + ids: list[UUID], + status: str, + table_name: str, + status_type: str, + column_name: str, + ): + """Set the workflow status for a given document or list of documents. + + Args: + ids (list[UUID]): The document IDs. + status (str): The status to set. + table_name (str): The table name. + status_type (str): The type of status to set. + column_name (str): The column name in the table to update. + """ + query = f""" + UPDATE {self._get_table_name(table_name)} + SET {status_type} = $1 + WHERE {column_name} = Any($2) + """ + await self.connection_manager.execute_query(query, [status, ids]) + + def _get_status_model(self, status_type: str): + """Get the status model for a given status type. + + Args: + status_type (str): The type of status to retrieve. + + Returns: + The status model for the given status type. + """ + if status_type == "ingestion": + return IngestionStatus + elif status_type == "extraction_status": + return GraphExtractionStatus + elif status_type in {"graph_cluster_status", "graph_sync_status"}: + return GraphConstructionStatus + else: + raise R2RException( + status_code=400, message=f"Invalid status type: {status_type}" + ) + + async def get_workflow_status( + self, id: UUID | list[UUID], status_type: str + ): + """Get the workflow status for a given document or list of documents. + + Args: + id (UUID | list[UUID]): The document ID or list of document IDs. + status_type (str): The type of status to retrieve. + + Returns: + The workflow status for the given document or list of documents. 
+ """ + + ids = [id] if isinstance(id, UUID) else id + out_model = self._get_status_model(status_type) + result = await self._get_status_from_table( + ids, + out_model.table_name(), + status_type, + out_model.id_column(), + ) + + result = [out_model[status.upper()] for status in result] + return result[0] if isinstance(id, UUID) else result + + async def set_workflow_status( + self, id: UUID | list[UUID], status_type: str, status: str + ): + """Set the workflow status for a given document or list of documents. + + Args: + id (UUID | list[UUID]): The document ID or list of document IDs. + status_type (str): The type of status to set. + status (str): The status to set. + """ + ids = [id] if isinstance(id, UUID) else id + out_model = self._get_status_model(status_type) + + return await self._set_status_in_table( + ids, + status, + out_model.table_name(), + status_type, + out_model.id_column(), + ) + + async def get_document_ids_by_status( + self, + status_type: str, + status: str | list[str], + collection_id: Optional[UUID] = None, + ): + """Get the IDs for a given status. + + Args: + ids_key (str): The key to retrieve the IDs. + status_type (str): The type of status to retrieve. + status (str | list[str]): The status or list of statuses to retrieve. + """ + + if isinstance(status, str): + status = [status] + + out_model = self._get_status_model(status_type) + return await self._get_ids_from_table( + status, out_model.table_name(), status_type, collection_id + ) + + async def get_documents_overview( + self, + offset: int, + limit: int, + filter_user_ids: Optional[list[UUID]] = None, + filter_document_ids: Optional[list[UUID]] = None, + filter_collection_ids: Optional[list[UUID]] = None, + include_summary_embedding: Optional[bool] = True, + filters: Optional[dict[str, Any]] = None, + sort_order: str = "DESC", # Add this parameter with a default of DESC + ) -> dict[str, Any]: + """Fetch overviews of documents with optional offset/limit pagination. + + You can use either: + - Traditional filters: `filter_user_ids`, `filter_document_ids`, `filter_collection_ids` + - A `filters` dict (e.g., like we do in semantic search), which will be passed to `apply_filters`. + + If both the `filters` dict and any of the traditional filter arguments are provided, + this method will raise an error. + """ + + filters = copy.deepcopy(filters) + filters = transform_filter_fields(filters) # type: ignore + + # Safety check: We do not allow mixing the old filter arguments with the new `filters` dict. + # This keeps the query logic unambiguous. + if filters and any( + [ + filter_user_ids, + filter_document_ids, + filter_collection_ids, + ] + ): + raise HTTPException( + status_code=400, + detail=( + "Cannot use both the 'filters' dictionary " + "and the 'filter_*_ids' parameters simultaneously." 
+ ), + ) + + conditions = [] + params: list[Any] = [] + param_index = 1 + + # ------------------------------------------- + # 1) If using the new `filters` dict approach + # ------------------------------------------- + if filters: + # Apply the filters to generate a WHERE clause + filter_condition, filter_params = apply_filters( + filters, params, mode="condition_only" + ) + if filter_condition: + conditions.append(filter_condition) + # Make sure we keep adding to the same params list + # params.extend(filter_params) + param_index += len(filter_params) + + # ------------------------------------------- + # 2) If using the old filter_*_ids approach + # ------------------------------------------- + else: + # Handle document IDs with AND + if filter_document_ids: + conditions.append(f"id = ANY(${param_index})") + params.append(filter_document_ids) + param_index += 1 + + # For owner/collection filters, we used OR logic previously + # so we combine them into a single sub-condition in parentheses + or_conditions = [] + if filter_user_ids: + or_conditions.append(f"owner_id = ANY(${param_index})") + params.append(filter_user_ids) + param_index += 1 + + if filter_collection_ids: + or_conditions.append(f"collection_ids && ${param_index}") + params.append(filter_collection_ids) + param_index += 1 + + if or_conditions: + conditions.append(f"({' OR '.join(or_conditions)})") + + # ------------------------- + # Build the full query + # ------------------------- + base_query = ( + f"FROM {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}" + ) + if conditions: + # Combine everything with AND + base_query += " WHERE " + " AND ".join(conditions) + + # Construct SELECT fields (including total_entries via window function) + select_fields = """ + SELECT + id, + collection_ids, + owner_id, + type, + metadata, + title, + version, + size_in_bytes, + ingestion_status, + extraction_status, + created_at, + updated_at, + summary, + summary_embedding, + total_tokens, + COUNT(*) OVER() AS total_entries + """ + + query = f""" + {select_fields} + {base_query} + ORDER BY created_at {sort_order} + OFFSET ${param_index} + """ + params.append(offset) + param_index += 1 + + if limit != -1: + query += f" LIMIT ${param_index}" + params.append(limit) + param_index += 1 + + try: + results = await self.connection_manager.fetch_query(query, params) + total_entries = results[0]["total_entries"] if results else 0 + + documents = [] + for row in results: + # Safely handle the embedding + embedding = None + if ( + "summary_embedding" in row + and row["summary_embedding"] is not None + ): + try: + # The embedding is stored as a string like "[0.1, 0.2, ...]" + embedding_str = row["summary_embedding"] + if embedding_str.startswith( + "[" + ) and embedding_str.endswith("]"): + embedding = [ + float(x) + for x in embedding_str[1:-1].split(",") + if x + ] + except Exception as e: + logger.warning( + f"Failed to parse embedding for document {row['id']}: {e}" + ) + + documents.append( + DocumentResponse( + id=row["id"], + collection_ids=row["collection_ids"], + owner_id=row["owner_id"], + document_type=DocumentType(row["type"]), + metadata=json.loads(row["metadata"]), + title=row["title"], + version=row["version"], + size_in_bytes=row["size_in_bytes"], + ingestion_status=IngestionStatus( + row["ingestion_status"] + ), + extraction_status=GraphExtractionStatus( + row["extraction_status"] + ), + created_at=row["created_at"], + updated_at=row["updated_at"], + summary=row["summary"] if "summary" in row else None, + summary_embedding=( + 
embedding if include_summary_embedding else None + ), + total_tokens=row["total_tokens"], + ) + ) + return {"results": documents, "total_entries": total_entries} + except Exception as e: + logger.error(f"Error in get_documents_overview: {str(e)}") + raise HTTPException( + status_code=500, + detail="Database query failed", + ) from e + + async def update_document_metadata( + self, + document_id: UUID, + metadata: list[dict], + overwrite: bool = False, + ) -> DocumentResponse: + """ + Update the metadata of a document, either by appending to existing metadata or overwriting it. + Accepts a list of metadata dictionaries. + """ + + doc_result = await self.get_documents_overview( + offset=0, + limit=1, + filter_document_ids=[document_id], + ) + + if not doc_result["results"]: + raise HTTPException( + status_code=404, + detail=f"Document with ID {document_id} not found", + ) + + existing_doc = doc_result["results"][0] + + if overwrite: + combined_metadata: dict[str, Any] = {} + for meta_item in metadata: + combined_metadata |= meta_item + existing_doc.metadata = combined_metadata + else: + for meta_item in metadata: + existing_doc.metadata.update(meta_item) + + await self.upsert_documents_overview(existing_doc) + + return existing_doc + + async def semantic_document_search( + self, query_embedding: list[float], search_settings: SearchSettings + ) -> list[DocumentResponse]: + """Search documents using semantic similarity with their summary + embeddings.""" + + where_clauses = ["summary_embedding IS NOT NULL"] + params: list[str | int | bytes] = [str(query_embedding)] + + vector_dim = ( + "" if math.isnan(self.dimension) else f"({self.dimension})" + ) + filters = copy.deepcopy(search_settings.filters) + if filters: + filter_condition, params = apply_filters( + transform_filter_fields(filters), params, mode="condition_only" + ) + if filter_condition: + where_clauses.append(filter_condition) + + where_clause = " AND ".join(where_clauses) + + query = f""" + WITH document_scores AS ( + SELECT + id, + collection_ids, + owner_id, + type, + metadata, + title, + version, + size_in_bytes, + ingestion_status, + extraction_status, + created_at, + updated_at, + summary, + summary_embedding, + total_tokens, + (summary_embedding <=> $1::vector({vector_dim})) as semantic_distance + FROM {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)} + WHERE {where_clause} + ORDER BY semantic_distance ASC + LIMIT ${len(params) + 1} + OFFSET ${len(params) + 2} + ) + SELECT *, + 1.0 - semantic_distance as semantic_score + FROM document_scores + """ + + params.extend([search_settings.limit, search_settings.offset]) + + results = await self.connection_manager.fetch_query(query, params) + + return [ + DocumentResponse( + id=row["id"], + collection_ids=row["collection_ids"], + owner_id=row["owner_id"], + document_type=DocumentType(row["type"]), + metadata={ + **( + json.loads(row["metadata"]) + if search_settings.include_metadatas + else {} + ), + "search_score": float(row["semantic_score"]), + "search_type": "semantic", + }, + title=row["title"], + version=row["version"], + size_in_bytes=row["size_in_bytes"], + ingestion_status=IngestionStatus(row["ingestion_status"]), + extraction_status=GraphExtractionStatus( + row["extraction_status"] + ), + created_at=row["created_at"], + updated_at=row["updated_at"], + summary=row["summary"], + summary_embedding=[ + float(x) + for x in row["summary_embedding"][1:-1].split(",") + if x + ], + total_tokens=row["total_tokens"], + ) + for row in results + ] + + async def 
full_text_document_search( + self, query_text: str, search_settings: SearchSettings + ) -> list[DocumentResponse]: + """Enhanced full-text search using generated tsvector.""" + + where_clauses = ["raw_tsvector @@ websearch_to_tsquery('english', $1)"] + params: list[str | int | bytes] = [query_text] + + filters = copy.deepcopy(search_settings.filters) + if filters: + filter_condition, params = apply_filters( + transform_filter_fields(filters), params, mode="condition_only" + ) + if filter_condition: + where_clauses.append(filter_condition) + + where_clause = " AND ".join(where_clauses) + + query = f""" + WITH document_scores AS ( + SELECT + id, + collection_ids, + owner_id, + type, + metadata, + title, + version, + size_in_bytes, + ingestion_status, + extraction_status, + created_at, + updated_at, + summary, + summary_embedding, + total_tokens, + ts_rank_cd(raw_tsvector, websearch_to_tsquery('english', $1), 32) as text_score + FROM {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)} + WHERE {where_clause} + ORDER BY text_score DESC + LIMIT ${len(params) + 1} + OFFSET ${len(params) + 2} + ) + SELECT * FROM document_scores + """ + + params.extend([search_settings.limit, search_settings.offset]) + + results = await self.connection_manager.fetch_query(query, params) + + return [ + DocumentResponse( + id=row["id"], + collection_ids=row["collection_ids"], + owner_id=row["owner_id"], + document_type=DocumentType(row["type"]), + metadata={ + **( + json.loads(row["metadata"]) + if search_settings.include_metadatas + else {} + ), + "search_score": float(row["text_score"]), + "search_type": "full_text", + }, + title=row["title"], + version=row["version"], + size_in_bytes=row["size_in_bytes"], + ingestion_status=IngestionStatus(row["ingestion_status"]), + extraction_status=GraphExtractionStatus( + row["extraction_status"] + ), + created_at=row["created_at"], + updated_at=row["updated_at"], + summary=row["summary"], + summary_embedding=( + [ + float(x) + for x in row["summary_embedding"][1:-1].split(",") + if x + ] + if row["summary_embedding"] + else None + ), + total_tokens=row["total_tokens"], + ) + for row in results + ] + + async def hybrid_document_search( + self, + query_text: str, + query_embedding: list[float], + search_settings: SearchSettings, + ) -> list[DocumentResponse]: + """Search documents using both semantic and full-text search with RRF + fusion.""" + + # Get more results than needed for better fusion + extended_settings = copy.deepcopy(search_settings) + extended_settings.limit = search_settings.limit * 3 + + # Get results from both search methods + semantic_results = await self.semantic_document_search( + query_embedding, extended_settings + ) + full_text_results = await self.full_text_document_search( + query_text, extended_settings + ) + + # Combine results using RRF + doc_scores: dict[str, dict] = {} + + # Process semantic results + for rank, result in enumerate(semantic_results, 1): + doc_id = str(result.id) + doc_scores[doc_id] = { + "semantic_rank": rank, + "full_text_rank": len(full_text_results) + + 1, # Default rank if not found + "data": result, + } + + # Process full-text results + for rank, result in enumerate(full_text_results, 1): + doc_id = str(result.id) + if doc_id in doc_scores: + doc_scores[doc_id]["full_text_rank"] = rank + else: + doc_scores[doc_id] = { + "semantic_rank": len(semantic_results) + + 1, # Default rank if not found + "full_text_rank": rank, + "data": result, + } + + # Calculate RRF scores using hybrid search settings + rrf_k = 
search_settings.hybrid_settings.rrf_k + semantic_weight = search_settings.hybrid_settings.semantic_weight + full_text_weight = search_settings.hybrid_settings.full_text_weight + + for scores in doc_scores.values(): + semantic_score = 1 / (rrf_k + scores["semantic_rank"]) + full_text_score = 1 / (rrf_k + scores["full_text_rank"]) + + # Weighted combination + combined_score = ( + semantic_score * semantic_weight + + full_text_score * full_text_weight + ) / (semantic_weight + full_text_weight) + + scores["final_score"] = combined_score + + # Sort by final score and apply offset/limit + sorted_results = sorted( + doc_scores.values(), key=lambda x: x["final_score"], reverse=True + )[ + search_settings.offset : search_settings.offset + + search_settings.limit + ] + + return [ + DocumentResponse( + **{ + **result["data"].__dict__, + "metadata": { + **( + result["data"].metadata + if search_settings.include_metadatas + else {} + ), + "search_score": result["final_score"], + "semantic_rank": result["semantic_rank"], + "full_text_rank": result["full_text_rank"], + "search_type": "hybrid", + }, + } + ) + for result in sorted_results + ] + + async def search_documents( + self, + query_text: str, + query_embedding: Optional[list[float]] = None, + settings: Optional[SearchSettings] = None, + ) -> list[DocumentResponse]: + """Main search method that delegates to the appropriate search method + based on settings.""" + if settings is None: + settings = SearchSettings() + + if ( + settings.use_semantic_search and settings.use_fulltext_search + ) or settings.use_hybrid_search: + if query_embedding is None: + raise ValueError( + "query_embedding is required for hybrid search" + ) + return await self.hybrid_document_search( + query_text, query_embedding, settings + ) + elif settings.use_semantic_search: + if query_embedding is None: + raise ValueError( + "query_embedding is required for vector search" + ) + return await self.semantic_document_search( + query_embedding, settings + ) + else: + return await self.full_text_document_search(query_text, settings) + + async def export_to_csv( + self, + columns: Optional[list[str]] = None, + filters: Optional[dict] = None, + include_header: bool = True, + ) -> tuple[str, IO]: + """Creates a CSV file from the PostgreSQL data and returns the path to + the temp file.""" + valid_columns = { + "id", + "collection_ids", + "owner_id", + "type", + "metadata", + "title", + "summary", + "version", + "size_in_bytes", + "ingestion_status", + "extraction_status", + "created_at", + "updated_at", + "total_tokens", + } + filters = copy.deepcopy(filters) + filters = transform_filter_fields(filters) # type: ignore + + if not columns: + columns = list(valid_columns) + elif invalid_cols := set(columns) - valid_columns: + raise ValueError(f"Invalid columns: {invalid_cols}") + + select_stmt = f""" + SELECT + id::text, + collection_ids::text, + owner_id::text, + type::text, + metadata::text AS metadata, + title, + summary, + version, + size_in_bytes, + ingestion_status, + extraction_status, + to_char(created_at, 'YYYY-MM-DD HH24:MI:SS') AS created_at, + to_char(updated_at, 'YYYY-MM-DD HH24:MI:SS') AS updated_at, + total_tokens + FROM {self._get_table_name(self.TABLE_NAME)} + """ + + conditions = [] + params: list[Any] = [] + param_index = 1 + + if filters: + for field, value in filters.items(): + if field not in valid_columns: + continue + + if isinstance(value, dict): + for op, val in value.items(): + if op == "$eq": + conditions.append(f"{field} = ${param_index}") + params.append(val) + 
param_index += 1 + elif op == "$gt": + conditions.append(f"{field} > ${param_index}") + params.append(val) + param_index += 1 + elif op == "$lt": + conditions.append(f"{field} < ${param_index}") + params.append(val) + param_index += 1 + else: + # Direct equality + conditions.append(f"{field} = ${param_index}") + params.append(value) + param_index += 1 + + if conditions: + select_stmt = f"{select_stmt} WHERE {' AND '.join(conditions)}" + + select_stmt = f"{select_stmt} ORDER BY created_at DESC" + + temp_file = None + try: + temp_file = tempfile.NamedTemporaryFile( + mode="w", delete=True, suffix=".csv" + ) + writer = csv.writer(temp_file, quoting=csv.QUOTE_ALL) + + async with self.connection_manager.pool.get_connection() as conn: # type: ignore + async with conn.transaction(): + cursor = await conn.cursor(select_stmt, *params) + + if include_header: + writer.writerow(columns) + + chunk_size = 1000 + while True: + rows = await cursor.fetch(chunk_size) + if not rows: + break + for row in rows: + row_dict = { + "id": row[0], + "collection_ids": row[1], + "owner_id": row[2], + "type": row[3], + "metadata": row[4], + "title": row[5], + "summary": row[6], + "version": row[7], + "size_in_bytes": row[8], + "ingestion_status": row[9], + "extraction_status": row[10], + "created_at": row[11], + "updated_at": row[12], + "total_tokens": row[13], + } + writer.writerow([row_dict[col] for col in columns]) + + temp_file.flush() + return temp_file.name, temp_file + + except Exception as e: + if temp_file: + temp_file.close() + raise HTTPException( + status_code=500, + detail=f"Failed to export data: {str(e)}", + ) from e |
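
For readers skimming the new module, here is a small usage sketch of the `transform_filter_fields` helper it adds. The call is illustrative only: the import path mirrors the file location in this diff, and the UUID and filter keys are made-up placeholders.

```python
# Illustrative sketch of the filter rewrite performed by transform_filter_fields.
# The import path follows the file added in this diff; in a normal R2R install
# the module lives under the installed `core` package and its dependencies
# (asyncpg, fastapi, etc.) must be available.
from core.providers.database.documents import transform_filter_fields

filters = {
    "$and": [
        # placeholder UUID; 'document_id' is rewritten to 'id'
        {"document_id": {"$eq": "00000000-0000-0000-0000-000000000000"}},
        # non-document_id keys pass through unchanged
        {"title": {"$eq": "example.pdf"}},
    ]
}

print(transform_filter_fields(filters))
# {'$and': [{'id': {'$eq': '00000000-0000-0000-0000-000000000000'}},
#           {'title': {'$eq': 'example.pdf'}}]}
```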
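The hybrid search path merges the semantic and full-text result lists with weighted reciprocal-rank fusion. The sketch below reproduces only that arithmetic with made-up values; in the handler, `rrf_k`, `semantic_weight`, and `full_text_weight` come from `search_settings.hybrid_settings`, and the 60 used here for `rrf_k` is an assumption, not a documented default.

```python
# Minimal sketch of the weighted reciprocal-rank fusion used in
# hybrid_document_search. Constants are assumptions for illustration;
# the handler reads them from SearchSettings.hybrid_settings.
RRF_K = 60
SEMANTIC_WEIGHT = 5.0
FULL_TEXT_WEIGHT = 1.0

def fused_score(semantic_rank: int, full_text_rank: int) -> float:
    semantic_score = 1 / (RRF_K + semantic_rank)
    full_text_score = 1 / (RRF_K + full_text_rank)
    # Weighted average of the two reciprocal-rank scores, as in the handler.
    return (
        semantic_score * SEMANTIC_WEIGHT + full_text_score * FULL_TEXT_WEIGHT
    ) / (SEMANTIC_WEIGHT + FULL_TEXT_WEIGHT)

# A document ranked first semantically but tenth in full text still beats
# one ranked fifth by both methods.
print(fused_score(1, 10) > fused_score(5, 5))  # True
```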
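Finally, a note on the upsert path: on `UniqueViolationError` or `DeadlockDetectedError`, `upsert_documents_overview` retries with exponential backoff, up to a hard-coded 20 attempts (flagged as a TODO in the diff). The sketch below just prints the resulting sleep schedule.

```python
# Backoff schedule used between upsert retries: the counter is incremented
# first, then the handler sleeps 0.1 * (2 ** retries) seconds.
for retries in range(1, 6):
    wait_time = 0.1 * (2**retries)
    print(f"retry {retries}: sleep {wait_time:.1f}s")  # 0.2, 0.4, 0.8, 1.6, 3.2
```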