about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/shared/abstractions/search.py
diff options
context:
space:
mode:
author: S. Solomon Darnell, 2025-03-28 21:52:21 -0500
committer: S. Solomon Darnell, 2025-03-28 21:52:21 -0500
commit: 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree: ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/shared/abstractions/search.py
parent: cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
download: gn-ai-4a52a71956a8d46fcb7294ac71734504bb09bcc2.tar.gz
two version of R2R are here HEAD master
Diffstat (limited to '.venv/lib/python3.12/site-packages/shared/abstractions/search.py')
-rw-r--r-- .venv/lib/python3.12/site-packages/shared/abstractions/search.py 614
1 files changed, 614 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/shared/abstractions/search.py b/.venv/lib/python3.12/site-packages/shared/abstractions/search.py
new file mode 100644
index 00000000..bf0f650e
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/shared/abstractions/search.py
@@ -0,0 +1,614 @@
+"""Abstractions for search functionality."""
+
+from copy import copy
+from enum import Enum
+from typing import Any, Optional
+from uuid import NAMESPACE_DNS, UUID, uuid5
+
+from pydantic import Field
+
+from .base import R2RSerializable
+from .document import DocumentResponse
+from .llm import GenerationConfig
+from .vector import IndexMeasure
+
+
+def generate_id_from_label(label) -> UUID:
+    return uuid5(NAMESPACE_DNS, label)
+
+
+class ChunkSearchResult(R2RSerializable):
+    """Result of a search operation."""
+
+    id: UUID
+    document_id: UUID
+    owner_id: Optional[UUID]
+    collection_ids: list[UUID]
+    score: Optional[float] = None
+    text: str
+    metadata: dict[str, Any]
+
+    def __str__(self) -> str:
+        if self.score:
+            return (
+                f"ChunkSearchResult(score={self.score:.3f}, text={self.text})"
+            )
+        else:
+            return f"ChunkSearchResult(text={self.text})"
+
+    def __repr__(self) -> str:
+        return self.__str__()
+
+    def as_dict(self) -> dict:
+        return {
+            "id": self.id,
+            "document_id": self.document_id,
+            "owner_id": self.owner_id,
+            "collection_ids": self.collection_ids,
+            "score": self.score,
+            "text": self.text,
+            "metadata": self.metadata,
+        }
+
+    class Config:
+        populate_by_name = True
+        json_schema_extra = {
+            "example": {
+                "id": "3f3d47f3-8baf-58eb-8bc2-0171fb1c6e09",
+                "document_id": "3e157b3a-8469-51db-90d9-52e7d896b49b",
+                "owner_id": "2acb499e-8428-543b-bd85-0d9098718220",
+                "collection_ids": [],
+                "score": 0.23943702876567796,
+                "text": "Example text from the document",
+                "metadata": {
+                    "title": "example_document.pdf",
+                    "associated_query": "What is the capital of France?",
+                },
+            }
+        }
+
+
+class GraphSearchResultType(str, Enum):
+    # Kinds of content a graph search result can carry. Because the enum
+    # also derives from ``str``, each member compares equal to its plain
+    # string value (e.g. ``GraphSearchResultType.ENTITY == "entity"``).
+    ENTITY = "entity"
+    RELATIONSHIP = "relationship"
+    COMMUNITY = "community"
+
+
+class GraphEntityResult(R2RSerializable):
+    # An entity returned by graph search. ``name`` and ``description`` are
+    # required; ``id`` and ``metadata`` are optional and default to None.
+    id: Optional[UUID] = None
+    name: str
+    description: str
+    metadata: Optional[dict[str, Any]] = None
+
+    class Config:
+        # Example payload attached to the generated JSON schema.
+        json_schema_extra = {
+            "example": {
+                "name": "Entity Name",
+                "description": "Entity Description",
+                "metadata": {},
+            }
+        }
+
+
+class GraphRelationshipResult(R2RSerializable):
+    id: Optional[UUID] = None
+    subject: str
+    predicate: str
+    object: str
+    subject_id: Optional[UUID] = None
+    object_id: Optional[UUID] = None
+    metadata: Optional[dict[str, Any]] = None
+    score: Optional[float] = None
+    description: str | None = None
+
+    class Config:
+        json_schema_extra = {
+            "example": {
+                "name": "Relationship Name",
+                "description": "Relationship Description",
+                "metadata": {},
+            }
+        }
+
+    def __str__(self) -> str:
+        return f"GraphRelationshipResult(subject={self.subject}, predicate={self.predicate}, object={self.object})"
+
+
+class GraphCommunityResult(R2RSerializable):
+    # A community returned by graph search. ``name`` and ``summary`` are
+    # required; ``id`` and ``metadata`` are optional and default to None.
+    id: Optional[UUID] = None
+    name: str
+    summary: str
+    metadata: Optional[dict[str, Any]] = None
+
+    class Config:
+        # NOTE(review): the example lists "rating" and "rating_explanation",
+        # which are not declared as fields on this model -- confirm whether
+        # they are still meant to appear here.
+        json_schema_extra = {
+            "example": {
+                "name": "Community Name",
+                "summary": "Community Summary",
+                "rating": 9,
+                "rating_explanation": "Rating Explanation",
+                "metadata": {},
+            }
+        }
+
+    def __str__(self) -> str:
+        # One-line summary showing only the name and the summary.
+        return (
+            f"GraphCommunityResult(name={self.name}, summary={self.summary})"
+        )
+
+
+class GraphSearchResult(R2RSerializable):
+    # Wrapper around one graph search hit. ``content`` holds the entity,
+    # relationship or community payload; ``content`` and ``id`` are the only
+    # required fields.
+    content: GraphEntityResult | GraphRelationshipResult | GraphCommunityResult
+    result_type: Optional[GraphSearchResultType] = None
+    chunk_ids: Optional[list[UUID]] = None
+    metadata: dict[str, Any] = {}
+    score: Optional[float] = None
+    id: UUID
+
+    def __str__(self) -> str:
+        # One-line summary showing the content and the result type.
+        return f"GraphSearchResult(content={self.content}, result_type={self.result_type})"
+
+    class Config:
+        populate_by_name = True
+        # NOTE(review): the example has no top-level "id" although the field
+        # is required -- confirm whether the example should include one.
+        json_schema_extra = {
+            "example": {
+                "content": {
+                    "id": "3f3d47f3-8baf-58eb-8bc2-0171fb1c6e09",
+                    "name": "Entity Name",
+                    "description": "Entity Description",
+                    "metadata": {},
+                },
+                "result_type": "entity",
+                "chunk_ids": ["c68dc72e-fc23-5452-8f49-d7bd46088a96"],
+                "metadata": {
+                    "associated_query": "What is the capital of France?"
+                },
+            }
+        }
+
+
+class WebPageSearchResult(R2RSerializable):
+    # One web page hit. ``position`` and ``id`` are required; ``type``
+    # defaults to "organic" and the remaining fields default to None.
+    title: Optional[str] = None
+    link: Optional[str] = None
+    snippet: Optional[str] = None
+    position: int
+    type: str = "organic"
+    date: Optional[str] = None
+    sitelinks: Optional[list[dict]] = None
+    id: UUID
+
+    class Config:
+        # NOTE(review): the example has no "id" although the field is
+        # required -- confirm whether the example should include one.
+        json_schema_extra = {
+            "example": {
+                "title": "Page Title",
+                "link": "https://example.com/page",
+                "snippet": "Page snippet",
+                "position": 1,
+                "date": "2021-01-01",
+                "sitelinks": [
+                    {
+                        "title": "Sitelink Title",
+                        "link": "https://example.com/sitelink",
+                    }
+                ],
+            }
+        }
+
+    def __str__(self) -> str:
+        # One-line summary showing the title, link and snippet.
+        return f"WebPageSearchResult(title={self.title}, link={self.link}, snippet={self.snippet})"
+
+
+class RelatedSearchResult(R2RSerializable):
+    # A related-search suggestion. ``query`` and ``id`` are required;
+    # ``type`` defaults to "related".
+    query: str
+    type: str = "related"
+    id: UUID
+
+
+class PeopleAlsoAskResult(R2RSerializable):
+    # A "people also ask" entry. Every field except ``type`` is required;
+    # ``type`` defaults to "peopleAlsoAsk".
+    question: str
+    snippet: str
+    link: str
+    title: str
+    id: UUID
+    type: str = "peopleAlsoAsk"
+
+
+class WebSearchResult(R2RSerializable):
+    organic_results: list[WebPageSearchResult] = []
+    related_searches: list[RelatedSearchResult] = []
+    people_also_ask: list[PeopleAlsoAskResult] = []
+
+    @classmethod
+    def from_serper_results(cls, results: list[dict]) -> "WebSearchResult":
+        organic = []
+        related = []
+        paa = []
+
+        for result in results:
+            if result["type"] == "organic":
+                organic.append(
+                    WebPageSearchResult(
+                        **result, id=generate_id_from_label(result.get("link"))
+                    )
+                )
+            elif result["type"] == "relatedSearches":
+                related.append(
+                    RelatedSearchResult(
+                        **result,
+                        id=generate_id_from_label(result.get("query")),
+                    )
+                )
+            elif result["type"] == "peopleAlsoAsk":
+                paa.append(
+                    PeopleAlsoAskResult(
+                        **result, id=generate_id_from_label(result.get("link"))
+                    )
+                )
+
+        return cls(
+            organic_results=organic,
+            related_searches=related,
+            people_also_ask=paa,
+        )
+
+
+class AggregateSearchResult(R2RSerializable):
+    """Result of an aggregate search operation."""
+
+    chunk_search_results: Optional[list[ChunkSearchResult]] = None
+    graph_search_results: Optional[list[GraphSearchResult]] = None
+    web_search_results: Optional[list[WebPageSearchResult]] = None
+    document_search_results: Optional[list[DocumentResponse]] = None
+
+    def __str__(self) -> str:
+        return f"AggregateSearchResult(chunk_search_results={self.chunk_search_results}, graph_search_results={self.graph_search_results}, web_search_results={self.web_search_results}, document_search_results={str(self.document_search_results)})"
+
+    def __repr__(self) -> str:
+        return f"AggregateSearchResult(chunk_search_results={self.chunk_search_results}, graph_search_results={self.graph_search_results}, web_search_results={self.web_search_results}, document_search_results={str(self.document_search_results)})"
+
+    def as_dict(self) -> dict:
+        return {
+            "chunk_search_results": (
+                [result.as_dict() for result in self.chunk_search_results]
+                if self.chunk_search_results
+                else []
+            ),
+            "graph_search_results": (
+                [result.to_dict() for result in self.graph_search_results]
+                if self.graph_search_results
+                else []
+            ),
+            "web_search_results": (
+                [result.to_dict() for result in self.web_search_results]
+                if self.web_search_results
+                else []
+            ),
+            "document_search_results": (
+                [cdr.to_dict() for cdr in self.document_search_results]
+                if self.document_search_results
+                else []
+            ),
+        }
+
+    class Config:
+        populate_by_name = True
+        json_schema_extra = {
+            "example": {
+                "chunk_search_results": [
+                    {
+                        "id": "3f3d47f3-8baf-58eb-8bc2-0171fb1c6e09",
+                        "document_id": "3e157b3a-8469-51db-90d9-52e7d896b49b",
+                        "owner_id": "2acb499e-8428-543b-bd85-0d9098718220",
+                        "collection_ids": [],
+                        "score": 0.23943702876567796,
+                        "text": "Example text from the document",
+                        "metadata": {
+                            "title": "example_document.pdf",
+                            "associated_query": "What is the capital of France?",
+                        },
+                    }
+                ],
+                "graph_search_results": [
+                    {
+                        "content": {
+                            "id": "3f3d47f3-8baf-58eb-8bc2-0171fb1c6e09",
+                            "name": "Entity Name",
+                            "description": "Entity Description",
+                            "metadata": {},
+                        },
+                        "result_type": "entity",
+                        "chunk_ids": ["c68dc72e-fc23-5452-8f49-d7bd46088a96"],
+                        "metadata": {
+                            "associated_query": "What is the capital of France?"
+                        },
+                    }
+                ],
+                "web_search_results": [
+                    {
+                        "title": "Page Title",
+                        "link": "https://example.com/page",
+                        "snippet": "Page snippet",
+                        "position": 1,
+                        "date": "2021-01-01",
+                        "sitelinks": [
+                            {
+                                "title": "Sitelink Title",
+                                "link": "https://example.com/sitelink",
+                            }
+                        ],
+                    }
+                ],
+                "document_search_results": [
+                    {
+                        "document": {
+                            "id": "3f3d47f3-8baf-58eb-8bc2-0171fb1c6e09",
+                            "title": "Document Title",
+                            "chunks": ["Chunk 1", "Chunk 2"],
+                            "metadata": {},
+                        },
+                    }
+                ],
+            }
+        }
+
+
+class HybridSearchSettings(R2RSerializable):
+    """Settings for hybrid search combining full-text and semantic search."""
+
+    # Every field has a default, so ``HybridSearchSettings()`` is valid.
+    full_text_weight: float = Field(
+        default=1.0, description="Weight to apply to full text search"
+    )
+    semantic_weight: float = Field(
+        default=5.0, description="Weight to apply to semantic search"
+    )
+    full_text_limit: int = Field(
+        default=200,
+        description="Maximum number of results to return from full text search",
+    )
+    # NOTE(review): the technique is usually named "Reciprocal Rank Fusion";
+    # the description string is left untouched here.
+    rrf_k: int = Field(
+        default=50, description="K-value for RRF (Rank Reciprocal Fusion)"
+    )
+
+
+class ChunkSearchSettings(R2RSerializable):
+    """Settings specific to chunk/vector search."""
+
+    # Every field has a default, so ``ChunkSearchSettings()`` is valid.
+    index_measure: IndexMeasure = Field(
+        default=IndexMeasure.cosine_distance,
+        description="The distance measure to use for indexing",
+    )
+    # Per its description, ``probes`` applies to ivfflat indexes.
+    probes: int = Field(
+        default=10,
+        description="Number of ivfflat index lists to query. Higher increases accuracy but decreases speed.",
+    )
+    # Per its description, ``ef_search`` applies to HNSW indexes.
+    ef_search: int = Field(
+        default=40,
+        description="Size of the dynamic candidate list for HNSW index search. Higher increases accuracy but decreases speed.",
+    )
+    enabled: bool = Field(
+        default=True,
+        description="Whether to enable chunk search",
+    )
+
+
+class GraphSearchSettings(R2RSerializable):
+    """Settings specific to knowledge graph search."""
+
+    # Every field has a default, so ``GraphSearchSettings()`` is valid.
+    generation_config: Optional[GenerationConfig] = Field(
+        default=None,
+        description="Configuration for text generation during graph search.",
+    )
+    max_community_description_length: int = Field(
+        default=65536,
+    )
+    max_llm_queries_for_global_search: int = Field(
+        default=250,
+    )
+    # Maps a name to an integer limit. Presumably keyed by result type
+    # ("entity", "relationship", "community"), as in the SearchSettings
+    # schema example -- verify against the code that reads it.
+    limits: dict[str, int] = Field(
+        default={},
+    )
+    enabled: bool = Field(
+        default=True,
+        description="Whether to enable graph search",
+    )
+
+
+class SearchSettings(R2RSerializable):
+    """Main search settings class that combines shared settings with
+    specialized settings for chunks and graph."""
+
+    # Search type flags
+    use_hybrid_search: bool = Field(
+        default=False,
+        description="Whether to perform a hybrid search. This is equivalent to setting `use_semantic_search=True` and `use_fulltext_search=True`, e.g. combining vector and keyword search.",
+    )
+    use_semantic_search: bool = Field(
+        default=True,
+        description="Whether to use semantic search",
+    )
+    use_fulltext_search: bool = Field(
+        default=False,
+        description="Whether to use full-text search",
+    )
+
+    # Common search parameters
+    filters: dict[str, Any] = Field(
+        default_factory=dict,
+        description="""Filters to apply to the search. Allowed operators include `eq`, `neq`, `gt`, `gte`, `lt`, `lte`, `like`, `ilike`, `in`, and `nin`.
+
+      Commonly seen filters include operations include the following:
+
+        `{"document_id": {"$eq": "9fbe403b-..."}}`
+
+        `{"document_id": {"$in": ["9fbe403b-...", "3e157b3a-..."]}}`
+
+        `{"collection_ids": {"$overlap": ["122fdf6a-...", "..."]}}`
+
+        `{"$and": {"$document_id": ..., "collection_ids": ...}}`""",
+    )
+    limit: int = Field(
+        default=10,
+        description="Maximum number of results to return",
+        ge=1,
+        le=1_000,
+    )
+    offset: int = Field(
+        default=0,
+        ge=0,
+        description="Offset to paginate search results",
+    )
+    include_metadatas: bool = Field(
+        default=True,
+        description="Whether to include element metadata in the search results",
+    )
+    include_scores: bool = Field(
+        default=True,
+        description="""Whether to include search score values in the
+        search results""",
+    )
+
+    # Search strategy and settings
+    search_strategy: str = Field(
+        default="vanilla",
+        description="""Search strategy to use
+        (e.g., 'vanilla', 'query_fusion', 'hyde')""",
+    )
+    hybrid_settings: HybridSearchSettings = Field(
+        default_factory=HybridSearchSettings,
+        description="""Settings for hybrid search (only used if
+        `use_semantic_search` and `use_fulltext_search` are both true)""",
+    )
+
+    # Specialized settings
+    chunk_settings: ChunkSearchSettings = Field(
+        default_factory=ChunkSearchSettings,
+        description="Settings specific to chunk/vector search",
+    )
+    graph_settings: GraphSearchSettings = Field(
+        default_factory=GraphSearchSettings,
+        description="Settings specific to knowledge graph search",
+    )
+
+    # For HyDE or multi-query:
+    num_sub_queries: int = Field(
+        default=5,
+        description="Number of sub-queries/hypothetical docs to generate when using hyde or rag_fusion search strategies.",
+    )
+
+    class Config:
+        populate_by_name = True
+        json_encoders = {UUID: str}
+        json_schema_extra = {
+            "example": {
+                "use_semantic_search": True,
+                "use_fulltext_search": False,
+                "use_hybrid_search": False,
+                "filters": {"category": "technology"},
+                "limit": 20,
+                "offset": 0,
+                "search_strategy": "vanilla",
+                "hybrid_settings": {
+                    "full_text_weight": 1.0,
+                    "semantic_weight": 5.0,
+                    "full_text_limit": 200,
+                    "rrf_k": 50,
+                },
+                "chunk_settings": {
+                    "enabled": True,
+                    "index_measure": "cosine_distance",
+                    "include_metadata": True,
+                    "probes": 10,
+                    "ef_search": 40,
+                },
+                "graph_settings": {
+                    "enabled": True,
+                    "generation_config": GenerationConfig.Config.json_schema_extra,
+                    "max_community_description_length": 65536,
+                    "max_llm_queries_for_global_search": 250,
+                    "limits": {
+                        "entity": 20,
+                        "relationship": 20,
+                        "community": 20,
+                    },
+                },
+            }
+        }
+
+    def __init__(self, **data):
+        # Handle legacy search_filters field
+        data["filters"] = {
+            **data.get("filters", {}),
+            **data.get("search_filters", {}),
+        }
+        super().__init__(**data)
+
+    def model_dump(self, *args, **kwargs):
+        return super().model_dump(*args, **kwargs)
+
+    @classmethod
+    def get_default(cls, mode: str) -> "SearchSettings":
+        """Return default search settings for a given mode."""
+        if mode == "basic":
+            # A simpler search that relies primarily on semantic search.
+            return cls(
+                use_semantic_search=True,
+                use_fulltext_search=False,
+                use_hybrid_search=False,
+                search_strategy="vanilla",
+                # Other relevant defaults can be provided here as needed
+            )
+        elif mode == "advanced":
+            # A more powerful, combined search that leverages both semantic and fulltext.
+            return cls(
+                use_semantic_search=True,
+                use_fulltext_search=True,
+                use_hybrid_search=True,
+                search_strategy="hyde",
+                # Other advanced defaults as needed
+            )
+        else:
+            # For 'custom' or unrecognized modes, return a basic empty config.
+            return cls()
+
+
+class SearchMode(str, Enum):
+    """Search modes for the search endpoint."""
+
+    # The values are the mode strings ``SearchSettings.get_default`` tests
+    # for: "basic" and "advanced" have dedicated presets there, and any
+    # other value (including "custom") falls through to the plain defaults.
+    basic = "basic"
+    advanced = "advanced"
+    custom = "custom"
+
+
+def select_search_filters(
+    auth_user: Any,
+    search_settings: SearchSettings,
+) -> dict[str, Any]:
+    filters = copy(search_settings.filters)
+    selected_collections = None
+    if not auth_user.is_superuser:
+        user_collections = set(auth_user.collection_ids)
+        for key in filters.keys():
+            if "collection_ids" in key:
+                selected_collections = set(map(UUID, filters[key]["$overlap"]))
+                break
+
+        if selected_collections:
+            allowed_collections = user_collections.intersection(
+                selected_collections
+            )
+        else:
+            allowed_collections = user_collections
+        # for non-superusers, we filter by user_id and selected & allowed collections
+        collection_filters = {
+            "$or": [
+                {"owner_id": {"$eq": auth_user.id}},
+                {"collection_ids": {"$overlap": list(allowed_collections)}},
+            ]  # type: ignore
+        }
+
+        filters.pop("collection_ids", None)
+        if filters != {}:
+            filters = {"$and": [collection_filters, filters]}  # type: ignore
+        else:
+            filters = collection_filters
+    return filters