author    | S. Solomon Darnell | 2025-03-28 21:52:21 -0500
committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500
commit    | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree      | ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/core/providers/database/chunks.py
parent    | cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
download  | gn-ai-master.tar.gz
Diffstat (limited to '.venv/lib/python3.12/site-packages/core/providers/database/chunks.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/core/providers/database/chunks.py | 1316 |
1 file changed, 1316 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/core/providers/database/chunks.py b/.venv/lib/python3.12/site-packages/core/providers/database/chunks.py
new file mode 100644
index 00000000..177f3395
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/database/chunks.py
@@ -0,0 +1,1316 @@
+import copy
+import json
+import logging
+import math
+import time
+import uuid
+from typing import Any, Optional, TypedDict
+from uuid import UUID
+
+import numpy as np
+
+from core.base import (
+    ChunkSearchResult,
+    Handler,
+    IndexArgsHNSW,
+    IndexArgsIVFFlat,
+    IndexMeasure,
+    IndexMethod,
+    R2RException,
+    SearchSettings,
+    VectorEntry,
+    VectorQuantizationType,
+    VectorTableName,
+)
+from core.base.utils import _decorate_vector_type
+
+from .base import PostgresConnectionManager
+from .filters import apply_filters
+
+logger = logging.getLogger()
+
+
+def psql_quote_literal(value: str) -> str:
+    """Safely quote a string literal for PostgreSQL to prevent SQL injection.
+
+    This is a simple implementation - in production, you should use proper parameterization
+    or your database driver's quoting functions.
+    """
+    return "'" + value.replace("'", "''") + "'"
+
+
+def index_measure_to_ops(
+    measure: IndexMeasure,
+    quantization_type: VectorQuantizationType = VectorQuantizationType.FP32,
+):
+    return _decorate_vector_type(measure.ops, quantization_type)
+
+
+def quantize_vector_to_binary(
+    vector: list[float] | np.ndarray,
+    threshold: float = 0.0,
+) -> bytes:
+    """Quantizes a float vector to a binary vector string for PostgreSQL bit
+    type. Used when quantization_type is INT1.
+
+    Args:
+        vector (List[float] | np.ndarray): Input vector of floats
+        threshold (float, optional): Threshold for binarization. Defaults to 0.0.
+
+    Returns:
+        str: Binary string representation for PostgreSQL bit type
+    """
+    # Convert input to numpy array if it isn't already
+    if not isinstance(vector, np.ndarray):
+        vector = np.array(vector)
+
+    # Convert to binary (1 where value > threshold, 0 otherwise)
+    binary_vector = (vector > threshold).astype(int)
+
+    # Convert to string of 1s and 0s, then to bytes
+    binary_string = "".join(map(str, binary_vector))
+    return binary_string.encode("ascii")
+
+
+class HybridSearchIntermediateResult(TypedDict):
+    semantic_rank: int
+    full_text_rank: int
+    data: ChunkSearchResult
+    rrf_score: float
+
+
+class PostgresChunksHandler(Handler):
+    TABLE_NAME = VectorTableName.CHUNKS
+
+    def __init__(
+        self,
+        project_name: str,
+        connection_manager: PostgresConnectionManager,
+        dimension: int | float,
+        quantization_type: VectorQuantizationType,
+    ):
+        super().__init__(project_name, connection_manager)
+        self.dimension = dimension
+        self.quantization_type = quantization_type
+
+    async def create_tables(self):
+        # First check if table already exists and validate dimensions
+        table_exists_query = """
+        SELECT EXISTS (
+            SELECT FROM pg_tables
+            WHERE schemaname = $1
+            AND tablename = $2
+        );
+        """
+        table_name = VectorTableName.CHUNKS
+        table_exists = await self.connection_manager.fetch_query(
+            table_exists_query, (self.project_name, table_name)
+        )
+
+        if len(table_exists) > 0 and table_exists[0]["exists"]:
+            # Table exists, check vector dimension
+            vector_dim_query = """
+            SELECT a.atttypmod as dimension
+            FROM pg_attribute a
+            JOIN pg_class c ON a.attrelid = c.oid
+            JOIN pg_namespace n ON c.relnamespace = n.oid
+            WHERE n.nspname = $1
+            AND c.relname = $2
+            AND a.attname = 'vec';
+            """
+
+            vector_dim_result = await self.connection_manager.fetch_query(
+                vector_dim_query, (self.project_name, table_name)
+            )
+
+            if vector_dim_result and len(vector_dim_result) > 0:
+                existing_dimension = vector_dim_result[0]["dimension"]
+                # In pgvector, dimension is stored as atttypmod - 4
+                if existing_dimension > 0:  # If it has a specific dimension
+                    # Compare with provided dimension
+                    if (
+                        self.dimension > 0
+                        and existing_dimension != self.dimension
+                    ):
+                        raise ValueError(
+                            f"Dimension mismatch: Table '{self.project_name}.{table_name}' was created with "
+                            f"dimension {existing_dimension}, but {self.dimension} was provided. "
+                            f"You must use the same dimension for existing tables."
+                        )
+
+        # Check for old table name
+        check_query = """
+        SELECT EXISTS (
+            SELECT FROM pg_tables
+            WHERE schemaname = $1
+            AND tablename = $2
+        );
+        """
+        old_table_exists = await self.connection_manager.fetch_query(
+            check_query, (self.project_name, self.project_name)
+        )
+
+        if len(old_table_exists) > 0 and old_table_exists[0]["exists"]:
+            raise ValueError(
+                f"Found old vector table '{self.project_name}.{self.project_name}'. "
+                "Please run `r2r db upgrade` with the CLI, or to run manually, "
+                "run in R2R/py/migrations with 'alembic upgrade head' to update "
+                "your database schema to the new version."
+            )
+
+        binary_col = (
+            ""
+            if self.quantization_type != VectorQuantizationType.INT1
+            else f"vec_binary bit({self.dimension}),"
+        )
+
+        if self.dimension > 0:
+            vector_col = f"vec vector({self.dimension})"
+        else:
+            vector_col = "vec vector"
+
+        query = f"""
+        CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresChunksHandler.TABLE_NAME)} (
+            id UUID PRIMARY KEY,
+            document_id UUID,
+            owner_id UUID,
+            collection_ids UUID[],
+            {vector_col},
+            {binary_col}
+            text TEXT,
+            metadata JSONB,
+            fts tsvector GENERATED ALWAYS AS (to_tsvector('english', text)) STORED
+        );
+        CREATE INDEX IF NOT EXISTS idx_vectors_document_id ON {self._get_table_name(PostgresChunksHandler.TABLE_NAME)} (document_id);
+        CREATE INDEX IF NOT EXISTS idx_vectors_owner_id ON {self._get_table_name(PostgresChunksHandler.TABLE_NAME)} (owner_id);
+        CREATE INDEX IF NOT EXISTS idx_vectors_collection_ids ON {self._get_table_name(PostgresChunksHandler.TABLE_NAME)} USING GIN (collection_ids);
+        CREATE INDEX IF NOT EXISTS idx_vectors_text ON {self._get_table_name(PostgresChunksHandler.TABLE_NAME)} USING GIN (to_tsvector('english', text));
+        """
+
+        await self.connection_manager.execute_query(query)
+
+    async def upsert(self, entry: VectorEntry) -> None:
+        """Upsert function that handles vector quantization only when
+        quantization_type is INT1.
+
+        Matches the table schema where vec_binary column only exists for INT1
+        quantization.
+        """
+        # Check the quantization type to determine which columns to use
+        if self.quantization_type == VectorQuantizationType.INT1:
+            bit_dim = (
+                "" if math.isnan(self.dimension) else f"({self.dimension})"
+            )
+
+            # For quantized vectors, use vec_binary column
+            query = f"""
+            INSERT INTO {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
+            (id, document_id, owner_id, collection_ids, vec, vec_binary, text, metadata)
+            VALUES ($1, $2, $3, $4, $5, $6::bit({bit_dim}), $7, $8)
+            ON CONFLICT (id) DO UPDATE SET
+            document_id = EXCLUDED.document_id,
+            owner_id = EXCLUDED.owner_id,
+            collection_ids = EXCLUDED.collection_ids,
+            vec = EXCLUDED.vec,
+            vec_binary = EXCLUDED.vec_binary,
+            text = EXCLUDED.text,
+            metadata = EXCLUDED.metadata;
+            """
+            await self.connection_manager.execute_query(
+                query,
+                (
+                    entry.id,
+                    entry.document_id,
+                    entry.owner_id,
+                    entry.collection_ids,
+                    str(entry.vector.data),
+                    quantize_vector_to_binary(
+                        entry.vector.data
+                    ),  # Convert to binary
+                    entry.text,
+                    json.dumps(entry.metadata),
+                ),
+            )
+        else:
+            # For regular vectors, use vec column only
+            query = f"""
+            INSERT INTO {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
+            (id, document_id, owner_id, collection_ids, vec, text, metadata)
+            VALUES ($1, $2, $3, $4, $5, $6, $7)
+            ON CONFLICT (id) DO UPDATE SET
+            document_id = EXCLUDED.document_id,
+            owner_id = EXCLUDED.owner_id,
+            collection_ids = EXCLUDED.collection_ids,
+            vec = EXCLUDED.vec,
+            text = EXCLUDED.text,
+            metadata = EXCLUDED.metadata;
+            """
+
+            await self.connection_manager.execute_query(
+                query,
+                (
+                    entry.id,
+                    entry.document_id,
+                    entry.owner_id,
+                    entry.collection_ids,
+                    str(entry.vector.data),
+                    entry.text,
+                    json.dumps(entry.metadata),
+                ),
+            )
+
+    async def upsert_entries(self, entries: list[VectorEntry]) -> None:
+        """Batch upsert function that handles vector quantization only when
+        quantization_type is INT1.
+
+        Matches the table schema where vec_binary column only exists for INT1
+        quantization.
+        """
+        if self.quantization_type == VectorQuantizationType.INT1:
+            bit_dim = (
+                "" if math.isnan(self.dimension) else f"({self.dimension})"
+            )
+
+            # For quantized vectors, use vec_binary column
+            query = f"""
+            INSERT INTO {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
+            (id, document_id, owner_id, collection_ids, vec, vec_binary, text, metadata)
+            VALUES ($1, $2, $3, $4, $5, $6::bit({bit_dim}), $7, $8)
+            ON CONFLICT (id) DO UPDATE SET
+            document_id = EXCLUDED.document_id,
+            owner_id = EXCLUDED.owner_id,
+            collection_ids = EXCLUDED.collection_ids,
+            vec = EXCLUDED.vec,
+            vec_binary = EXCLUDED.vec_binary,
+            text = EXCLUDED.text,
+            metadata = EXCLUDED.metadata;
+            """
+            bin_params = [
+                (
+                    entry.id,
+                    entry.document_id,
+                    entry.owner_id,
+                    entry.collection_ids,
+                    str(entry.vector.data),
+                    quantize_vector_to_binary(
+                        entry.vector.data
+                    ),  # Convert to binary
+                    entry.text,
+                    json.dumps(entry.metadata),
+                )
+                for entry in entries
+            ]
+            await self.connection_manager.execute_many(query, bin_params)
+
+        else:
+            # For regular vectors, use vec column only
+            query = f"""
+            INSERT INTO {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
+            (id, document_id, owner_id, collection_ids, vec, text, metadata)
+            VALUES ($1, $2, $3, $4, $5, $6, $7)
+            ON CONFLICT (id) DO UPDATE SET
+            document_id = EXCLUDED.document_id,
+            owner_id = EXCLUDED.owner_id,
+            collection_ids = EXCLUDED.collection_ids,
+            vec = EXCLUDED.vec,
+            text = EXCLUDED.text,
+            metadata = EXCLUDED.metadata;
+            """
+            params = [
+                (
+                    entry.id,
+                    entry.document_id,
+                    entry.owner_id,
+                    entry.collection_ids,
+                    str(entry.vector.data),
+                    entry.text,
+                    json.dumps(entry.metadata),
+                )
+                for entry in entries
+            ]
+
+            await self.connection_manager.execute_many(query, params)
+
+    async def semantic_search(
+        self, query_vector: list[float], search_settings: SearchSettings
+    ) -> list[ChunkSearchResult]:
+        try:
+            imeasure_obj = IndexMeasure(
+                search_settings.chunk_settings.index_measure
+            )
+        except ValueError:
+            raise ValueError("Invalid index measure") from None
+
+        table_name = self._get_table_name(PostgresChunksHandler.TABLE_NAME)
+        cols = [
+            f"{table_name}.id",
+            f"{table_name}.document_id",
+            f"{table_name}.owner_id",
+            f"{table_name}.collection_ids",
+            f"{table_name}.text",
+        ]
+
+        params: list[str | int | bytes] = []
+
+        # For binary vectors (INT1), implement two-stage search
+        if self.quantization_type == VectorQuantizationType.INT1:
+            # Convert query vector to binary format
+            binary_query = quantize_vector_to_binary(query_vector)
+            # TODO - Put depth multiplier in config / settings
+            extended_limit = (
+                search_settings.limit * 20
+            )  # Get 20x candidates for re-ranking
+
+            if (
+                imeasure_obj == IndexMeasure.hamming_distance
+                or imeasure_obj == IndexMeasure.jaccard_distance
+            ):
+                binary_search_measure_repr = imeasure_obj.pgvector_repr
+            else:
+                binary_search_measure_repr = (
+                    IndexMeasure.hamming_distance.pgvector_repr
+                )
+
+            # Use binary column and binary-specific distance measures for first stage
+            bit_dim = (
+                "" if math.isnan(self.dimension) else f"({self.dimension})"
+            )
+            stage1_distance = f"{table_name}.vec_binary {binary_search_measure_repr} $1::bit{bit_dim}"
+            stage1_param = binary_query
+
+            cols.append(
+                f"{table_name}.vec"
+            )  # Need original vector for re-ranking
+            if search_settings.include_metadatas:
+                cols.append(f"{table_name}.metadata")
+
+            select_clause = ", ".join(cols)
+            where_clause = ""
+            params.append(stage1_param)
+
+            if search_settings.filters:
+                where_clause, params = apply_filters(
+                    search_settings.filters, params, mode="where_clause"
+                )
+
+            vector_dim = (
+                "" if math.isnan(self.dimension) else f"({self.dimension})"
+            )
+
+            # First stage: Get candidates using binary search
+            query = f"""
+            WITH candidates AS (
+                SELECT {select_clause},
+                    ({stage1_distance}) as binary_distance
+                FROM {table_name}
+                {where_clause}
+                ORDER BY {stage1_distance}
+                LIMIT ${len(params) + 1}
+                OFFSET ${len(params) + 2}
+            )
+            -- Second stage: Re-rank using original vectors
+            SELECT
+                id,
+                document_id,
+                owner_id,
+                collection_ids,
+                text,
+                {"metadata," if search_settings.include_metadatas else ""}
+                (vec <=> ${len(params) + 4}::vector{vector_dim}) as distance
+            FROM candidates
+            ORDER BY distance
+            LIMIT ${len(params) + 3}
+            """
+
+            params.extend(
+                [
+                    extended_limit,  # First stage limit
+                    search_settings.offset,
+                    search_settings.limit,  # Final limit
+                    str(query_vector),  # For re-ranking
+                ]
+            )
+
+        else:
+            # Standard float vector handling
+            vector_dim = (
+                "" if math.isnan(self.dimension) else f"({self.dimension})"
+            )
+            distance_calc = f"{table_name}.vec {search_settings.chunk_settings.index_measure.pgvector_repr} $1::vector{vector_dim}"
+            query_param = str(query_vector)
+
+            if search_settings.include_scores:
+                cols.append(f"({distance_calc}) AS distance")
+            if search_settings.include_metadatas:
+                cols.append(f"{table_name}.metadata")
+
+            select_clause = ", ".join(cols)
+            where_clause = ""
+            params.append(query_param)
+
+            if search_settings.filters:
+                where_clause, new_params = apply_filters(
+                    search_settings.filters,
+                    params,
+                    mode="where_clause",  # Get just conditions without WHERE
+                )
+                params = new_params
+
+            query = f"""
+            SELECT {select_clause}
+            FROM {table_name}
+            {where_clause}
+            ORDER BY {distance_calc}
+            LIMIT ${len(params) + 1}
+            OFFSET ${len(params) + 2}
+            """
+            params.extend([search_settings.limit, search_settings.offset])
+        results = await self.connection_manager.fetch_query(query, params)
+
+        return [
+            ChunkSearchResult(
+                id=UUID(str(result["id"])),
+                document_id=UUID(str(result["document_id"])),
+                owner_id=UUID(str(result["owner_id"])),
+                collection_ids=result["collection_ids"],
+                text=result["text"],
+                score=(
+                    (1 - float(result["distance"]))
+                    if "distance" in result
+                    else -1
+                ),
+                metadata=(
+                    json.loads(result["metadata"])
+                    if search_settings.include_metadatas
+                    else {}
+                ),
+            )
+            for result in results
+        ]
+
+    async def full_text_search(
+        self, query_text: str, search_settings: SearchSettings
+    ) -> list[ChunkSearchResult]:
+        conditions = []
+        params: list[str | int | bytes] = [query_text]
+
+        conditions.append("fts @@ websearch_to_tsquery('english', $1)")
+
+        if search_settings.filters:
+            filter_condition, params = apply_filters(
+                search_settings.filters, params, mode="condition_only"
+            )
+            if filter_condition:
+                conditions.append(filter_condition)
+
+        where_clause = "WHERE " + " AND ".join(conditions)
+
+        query = f"""
+        SELECT
+            id,
+            document_id,
+            owner_id,
+            collection_ids,
+            text,
+            metadata,
+            ts_rank(fts, websearch_to_tsquery('english', $1), 32) as rank
+        FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
+        {where_clause}
+        ORDER BY rank DESC
+        OFFSET ${len(params) + 1}
+        LIMIT ${len(params) + 2}
+        """
+
+        params.extend(
+            [
+                search_settings.offset,
+                search_settings.hybrid_settings.full_text_limit,
+            ]
+        )
+
+        results = await self.connection_manager.fetch_query(query, params)
+        return [
+            ChunkSearchResult(
+                id=UUID(str(r["id"])),
+                document_id=UUID(str(r["document_id"])),
+                owner_id=UUID(str(r["owner_id"])),
+                collection_ids=r["collection_ids"],
+                text=r["text"],
+                score=float(r["rank"]),
+                metadata=json.loads(r["metadata"]),
+            )
+            for r in results
+        ]
+
+    async def hybrid_search(
+        self,
+        query_text: str,
+        query_vector: list[float],
+        search_settings: SearchSettings,
+        *args,
+        **kwargs,
+    ) -> list[ChunkSearchResult]:
+        if search_settings.hybrid_settings is None:
+            raise ValueError(
+                "Please provide a valid `hybrid_settings` in the `search_settings`."
+            )
+        if (
+            search_settings.hybrid_settings.full_text_limit
+            < search_settings.limit
+        ):
+            raise ValueError(
+                "The `full_text_limit` must be greater than or equal to the `limit`."
+            )
+
+        semantic_settings = copy.deepcopy(search_settings)
+        semantic_settings.limit += search_settings.offset
+
+        full_text_settings = copy.deepcopy(search_settings)
+        full_text_settings.hybrid_settings.full_text_limit += (
+            search_settings.offset
+        )
+
+        semantic_results: list[ChunkSearchResult] = await self.semantic_search(
+            query_vector, semantic_settings
+        )
+        full_text_results: list[
+            ChunkSearchResult
+        ] = await self.full_text_search(query_text, full_text_settings)
+
+        semantic_limit = search_settings.limit
+        full_text_limit = search_settings.hybrid_settings.full_text_limit
+        semantic_weight = search_settings.hybrid_settings.semantic_weight
+        full_text_weight = search_settings.hybrid_settings.full_text_weight
+        rrf_k = search_settings.hybrid_settings.rrf_k
+
+        combined_results: dict[uuid.UUID, HybridSearchIntermediateResult] = {}
+
+        for rank, result in enumerate(semantic_results, 1):
+            combined_results[result.id] = {
+                "semantic_rank": rank,
+                "full_text_rank": full_text_limit,
+                "data": result,
+                "rrf_score": 0.0,  # Initialize with 0, will be calculated later
+            }
+
+        for rank, result in enumerate(full_text_results, 1):
+            if result.id in combined_results:
+                combined_results[result.id]["full_text_rank"] = rank
+            else:
+                combined_results[result.id] = {
+                    "semantic_rank": semantic_limit,
+                    "full_text_rank": rank,
+                    "data": result,
+                    "rrf_score": 0.0,  # Initialize with 0, will be calculated later
+                }
+
+        combined_results = {
+            k: v
+            for k, v in combined_results.items()
+            if v["semantic_rank"] <= semantic_limit * 2
+            and v["full_text_rank"] <= full_text_limit * 2
+        }
+
+        for hyb_result in combined_results.values():
+            semantic_score = 1 / (rrf_k + hyb_result["semantic_rank"])
+            full_text_score = 1 / (rrf_k + hyb_result["full_text_rank"])
+            hyb_result["rrf_score"] = (
+                semantic_score * semantic_weight
+                + full_text_score * full_text_weight
+            ) / (semantic_weight + full_text_weight)
+
+        sorted_results = sorted(
+            combined_results.values(),
+            key=lambda x: x["rrf_score"],
+            reverse=True,
+        )
+        offset_results = sorted_results[
+            search_settings.offset : search_settings.offset
+            + search_settings.limit
+        ]
+
+        return [
+            ChunkSearchResult(
+                id=result["data"].id,
+                document_id=result["data"].document_id,
+                owner_id=result["data"].owner_id,
+                collection_ids=result["data"].collection_ids,
+                text=result["data"].text,
+                score=result["rrf_score"],
+                metadata={
+                    **result["data"].metadata,
+                    "semantic_rank": result["semantic_rank"],
+                    "full_text_rank": result["full_text_rank"],
+                },
+            )
+            for result in offset_results
+        ]
+
+    async def delete(
+        self, filters: dict[str, Any]
+    ) -> dict[str, dict[str, str]]:
+        params: list[str | int | bytes] = []
+        where_clause, params = apply_filters(
+            filters, params, mode="condition_only"
+        )
+
+        query = f"""
+        DELETE FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
+        WHERE {where_clause}
+        RETURNING id, document_id, text;
+        """
+
+        results = await self.connection_manager.fetch_query(query, params)
+
+        return {
+            str(result["id"]): {
+                "status": "deleted",
+                "id": str(result["id"]),
+                "document_id": str(result["document_id"]),
+                "text": result["text"],
+            }
+            for result in results
+        }
+
+    async def assign_document_chunks_to_collection(
+        self, document_id: UUID, collection_id: UUID
+    ) -> None:
+        query = f"""
+        UPDATE {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
+        SET collection_ids = array_append(collection_ids, $1)
+        WHERE document_id = $2 AND NOT ($1 = ANY(collection_ids));
+        """
+        return await self.connection_manager.execute_query(
+            query, (str(collection_id), str(document_id))
+        )
+
+    async def remove_document_from_collection_vector(
+        self, document_id: UUID, collection_id: UUID
+    ) -> None:
+        query = f"""
+        UPDATE {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
+        SET collection_ids = array_remove(collection_ids, $1)
+        WHERE document_id = $2;
+        """
+        await self.connection_manager.execute_query(
+            query, (collection_id, document_id)
+        )
+
+    async def delete_user_vector(self, owner_id: UUID) -> None:
+        query = f"""
+        DELETE FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
+        WHERE owner_id = $1;
+        """
+        await self.connection_manager.execute_query(query, (owner_id,))
+
+    async def delete_collection_vector(self, collection_id: UUID) -> None:
+        query = f"""
+        DELETE FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
+        WHERE $1 = ANY(collection_ids)
+        RETURNING collection_ids
+        """
+        await self.connection_manager.fetchrow_query(query, (collection_id,))
+        return None
+
+    async def list_document_chunks(
+        self,
+        document_id: UUID,
+        offset: int,
+        limit: int,
+        include_vectors: bool = False,
+    ) -> dict[str, Any]:
+        vector_select = ", vec" if include_vectors else ""
+        limit_clause = f"LIMIT {limit}" if limit > -1 else ""
+
+        query = f"""
+        SELECT id, document_id, owner_id, collection_ids, text, metadata{vector_select}, COUNT(*) OVER() AS total
+        FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
+        WHERE document_id = $1
+        ORDER BY (metadata->>'chunk_order')::integer
+        OFFSET $2
+        {limit_clause};
+        """
+
+        params = [document_id, offset]
+
+        results = await self.connection_manager.fetch_query(query, params)
+
+        chunks = []
+        total = 0
+        if results:
+            total = results[0].get("total", 0)
+            chunks = [
+                {
+                    "id": result["id"],
+                    "document_id": result["document_id"],
+                    "owner_id": result["owner_id"],
+                    "collection_ids": result["collection_ids"],
+                    "text": result["text"],
+                    "metadata": json.loads(result["metadata"]),
+                    "vector": (
+                        json.loads(result["vec"]) if include_vectors else None
+                    ),
+                }
+                for result in results
+            ]
+
+        return {"results": chunks, "total_entries": total}
+
+    async def get_chunk(self, id: UUID) -> dict:
+        query = f"""
+        SELECT id, document_id, owner_id, collection_ids, text, metadata
+        FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
+        WHERE id = $1;
+        """
+
+        result = await self.connection_manager.fetchrow_query(query, (id,))
+
+        if result:
+            return {
+                "id": result["id"],
+                "document_id": result["document_id"],
+                "owner_id": result["owner_id"],
+                "collection_ids": result["collection_ids"],
+                "text": result["text"],
+                "metadata": json.loads(result["metadata"]),
+            }
+        raise R2RException(
+            message=f"Chunk with ID {id} not found", status_code=404
+        )
+
+    async def create_index(
+        self,
+        table_name: Optional[VectorTableName] = None,
+        index_measure: IndexMeasure = IndexMeasure.cosine_distance,
+        index_method: IndexMethod = IndexMethod.auto,
+        index_arguments: Optional[IndexArgsIVFFlat | IndexArgsHNSW] = None,
+        index_name: Optional[str] = None,
+        index_column: Optional[str] = None,
+        concurrently: bool = True,
+    ) -> None:
+        """Creates an index for the collection.
+
+        Note:
+            When `vecs` creates an index on a pgvector column in PostgreSQL, it uses a multi-step
+            process that enables performant indexes to be built for large collections with low end
+            database hardware.
+
+            Those steps are:
+
+            - Creates a new table with a different name
+            - Randomly selects records from the existing table
+            - Inserts the random records from the existing table into the new table
+            - Creates the requested vector index on the new table
+            - Upserts all data from the existing table into the new table
+            - Drops the existing table
+            - Renames the new table to the existing tables name
+
+            If you create dependencies (like views) on the table that underpins
+            a `vecs.Collection` the `create_index` step may require you to drop those dependencies before
+            it will succeed.
+
+        Args:
+            index_measure (IndexMeasure, optional): The measure to index for. Defaults to 'cosine_distance'.
+            index_method (IndexMethod, optional): The indexing method to use. Defaults to 'auto'.
+            index_arguments: (IndexArgsIVFFlat | IndexArgsHNSW, optional): Index type specific arguments
+            index_name (str, optional): The name of the index to create. Defaults to None.
+            concurrently (bool, optional): Whether to create the index concurrently. Defaults to True.
+        Raises:
+            ValueError: If an invalid index method is used, or if *replace* is False and an index already exists.
+        """
+
+        if table_name == VectorTableName.CHUNKS:
+            table_name_str = f"{self.project_name}.{VectorTableName.CHUNKS}"  # TODO - Fix bug in vector table naming convention
+            if index_column:
+                col_name = index_column
+            else:
+                col_name = (
+                    "vec"
+                    if (
+                        index_measure != IndexMeasure.hamming_distance
+                        and index_measure != IndexMeasure.jaccard_distance
+                    )
+                    else "vec_binary"
+                )
+        elif table_name == VectorTableName.ENTITIES_DOCUMENT:
+            table_name_str = (
+                f"{self.project_name}.{VectorTableName.ENTITIES_DOCUMENT}"
+            )
+            col_name = "description_embedding"
+        elif table_name == VectorTableName.GRAPHS_ENTITIES:
+            table_name_str = (
+                f"{self.project_name}.{VectorTableName.GRAPHS_ENTITIES}"
+            )
+            col_name = "description_embedding"
+        elif table_name == VectorTableName.COMMUNITIES:
+            table_name_str = (
+                f"{self.project_name}.{VectorTableName.COMMUNITIES}"
+            )
+            col_name = "embedding"
+        else:
+            raise ValueError("invalid table name")
+
+        if index_method not in (
+            IndexMethod.ivfflat,
+            IndexMethod.hnsw,
+            IndexMethod.auto,
+        ):
+            raise ValueError("invalid index method")
+
+        if index_arguments:
+            # Disallow case where user submits index arguments but uses the
+            # IndexMethod.auto index (index build arguments should only be
+            # used with a specific index)
+            if index_method == IndexMethod.auto:
+                raise ValueError(
+                    "Index build parameters are not allowed when using the IndexMethod.auto index."
+                )
+            # Disallow case where user specifies one index type but submits
+            # index build arguments for the other index type
+            if (
+                isinstance(index_arguments, IndexArgsHNSW)
+                and index_method != IndexMethod.hnsw
+            ) or (
+                isinstance(index_arguments, IndexArgsIVFFlat)
+                and index_method != IndexMethod.ivfflat
+            ):
+                raise ValueError(
+                    f"{index_arguments.__class__.__name__} build parameters were supplied but {index_method} index was specified."
+                )
+
+        if index_method == IndexMethod.auto:
+            index_method = IndexMethod.hnsw
+
+        ops = index_measure_to_ops(
+            index_measure  # , quantization_type=self.quantization_type
+        )
+
+        if ops is None:
+            raise ValueError("Unknown index measure")
+
+        concurrently_sql = "CONCURRENTLY" if concurrently else ""
+
+        index_name = (
+            index_name
+            or f"ix_{ops}_{index_method}__{col_name}_{time.strftime('%Y%m%d%H%M%S')}"
+        )
+
+        create_index_sql = f"""
+        CREATE INDEX {concurrently_sql} {index_name}
+        ON {table_name_str}
+        USING {index_method} ({col_name} {ops}) {self._get_index_options(index_method, index_arguments)};
+        """
+
+        try:
+            if concurrently:
+                async with (
+                    self.connection_manager.pool.get_connection() as conn  # type: ignore
+                ):
+                    # Disable automatic transaction management
+                    await conn.execute(
+                        "SET SESSION CHARACTERISTICS AS TRANSACTION ISOLATION LEVEL READ COMMITTED"
+                    )
+                    await conn.execute(create_index_sql)
+            else:
+                # Non-concurrent index creation can use normal query execution
+                await self.connection_manager.execute_query(create_index_sql)
+        except Exception as e:
+            raise Exception(f"Failed to create index: {e}") from e
+        return None
+
+    async def list_indices(
+        self,
+        offset: int,
+        limit: int,
+        filters: Optional[dict[str, Any]] = None,
+    ) -> dict:
+        where_clauses = []
+        params: list[Any] = [self.project_name]  # Start with schema name
+        param_count = 1
+
+        # Handle filtering
+        if filters:
+            if "table_name" in filters:
+                where_clauses.append(f"i.tablename = ${param_count + 1}")
+                params.append(filters["table_name"])
+                param_count += 1
+            if "index_method" in filters:
+                where_clauses.append(f"am.amname = ${param_count + 1}")
+                params.append(filters["index_method"])
+                param_count += 1
+            if "index_name" in filters:
+                where_clauses.append(
+                    f"LOWER(i.indexname) LIKE LOWER(${param_count + 1})"
+                )
+                params.append(f"%{filters['index_name']}%")
+                param_count += 1
+
+        where_clause = " AND ".join(where_clauses) if where_clauses else ""
+        if where_clause:
+            where_clause = f"AND {where_clause}"
+
+        query = f"""
+        WITH index_info AS (
+            SELECT
+                i.indexname as name,
+                i.tablename as table_name,
+                i.indexdef as definition,
+                am.amname as method,
+                pg_relation_size(c.oid) as size_in_bytes,
+                c.reltuples::bigint as row_estimate,
+                COALESCE(psat.idx_scan, 0) as number_of_scans,
+                COALESCE(psat.idx_tup_read, 0) as tuples_read,
+                COALESCE(psat.idx_tup_fetch, 0) as tuples_fetched,
+                COUNT(*) OVER() as total_count
+            FROM pg_indexes i
+            JOIN pg_class c ON c.relname = i.indexname
+            JOIN pg_am am ON c.relam = am.oid
+            LEFT JOIN pg_stat_user_indexes psat ON psat.indexrelname = i.indexname
+                AND psat.schemaname = i.schemaname
+            WHERE i.schemaname = $1
+            AND i.indexdef LIKE '%vector%'
+            {where_clause}
+        )
+        SELECT *
+        FROM index_info
+        ORDER BY name
+        LIMIT ${param_count + 1}
+        OFFSET ${param_count + 2}
+        """
+
+        # Add limit and offset to params
+        params.extend([limit, offset])
+
+        results = await self.connection_manager.fetch_query(query, params)
+
+        indices = []
+        total_entries = 0
+
+        if results:
+            total_entries = results[0]["total_count"]
+            for result in results:
+                index_info = {
+                    "name": result["name"],
+                    "table_name": result["table_name"],
+                    "definition": result["definition"],
+                    "size_in_bytes": result["size_in_bytes"],
+                    "row_estimate": result["row_estimate"],
+                    "number_of_scans": result["number_of_scans"],
+                    "tuples_read": result["tuples_read"],
+                    "tuples_fetched": result["tuples_fetched"],
+                }
+                indices.append(index_info)
+
+        return {"indices": indices, "total_entries": total_entries}
+    async def delete_index(
+        self,
+        index_name: str,
+        table_name: Optional[VectorTableName] = None,
+        concurrently: bool = True,
+    ) -> None:
+        """Deletes a vector index.
+
+        Args:
+            index_name (str): Name of the index to delete
+            table_name (VectorTableName, optional): Table the index belongs to
+            concurrently (bool): Whether to drop the index concurrently
+
+        Raises:
+            ValueError: If table name is invalid or index doesn't exist
+            Exception: If index deletion fails
+        """
+        # Validate table name and get column name
+        if table_name == VectorTableName.CHUNKS:
+            table_name_str = f"{self.project_name}.{VectorTableName.CHUNKS}"
+            col_name = "vec"
+        elif table_name == VectorTableName.ENTITIES_DOCUMENT:
+            table_name_str = (
+                f"{self.project_name}.{VectorTableName.ENTITIES_DOCUMENT}"
+            )
+            col_name = "description_embedding"
+        elif table_name == VectorTableName.GRAPHS_ENTITIES:
+            table_name_str = (
+                f"{self.project_name}.{VectorTableName.GRAPHS_ENTITIES}"
+            )
+            col_name = "description_embedding"
+        elif table_name == VectorTableName.COMMUNITIES:
+            table_name_str = (
+                f"{self.project_name}.{VectorTableName.COMMUNITIES}"
+            )
+            col_name = "description_embedding"
+        else:
+            raise ValueError("invalid table name")
+
+        # Extract schema and base table name
+        schema_name, base_table_name = table_name_str.split(".")
+
+        # Verify index exists and is a vector index
+        query = """
+        SELECT indexdef
+        FROM pg_indexes
+        WHERE indexname = $1
+        AND schemaname = $2
+        AND tablename = $3
+        AND indexdef LIKE $4
+        """
+
+        result = await self.connection_manager.fetchrow_query(
+            query, (index_name, schema_name, base_table_name, f"%({col_name}%")
+        )
+
+        if not result:
+            raise ValueError(
+                f"Vector index '{index_name}' does not exist on table {table_name_str}"
+            )
+
+        # Drop the index
+        concurrently_sql = "CONCURRENTLY" if concurrently else ""
+        drop_query = (
+            f"DROP INDEX {concurrently_sql} {schema_name}.{index_name}"
+        )
+
+        try:
+            if concurrently:
+                async with (
+                    self.connection_manager.pool.get_connection() as conn  # type: ignore
+                ):
+                    # Disable automatic transaction management
+                    await conn.execute(
+                        "SET SESSION CHARACTERISTICS AS TRANSACTION ISOLATION LEVEL READ COMMITTED"
+                    )
+                    await conn.execute(drop_query)
+            else:
+                await self.connection_manager.execute_query(drop_query)
+        except Exception as e:
+            raise Exception(f"Failed to delete index: {e}") from e
+
+    async def list_chunks(
+        self,
+        offset: int,
+        limit: int,
+        filters: Optional[dict[str, Any]] = None,
+        include_vectors: bool = False,
+    ) -> dict[str, Any]:
+        """List chunks with pagination support.
+
+        Args:
+            offset (int, optional): Number of records to skip. Defaults to 0.
+            limit (int, optional): Maximum number of records to return. Defaults to 10.
+            filters (dict, optional): Dictionary of filters to apply. Defaults to None.
+            include_vectors (bool, optional): Whether to include vector data. Defaults to False.
+
+        Returns:
+            dict: Dictionary containing:
+                - results: List of chunk records
+                - total_entries: Total number of chunks matching the filters
+        """
+        vector_select = ", vec" if include_vectors else ""
+        select_clause = f"""
+            id, document_id, owner_id, collection_ids,
+            text, metadata{vector_select}, COUNT(*) OVER() AS total_entries
+        """
+
+        params: list[str | int | bytes] = []
+        where_clause = ""
+        if filters:
+            where_clause, params = apply_filters(
+                filters, params, mode="where_clause"
+            )
+
+        query = f"""
+        SELECT {select_clause}
+        FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
+        {where_clause}
+        LIMIT ${len(params) + 1}
+        OFFSET ${len(params) + 2}
+        """
+
+        params.extend([limit, offset])
+
+        # Execute the query
+        results = await self.connection_manager.fetch_query(query, params)
+
+        # Process results
+        chunks = []
+        total_entries = 0
+        if results:
+            total_entries = results[0].get("total_entries", 0)
+            chunks = [
+                {
+                    "id": str(result["id"]),
+                    "document_id": str(result["document_id"]),
+                    "owner_id": str(result["owner_id"]),
+                    "collection_ids": result["collection_ids"],
+                    "text": result["text"],
+                    "metadata": json.loads(result["metadata"]),
+                    "vector": (
+                        json.loads(result["vec"]) if include_vectors else None
+                    ),
+                }
+                for result in results
+            ]
+
+        return {"results": chunks, "total_entries": total_entries}
+
+    async def search_documents(
+        self,
+        query_text: str,
+        settings: SearchSettings,
+    ) -> list[dict[str, Any]]:
+        """Search for documents based on their metadata fields and/or body
+        text. Joins with documents table to get complete document metadata.
+
+        Args:
+            query_text (str): The search query text
+            settings (SearchSettings): Search settings including search preferences and filters
+
+        Returns:
+            list[dict[str, Any]]: List of documents with their search scores and complete metadata
+        """
+        where_clauses = []
+        params: list[str | int | bytes] = [query_text]
+
+        search_over_body = getattr(settings, "search_over_body", True)
+        search_over_metadata = getattr(settings, "search_over_metadata", True)
+        metadata_weight = getattr(settings, "metadata_weight", 3.0)
+        title_weight = getattr(settings, "title_weight", 1.0)
+        metadata_keys = getattr(
+            settings, "metadata_keys", ["title", "description"]
+        )
+
+        # Build the dynamic metadata field search expression
+        metadata_fields_expr = " || ' ' || ".join(
+            [
+                f"COALESCE(v.metadata->>{psql_quote_literal(key)}, '')"
+                for key in metadata_keys  # type: ignore
+            ]
+        )
+
+        query = f"""
+        WITH
+        -- Metadata search scores
+        metadata_scores AS (
+            SELECT DISTINCT ON (v.document_id)
+                v.document_id,
+                d.metadata as doc_metadata,
+                CASE WHEN $1 = '' THEN 0.0
+                ELSE
+                    ts_rank_cd(
+                        setweight(to_tsvector('english', {metadata_fields_expr}), 'A'),
+                        websearch_to_tsquery('english', $1),
+                        32
+                    )
+                END as metadata_rank
+            FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)} v
+            LEFT JOIN {self._get_table_name("documents")} d ON v.document_id = d.id
+            WHERE v.metadata IS NOT NULL
+        ),
+        -- Body search scores
+        body_scores AS (
+            SELECT
+                document_id,
+                AVG(
+                    ts_rank_cd(
+                        setweight(to_tsvector('english', COALESCE(text, '')), 'B'),
+                        websearch_to_tsquery('english', $1),
+                        32
+                    )
+                ) as body_rank
+            FROM {self._get_table_name(PostgresChunksHandler.TABLE_NAME)}
+            WHERE $1 != ''
+            {"AND to_tsvector('english', text) @@ websearch_to_tsquery('english', $1)" if search_over_body else ""}
+            GROUP BY document_id
+        ),
+        -- Combined scores with document metadata
+        combined_scores AS (
+            SELECT
+                COALESCE(m.document_id, b.document_id) as document_id,
+                m.doc_metadata as metadata,
+                COALESCE(m.metadata_rank, 0) as debug_metadata_rank,
+                COALESCE(b.body_rank, 0) as debug_body_rank,
+                CASE
+                    WHEN {str(search_over_metadata).lower()} AND {str(search_over_body).lower()} THEN
+                        COALESCE(m.metadata_rank, 0) * {metadata_weight} + COALESCE(b.body_rank, 0) * {title_weight}
+                    WHEN {str(search_over_metadata).lower()} THEN
+                        COALESCE(m.metadata_rank, 0)
+                    WHEN {str(search_over_body).lower()} THEN
+                        COALESCE(b.body_rank, 0)
+                    ELSE 0
+                END as rank
+            FROM metadata_scores m
+            FULL OUTER JOIN body_scores b ON m.document_id = b.document_id
+            WHERE (
+                ($1 = '') OR
+                ({str(search_over_metadata).lower()} AND m.metadata_rank > 0) OR
+                ({str(search_over_body).lower()} AND b.body_rank > 0)
+            )
+        """
+
+        # Add any additional filters
+        if settings.filters:
+            filter_clause, params = apply_filters(settings.filters, params)
+            where_clauses.append(filter_clause)
+
+        if where_clauses:
+            query += f" AND {' AND '.join(where_clauses)}"
+
+        query += """
+        )
+        SELECT
+            document_id,
+            metadata,
+            rank as score,
+            debug_metadata_rank,
+            debug_body_rank
+        FROM combined_scores
+        WHERE rank > 0
+        ORDER BY rank DESC
+        OFFSET ${offset_param} LIMIT ${limit_param}
+        """.format(
+            offset_param=len(params) + 1,
+            limit_param=len(params) + 2,
+        )
+
+        # Add offset and limit to params
+        params.extend([settings.offset, settings.limit])
+
+        # Execute query
+        results = await self.connection_manager.fetch_query(query, params)
+
+        # Format results with complete document metadata
+        return [
+            {
+                "document_id": str(r["document_id"]),
+                "metadata": (
+                    json.loads(r["metadata"])
+                    if isinstance(r["metadata"], str)
+                    else r["metadata"]
+                ),
+                "score": float(r["score"]),
+                "debug_metadata_rank": float(r["debug_metadata_rank"]),
+                "debug_body_rank": float(r["debug_body_rank"]),
+            }
+            for r in results
+        ]
+
+    def _get_index_options(
+        self,
+        method: IndexMethod,
+        index_arguments: Optional[IndexArgsIVFFlat | IndexArgsHNSW],
+    ) -> str:
+        if method == IndexMethod.ivfflat:
+            if isinstance(index_arguments, IndexArgsIVFFlat):
+                return f"WITH (lists={index_arguments.n_lists})"
+            else:
+                # Default value if no arguments provided
+                return "WITH (lists=100)"
+        elif method == IndexMethod.hnsw:
+            if isinstance(index_arguments, IndexArgsHNSW):
+                return f"WITH (m={index_arguments.m}, ef_construction={index_arguments.ef_construction})"
+            else:
+                # Default values if no arguments provided
+                return "WITH (m=16, ef_construction=64)"
+        else:
+            return ""  # No options for other methods
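
For readers skimming the diff: the INT1 path in `quantize_vector_to_binary` and `semantic_search` amounts to two-stage retrieval: threshold the floats to bits, rank cheaply by a bit-wise distance over an oversampled candidate pool, then re-rank the survivors with the exact float distance. A minimal, database-free sketch of that flow, in plain Python with numpy; the helper names, the random corpus, and the oversampling constant are illustrative (the constant mirrors `search_settings.limit * 20` in the handler):

```python
import numpy as np

def quantize(vec: np.ndarray, threshold: float = 0.0) -> np.ndarray:
    # 1 where the component exceeds the threshold, else 0 -- mirrors
    # quantize_vector_to_binary, but returns an int array instead of bytes.
    return (np.asarray(vec) > threshold).astype(int)

def hamming(a: np.ndarray, b: np.ndarray) -> int:
    # Number of differing bits; the cheap first-stage distance.
    return int(np.count_nonzero(a != b))

def cosine_distance(a: np.ndarray, b: np.ndarray) -> float:
    # Exact second-stage distance; the handler converts it to a score as 1 - distance.
    return 1.0 - float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

def two_stage_search(query: np.ndarray, corpus: list, limit: int = 3) -> list:
    q_bits = quantize(query)
    # Stage 1: rank everything by Hamming distance on the binary codes,
    # keeping an oversampled candidate pool (limit * 20, as in the handler).
    pool = sorted(
        range(len(corpus)),
        key=lambda i: hamming(q_bits, quantize(corpus[i])),
    )[: limit * 20]
    # Stage 2: re-rank only the candidates with the exact float distance.
    return sorted(pool, key=lambda i: cosine_distance(query, corpus[i]))[:limit]

rng = np.random.default_rng(0)
corpus = [rng.normal(size=8) for _ in range(100)]
print(two_stage_search(rng.normal(size=8), corpus))
```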
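`hybrid_search` fuses the semantic and full-text result lists with a weighted reciprocal-rank formula: each rank contributes `1 / (rrf_k + rank)`, the two contributions are blended by their weights, and the sum is normalized. A small sketch of just that scoring step; the default values shown for `rrf_k` and the weights are illustrative stand-ins, since the real ones come from `search_settings.hybrid_settings`:

```python
def weighted_rrf(
    semantic_rank: int,
    full_text_rank: int,
    rrf_k: int = 50,            # illustrative default, not the library's
    semantic_weight: float = 5.0,
    full_text_weight: float = 1.0,
) -> float:
    # Same shape as the loop in hybrid_search over combined_results.
    semantic_score = 1 / (rrf_k + semantic_rank)
    full_text_score = 1 / (rrf_k + full_text_rank)
    return (
        semantic_score * semantic_weight + full_text_score * full_text_weight
    ) / (semantic_weight + full_text_weight)

# With semantic_weight > full_text_weight, a chunk ranked 1st semantically and
# 10th in full text outscores one ranked 10th semantically and 1st in full text.
print(weighted_rrf(1, 10), weighted_rrf(10, 1))
```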
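`create_index` assembles its `CREATE INDEX` statement from a measure-specific operator class plus the method-specific storage options returned by `_get_index_options`. A rough, standalone sketch of that assembly under stated assumptions: the identifiers (`my_project.chunks`, `ix_demo`, `vector_cosine_ops`) are hypothetical placeholders for the values the handler derives at runtime:

```python
def get_index_options(method: str, m: int | None = None,
                      ef_construction: int | None = None,
                      n_lists: int | None = None) -> str:
    # Mirrors _get_index_options: HNSW and IVFFlat take different WITH options.
    if method == "hnsw":
        return f"WITH (m={m or 16}, ef_construction={ef_construction or 64})"
    if method == "ivfflat":
        return f"WITH (lists={n_lists or 100})"
    return ""

# Roughly the statement shape create_index builds (identifiers are illustrative).
table, col, ops = "my_project.chunks", "vec", "vector_cosine_ops"
statement = (
    f"CREATE INDEX CONCURRENTLY ix_demo ON {table} "
    f"USING hnsw ({col} {ops}) {get_index_options('hnsw')};"
)
print(statement)
```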