"""Abstractions for search functionality."""
from copy import copy
from enum import Enum
from typing import Any, Optional
from uuid import NAMESPACE_DNS, UUID, uuid5
from pydantic import Field
from .base import R2RSerializable
from .document import DocumentResponse
from .llm import GenerationConfig
from .vector import IndexMeasure
def generate_id_from_label(label: str) -> UUID:
    """Derive a deterministic version-5 UUID from ``label``.

    The label is hashed inside the fixed ``NAMESPACE_DNS`` namespace, so the
    same label always yields the same UUID.
    """
    return uuid5(namespace=NAMESPACE_DNS, name=label)
class ChunkSearchResult(R2RSerializable):
    """Result of a search operation.

    Holds one matched chunk: its own id, the ids of the document, owner and
    collections it is associated with, the chunk text, its metadata and an
    optional relevance score.
    """

    id: UUID
    document_id: UUID
    owner_id: Optional[UUID]
    collection_ids: list[UUID]
    score: Optional[float] = None
    text: str
    metadata: dict[str, Any]

    def __str__(self) -> str:
        """Return a short description, including the score when one is set."""
        # Compare against None explicitly: 0.0 is a valid score and must
        # still be displayed rather than treated as "no score".
        if self.score is not None:
            return (
                f"ChunkSearchResult(score={self.score:.3f}, text={self.text})"
            )
        return f"ChunkSearchResult(text={self.text})"

    def __repr__(self) -> str:
        """Mirror ``__str__`` so containers print the compact form."""
        return self.__str__()

    def as_dict(self) -> dict:
        """Return the fields as a plain dict (values are not converted)."""
        return {
            "id": self.id,
            "document_id": self.document_id,
            "owner_id": self.owner_id,
            "collection_ids": self.collection_ids,
            "score": self.score,
            "text": self.text,
            "metadata": self.metadata,
        }

    class Config:
        populate_by_name = True
        json_schema_extra = {
            "example": {
                "id": "3f3d47f3-8baf-58eb-8bc2-0171fb1c6e09",
                "document_id": "3e157b3a-8469-51db-90d9-52e7d896b49b",
                "owner_id": "2acb499e-8428-543b-bd85-0d9098718220",
                "collection_ids": [],
                "score": 0.23943702876567796,
                "text": "Example text from the document",
                "metadata": {
                    "title": "example_document.pdf",
                    "associated_query": "What is the capital of France?",
                },
            }
        }
class GraphSearchResultType(str, Enum):
    """Kinds of content a graph search result can carry.

    Members are also ``str`` instances, so they compare equal to their
    plain string values.
    """

    ENTITY = "entity"
    RELATIONSHIP = "relationship"
    COMMUNITY = "community"
class GraphEntityResult(R2RSerializable):
    """An entity returned from a graph search.

    ``name`` and ``description`` are required; ``id`` and ``metadata`` are
    optional and default to ``None``.
    """

    id: Optional[UUID] = None
    name: str
    description: str
    metadata: Optional[dict[str, Any]] = None

    class Config:
        # Example payload attached to the generated JSON schema.
        json_schema_extra = {
            "example": {
                "name": "Entity Name",
                "description": "Entity Description",
                "metadata": {},
            }
        }
class GraphRelationshipResult(R2RSerializable):
    """A relationship (subject / predicate / object) from a graph search.

    ``subject``, ``predicate`` and ``object`` are required strings; every
    other field is optional and defaults to ``None``.
    """

    id: Optional[UUID] = None
    subject: str
    predicate: str
    object: str
    subject_id: Optional[UUID] = None
    object_id: Optional[UUID] = None
    metadata: Optional[dict[str, Any]] = None
    score: Optional[float] = None
    # Spelled with Optional[...] to match the other fields in this module.
    description: Optional[str] = None

    class Config:
        # The example uses the fields this model actually declares,
        # including the three required ones.
        json_schema_extra = {
            "example": {
                "subject": "Subject Name",
                "predicate": "Predicate",
                "object": "Object Name",
                "description": "Relationship Description",
                "metadata": {},
            }
        }

    def __str__(self) -> str:
        """Return a compact description of the subject/predicate/object."""
        return f"GraphRelationshipResult(subject={self.subject}, predicate={self.predicate}, object={self.object})"
