import copy
import json
import logging
import math
import time
import uuid
from typing import Any, Optional, TypedDict
from uuid import UUID
import numpy as np
from core.base import (
ChunkSearchResult,
Handler,
IndexArgsHNSW,
IndexArgsIVFFlat,
IndexMeasure,
IndexMethod,
R2RException,
SearchSettings,
VectorEntry,
VectorQuantizationType,
VectorTableName,
)
from core.base.utils import _decorate_vector_type
from .base import PostgresConnectionManager
from .filters import apply_filters
logger = logging.getLogger()
def psql_quote_literal(value: str) -> str:
"""Safely quote a string literal for PostgreSQL to prevent SQL injection.
This is a simple implementation - in production, you should use proper parameterization
or your database driver's quoting functions.
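    Example (illustrative):
        psql_quote_literal("O'Brien")  # -> "'O''Brien'"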
"""
return "'" + value.replace("'", "''") + "'"
def index_measure_to_ops(
measure: IndexMeasure,
quantization_type: VectorQuantizationType = VectorQuantizationType.FP32,
):
return _decorate_vector_type(measure.ops, quantization_type)
def quantize_vector_to_binary(
vector: list[float] | np.ndarray,
threshold: float = 0.0,
) -> bytes:
"""Quantizes a float vector to a binary vector string for PostgreSQL bit
type. Used when quantization_type is INT1.
Args:
vector (List[float] | np.ndarray): Input vector of floats
threshold (float, optional): Threshold for binarization. Defaults to 0.0.
Returns:
str: Binary string representation for PostgreSQL bit type
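    Example (illustrative):
        quantize_vector_to_binary([0.5, -0.2, 1.3])  # -> b"101"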
"""
# Convert input to numpy array if it isn't already
if not isinstance(vector, np.ndarray):
vector = np.array(vector)
# Convert to binary (1 where value > threshold, 0 otherwise)
binary_vector = (vector > threshold).astype(int)
    # Convert to a string of 1s and 0s, then to ASCII bytes
binary_string = "".join(map(str, binary_vector))
return binary_string.encode("ascii")
class HybridSearchIntermediateResult(TypedDict):
semantic_rank: int
full_text_rank: int
data: ChunkSearchResult
rrf_score: float
class PostgresChunksHandler(Handler):
TABLE_NAME = VectorTableName.CHUNKS
def __init__(
self,
project_name: str,
connection_manager: PostgresConnectionManager,
dimension: int | float,
quantization_type: VectorQuantizationType,
):
super().__init__(project_name, connection_manager)
self.dimension = dimension
self.quantization_type = quantization_type
async def create_tables(self):
# First check if table already exists and validate dimensions
table_exists_query = """
SELECT EXISTS (
SELECT FROM pg_tables
WHERE schemaname = $1
AND tablename = $2
);
"""
table_name = VectorTableName.CHUNKS
table_exists = await self.connection_manager.fetch_query(
table_exists_query, (self.project_name, table_name)
)
if len(table_exists) > 0 and table_exists[0]["exists"]:
# Table exists, check vector dimension
vector_dim_query = """
SELECT a.atttypmod as dimension
FROM pg_attribute a
JOIN pg_class c ON a.attrelid = c.oid
JOIN pg_namespace n ON c.relnamespace = n.oid
WHERE n.nspname = $1
AND c.relname = $2
AND a.attname = 'vec';
"""
vector_dim_result = await self.connection_manager.fetch_query(
vector_dim_query, (self.project_name, table_name)
)
if vector_dim_result and len(vector_dim_result) > 0:
existing_dimension = vector_dim_result[0]["dimension"]
                # For pgvector columns, atttypmod holds the declared dimension (-1 if unspecified)
if existing_dimension > 0: # If it has a specific dimension
# Compare with provided dimension
if (
self.dimension > 0
and existing_dimension != self.dimension
):
raise ValueError(
f"Dimension mismatch: Table '{self.project_name}.{table_name}' was created with "
f"dimension {existing_dimension}, but {self.dimension} was provided. "
f"You must use the same dimension for existing tables."
)
# Check for old table name
check_query = """
SELECT EXISTS (
SELECT FROM pg_tables
WHERE schemaname = $1
AND tablename = $2
);
"""
old_table_exists = await self.connection_manager.fetch_query(
check_query, (self.project_name, self.project_name)
)
if len(old_table_exists) > 0 and old_table_exists[0]["exists"]:
raise ValueError(
f"Found old vector table '{self.project_name}.{self.project_name}'. "
"Please run `r2r db upgrade` with the CLI, or to run manually, "
"run in R2R/py/migrations with 'alembic upgrade head' to update "
"your database schema to the new version."
)
binary_col = (
""
if self.quantization_type != VectorQuantizationType.INT1
else f"vec_binary bit({self.dimension}),"
)
if self.dimension > 0:
vector_col = f"vec vector({self.dimension})"
else:
vector_col = "vec vector"
query = f"""
CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresChunksHandler.TABLE_NAME)} (
id UUID PRIMARY KEY,
document_id UUID,
owner_id UUID,
collection_ids UUID[],
{vector_col},
{binary_col}
text TEXT,
metadata JSONB,
fts tsvector GENERATED ALWAYS AS (to_tsvector('english', text)) STORED
);
CREATE INDEX IF NOT EXISTS idx_vectors_document_id ON {self._get_table_name(PostgresChunksHandler.TABLE_NAME)} (document_id);
CREATE INDEX IF NOT EXISTS idx_vectors_owner_id ON {self._get_table_name(PostgresChunksHandler.TABLE_NAME)} (owner_id);
CREATE INDEX IF NOT EXISTS idx_vectors_collection_ids ON {self._get_table_name(PostgresChunksHandler.TABLE_NAME)} USING GIN (collection_ids);
CREATE INDEX IF NOT EXISTS idx_vectors_text ON {self._get_table_name(PostgresChunksHandler.TABLE_NAME)} USING GIN (to_tsvector('english', text));
"""
await self.connection_manager.execute_query(query)
async def upsert(self, entry: VectorEntry) -> None:
"""Upsert function that handles vector quantization only when
quantization_type is INT1.
        Matches the table schema, where the vec_binary column only exists for
        INT1 quantization.
"""
# Check the quantization type to determine which columns to use
if self.quantization_type == VectorQuantizationType.INT1:
bit_dim = (
"" if math.isnan(self.dimension) else f"({self.dimension})"
)
# For quantized vectors, use vec_binary column
query = f"""
INSERT INTO {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
(id, document_id, owner_id, collection_ids, vec, vec_binary, text, metadata)
            VALUES ($1, $2, $3, $4, $5, $6::bit{bit_dim}, $7, $8)
ON CONFLICT (id) DO UPDATE SET
document_id = EXCLUDED.document_id,
owner_id = EXCLUDED.owner_id,
collection_ids = EXCLUDED.collection_ids,
vec = EXCLUDED.vec,
vec_binary = EXCLUDED.vec_binary,
text = EXCLUDED.text,
metadata = EXCLUDED.metadata;
"""
await self.connection_manager.execute_query(
query,
(
entry.id,
entry.document_id,
entry.owner_id,
entry.collection_ids,
str(entry.vector.data),
quantize_vector_to_binary(
entry.vector.data
), # Convert to binary
entry.text,
json.dumps(entry.metadata),
),
)
else:
# For regular vectors, use vec column only
query = f"""
INSERT INTO {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
(id, document_id, owner_id, collection_ids, vec, text, metadata)
VALUES ($1, $2, $3, $4, $5, $6, $7)
ON CONFLICT (id) DO UPDATE SET
document_id = EXCLUDED.document_id,
owner_id = EXCLUDED.owner_id,
collection_ids = EXCLUDED.collection_ids,
vec = EXCLUDED.vec,
text = EXCLUDED.text,
metadata = EXCLUDED.metadata;
"""
await self.connection_manager.execute_query(
query,
(
entry.id,
entry.document_id,
entry.owner_id,
entry.collection_ids,
str(entry.vector.data),
entry.text,
json.dumps(entry.metadata),
),
)
async def upsert_entries(self, entries: list[VectorEntry]) -> None:
"""Batch upsert function that handles vector quantization only when
quantization_type is INT1.
        Matches the table schema, where the vec_binary column only exists for
        INT1 quantization.
"""
if self.quantization_type == VectorQuantizationType.INT1:
bit_dim = (
"" if math.isnan(self.dimension) else f"({self.dimension})"
)
# For quantized vectors, use vec_binary column
query = f"""
INSERT INTO {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
(id, document_id, owner_id, collection_ids, vec, vec_binary, text, metadata)
            VALUES ($1, $2, $3, $4, $5, $6::bit{bit_dim}, $7, $8)
ON CONFLICT (id) DO UPDATE SET
document_id = EXCLUDED.document_id,
owner_id = EXCLUDED.owner_id,
collection_ids = EXCLUDED.collection_ids,
vec = EXCLUDED.vec,
vec_binary = EXCLUDED.vec_binary,
text = EXCLUDED.text,
metadata = EXCLUDED.metadata;
"""
bin_params = [
(
entry.id,
entry.document_id,
entry.owner_id,
entry.collection_ids,
str(entry.vector.data),
quantize_vector_to_binary(
entry.vector.data
), # Convert to binary
entry.text,
json.dumps(entry.metadata),
)
for entry in entries
]
await self.connection_manager.execute_many(query, bin_params)
else:
# For regular vectors, use vec column only
query = f"""
INSERT INTO {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
(id, document_id, owner_id, collection_ids, vec, text, metadata)
VALUES ($1, $2, $3, $4, $5, $6, $7)
ON CONFLICT (id) DO UPDATE SET
document_id = EXCLUDED.document_id,
owner_id = EXCLUDED.owner_id,
collection_ids = EXCLUDED.collection_ids,
vec = EXCLUDED.vec,
text = EXCLUDED.text,
metadata = EXCLUDED.metadata;
"""
params = [
(
entry.id,
entry.document_id,
entry.owner_id,
entry.collection_ids,
str(entry.vector.data),
entry.text,
json.dumps(entry.metadata),
)
for entry in entries
]
await self.connection_manager.execute_many(query, params)
async def semantic_search(
self, query_vector: list[float], search_settings: SearchSettings
) -> list[ChunkSearchResult]:
try:
imeasure_obj = IndexMeasure(
search_settings.chunk_settings.index_measure
)
except ValueError:
raise ValueError("Invalid index measure") from None
table_name = self._get_table_name(PostgresChunksHandler.TABLE_NAME)
cols = [
f"{table_name}.id",
f"{table_name}.document_id",
f"{table_name}.owner_id",
f"{table_name}.collection_ids",
f"{table_name}.text",
]
params: list[str | int | bytes] = []
# For binary vectors (INT1), implement two-stage search
if self.quantization_type == VectorQuantizationType.INT1:
# Convert query vector to binary format
binary_query = quantize_vector_to_binary(query_vector)
# TODO - Put depth multiplier in config / settings
extended_limit = (
search_settings.limit * 20
) # Get 20x candidates for re-ranking
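            # Stage 1 ranks candidates by a bit-wise distance (Hamming or Jaccard)
            # over vec_binary; stage 2 re-ranks the surviving candidates with
            # cosine distance on the original float vectors.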
if (
imeasure_obj == IndexMeasure.hamming_distance
or imeasure_obj == IndexMeasure.jaccard_distance
):
binary_search_measure_repr = imeasure_obj.pgvector_repr
else:
binary_search_measure_repr = (
IndexMeasure.hamming_distance.pgvector_repr
)
# Use binary column and binary-specific distance measures for first stage
bit_dim = (
"" if math.isnan(self.dimension) else f"({self.dimension})"
)
stage1_distance = f"{table_name}.vec_binary {binary_search_measure_repr} $1::bit{bit_dim}"
stage1_param = binary_query
cols.append(
f"{table_name}.vec"
) # Need original vector for re-ranking
if search_settings.include_metadatas:
cols.append(f"{table_name}.metadata")
select_clause = ", ".join(cols)
where_clause = ""
params.append(stage1_param)
if search_settings.filters:
where_clause, params = apply_filters(
search_settings.filters, params, mode="where_clause"
)
vector_dim = (
"" if math.isnan(self.dimension) else f"({self.dimension})"
)
# First stage: Get candidates using binary search
query = f"""
WITH candidates AS (
SELECT {select_clause},
({stage1_distance}) as binary_distance
FROM {table_name}
{where_clause}
ORDER BY {stage1_distance}
LIMIT ${len(params) + 1}
OFFSET ${len(params) + 2}
)
-- Second stage: Re-rank using original vectors
SELECT
id,
document_id,
owner_id,
collection_ids,
text,
{"metadata," if search_settings.include_metadatas else ""}
(vec <=> ${len(params) + 4}::vector{vector_dim}) as distance
FROM candidates
ORDER BY distance
LIMIT ${len(params) + 3}
"""
params.extend(
[
extended_limit, # First stage limit
search_settings.offset,
search_settings.limit, # Final limit
str(query_vector), # For re-ranking
]
)
else:
# Standard float vector handling
vector_dim = (
"" if math.isnan(self.dimension) else f"({self.dimension})"
)
distance_calc = f"{table_name}.vec {search_settings.chunk_settings.index_measure.pgvector_repr} $1::vector{vector_dim}"
query_param = str(query_vector)
if search_settings.include_scores:
cols.append(f"({distance_calc}) AS distance")
if search_settings.include_metadatas:
cols.append(f"{table_name}.metadata")
select_clause = ", ".join(cols)
where_clause = ""
params.append(query_param)
if search_settings.filters:
where_clause, new_params = apply_filters(
search_settings.filters,
params,
mode="where_clause", # Get just conditions without WHERE
)
params = new_params
query = f"""
SELECT {select_clause}
FROM {table_name}
{where_clause}
ORDER BY {distance_calc}
LIMIT ${len(params) + 1}
OFFSET ${len(params) + 2}
"""
params.extend([search_settings.limit, search_settings.offset])
results = await self.connection_manager.fetch_query(query, params)
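        # Distances are converted to scores as 1 - distance (natural for cosine
        # distance); -1 is a sentinel when no distance column was selected.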
return [
ChunkSearchResult(
id=UUID(str(result["id"])),
document_id=UUID(str(result["document_id"])),
owner_id=UUID(str(result["owner_id"])),
collection_ids=result["collection_ids"],
text=result["text"],
score=(
(1 - float(result["distance"]))
if "distance" in result
else -1
),
metadata=(
json.loads(result["metadata"])
if search_settings.include_metadatas
else {}
),
)
for result in results
]
async def full_text_search(
self, query_text: str, search_settings: SearchSettings
) -> list[ChunkSearchResult]:
conditions = []
params: list[str | int | bytes] = [query_text]
conditions.append("fts @@ websearch_to_tsquery('english', $1)")
if search_settings.filters:
filter_condition, params = apply_filters(
search_settings.filters, params, mode="condition_only"
)
if filter_condition:
conditions.append(filter_condition)
where_clause = "WHERE " + " AND ".join(conditions)
query = f"""
SELECT
id,
document_id,
owner_id,
collection_ids,
text,
metadata,
ts_rank(fts, websearch_to_tsquery('english', $1), 32) as rank
FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
{where_clause}
ORDER BY rank DESC
OFFSET ${len(params) + 1}
LIMIT ${len(params) + 2}
"""
params.extend(
[
search_settings.offset,
search_settings.hybrid_settings.full_text_limit,
]
)
results = await self.connection_manager.fetch_query(query, params)
return [
ChunkSearchResult(
id=UUID(str(r["id"])),
document_id=UUID(str(r["document_id"])),
owner_id=UUID(str(r["owner_id"])),
collection_ids=r["collection_ids"],
text=r["text"],
score=float(r["rank"]),
metadata=json.loads(r["metadata"]),
)
for r in results
]
async def hybrid_search(
self,
query_text: str,
query_vector: list[float],
search_settings: SearchSettings,
*args,
**kwargs,
) -> list[ChunkSearchResult]:
if search_settings.hybrid_settings is None:
raise ValueError(
"Please provide a valid `hybrid_settings` in the `search_settings`."
)
if (
search_settings.hybrid_settings.full_text_limit
< search_settings.limit
):
raise ValueError(
"The `full_text_limit` must be greater than or equal to the `limit`."
)
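        # Both sub-searches fetch enough extra rows to cover the requested offset,
        # since pagination is applied after the two result sets are fused.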
semantic_settings = copy.deepcopy(search_settings)
semantic_settings.limit += search_settings.offset
full_text_settings = copy.deepcopy(search_settings)
full_text_settings.hybrid_settings.full_text_limit += (
search_settings.offset
)
semantic_results: list[ChunkSearchResult] = await self.semantic_search(
query_vector, semantic_settings
)
full_text_results: list[
ChunkSearchResult
] = await self.full_text_search(query_text, full_text_settings)
semantic_limit = search_settings.limit
full_text_limit = search_settings.hybrid_settings.full_text_limit
semantic_weight = search_settings.hybrid_settings.semantic_weight
full_text_weight = search_settings.hybrid_settings.full_text_weight
rrf_k = search_settings.hybrid_settings.rrf_k
combined_results: dict[uuid.UUID, HybridSearchIntermediateResult] = {}
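        # Results that appear in only one ranking get that ranking's limit as a
        # worst-case rank so they still participate in the fusion below.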
for rank, result in enumerate(semantic_results, 1):
combined_results[result.id] = {
"semantic_rank": rank,
"full_text_rank": full_text_limit,
"data": result,
"rrf_score": 0.0, # Initialize with 0, will be calculated later
}
for rank, result in enumerate(full_text_results, 1):
if result.id in combined_results:
combined_results[result.id]["full_text_rank"] = rank
else:
combined_results[result.id] = {
"semantic_rank": semantic_limit,
"full_text_rank": rank,
"data": result,
"rrf_score": 0.0, # Initialize with 0, will be calculated later
}
combined_results = {
k: v
for k, v in combined_results.items()
if v["semantic_rank"] <= semantic_limit * 2
and v["full_text_rank"] <= full_text_limit * 2
}
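        # Reciprocal Rank Fusion: each result's score is the weighted average of
        # 1 / (rrf_k + rank) over its semantic and full-text ranks.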
for hyb_result in combined_results.values():
semantic_score = 1 / (rrf_k + hyb_result["semantic_rank"])
full_text_score = 1 / (rrf_k + hyb_result["full_text_rank"])
hyb_result["rrf_score"] = (
semantic_score * semantic_weight
+ full_text_score * full_text_weight
) / (semantic_weight + full_text_weight)
sorted_results = sorted(
combined_results.values(),
key=lambda x: x["rrf_score"],
reverse=True,
)
offset_results = sorted_results[
search_settings.offset : search_settings.offset
+ search_settings.limit
]
return [
ChunkSearchResult(
id=result["data"].id,
document_id=result["data"].document_id,
owner_id=result["data"].owner_id,
collection_ids=result["data"].collection_ids,
text=result["data"].text,
score=result["rrf_score"],
metadata={
**result["data"].metadata,
"semantic_rank": result["semantic_rank"],
"full_text_rank": result["full_text_rank"],
},
)
for result in offset_results
]
async def delete(
self, filters: dict[str, Any]
) -> dict[str, dict[str, str]]:
params: list[str | int | bytes] = []
where_clause, params = apply_filters(
filters, params, mode="condition_only"
)
query = f"""
DELETE FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
WHERE {where_clause}
RETURNING id, document_id, text;
"""
results = await self.connection_manager.fetch_query(query, params)
return {
str(result["id"]): {
"status": "deleted",
"id": str(result["id"]),
"document_id": str(result["document_id"]),
"text": result["text"],
}
for result in results
}
async def assign_document_chunks_to_collection(
self, document_id: UUID, collection_id: UUID
) -> None:
query = f"""
UPDATE {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
SET collection_ids = array_append(collection_ids, $1)
WHERE document_id = $2 AND NOT ($1 = ANY(collection_ids));
"""
return await self.connection_manager.execute_query(
query, (str(collection_id), str(document_id))
)
async def remove_document_from_collection_vector(
self, document_id: UUID, collection_id: UUID
) -> None:
query = f"""
UPDATE {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
SET collection_ids = array_remove(collection_ids, $1)
WHERE document_id = $2;
"""
await self.connection_manager.execute_query(
query, (collection_id, document_id)
)
async def delete_user_vector(self, owner_id: UUID) -> None:
query = f"""
DELETE FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
WHERE owner_id = $1;
"""
await self.connection_manager.execute_query(query, (owner_id,))
async def delete_collection_vector(self, collection_id: UUID) -> None:
query = f"""
DELETE FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
WHERE $1 = ANY(collection_ids)
RETURNING collection_ids
"""
await self.connection_manager.fetchrow_query(query, (collection_id,))
return None
async def list_document_chunks(
self,
document_id: UUID,
offset: int,
limit: int,
include_vectors: bool = False,
) -> dict[str, Any]:
vector_select = ", vec" if include_vectors else ""
limit_clause = f"LIMIT {limit}" if limit > -1 else ""
query = f"""
SELECT id, document_id, owner_id, collection_ids, text, metadata{vector_select}, COUNT(*) OVER() AS total
FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
WHERE document_id = $1
ORDER BY (metadata->>'chunk_order')::integer
OFFSET $2
{limit_clause};
"""
params = [document_id, offset]
results = await self.connection_manager.fetch_query(query, params)
chunks = []
total = 0
if results:
total = results[0].get("total", 0)
chunks = [
{
"id": result["id"],
"document_id": result["document_id"],
"owner_id": result["owner_id"],
"collection_ids": result["collection_ids"],
"text": result["text"],
"metadata": json.loads(result["metadata"]),
"vector": (
json.loads(result["vec"]) if include_vectors else None
),
}
for result in results
]
return {"results": chunks, "total_entries": total}
async def get_chunk(self, id: UUID) -> dict:
query = f"""
SELECT id, document_id, owner_id, collection_ids, text, metadata
FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
WHERE id = $1;
"""
result = await self.connection_manager.fetchrow_query(query, (id,))
if result:
return {
"id": result["id"],
"document_id": result["document_id"],
"owner_id": result["owner_id"],
"collection_ids": result["collection_ids"],
"text": result["text"],
"metadata": json.loads(result["metadata"]),
}
raise R2RException(
message=f"Chunk with ID {id} not found", status_code=404
)
async def create_index(
self,
table_name: Optional[VectorTableName] = None,
index_measure: IndexMeasure = IndexMeasure.cosine_distance,
index_method: IndexMethod = IndexMethod.auto,
index_arguments: Optional[IndexArgsIVFFlat | IndexArgsHNSW] = None,
index_name: Optional[str] = None,
index_column: Optional[str] = None,
concurrently: bool = True,
) -> None:
"""Creates an index for the collection.
Note:
When `vecs` creates an index on a pgvector column in PostgreSQL, it uses a multi-step
process that enables performant indexes to be built for large collections with low end
database hardware.
Those steps are:
- Creates a new table with a different name
- Randomly selects records from the existing table
- Inserts the random records from the existing table into the new table
- Creates the requested vector index on the new table
- Upserts all data from the existing table into the new table
- Drops the existing table
            - Renames the new table to the existing table's name
If you create dependencies (like views) on the table that underpins
a `vecs.Collection` the `create_index` step may require you to drop those dependencies before
it will succeed.
Args:
index_measure (IndexMeasure, optional): The measure to index for. Defaults to 'cosine_distance'.
index_method (IndexMethod, optional): The indexing method to use. Defaults to 'auto'.
            index_arguments (IndexArgsIVFFlat | IndexArgsHNSW, optional): Index-type-specific build arguments. Defaults to None.
index_name (str, optional): The name of the index to create. Defaults to None.
concurrently (bool, optional): Whether to create the index concurrently. Defaults to True.
Raises:
            ValueError: If an invalid index method is used, or if incompatible index arguments are supplied for the chosen method.
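        Example (illustrative sketch; assumes an instantiated handler and keyword
        construction of IndexArgsHNSW):
            await handler.create_index(
                table_name=VectorTableName.CHUNKS,
                index_method=IndexMethod.hnsw,
                index_arguments=IndexArgsHNSW(m=16, ef_construction=64),
            )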
"""
if table_name == VectorTableName.CHUNKS:
table_name_str = f"{self.project_name}.{VectorTableName.CHUNKS}" # TODO - Fix bug in vector table naming convention
if index_column:
col_name = index_column
else:
col_name = (
"vec"
if (
index_measure != IndexMeasure.hamming_distance
and index_measure != IndexMeasure.jaccard_distance
)
else "vec_binary"
)
elif table_name == VectorTableName.ENTITIES_DOCUMENT:
table_name_str = (
f"{self.project_name}.{VectorTableName.ENTITIES_DOCUMENT}"
)
col_name = "description_embedding"
elif table_name == VectorTableName.GRAPHS_ENTITIES:
table_name_str = (
f"{self.project_name}.{VectorTableName.GRAPHS_ENTITIES}"
)
col_name = "description_embedding"
elif table_name == VectorTableName.COMMUNITIES:
table_name_str = (
f"{self.project_name}.{VectorTableName.COMMUNITIES}"
)
col_name = "embedding"
else:
raise ValueError("invalid table name")
if index_method not in (
IndexMethod.ivfflat,
IndexMethod.hnsw,
IndexMethod.auto,
):
raise ValueError("invalid index method")
if index_arguments:
# Disallow case where user submits index arguments but uses the
# IndexMethod.auto index (index build arguments should only be
# used with a specific index)
if index_method == IndexMethod.auto:
raise ValueError(
"Index build parameters are not allowed when using the IndexMethod.auto index."
)
# Disallow case where user specifies one index type but submits
# index build arguments for the other index type
if (
isinstance(index_arguments, IndexArgsHNSW)
and index_method != IndexMethod.hnsw
) or (
isinstance(index_arguments, IndexArgsIVFFlat)
and index_method != IndexMethod.ivfflat
):
raise ValueError(
f"{index_arguments.__class__.__name__} build parameters were supplied but {index_method} index was specified."
)
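        # `auto` currently resolves to HNSW.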
if index_method == IndexMethod.auto:
index_method = IndexMethod.hnsw
ops = index_measure_to_ops(
index_measure # , quantization_type=self.quantization_type
)
if ops is None:
raise ValueError("Unknown index measure")
concurrently_sql = "CONCURRENTLY" if concurrently else ""
index_name = (
index_name
or f"ix_{ops}_{index_method}__{col_name}_{time.strftime('%Y%m%d%H%M%S')}"
)
create_index_sql = f"""
CREATE INDEX {concurrently_sql} {index_name}
ON {table_name_str}
USING {index_method} ({col_name} {ops}) {self._get_index_options(index_method, index_arguments)};
"""
try:
if concurrently:
async with (
self.connection_manager.pool.get_connection() as conn # type: ignore
):
# Disable automatic transaction management
await conn.execute(
"SET SESSION CHARACTERISTICS AS TRANSACTION ISOLATION LEVEL READ COMMITTED"
)
await conn.execute(create_index_sql)
else:
# Non-concurrent index creation can use normal query execution
await self.connection_manager.execute_query(create_index_sql)
except Exception as e:
raise Exception(f"Failed to create index: {e}") from e
return None
async def list_indices(
self,
offset: int,
limit: int,
filters: Optional[dict[str, Any]] = None,
) -> dict:
where_clauses = []
params: list[Any] = [self.project_name] # Start with schema name
param_count = 1
# Handle filtering
if filters:
if "table_name" in filters:
where_clauses.append(f"i.tablename = ${param_count + 1}")
params.append(filters["table_name"])
param_count += 1
if "index_method" in filters:
where_clauses.append(f"am.amname = ${param_count + 1}")
params.append(filters["index_method"])
param_count += 1
if "index_name" in filters:
where_clauses.append(
f"LOWER(i.indexname) LIKE LOWER(${param_count + 1})"
)
params.append(f"%{filters['index_name']}%")
param_count += 1
where_clause = " AND ".join(where_clauses) if where_clauses else ""
if where_clause:
where_clause = f"AND {where_clause}"
query = f"""
WITH index_info AS (
SELECT
i.indexname as name,
i.tablename as table_name,
i.indexdef as definition,
am.amname as method,
pg_relation_size(c.oid) as size_in_bytes,
c.reltuples::bigint as row_estimate,
COALESCE(psat.idx_scan, 0) as number_of_scans,
COALESCE(psat.idx_tup_read, 0) as tuples_read,
COALESCE(psat.idx_tup_fetch, 0) as tuples_fetched,
COUNT(*) OVER() as total_count
FROM pg_indexes i
JOIN pg_class c ON c.relname = i.indexname
JOIN pg_am am ON c.relam = am.oid
LEFT JOIN pg_stat_user_indexes psat ON psat.indexrelname = i.indexname
AND psat.schemaname = i.schemaname
WHERE i.schemaname = $1
AND i.indexdef LIKE '%vector%'
{where_clause}
)
SELECT *
FROM index_info
ORDER BY name
LIMIT ${param_count + 1}
OFFSET ${param_count + 2}
"""
# Add limit and offset to params
params.extend([limit, offset])
results = await self.connection_manager.fetch_query(query, params)
indices = []
total_entries = 0
if results:
total_entries = results[0]["total_count"]
for result in results:
index_info = {
"name": result["name"],
"table_name": result["table_name"],
"definition": result["definition"],
"size_in_bytes": result["size_in_bytes"],
"row_estimate": result["row_estimate"],
"number_of_scans": result["number_of_scans"],
"tuples_read": result["tuples_read"],
"tuples_fetched": result["tuples_fetched"],
}
indices.append(index_info)
return {"indices": indices, "total_entries": total_entries}
async def delete_index(
self,
index_name: str,
table_name: Optional[VectorTableName] = None,
concurrently: bool = True,
) -> None:
"""Deletes a vector index.
Args:
index_name (str): Name of the index to delete
table_name (VectorTableName, optional): Table the index belongs to
concurrently (bool): Whether to drop the index concurrently
Raises:
ValueError: If table name is invalid or index doesn't exist
Exception: If index deletion fails
"""
# Validate table name and get column name
if table_name == VectorTableName.CHUNKS:
table_name_str = f"{self.project_name}.{VectorTableName.CHUNKS}"
col_name = "vec"
elif table_name == VectorTableName.ENTITIES_DOCUMENT:
table_name_str = (
f"{self.project_name}.{VectorTableName.ENTITIES_DOCUMENT}"
)
col_name = "description_embedding"
elif table_name == VectorTableName.GRAPHS_ENTITIES:
table_name_str = (
f"{self.project_name}.{VectorTableName.GRAPHS_ENTITIES}"
)
col_name = "description_embedding"
elif table_name == VectorTableName.COMMUNITIES:
table_name_str = (
f"{self.project_name}.{VectorTableName.COMMUNITIES}"
)
col_name = "description_embedding"
else:
raise ValueError("invalid table name")
# Extract schema and base table name
schema_name, base_table_name = table_name_str.split(".")
# Verify index exists and is a vector index
query = """
SELECT indexdef
FROM pg_indexes
WHERE indexname = $1
AND schemaname = $2
AND tablename = $3
AND indexdef LIKE $4
"""
result = await self.connection_manager.fetchrow_query(
query, (index_name, schema_name, base_table_name, f"%({col_name}%")
)
if not result:
raise ValueError(
f"Vector index '{index_name}' does not exist on table {table_name_str}"
)
# Drop the index
concurrently_sql = "CONCURRENTLY" if concurrently else ""
drop_query = (
f"DROP INDEX {concurrently_sql} {schema_name}.{index_name}"
)
try:
if concurrently:
async with (
self.connection_manager.pool.get_connection() as conn # type: ignore
):
# Disable automatic transaction management
await conn.execute(
"SET SESSION CHARACTERISTICS AS TRANSACTION ISOLATION LEVEL READ COMMITTED"
)
await conn.execute(drop_query)
else:
await self.connection_manager.execute_query(drop_query)
except Exception as e:
raise Exception(f"Failed to delete index: {e}") from e
async def list_chunks(
self,
offset: int,
limit: int,
filters: Optional[dict[str, Any]] = None,
include_vectors: bool = False,
) -> dict[str, Any]:
"""List chunks with pagination support.
Args:
            offset (int): Number of records to skip.
            limit (int): Maximum number of records to return.
filters (dict, optional): Dictionary of filters to apply. Defaults to None.
include_vectors (bool, optional): Whether to include vector data. Defaults to False.
Returns:
dict: Dictionary containing:
- results: List of chunk records
- total_entries: Total number of chunks matching the filters
"""
vector_select = ", vec" if include_vectors else ""
select_clause = f"""
id, document_id, owner_id, collection_ids,
text, metadata{vector_select}, COUNT(*) OVER() AS total_entries
"""
params: list[str | int | bytes] = []
where_clause = ""
if filters:
where_clause, params = apply_filters(
filters, params, mode="where_clause"
)
query = f"""
SELECT {select_clause}
FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
{where_clause}
LIMIT ${len(params) + 1}
OFFSET ${len(params) + 2}
"""
params.extend([limit, offset])
# Execute the query
results = await self.connection_manager.fetch_query(query, params)
# Process results
chunks = []
total_entries = 0
if results:
total_entries = results[0].get("total_entries", 0)
chunks = [
{
"id": str(result["id"]),
"document_id": str(result["document_id"]),
"owner_id": str(result["owner_id"]),
"collection_ids": result["collection_ids"],
"text": result["text"],
"metadata": json.loads(result["metadata"]),
"vector": (
json.loads(result["vec"]) if include_vectors else None
),
}
for result in results
]
return {"results": chunks, "total_entries": total_entries}
async def search_documents(
self,
query_text: str,
settings: SearchSettings,
) -> list[dict[str, Any]]:
"""Search for documents based on their metadata fields and/or body
text. Joins with documents table to get complete document metadata.
Args:
query_text (str): The search query text
settings (SearchSettings): Search settings including search preferences and filters
Returns:
list[dict[str, Any]]: List of documents with their search scores and complete metadata
"""
where_clauses = []
params: list[str | int | bytes] = [query_text]
search_over_body = getattr(settings, "search_over_body", True)
search_over_metadata = getattr(settings, "search_over_metadata", True)
metadata_weight = getattr(settings, "metadata_weight", 3.0)
title_weight = getattr(settings, "title_weight", 1.0)
metadata_keys = getattr(
settings, "metadata_keys", ["title", "description"]
)
# Build the dynamic metadata field search expression
metadata_fields_expr = " || ' ' || ".join(
[
f"COALESCE(v.metadata->>{psql_quote_literal(key)}, '')"
for key in metadata_keys # type: ignore
]
)
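        # When both metadata and body search are enabled, metadata hits are scaled
        # by metadata_weight and body hits by title_weight (see the CASE expression
        # below).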
query = f"""
WITH
-- Metadata search scores
metadata_scores AS (
SELECT DISTINCT ON (v.document_id)
v.document_id,
d.metadata as doc_metadata,
CASE WHEN $1 = '' THEN 0.0
ELSE
ts_rank_cd(
setweight(to_tsvector('english', {metadata_fields_expr}), 'A'),
websearch_to_tsquery('english', $1),
32
)
END as metadata_rank
FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)} v
LEFT JOIN {self._get_table_name("documents")} d ON v.document_id = d.id
WHERE v.metadata IS NOT NULL
),
-- Body search scores
body_scores AS (
SELECT
document_id,
AVG(
ts_rank_cd(
setweight(to_tsvector('english', COALESCE(text, '')), 'B'),
websearch_to_tsquery('english', $1),
32
)
) as body_rank
FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
WHERE $1 != ''
{"AND to_tsvector('english', text) @@ websearch_to_tsquery('english', $1)" if search_over_body else ""}
GROUP BY document_id
),
-- Combined scores with document metadata
combined_scores AS (
SELECT
COALESCE(m.document_id, b.document_id) as document_id,
m.doc_metadata as metadata,
COALESCE(m.metadata_rank, 0) as debug_metadata_rank,
COALESCE(b.body_rank, 0) as debug_body_rank,
CASE
WHEN {str(search_over_metadata).lower()} AND {str(search_over_body).lower()} THEN
COALESCE(m.metadata_rank, 0) * {metadata_weight} + COALESCE(b.body_rank, 0) * {title_weight}
WHEN {str(search_over_metadata).lower()} THEN
COALESCE(m.metadata_rank, 0)
WHEN {str(search_over_body).lower()} THEN
COALESCE(b.body_rank, 0)
ELSE 0
END as rank
FROM metadata_scores m
FULL OUTER JOIN body_scores b ON m.document_id = b.document_id
WHERE (
($1 = '') OR
({str(search_over_metadata).lower()} AND m.metadata_rank > 0) OR
({str(search_over_body).lower()} AND b.body_rank > 0)
)
"""
# Add any additional filters
if settings.filters:
            filter_clause, params = apply_filters(
                settings.filters, params, mode="condition_only"
            )
where_clauses.append(filter_clause)
if where_clauses:
query += f" AND {' AND '.join(where_clauses)}"
query += """
)
SELECT
document_id,
metadata,
rank as score,
debug_metadata_rank,
debug_body_rank
FROM combined_scores
WHERE rank > 0
ORDER BY rank DESC
OFFSET ${offset_param} LIMIT ${limit_param}
""".format(
offset_param=len(params) + 1,
limit_param=len(params) + 2,
)
# Add offset and limit to params
params.extend([settings.offset, settings.limit])
# Execute query
results = await self.connection_manager.fetch_query(query, params)
# Format results with complete document metadata
return [
{
"document_id": str(r["document_id"]),
"metadata": (
json.loads(r["metadata"])
if isinstance(r["metadata"], str)
else r["metadata"]
),
"score": float(r["score"]),
"debug_metadata_rank": float(r["debug_metadata_rank"]),
"debug_body_rank": float(r["debug_body_rank"]),
}
for r in results
]
def _get_index_options(
self,
method: IndexMethod,
index_arguments: Optional[IndexArgsIVFFlat | IndexArgsHNSW],
) -> str:
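        # Maps index build arguments onto the WITH (...) options clause of CREATE
        # INDEX, falling back to defaults (lists=100 for IVFFlat; m=16,
        # ef_construction=64 for HNSW) when no arguments are supplied.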
if method == IndexMethod.ivfflat:
if isinstance(index_arguments, IndexArgsIVFFlat):
return f"WITH (lists={index_arguments.n_lists})"
else:
# Default value if no arguments provided
return "WITH (lists=100)"
elif method == IndexMethod.hnsw:
if isinstance(index_arguments, IndexArgsHNSW):
return f"WITH (m={index_arguments.m}, ef_construction={index_arguments.ef_construction})"
else:
# Default values if no arguments provided
return "WITH (m=16, ef_construction=64)"
else:
return "" # No options for other methods