aboutsummaryrefslogtreecommitdiff
"""Abstractions for search functionality."""

from copy import copy
from enum import Enum
from typing import Any, Optional
from uuid import NAMESPACE_DNS, UUID, uuid5

from pydantic import Field

from .base import R2RSerializable
from .document import DocumentResponse
from .llm import GenerationConfig
from .vector import IndexMeasure


def generate_id_from_label(label) -> UUID:
    return uuid5(NAMESPACE_DNS, label)


class ChunkSearchResult(R2RSerializable):
    """Result of a search operation."""

    id: UUID
    document_id: UUID
    owner_id: Optional[UUID]
    collection_ids: list[UUID]
    score: Optional[float] = None
    text: str
    metadata: dict[str, Any]

    def __str__(self) -> str:
        if self.score:
            return (
                f"ChunkSearchResult(score={self.score:.3f}, text={self.text})"
            )
        else:
            return f"ChunkSearchResult(text={self.text})"

    def __repr__(self) -> str:
        return self.__str__()

    def as_dict(self) -> dict:
        return {
            "id": self.id,
            "document_id": self.document_id,
            "owner_id": self.owner_id,
            "collection_ids": self.collection_ids,
            "score": self.score,
            "text": self.text,
            "metadata": self.metadata,
        }

    class Config:
        populate_by_name = True
        json_schema_extra = {
            "example": {
                "id": "3f3d47f3-8baf-58eb-8bc2-0171fb1c6e09",
                "document_id": "3e157b3a-8469-51db-90d9-52e7d896b49b",
                "owner_id": "2acb499e-8428-543b-bd85-0d9098718220",
                "collection_ids": [],
                "score": 0.23943702876567796,
                "text": "Example text from the document",
                "metadata": {
                    "title": "example_document.pdf",
                    "associated_query": "What is the capital of France?",
                },
            }
        }


class GraphSearchResultType(str, Enum):
    ENTITY = "entity"
    RELATIONSHIP = "relationship"
    COMMUNITY = "community"


class GraphEntityResult(R2RSerializable):
    id: Optional[UUID] = None
    name: str
    description: str
    metadata: Optional[dict[str, Any]] = None

    class Config:
        json_schema_extra = {
            "example": {
                "name": "Entity Name",
                "description": "Entity Description",
                "metadata": {},
            }
        }


class GraphRelationshipResult(R2RSerializable):
    id: Optional[UUID] = None
    subject: str
    predicate: str
    object: str
    subject_id: Optional[UUID] = None
    object_id: Optional[UUID] = None
    metadata: Optional[dict[str, Any]] = None
    score: Optional[float] = None
    description: str | None = None

    class Config:
        json_schema_extra = {
            "example": {
                "name": "Relationship Name",
                "description": "Relationship Description",
                "metadata": {},
            }
        }

    def __str__(self) -> str:
        return f"GraphRelationshipResult(subject={self.subject}, predicate={self.predicate}, object={self.object})"


class GraphCommunityResult(R2RSerializable):
    id: Optional[UUID] = None
    name: str
    summary: str
    metadata: Optional[dict[str, Any]] = None

    class Config:
        json_schema_extra = {
            "example": {
                "name": "Community Name",
                "summary": "Community Summary",
                "rating": 9,
                "rating_explanation": "Rating Explanation",
                "metadata": {},
            }
        }

    def __str__(self) -> str:
        return (
            f"GraphCommunityResult(name={self.name}, summary={self.summary})"
        )


class GraphSearchResult(R2RSerializable):
    content: GraphEntityResult | GraphRelationshipResult | GraphCommunityResult
    result_type: Optional[GraphSearchResultType] = None
    chunk_ids: Optional[list[UUID]] = None
    metadata: dict[str, Any] = {}
    score: Optional[float] = None
    id: UUID

    def __str__(self) -> str:
        return f"GraphSearchResult(content={self.content}, result_type={self.result_type})"

    class Config:
        populate_by_name = True
        json_schema_extra = {
            "example": {
                "content": {
                    "id": "3f3d47f3-8baf-58eb-8bc2-0171fb1c6e09",
                    "name": "Entity Name",
                    "description": "Entity Description",
                    "metadata": {},
                },
                "result_type": "entity",
                "chunk_ids": ["c68dc72e-fc23-5452-8f49-d7bd46088a96"],
                "metadata": {
                    "associated_query": "What is the capital of France?"
                },
            }
        }