class GraphCommunityResult(R2RSerializable):
    """A community returned from a graph search.

    ``name`` and ``summary`` are required; ``id`` and ``metadata`` are
    optional and default to ``None``.
    """

    id: Optional[UUID] = None
    name: str
    summary: str
    metadata: Optional[dict[str, Any]] = None

    def __str__(self) -> str:
        """Return a compact description with the name and summary."""
        return (
            f"GraphCommunityResult(name={self.name}, summary={self.summary})"
        )

    class Config:
        # NOTE(review): "rating" and "rating_explanation" are not declared
        # as fields on this model — confirm whether the example is stale.
        json_schema_extra = {
            "example": {
                "name": "Community Name",
                "summary": "Community Summary",
                "rating": 9,
                "rating_explanation": "Rating Explanation",
                "metadata": {},
            }
        }
class GraphSearchResult(R2RSerializable):
    """A single graph search hit wrapping an entity, relationship or
    community payload.

    ``content`` and ``id`` are required; the remaining fields have defaults.
    """

    content: GraphEntityResult | GraphRelationshipResult | GraphCommunityResult
    result_type: Optional[GraphSearchResultType] = None
    chunk_ids: Optional[list[UUID]] = None
    metadata: dict[str, Any] = {}
    score: Optional[float] = None
    id: UUID

    class Config:
        populate_by_name = True
        json_schema_extra = {
            "example": {
                "content": {
                    "id": "3f3d47f3-8baf-58eb-8bc2-0171fb1c6e09",
                    "name": "Entity Name",
                    "description": "Entity Description",
                    "metadata": {},
                },
                "result_type": "entity",
                "chunk_ids": ["c68dc72e-fc23-5452-8f49-d7bd46088a96"],
                "metadata": {
                    "associated_query": "What is the capital of France?"
                },
            }
        }

    def __str__(self) -> str:
        """Return a compact description of the content and result type."""
        return (
            f"GraphSearchResult(content={self.content}, "
            f"result_type={self.result_type})"
        )
class WebPageSearchResult(R2RSerializable):
    """One web page entry from a web search.

    ``position`` and ``id`` are required; ``type`` defaults to ``"organic"``
    and the remaining fields default to ``None``.
    """

    title: Optional[str] = None
    link: Optional[str] = None
    snippet: Optional[str] = None
    position: int
    type: str = "organic"
    date: Optional[str] = None
    sitelinks: Optional[list[dict]] = None
    id: UUID

    def __str__(self) -> str:
        """Return a compact description with title, link and snippet."""
        return (
            f"WebPageSearchResult(title={self.title}, "
            f"link={self.link}, "
            f"snippet={self.snippet})"
        )

    class Config:
        # Example payload attached to the generated JSON schema.
        json_schema_extra = {
            "example": {
                "title": "Page Title",
                "link": "https://example.com/page",
                "snippet": "Page snippet",
                "position": 1,
                "date": "2021-01-01",
                "sitelinks": [
                    {
                        "title": "Sitelink Title",
                        "link": "https://example.com/sitelink",
                    }
                ],
            }
        }
class RelatedSearchResult(R2RSerializable):
    """A related-search suggestion from a web search.

    ``query`` and ``id`` are required; ``type`` defaults to ``"related"``.
    """

    query: str
    type: str = "related"
    id: UUID
class PeopleAlsoAskResult(R2RSerializable):
    """A "people also ask" entry from a web search.

    All fields except ``type`` are required; ``type`` defaults to
    ``"peopleAlsoAsk"``.
    """

    question: str
    snippet: str
    link: str
    title: str
    id: UUID
    type: str = "peopleAlsoAsk"
class WebSearchResult(R2RSerializable):
    """Web search results grouped into organic, related and
    "people also ask" lists."""

    organic_results: list[WebPageSearchResult] = []
    related_searches: list[RelatedSearchResult] = []
    people_also_ask: list[PeopleAlsoAskResult] = []

    @classmethod
    def from_serper_results(cls, results: list[dict]) -> "WebSearchResult":
        """Build a ``WebSearchResult`` from a list of raw result dicts.

        Each dict is routed by its ``"type"`` value; entries whose type is
        not one of the three handled values are skipped. The ``id`` of each
        model is derived from the entry's ``"link"`` (or ``"query"`` for
        related searches).
        """
        organic_items = []
        related_items = []
        paa_items = []

        # (type value, model class, key used to derive the id, destination)
        routes = (
            ("organic", WebPageSearchResult, "link", organic_items),
            ("relatedSearches", RelatedSearchResult, "query", related_items),
            ("peopleAlsoAsk", PeopleAlsoAskResult, "link", paa_items),
        )

        for entry in results:
            entry_type = entry["type"]
            for type_name, model, label_key, bucket in routes:
                if entry_type == type_name:
                    # NOTE(review): a missing label key passes None to the id
                    # helper — confirm upstream always supplies it.
                    bucket.append(
                        model(
                            **entry,
                            id=generate_id_from_label(entry.get(label_key)),
                        )
                    )
                    break

        return cls(
            organic_results=organic_items,
            related_searches=related_items,
            people_also_ask=paa_items,
        )
