import asyncio
import copy
import csv
import json
import logging
import math
import tempfile
from typing import IO, Any, Optional
from uuid import UUID
import asyncpg
from fastapi import HTTPException
from core.base import (
DocumentResponse,
DocumentType,
GraphConstructionStatus,
GraphExtractionStatus,
Handler,
IngestionStatus,
R2RException,
SearchSettings,
)
from .base import PostgresConnectionManager
from .filters import apply_filters
logger = logging.getLogger()
def transform_filter_fields(filters: dict[str, Any]) -> dict[str, Any]:
"""Recursively transform filter field names by replacing 'document_id' with
'id'. Handles nested logical operators like $and, $or, etc.
Args:
filters (dict[str, Any]): The original filters dictionary
Returns:
dict[str, Any]: A new dictionary with transformed field names
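    Example (illustrative; operator handling ultimately depends on apply_filters):
        >>> transform_filter_fields({"document_id": {"$eq": "some-uuid"}})
        {'id': {'$eq': 'some-uuid'}}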
"""
if not filters:
return {}
transformed = {}
for key, value in filters.items():
# Handle logical operators recursively
if key in ("$and", "$or", "$not"):
if isinstance(value, list):
transformed[key] = [
transform_filter_fields(item) for item in value
]
else:
transformed[key] = transform_filter_fields(value) # type: ignore
continue
# Replace 'document_id' with 'id'
new_key = "id" if key == "document_id" else key
# Handle nested dictionary cases (e.g., for operators like $eq, $gt, etc.)
if isinstance(value, dict):
transformed[new_key] = transform_filter_fields(value) # type: ignore
else:
transformed[new_key] = value
logger.debug(f"Transformed filters from {filters} to {transformed}")
return transformed
class PostgresDocumentsHandler(Handler):
TABLE_NAME = "documents"
def __init__(
self,
project_name: str,
connection_manager: PostgresConnectionManager,
dimension: int | float,
):
self.dimension = dimension
super().__init__(project_name, connection_manager)
async def create_tables(self):
logger.info(
f"Creating table, if it does not exist: {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}"
)
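        # A NaN dimension is treated as "unspecified", producing a `vector`
        # column with no fixed dimension; otherwise `vector(<dimension>)` is used.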
vector_dim = (
"" if math.isnan(self.dimension) else f"({self.dimension})"
)
vector_type = f"vector{vector_dim}"
try:
query = f"""
CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)} (
id UUID PRIMARY KEY,
collection_ids UUID[],
owner_id UUID,
type TEXT,
metadata JSONB,
title TEXT,
summary TEXT NULL,
summary_embedding {vector_type} NULL,
version TEXT,
size_in_bytes INT,
ingestion_status TEXT DEFAULT 'pending',
extraction_status TEXT DEFAULT 'pending',
created_at TIMESTAMPTZ DEFAULT NOW(),
updated_at TIMESTAMPTZ DEFAULT NOW(),
ingestion_attempt_number INT DEFAULT 0,
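                -- Weighted search document: title (weight A) > summary (B) > metadata->>'description' (C)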
raw_tsvector tsvector GENERATED ALWAYS AS (
setweight(to_tsvector('english', COALESCE(title, '')), 'A') ||
setweight(to_tsvector('english', COALESCE(summary, '')), 'B') ||
setweight(to_tsvector('english', COALESCE((metadata->>'description')::text, '')), 'C')
) STORED,
total_tokens INT DEFAULT 0
);
CREATE INDEX IF NOT EXISTS idx_collection_ids_{self.project_name}
ON {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)} USING GIN (collection_ids);
-- Full text search index
CREATE INDEX IF NOT EXISTS idx_doc_search_{self.project_name}
ON {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}
USING GIN (raw_tsvector);
"""
await self.connection_manager.execute_query(query)
            # ---------------------------------------------------------------
            # Ensure the 'total_tokens' column exists on the documents table
            # (older installations may predate it).
            # ---------------------------------------------------------------
            # 1) Parse the table name into schema and table for the
            #    information_schema lookup.
table_full_name = self._get_table_name(
PostgresDocumentsHandler.TABLE_NAME
)
parsed_schema = "public"
parsed_table_name = table_full_name
if "." in table_full_name:
parts = table_full_name.split(".", maxsplit=1)
parsed_schema = parts[0].replace('"', "").strip()
parsed_table_name = parts[1].replace('"', "").strip()
else:
parsed_table_name = parsed_table_name.replace('"', "").strip()
            # 2) Check which columns currently exist
column_check_query = f"""
SELECT column_name
FROM information_schema.columns
WHERE table_name = '{parsed_table_name}'
AND table_schema = '{parsed_schema}'
"""
existing_columns = await self.connection_manager.fetch_query(
column_check_query
)
existing_column_names = {
row["column_name"] for row in existing_columns
}
if "total_tokens" not in existing_column_names:
                # 3) If the column is missing, check whether the table already
                #    holds documents so we can warn before altering it.
doc_count_query = f"SELECT COUNT(*) AS doc_count FROM {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}"
row = await self.connection_manager.fetchrow_query(
doc_count_query
)
                doc_count = row["doc_count"] if row is not None else 0
                if doc_count > 0:
                    # Existing documents predate the 'total_tokens' column; the
                    # new column defaults their value to 0.
                    logger.warning(
                        "Adding the missing 'total_tokens' column to the 'documents' "
                        "table; existing rows will default to 0."
                    )
create_tokens_col = f"""
ALTER TABLE {table_full_name}
ADD COLUMN total_tokens INT DEFAULT 0
"""
await self.connection_manager.execute_query(create_tokens_col)
        except Exception as e:
            logger.error(f"Failed to create the documents table: {e}")
            raise
async def upsert_documents_overview(
self, documents_overview: DocumentResponse | list[DocumentResponse]
) -> None:
if isinstance(documents_overview, DocumentResponse):
documents_overview = [documents_overview]
# TODO: make this an arg
max_retries = 20
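        # Concurrent upserts of the same document can hit unique-violation or
        # deadlock errors; those are retried below with exponential backoff
        # (0.1s * 2**retries) up to max_retries attempts.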
for document in documents_overview:
retries = 0
while retries < max_retries:
try:
async with (
self.connection_manager.pool.get_connection() as conn # type: ignore
):
async with conn.transaction():
# Lock the row for update
check_query = f"""
SELECT ingestion_attempt_number, ingestion_status FROM {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}
WHERE id = $1 FOR UPDATE
"""
existing_doc = await conn.fetchrow(
check_query, document.id
)
db_entry = document.convert_to_db_entry()
if existing_doc:
db_version = existing_doc[
"ingestion_attempt_number"
]
db_status = existing_doc["ingestion_status"]
new_version = db_entry[
"ingestion_attempt_number"
]
# Only increment version if status is changing to 'success' or if it's a new version
if (
db_status != "success"
and db_entry["ingestion_status"]
== "success"
) or (new_version > db_version):
new_attempt_number = db_version + 1
else:
new_attempt_number = db_version
db_entry["ingestion_attempt_number"] = (
new_attempt_number
)
update_query = f"""
UPDATE {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}
SET collection_ids = $1,
owner_id = $2,
type = $3,
metadata = $4,
title = $5,
version = $6,
size_in_bytes = $7,
ingestion_status = $8,
extraction_status = $9,
updated_at = $10,
ingestion_attempt_number = $11,
summary = $12,
summary_embedding = $13,
total_tokens = $14
WHERE id = $15
"""
await conn.execute(
update_query,
db_entry["collection_ids"],
db_entry["owner_id"],
db_entry["document_type"],
db_entry["metadata"],
db_entry["title"],
db_entry["version"],
db_entry["size_in_bytes"],
db_entry["ingestion_status"],
db_entry["extraction_status"],
db_entry["updated_at"],
db_entry["ingestion_attempt_number"],
db_entry["summary"],
db_entry["summary_embedding"],
                                    db_entry["total_tokens"],
document.id,
)
else:
insert_query = f"""
INSERT INTO {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}
(id, collection_ids, owner_id, type, metadata, title, version,
size_in_bytes, ingestion_status, extraction_status, created_at,
updated_at, ingestion_attempt_number, summary, summary_embedding, total_tokens)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16)
"""
await conn.execute(
insert_query,
db_entry["id"],
db_entry["collection_ids"],
db_entry["owner_id"],
db_entry["document_type"],
db_entry["metadata"],
db_entry["title"],
db_entry["version"],
db_entry["size_in_bytes"],
db_entry["ingestion_status"],
db_entry["extraction_status"],
db_entry["created_at"],
db_entry["updated_at"],
db_entry["ingestion_attempt_number"],
db_entry["summary"],
db_entry["summary_embedding"],
db_entry["total_tokens"],
)
break # Success, exit the retry loop
except (
asyncpg.exceptions.UniqueViolationError,
asyncpg.exceptions.DeadlockDetectedError,
) as e:
retries += 1
if retries == max_retries:
logger.error(
f"Failed to update document {document.id} after {max_retries} attempts. Error: {str(e)}"
)
raise
else:
wait_time = 0.1 * (2**retries) # Exponential backoff
await asyncio.sleep(wait_time)
async def delete(
self, document_id: UUID, version: Optional[str] = None
) -> None:
query = f"""
DELETE FROM {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}
WHERE id = $1
"""
params = [str(document_id)]
if version:
query += " AND version = $2"
params.append(version)
await self.connection_manager.execute_query(query=query, params=params)
async def _get_status_from_table(
self,
ids: list[UUID],
table_name: str,
status_type: str,
column_name: str,
):
"""Get the workflow status for a given document or list of documents.
Args:
ids (list[UUID]): The document IDs.
table_name (str): The table name.
            status_type (str): The status column to read.
            column_name (str): The ID column used to match the given IDs.
Returns:
The workflow status for the given document or list of documents.
"""
query = f"""
SELECT {status_type} FROM {self._get_table_name(table_name)}
WHERE {column_name} = ANY($1)
"""
return [
row[status_type]
for row in await self.connection_manager.fetch_query(query, [ids])
]
async def _get_ids_from_table(
self,
status: list[str],
table_name: str,
status_type: str,
collection_id: Optional[UUID] = None,
):
"""Get the IDs from a given table.
Args:
            status (list[str]): The statuses to match.
            table_name (str): The table name.
            status_type (str): The status column to filter on.
            collection_id (Optional[UUID]): Collection the documents must belong
                to; when None the query matches no rows.
"""
query = f"""
SELECT id FROM {self._get_table_name(table_name)}
            WHERE {status_type} = ANY($1) AND $2 = ANY(collection_ids)
"""
records = await self.connection_manager.fetch_query(
query, [status, collection_id]
)
return [record["id"] for record in records]
async def _set_status_in_table(
self,
ids: list[UUID],
status: str,
table_name: str,
status_type: str,
column_name: str,
):
"""Set the workflow status for a given document or list of documents.
Args:
ids (list[UUID]): The document IDs.
status (str): The status to set.
table_name (str): The table name.
status_type (str): The type of status to set.
column_name (str): The column name in the table to update.
"""
query = f"""
UPDATE {self._get_table_name(table_name)}
SET {status_type} = $1
            WHERE {column_name} = ANY($2)
"""
await self.connection_manager.execute_query(query, [status, ids])
def _get_status_model(self, status_type: str):
"""Get the status model for a given status type.
Args:
status_type (str): The type of status to retrieve.
Returns:
The status model for the given status type.
"""
if status_type == "ingestion":
return IngestionStatus
elif status_type == "extraction_status":
return GraphExtractionStatus
elif status_type in {"graph_cluster_status", "graph_sync_status"}:
return GraphConstructionStatus
else:
raise R2RException(
status_code=400, message=f"Invalid status type: {status_type}"
)
async def get_workflow_status(
self, id: UUID | list[UUID], status_type: str
):
"""Get the workflow status for a given document or list of documents.
Args:
id (UUID | list[UUID]): The document ID or list of document IDs.
status_type (str): The type of status to retrieve.
Returns:
The workflow status for the given document or list of documents.
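        Example (illustrative; the concrete enum members come from core.base):
            status = await handler.get_workflow_status(document_id, "extraction_status")
            # -> a GraphExtractionStatus member for that document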
"""
ids = [id] if isinstance(id, UUID) else id
out_model = self._get_status_model(status_type)
result = await self._get_status_from_table(
ids,
out_model.table_name(),
status_type,
out_model.id_column(),
)
result = [out_model[status.upper()] for status in result]
return result[0] if isinstance(id, UUID) else result
async def set_workflow_status(
self, id: UUID | list[UUID], status_type: str, status: str
):
"""Set the workflow status for a given document or list of documents.
Args:
id (UUID | list[UUID]): The document ID or list of document IDs.
status_type (str): The type of status to set.
status (str): The status to set.
"""
ids = [id] if isinstance(id, UUID) else id
out_model = self._get_status_model(status_type)
return await self._set_status_in_table(
ids,
status,
out_model.table_name(),
status_type,
out_model.id_column(),
)
async def get_document_ids_by_status(
self,
status_type: str,
status: str | list[str],
collection_id: Optional[UUID] = None,
):
"""Get the IDs for a given status.
Args:
            status_type (str): The type of status to filter on.
            status (str | list[str]): The status or list of statuses to match.
            collection_id (Optional[UUID]): Collection to restrict the results to.
"""
if isinstance(status, str):
status = [status]
out_model = self._get_status_model(status_type)
return await self._get_ids_from_table(
status, out_model.table_name(), status_type, collection_id
)
async def get_documents_overview(
self,
offset: int,
limit: int,
filter_user_ids: Optional[list[UUID]] = None,
filter_document_ids: Optional[list[UUID]] = None,
filter_collection_ids: Optional[list[UUID]] = None,
include_summary_embedding: Optional[bool] = True,
filters: Optional[dict[str, Any]] = None,
        sort_order: str = "DESC",  # "ASC" or "DESC" ordering on created_at
) -> dict[str, Any]:
"""Fetch overviews of documents with optional offset/limit pagination.
You can use either:
- Traditional filters: `filter_user_ids`, `filter_document_ids`, `filter_collection_ids`
- A `filters` dict (e.g., like we do in semantic search), which will be passed to `apply_filters`.
If both the `filters` dict and any of the traditional filter arguments are provided,
this method will raise an error.
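        Example (illustrative; operator support depends on apply_filters):
            page = await handler.get_documents_overview(
                offset=0,
                limit=10,
                filters={"document_id": {"$eq": "<document-uuid>"}},
            )
            # page == {"results": [...], "total_entries": <int>}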
"""
filters = copy.deepcopy(filters)
filters = transform_filter_fields(filters) # type: ignore
# Safety check: We do not allow mixing the old filter arguments with the new `filters` dict.
# This keeps the query logic unambiguous.
if filters and any(
[
filter_user_ids,
filter_document_ids,
filter_collection_ids,
]
):
raise HTTPException(
status_code=400,
detail=(
"Cannot use both the 'filters' dictionary "
"and the 'filter_*_ids' parameters simultaneously."
),
)
conditions = []
params: list[Any] = []
param_index = 1
# -------------------------------------------
# 1) If using the new `filters` dict approach
# -------------------------------------------
if filters:
# Apply the filters to generate a WHERE clause
filter_condition, filter_params = apply_filters(
filters, params, mode="condition_only"
)
if filter_condition:
conditions.append(filter_condition)
            # apply_filters has already appended its bound values to `params`,
            # so only the placeholder index needs to be advanced here.
            param_index += len(filter_params)
# -------------------------------------------
# 2) If using the old filter_*_ids approach
# -------------------------------------------
else:
# Handle document IDs with AND
if filter_document_ids:
conditions.append(f"id = ANY(${param_index})")
params.append(filter_document_ids)
param_index += 1
# For owner/collection filters, we used OR logic previously
# so we combine them into a single sub-condition in parentheses
or_conditions = []
if filter_user_ids:
or_conditions.append(f"owner_id = ANY(${param_index})")
params.append(filter_user_ids)
param_index += 1
if filter_collection_ids:
or_conditions.append(f"collection_ids && ${param_index}")
params.append(filter_collection_ids)
param_index += 1
if or_conditions:
conditions.append(f"({' OR '.join(or_conditions)})")
# -------------------------
# Build the full query
# -------------------------
base_query = (
f"FROM {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}"
)
if conditions:
# Combine everything with AND
base_query += " WHERE " + " AND ".join(conditions)
# Construct SELECT fields (including total_entries via window function)
select_fields = """
SELECT
id,
collection_ids,
owner_id,
type,
metadata,
title,
version,
size_in_bytes,
ingestion_status,
extraction_status,
created_at,
updated_at,
summary,
summary_embedding,
total_tokens,
COUNT(*) OVER() AS total_entries
"""
query = f"""
{select_fields}
{base_query}
ORDER BY created_at {sort_order}
OFFSET ${param_index}
"""
params.append(offset)
param_index += 1
if limit != -1:
query += f" LIMIT ${param_index}"
params.append(limit)
param_index += 1
try:
results = await self.connection_manager.fetch_query(query, params)
total_entries = results[0]["total_entries"] if results else 0
documents = []
for row in results:
# Safely handle the embedding
embedding = None
if (
"summary_embedding" in row
and row["summary_embedding"] is not None
):
try:
# The embedding is stored as a string like "[0.1, 0.2, ...]"
embedding_str = row["summary_embedding"]
if embedding_str.startswith(
"["
) and embedding_str.endswith("]"):
embedding = [
float(x)
for x in embedding_str[1:-1].split(",")
if x
]
except Exception as e:
logger.warning(
f"Failed to parse embedding for document {row['id']}: {e}"
)
documents.append(
DocumentResponse(
id=row["id"],
collection_ids=row["collection_ids"],
owner_id=row["owner_id"],
document_type=DocumentType(row["type"]),
metadata=json.loads(row["metadata"]),
title=row["title"],
version=row["version"],
size_in_bytes=row["size_in_bytes"],
ingestion_status=IngestionStatus(
row["ingestion_status"]
),
extraction_status=GraphExtractionStatus(
row["extraction_status"]
),
created_at=row["created_at"],
updated_at=row["updated_at"],
summary=row["summary"] if "summary" in row else None,
summary_embedding=(
embedding if include_summary_embedding else None
),
total_tokens=row["total_tokens"],
)
)
return {"results": documents, "total_entries": total_entries}
except Exception as e:
logger.error(f"Error in get_documents_overview: {str(e)}")
raise HTTPException(
status_code=500,
detail="Database query failed",
) from e
async def update_document_metadata(
self,
document_id: UUID,
metadata: list[dict],
overwrite: bool = False,
) -> DocumentResponse:
"""
Update the metadata of a document, either by appending to existing metadata or overwriting it.
Accepts a list of metadata dictionaries.
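        Example (illustrative):
            # merge {"source": "web"} into the existing metadata
            await handler.update_document_metadata(document_id, [{"source": "web"}])
            # replace the metadata with exactly {"source": "web"}
            await handler.update_document_metadata(
                document_id, [{"source": "web"}], overwrite=True
            )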
"""
doc_result = await self.get_documents_overview(
offset=0,
limit=1,
filter_document_ids=[document_id],
)
if not doc_result["results"]:
raise HTTPException(
status_code=404,
detail=f"Document with ID {document_id} not found",
)
existing_doc = doc_result["results"][0]
if overwrite:
combined_metadata: dict[str, Any] = {}
for meta_item in metadata:
combined_metadata |= meta_item
existing_doc.metadata = combined_metadata
else:
for meta_item in metadata:
existing_doc.metadata.update(meta_item)
await self.upsert_documents_overview(existing_doc)
return existing_doc
async def semantic_document_search(
self, query_embedding: list[float], search_settings: SearchSettings
) -> list[DocumentResponse]:
"""Search documents using semantic similarity with their summary
embeddings."""
where_clauses = ["summary_embedding IS NOT NULL"]
params: list[str | int | bytes] = [str(query_embedding)]
vector_dim = (
"" if math.isnan(self.dimension) else f"({self.dimension})"
)
filters = copy.deepcopy(search_settings.filters)
if filters:
filter_condition, params = apply_filters(
transform_filter_fields(filters), params, mode="condition_only"
)
if filter_condition:
where_clauses.append(filter_condition)
where_clause = " AND ".join(where_clauses)
query = f"""
WITH document_scores AS (
SELECT
id,
collection_ids,
owner_id,
type,
metadata,
title,
version,
size_in_bytes,
ingestion_status,
extraction_status,
created_at,
updated_at,
summary,
summary_embedding,
total_tokens,
                    (summary_embedding <=> $1::vector{vector_dim}) as semantic_distance
FROM {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}
WHERE {where_clause}
ORDER BY semantic_distance ASC
LIMIT ${len(params) + 1}
OFFSET ${len(params) + 2}
)
SELECT *,
1.0 - semantic_distance as semantic_score
FROM document_scores
"""
params.extend([search_settings.limit, search_settings.offset])
results = await self.connection_manager.fetch_query(query, params)
return [
DocumentResponse(
id=row["id"],
collection_ids=row["collection_ids"],
owner_id=row["owner_id"],
document_type=DocumentType(row["type"]),
metadata={
**(
json.loads(row["metadata"])
if search_settings.include_metadatas
else {}
),
"search_score": float(row["semantic_score"]),
"search_type": "semantic",
},
title=row["title"],
version=row["version"],
size_in_bytes=row["size_in_bytes"],
ingestion_status=IngestionStatus(row["ingestion_status"]),
extraction_status=GraphExtractionStatus(
row["extraction_status"]
),
created_at=row["created_at"],
updated_at=row["updated_at"],
summary=row["summary"],
summary_embedding=[
float(x)
for x in row["summary_embedding"][1:-1].split(",")
if x
],
total_tokens=row["total_tokens"],
)
for row in results
]
async def full_text_document_search(
self, query_text: str, search_settings: SearchSettings
) -> list[DocumentResponse]:
"""Enhanced full-text search using generated tsvector."""
where_clauses = ["raw_tsvector @@ websearch_to_tsquery('english', $1)"]
params: list[str | int | bytes] = [query_text]
filters = copy.deepcopy(search_settings.filters)
if filters:
filter_condition, params = apply_filters(
transform_filter_fields(filters), params, mode="condition_only"
)
if filter_condition:
where_clauses.append(filter_condition)
where_clause = " AND ".join(where_clauses)
query = f"""
WITH document_scores AS (
SELECT
id,
collection_ids,
owner_id,
type,
metadata,
title,
version,
size_in_bytes,
ingestion_status,
extraction_status,
created_at,
updated_at,
summary,
summary_embedding,
total_tokens,
ts_rank_cd(raw_tsvector, websearch_to_tsquery('english', $1), 32) as text_score
FROM {self._get_table_name(PostgresDocumentsHandler.TABLE_NAME)}
WHERE {where_clause}
ORDER BY text_score DESC
LIMIT ${len(params) + 1}
OFFSET ${len(params) + 2}
)
SELECT * FROM document_scores
"""
params.extend([search_settings.limit, search_settings.offset])
results = await self.connection_manager.fetch_query(query, params)
return [
DocumentResponse(
id=row["id"],
collection_ids=row["collection_ids"],
owner_id=row["owner_id"],
document_type=DocumentType(row["type"]),
metadata={
**(
json.loads(row["metadata"])
if search_settings.include_metadatas
else {}
),
"search_score": float(row["text_score"]),
"search_type": "full_text",
},
title=row["title"],
version=row["version"],
size_in_bytes=row["size_in_bytes"],
ingestion_status=IngestionStatus(row["ingestion_status"]),
extraction_status=GraphExtractionStatus(
row["extraction_status"]
),
created_at=row["created_at"],
updated_at=row["updated_at"],
summary=row["summary"],
summary_embedding=(
[
float(x)
for x in row["summary_embedding"][1:-1].split(",")
if x
]
if row["summary_embedding"]
else None
),
total_tokens=row["total_tokens"],
)
for row in results
]
async def hybrid_document_search(
self,
query_text: str,
query_embedding: list[float],
search_settings: SearchSettings,
) -> list[DocumentResponse]:
"""Search documents using both semantic and full-text search with RRF
fusion."""
# Get more results than needed for better fusion
extended_settings = copy.deepcopy(search_settings)
extended_settings.limit = search_settings.limit * 3
# Get results from both search methods
semantic_results = await self.semantic_document_search(
query_embedding, extended_settings
)
full_text_results = await self.full_text_document_search(
query_text, extended_settings
)
# Combine results using RRF
doc_scores: dict[str, dict] = {}
# Process semantic results
for rank, result in enumerate(semantic_results, 1):
doc_id = str(result.id)
doc_scores[doc_id] = {
"semantic_rank": rank,
"full_text_rank": len(full_text_results)
+ 1, # Default rank if not found
"data": result,
}
# Process full-text results
for rank, result in enumerate(full_text_results, 1):
doc_id = str(result.id)
if doc_id in doc_scores:
doc_scores[doc_id]["full_text_rank"] = rank
else:
doc_scores[doc_id] = {
"semantic_rank": len(semantic_results)
+ 1, # Default rank if not found
"full_text_rank": rank,
"data": result,
}
# Calculate RRF scores using hybrid search settings
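        # Standard reciprocal-rank fusion: each ranking contributes
        # 1 / (rrf_k + rank), and the two contributions are combined as a
        # weighted average below.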
rrf_k = search_settings.hybrid_settings.rrf_k
semantic_weight = search_settings.hybrid_settings.semantic_weight
full_text_weight = search_settings.hybrid_settings.full_text_weight
for scores in doc_scores.values():
semantic_score = 1 / (rrf_k + scores["semantic_rank"])
full_text_score = 1 / (rrf_k + scores["full_text_rank"])
# Weighted combination
combined_score = (
semantic_score * semantic_weight
+ full_text_score * full_text_weight
) / (semantic_weight + full_text_weight)
scores["final_score"] = combined_score
# Sort by final score and apply offset/limit
sorted_results = sorted(
doc_scores.values(), key=lambda x: x["final_score"], reverse=True
)[
search_settings.offset : search_settings.offset
+ search_settings.limit
]
return [
DocumentResponse(
**{
**result["data"].__dict__,
"metadata": {
**(
result["data"].metadata
if search_settings.include_metadatas
else {}
),
"search_score": result["final_score"],
"semantic_rank": result["semantic_rank"],
"full_text_rank": result["full_text_rank"],
"search_type": "hybrid",
},
}
)
for result in sorted_results
]
async def search_documents(
self,
query_text: str,
query_embedding: Optional[list[float]] = None,
settings: Optional[SearchSettings] = None,
) -> list[DocumentResponse]:
"""Main search method that delegates to the appropriate search method
based on settings."""
if settings is None:
settings = SearchSettings()
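        # Hybrid search (or semantic + full-text requested together) and pure
        # semantic search both need a query embedding; plain full-text search
        # works from the query text alone.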
if (
settings.use_semantic_search and settings.use_fulltext_search
) or settings.use_hybrid_search:
if query_embedding is None:
raise ValueError(
"query_embedding is required for hybrid search"
)
return await self.hybrid_document_search(
query_text, query_embedding, settings
)
elif settings.use_semantic_search:
if query_embedding is None:
raise ValueError(
"query_embedding is required for vector search"
)
return await self.semantic_document_search(
query_embedding, settings
)
else:
return await self.full_text_document_search(query_text, settings)
async def export_to_csv(
self,
columns: Optional[list[str]] = None,
filters: Optional[dict] = None,
include_header: bool = True,
) -> tuple[str, IO]:
"""Creates a CSV file from the PostgreSQL data and returns the path to
the temp file."""
valid_columns = {
"id",
"collection_ids",
"owner_id",
"type",
"metadata",
"title",
"summary",
"version",
"size_in_bytes",
"ingestion_status",
"extraction_status",
"created_at",
"updated_at",
"total_tokens",
}
filters = copy.deepcopy(filters)
filters = transform_filter_fields(filters) # type: ignore
if not columns:
columns = list(valid_columns)
elif invalid_cols := set(columns) - valid_columns:
raise ValueError(f"Invalid columns: {invalid_cols}")
select_stmt = f"""
SELECT
id::text,
collection_ids::text,
owner_id::text,
type::text,
metadata::text AS metadata,
title,
summary,
version,
size_in_bytes,
ingestion_status,
extraction_status,
to_char(created_at, 'YYYY-MM-DD HH24:MI:SS') AS created_at,
to_char(updated_at, 'YYYY-MM-DD HH24:MI:SS') AS updated_at,
total_tokens
FROM {self._get_table_name(self.TABLE_NAME)}
"""
conditions = []
params: list[Any] = []
param_index = 1
if filters:
for field, value in filters.items():
if field not in valid_columns:
continue
if isinstance(value, dict):
for op, val in value.items():
if op == "$eq":
conditions.append(f"{field} = ${param_index}")
params.append(val)
param_index += 1
elif op == "$gt":
conditions.append(f"{field} > ${param_index}")
params.append(val)
param_index += 1
elif op == "$lt":
conditions.append(f"{field} < ${param_index}")
params.append(val)
param_index += 1
else:
# Direct equality
conditions.append(f"{field} = ${param_index}")
params.append(value)
param_index += 1
if conditions:
select_stmt = f"{select_stmt} WHERE {' AND '.join(conditions)}"
select_stmt = f"{select_stmt} ORDER BY created_at DESC"
temp_file = None
try:
temp_file = tempfile.NamedTemporaryFile(
mode="w", delete=True, suffix=".csv"
)
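            # delete=True means the file disappears once the returned handle is
            # closed, so callers must consume the CSV before closing it.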
writer = csv.writer(temp_file, quoting=csv.QUOTE_ALL)
async with self.connection_manager.pool.get_connection() as conn: # type: ignore
async with conn.transaction():
cursor = await conn.cursor(select_stmt, *params)
if include_header:
writer.writerow(columns)
chunk_size = 1000
while True:
rows = await cursor.fetch(chunk_size)
if not rows:
break
for row in rows:
row_dict = {
"id": row[0],
"collection_ids": row[1],
"owner_id": row[2],
"type": row[3],
"metadata": row[4],
"title": row[5],
"summary": row[6],
"version": row[7],
"size_in_bytes": row[8],
"ingestion_status": row[9],
"extraction_status": row[10],
"created_at": row[11],
"updated_at": row[12],
"total_tokens": row[13],
}
writer.writerow([row_dict[col] for col in columns])
temp_file.flush()
return temp_file.name, temp_file
except Exception as e:
if temp_file:
temp_file.close()
raise HTTPException(
status_code=500,
detail=f"Failed to export data: {str(e)}",
) from e