class WebPageSearchResult(R2RSerializable):
    title: Optional[str] = None
    link: Optional[str] = None
    snippet: Optional[str] = None
    position: int
    type: str = "organic"
    date: Optional[str] = None
    sitelinks: Optional[list[dict]] = None
    id: UUID

    class Config:
        json_schema_extra = {
            "example": {
                "title": "Page Title",
                "link": "https://example.com/page",
                "snippet": "Page snippet",
                "position": 1,
                "date": "2021-01-01",
                "sitelinks": [
                    {
                        "title": "Sitelink Title",
                        "link": "https://example.com/sitelink",
                    }
                ],
            }
        }

    def __str__(self) -> str:
        return f"WebPageSearchResult(title={self.title}, link={self.link}, snippet={self.snippet})"


class RelatedSearchResult(R2RSerializable):
    query: str
    type: str = "related"
    id: UUID


class PeopleAlsoAskResult(R2RSerializable):
    question: str
    snippet: str
    link: str
    title: str
    id: UUID
    type: str = "peopleAlsoAsk"


class WebSearchResult(R2RSerializable):
    organic_results: list[WebPageSearchResult] = []
    related_searches: list[RelatedSearchResult] = []
    people_also_ask: list[PeopleAlsoAskResult] = []

    @classmethod
    def from_serper_results(cls, results: list[dict]) -> "WebSearchResult":
        organic = []
        related = []
        paa = []

        for result in results:
            if result["type"] == "organic":
                organic.append(
                    WebPageSearchResult(
                        **result, id=generate_id_from_label(result.get("link"))
                    )
                )
            elif result["type"] == "relatedSearches":
                related.append(
                    RelatedSearchResult(
                        **result,
                        id=generate_id_from_label(result.get("query")),
                    )
                )
            elif result["type"] == "peopleAlsoAsk":
                paa.append(
                    PeopleAlsoAskResult(
                        **result, id=generate_id_from_label(result.get("link"))
                    )
                )

        return cls(
            organic_results=organic,
            related_searches=related,
            people_also_ask=paa,
        )


class AggregateSearchResult(R2RSerializable):
    """Result of an aggregate search operation."""

    chunk_search_results: Optional[list[ChunkSearchResult]] = None
    graph_search_results: Optional[list[GraphSearchResult]] = None
    web_search_results: Optional[list[WebPageSearchResult]] = None
    document_search_results: Optional[list[DocumentResponse]] = None

    def __str__(self) -> str:
        return f"AggregateSearchResult(chunk_search_results={self.chunk_search_results}, graph_search_results={self.graph_search_results}, web_search_results={self.web_search_results}, document_search_results={str(self.document_search_results)})"

    def __repr__(self) -> str:
        return f"AggregateSearchResult(chunk_search_results={self.chunk_search_results}, graph_search_results={self.graph_search_results}, web_search_results={self.web_search_results}, document_search_results={str(self.document_search_results)})"

    def as_dict(self) -> dict:
        return {
            "chunk_search_results": (
                [result.as_dict() for result in self.chunk_search_results]
                if self.chunk_search_results
                else []
            ),
            "graph_search_results": (
                [result.to_dict() for result in self.graph_search_results]
                if self.graph_search_results
                else []
            ),
            "web_search_results": (
                [result.to_dict() for result in self.web_search_results]
                if self.web_search_results
                else []
            ),
            "document_search_results": (
                [cdr.to_dict() for cdr in self.document_search_results]
                if self.document_search_results
                else []
            ),
        }

    class Config:
        populate_by_name = True
        json_schema_extra = {
            "example": {
                "chunk_search_results": [
                    {
                        "id": "3f3d47f3-8baf-58eb-8bc2-0171fb1c6e09",
                        "document_id": "3e157b3a-8469-51db-90d9-52e7d896b49b",
                        "owner_id": "2acb499e-8428-543b-bd85-0d9098718220",
                        "collection_ids": [],
                        "score": 0.23943702876567796,
                        "text": "Example text from the document",
                        "metadata": {
                            "title": "example_document.pdf",
                            "associated_query": "What is the capital of France?",
                        },
                    }
                ],
                "graph_search_results": [
                    {
                        "content": {
                            "id": "3f3d47f3-8baf-58eb-8bc2-0171fb1c6e09",
                            "name": "Entity Name",
                            "description": "Entity Description",
                            "metadata": {},
                        },
                        "result_type": "entity",
                        "chunk_ids": ["c68dc72e-fc23-5452-8f49-d7bd46088a96"],
                        "metadata": {
                            "associated_query": "What is the capital of France?"
                        },
                    }
                ],
                "web_search_results": [
                    {
                        "title": "Page Title",
                        "link": "https://example.com/page",
                        "snippet": "Page snippet",
                        "position": 1,
                        "date": "2021-01-01",
                        "sitelinks": [
                            {
                                "title": "Sitelink Title",
                                "link": "https://example.com/sitelink",
                            }
                        ],
                    }
                ],
                "document_search_results": [
                    {
                        "document": {
                            "id": "3f3d47f3-8baf-58eb-8bc2-0171fb1c6e09",
                            "title": "Document Title",
                            "chunks": ["Chunk 1", "Chunk 2"],
                            "metadata": {},
                        },
                    }
                ],
            }
        }