class AggregateSearchResult(R2RSerializable):
    """Result of an aggregate search operation.

    Bundles the optional result lists for chunk, graph, web and document
    searches; each list defaults to ``None``.
    """

    chunk_search_results: Optional[list[ChunkSearchResult]] = None
    graph_search_results: Optional[list[GraphSearchResult]] = None
    web_search_results: Optional[list[WebPageSearchResult]] = None
    document_search_results: Optional[list[DocumentResponse]] = None

    def __str__(self) -> str:
        """Return a description listing all four result collections."""
        return (
            f"AggregateSearchResult(chunk_search_results={self.chunk_search_results}, "
            f"graph_search_results={self.graph_search_results}, "
            f"web_search_results={self.web_search_results}, "
            f"document_search_results={str(self.document_search_results)})"
        )

    # The repr is the same text as the str form.
    __repr__ = __str__

    def as_dict(self) -> dict:
        """Return the results as a dict of lists.

        A result list that is ``None`` or empty is emitted as ``[]``.
        """
        chunk_hits = self.chunk_search_results or []
        graph_hits = self.graph_search_results or []
        web_hits = self.web_search_results or []
        document_hits = self.document_search_results or []
        return {
            "chunk_search_results": [hit.as_dict() for hit in chunk_hits],
            "graph_search_results": [hit.to_dict() for hit in graph_hits],
            "web_search_results": [hit.to_dict() for hit in web_hits],
            "document_search_results": [
                hit.to_dict() for hit in document_hits
            ],
        }

    class Config:
        populate_by_name = True
        json_schema_extra = {
            "example": {
                "chunk_search_results": [
                    {
                        "id": "3f3d47f3-8baf-58eb-8bc2-0171fb1c6e09",
                        "document_id": "3e157b3a-8469-51db-90d9-52e7d896b49b",
                        "owner_id": "2acb499e-8428-543b-bd85-0d9098718220",
                        "collection_ids": [],
                        "score": 0.23943702876567796,
                        "text": "Example text from the document",
                        "metadata": {
                            "title": "example_document.pdf",
                            "associated_query": "What is the capital of France?",
                        },
                    }
                ],
                "graph_search_results": [
                    {
                        "content": {
                            "id": "3f3d47f3-8baf-58eb-8bc2-0171fb1c6e09",
                            "name": "Entity Name",
                            "description": "Entity Description",
                            "metadata": {},
                        },
                        "result_type": "entity",
                        "chunk_ids": ["c68dc72e-fc23-5452-8f49-d7bd46088a96"],
                        "metadata": {
                            "associated_query": "What is the capital of France?"
                        },
                    }
                ],
                "web_search_results": [
                    {
                        "title": "Page Title",
                        "link": "https://example.com/page",
                        "snippet": "Page snippet",
                        "position": 1,
                        "date": "2021-01-01",
                        "sitelinks": [
                            {
                                "title": "Sitelink Title",
                                "link": "https://example.com/sitelink",
                            }
                        ],
                    }
                ],
                "document_search_results": [
                    {
                        "document": {
                            "id": "3f3d47f3-8baf-58eb-8bc2-0171fb1c6e09",
                            "title": "Document Title",
                            "chunks": ["Chunk 1", "Chunk 2"],
                            "metadata": {},
                        },
                    }
                ],
            }
        }
class HybridSearchSettings(R2RSerializable):
    """Settings for hybrid search combining full-text and semantic search.

    Carries the two per-source weights, the cap on full-text results and
    the ``k`` constant for rank fusion.
    """

    full_text_weight: float = Field(
        default=1.0, description="Weight to apply to full text search"
    )
    semantic_weight: float = Field(
        default=5.0, description="Weight to apply to semantic search"
    )
    full_text_limit: int = Field(
        default=200,
        description="Maximum number of results to return from full text search",
    )
    # RRF stands for Reciprocal Rank Fusion; the description previously had
    # the words in the wrong order.
    rrf_k: int = Field(
        default=50, description="K-value for RRF (Reciprocal Rank Fusion)"
    )