class HybridSearchSettings(R2RSerializable):
    """Settings for hybrid search combining full-text and semantic search."""

    full_text_weight: float = Field(
        default=1.0, description="Weight to apply to full text search"
    )
    semantic_weight: float = Field(
        default=5.0, description="Weight to apply to semantic search"
    )
    full_text_limit: int = Field(
        default=200,
        description="Maximum number of results to return from full text search",
    )
    rrf_k: int = Field(
        default=50, description="K-value for RRF (Rank Reciprocal Fusion)"
    )


class ChunkSearchSettings(R2RSerializable):
    """Settings specific to chunk/vector search."""

    index_measure: IndexMeasure = Field(
        default=IndexMeasure.cosine_distance,
        description="The distance measure to use for indexing",
    )
    probes: int = Field(
        default=10,
        description="Number of ivfflat index lists to query. Higher increases accuracy but decreases speed.",
    )
    ef_search: int = Field(
        default=40,
        description="Size of the dynamic candidate list for HNSW index search. Higher increases accuracy but decreases speed.",
    )
    enabled: bool = Field(
        default=True,
        description="Whether to enable chunk search",
    )


class GraphSearchSettings(R2RSerializable):
    """Settings specific to knowledge graph search."""

    generation_config: Optional[GenerationConfig] = Field(
        default=None,
        description="Configuration for text generation during graph search.",
    )
    max_community_description_length: int = Field(
        default=65536,
    )
    max_llm_queries_for_global_search: int = Field(
        default=250,
    )
    limits: dict[str, int] = Field(
        default={},
    )
    enabled: bool = Field(
        default=True,
        description="Whether to enable graph search",
    )