class ChunkSearchSettings(R2RSerializable):
    """Settings specific to chunk/vector search.

    Covers the index distance measure, the ivfflat/HNSW query-time tuning
    knobs, and a switch to enable or disable chunk search.
    """

    index_measure: IndexMeasure = Field(
        description="The distance measure to use for indexing",
        default=IndexMeasure.cosine_distance,
    )
    probes: int = Field(
        description="Number of ivfflat index lists to query. Higher increases accuracy but decreases speed.",
        default=10,
    )
    ef_search: int = Field(
        description="Size of the dynamic candidate list for HNSW index search. Higher increases accuracy but decreases speed.",
        default=40,
    )
    enabled: bool = Field(
        description="Whether to enable chunk search",
        default=True,
    )
class GraphSearchSettings(R2RSerializable):
    """Settings specific to knowledge graph search.

    Holds an optional generation config, two numeric caps, a per-key
    ``limits`` mapping, and a switch to enable or disable graph search.
    """

    generation_config: Optional[GenerationConfig] = Field(
        description="Configuration for text generation during graph search.",
        default=None,
    )
    max_community_description_length: int = Field(default=65536)
    max_llm_queries_for_global_search: int = Field(default=250)
    # Mapping of string keys to integer limits; empty by default.
    limits: dict[str, int] = Field(default={})
    enabled: bool = Field(
        description="Whether to enable graph search",
        default=True,
    )
class SearchSettings(R2RSerializable):
    """Main search settings class that combines shared settings with
    specialized settings for chunks and graph.

    The constructor also accepts a legacy ``search_filters`` keyword, which
    is merged into ``filters`` (legacy keys win on conflict).
    """

    # Search type flags
    use_hybrid_search: bool = Field(
        default=False,
        description="Whether to perform a hybrid search. This is equivalent to setting `use_semantic_search=True` and `use_fulltext_search=True`, e.g. combining vector and keyword search.",
    )
    use_semantic_search: bool = Field(
        default=True,
        description="Whether to use semantic search",
    )
    use_fulltext_search: bool = Field(
        default=False,
        description="Whether to use full-text search",
    )

    # Common search parameters
    filters: dict[str, Any] = Field(
        default_factory=dict,
        description="""Filters to apply to the search. Allowed operators include `eq`, `neq`, `gt`, `gte`, `lt`, `lte`, `like`, `ilike`, `in`, and `nin`.
Commonly seen filters include the following:
`{"document_id": {"$eq": "9fbe403b-..."}}`
`{"document_id": {"$in": ["9fbe403b-...", "3e157b3a-..."]}}`
`{"collection_ids": {"$overlap": ["122fdf6a-...", "..."]}}`
`{"$and": {"document_id": ..., "collection_ids": ...}}`""",
    )
    limit: int = Field(
        default=10,
        description="Maximum number of results to return",
        ge=1,
        le=1_000,
    )
    offset: int = Field(
        default=0,
        ge=0,
        description="Offset to paginate search results",
    )
    include_metadatas: bool = Field(
        default=True,
        description="Whether to include element metadata in the search results",
    )
    include_scores: bool = Field(
        default=True,
        description="""Whether to include search score values in the
search results""",
    )

    # Search strategy and settings
    search_strategy: str = Field(
        default="vanilla",
        description="""Search strategy to use
(e.g., 'vanilla', 'query_fusion', 'hyde')""",
    )
    hybrid_settings: HybridSearchSettings = Field(
        default_factory=HybridSearchSettings,
        description="""Settings for hybrid search (only used if
`use_semantic_search` and `use_fulltext_search` are both true)""",
    )

    # Specialized settings
    chunk_settings: ChunkSearchSettings = Field(
        default_factory=ChunkSearchSettings,
        description="Settings specific to chunk/vector search",
    )
    graph_settings: GraphSearchSettings = Field(
        default_factory=GraphSearchSettings,
        description="Settings specific to knowledge graph search",
    )

    # For HyDE or multi-query:
    num_sub_queries: int = Field(
        default=5,
        description="Number of sub-queries/hypothetical docs to generate when using hyde or rag_fusion search strategies.",
    )

    class Config:
        populate_by_name = True
        json_encoders = {UUID: str}
        json_schema_extra = {
            "example": {
                "use_semantic_search": True,
                "use_fulltext_search": False,
                "use_hybrid_search": False,
                "filters": {"category": "technology"},
                "limit": 20,
                "offset": 0,
                "search_strategy": "vanilla",
                "hybrid_settings": {
                    "full_text_weight": 1.0,
                    "semantic_weight": 5.0,
                    "full_text_limit": 200,
                    "rrf_k": 50,
                },
                "chunk_settings": {
                    "enabled": True,
                    "index_measure": "cosine_distance",
                    "include_metadata": True,
                    "probes": 10,
                    "ef_search": 40,
                },
                "graph_settings": {
                    "enabled": True,
                    "generation_config": GenerationConfig.Config.json_schema_extra,
                    "max_community_description_length": 65536,
                    "max_llm_queries_for_global_search": 250,
                    "limits": {
                        "entity": 20,
                        "relationship": 20,
                        "community": 20,
                    },
                },
            }
        }

    def __init__(self, **data):
        """Create settings, folding legacy ``search_filters`` into ``filters``.

        An explicit ``None`` for either keyword is treated as an empty
        mapping instead of failing during the merge.
        """
        # Handle legacy search_filters field. `or {}` guards against an
        # explicit None, which `**` unpacking cannot handle.
        data["filters"] = {
            **(data.get("filters") or {}),
            **(data.get("search_filters") or {}),
        }
        super().__init__(**data)

    def model_dump(self, *args, **kwargs):
        """Delegate unchanged to the parent implementation."""
        return super().model_dump(*args, **kwargs)

    @classmethod
    def get_default(cls, mode: str) -> "SearchSettings":
        """Return default search settings for a given mode.

        ``"basic"`` yields semantic-only vanilla search, ``"advanced"``
        yields hybrid search with the ``hyde`` strategy, and any other value
        yields the class defaults.
        """
        if mode == "basic":
            # A simpler search that relies primarily on semantic search.
            return cls(
                use_semantic_search=True,
                use_fulltext_search=False,
                use_hybrid_search=False,
                search_strategy="vanilla",
                # Other relevant defaults can be provided here as needed
            )
        elif mode == "advanced":
            # A more powerful, combined search that leverages both semantic and fulltext.
            return cls(
                use_semantic_search=True,
                use_fulltext_search=True,
                use_hybrid_search=True,
                search_strategy="hyde",
                # Other advanced defaults as needed
            )
        else:
            # For 'custom' or unrecognized modes, return a basic empty config.
            return cls()
class SearchMode(str, Enum):
    """Search modes for the search endpoint.

    Members are also ``str`` instances, so they compare equal to their
    plain string values.
    """

    basic = "basic"
    advanced = "advanced"
    custom = "custom"
def select_search_filters(
    auth_user: Any,
    search_settings: "SearchSettings",
) -> dict[str, Any]:
    """Build the effective search filters for ``auth_user``.

    Superusers get a shallow copy of ``search_settings.filters`` unchanged.
    For any other user the result is restricted to items they own or that
    overlap with their collections.

    Args:
        auth_user: Object exposing ``is_superuser``, ``id`` and
            ``collection_ids``.
        search_settings: Object whose ``filters`` mapping is read; it is
            never mutated.

    Returns:
        The filter dict to apply to the search.

    Raises:
        KeyError: If a collection filter is present but has no
            ``"$overlap"`` entry.
        ValueError: If a requested collection id is not a valid UUID string.
    """
    filters = copy(search_settings.filters)
    if auth_user.is_superuser:
        return filters

    user_collections = set(auth_user.collection_ids)

    selected_collections = None
    for key in filters:
        if "collection_ids" in key:
            # Accept both UUID objects and their string form; calling
            # UUID() on an existing UUID instance raises.
            selected_collections = {
                value if isinstance(value, UUID) else UUID(value)
                for value in filters[key]["$overlap"]
            }
            break

    if selected_collections:
        # Never widen access beyond the collections the user belongs to.
        allowed_collections = user_collections.intersection(
            selected_collections
        )
    else:
        allowed_collections = user_collections

    # for non-superusers, we filter by user_id and selected & allowed collections
    collection_filters = {
        "$or": [
            {"owner_id": {"$eq": auth_user.id}},
            {"collection_ids": {"$overlap": list(allowed_collections)}},
        ]  # type: ignore
    }

    # The collection constraint is now carried by `collection_filters`.
    filters.pop("collection_ids", None)
    if filters:
        return {"$and": [collection_filters, filters]}  # type: ignore
    return collection_filters