class SearchSettings(R2RSerializable):
    """Main search settings class that combines shared settings with
    specialized settings for chunks and graph."""

    # Search type flags
    use_hybrid_search: bool = Field(
        default=False,
        description="Whether to perform a hybrid search. This is equivalent to setting `use_semantic_search=True` and `use_fulltext_search=True`, e.g. combining vector and keyword search.",
    )
    use_semantic_search: bool = Field(
        default=True,
        description="Whether to use semantic search",
    )
    use_fulltext_search: bool = Field(
        default=False,
        description="Whether to use full-text search",
    )

    # Common search parameters
    filters: dict[str, Any] = Field(
        default_factory=dict,
        description="""Filters to apply to the search. Allowed operators include `eq`, `neq`, `gt`, `gte`, `lt`, `lte`, `like`, `ilike`, `in`, and `nin`.

      Commonly seen filters include operations include the following:

        `{"document_id": {"$eq": "9fbe403b-..."}}`

        `{"document_id": {"$in": ["9fbe403b-...", "3e157b3a-..."]}}`

        `{"collection_ids": {"$overlap": ["122fdf6a-...", "..."]}}`

        `{"$and": {"$document_id": ..., "collection_ids": ...}}`""",
    )
    limit: int = Field(
        default=10,
        description="Maximum number of results to return",
        ge=1,
        le=1_000,
    )
    offset: int = Field(
        default=0,
        ge=0,
        description="Offset to paginate search results",
    )
    include_metadatas: bool = Field(
        default=True,
        description="Whether to include element metadata in the search results",
    )
    include_scores: bool = Field(
        default=True,
        description="""Whether to include search score values in the
        search results""",
    )

    # Search strategy and settings
    search_strategy: str = Field(
        default="vanilla",
        description="""Search strategy to use
        (e.g., 'vanilla', 'query_fusion', 'hyde')""",
    )
    hybrid_settings: HybridSearchSettings = Field(
        default_factory=HybridSearchSettings,
        description="""Settings for hybrid search (only used if
        `use_semantic_search` and `use_fulltext_search` are both true)""",
    )

    # Specialized settings
    chunk_settings: ChunkSearchSettings = Field(
        default_factory=ChunkSearchSettings,
        description="Settings specific to chunk/vector search",
    )
    graph_settings: GraphSearchSettings = Field(
        default_factory=GraphSearchSettings,
        description="Settings specific to knowledge graph search",
    )

    # For HyDE or multi-query:
    num_sub_queries: int = Field(
        default=5,
        description="Number of sub-queries/hypothetical docs to generate when using hyde or rag_fusion search strategies.",
    )

    class Config:
        populate_by_name = True
        json_encoders = {UUID: str}
        json_schema_extra = {
            "example": {
                "use_semantic_search": True,
                "use_fulltext_search": False,
                "use_hybrid_search": False,
                "filters": {"category": "technology"},
                "limit": 20,
                "offset": 0,
                "search_strategy": "vanilla",
                "hybrid_settings": {
                    "full_text_weight": 1.0,
                    "semantic_weight": 5.0,
                    "full_text_limit": 200,
                    "rrf_k": 50,
                },
                "chunk_settings": {
                    "enabled": True,
                    "index_measure": "cosine_distance",
                    "include_metadata": True,
                    "probes": 10,
                    "ef_search": 40,
                },
                "graph_settings": {
                    "enabled": True,
                    "generation_config": GenerationConfig.Config.json_schema_extra,
                    "max_community_description_length": 65536,
                    "max_llm_queries_for_global_search": 250,
                    "limits": {
                        "entity": 20,
                        "relationship": 20,
                        "community": 20,
                    },
                },
            }
        }

    def __init__(self, **data):
        # Handle legacy search_filters field
        data["filters"] = {
            **data.get("filters", {}),
            **data.get("search_filters", {}),
        }
        super().__init__(**data)

    def model_dump(self, *args, **kwargs):
        return super().model_dump(*args, **kwargs)

    @classmethod
    def get_default(cls, mode: str) -> "SearchSettings":
        """Return default search settings for a given mode."""
        if mode == "basic":
            # A simpler search that relies primarily on semantic search.
            return cls(
                use_semantic_search=True,
                use_fulltext_search=False,
                use_hybrid_search=False,
                search_strategy="vanilla",
                # Other relevant defaults can be provided here as needed
            )
        elif mode == "advanced":
            # A more powerful, combined search that leverages both semantic and fulltext.
            return cls(
                use_semantic_search=True,
                use_fulltext_search=True,
                use_hybrid_search=True,
                search_strategy="hyde",
                # Other advanced defaults as needed
            )
        else:
            # For 'custom' or unrecognized modes, return a basic empty config.
            return cls()


class SearchMode(str, Enum):
    """Search modes for the search endpoint."""

    basic = "basic"
    advanced = "advanced"
    custom = "custom"


def select_search_filters(
    auth_user: Any,
    search_settings: SearchSettings,
) -> dict[str, Any]:
    filters = copy(search_settings.filters)
    selected_collections = None
    if not auth_user.is_superuser:
        user_collections = set(auth_user.collection_ids)
        for key in filters.keys():
            if "collection_ids" in key:
                selected_collections = set(map(UUID, filters[key]["$overlap"]))
                break

        if selected_collections:
            allowed_collections = user_collections.intersection(
                selected_collections
            )
        else:
            allowed_collections = user_collections
        # for non-superusers, we filter by user_id and selected & allowed collections
        collection_filters = {
            "$or": [
                {"owner_id": {"$eq": auth_user.id}},
                {"collection_ids": {"$overlap": list(allowed_collections)}},
            ]  # type: ignore
        }

        filters.pop("collection_ids", None)
        if filters != {}:
            filters = {"$and": [collection_filters, filters]}  # type: ignore
        else:
            filters = collection_filters
    return filters