diff options
author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
---|---|---|
committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) | |
tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/shared/abstractions/search.py | |
parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff) | |
download | gn-ai-master.tar.gz |
Diffstat (limited to '.venv/lib/python3.12/site-packages/shared/abstractions/search.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/shared/abstractions/search.py | 614 |
1 files changed, 614 insertions, 0 deletions
"""Abstractions for search functionality."""

from copy import copy
from enum import Enum
from typing import Any, Optional
from uuid import NAMESPACE_DNS, UUID, uuid5

from pydantic import Field

from .base import R2RSerializable
from .document import DocumentResponse
from .llm import GenerationConfig
from .vector import IndexMeasure


def generate_id_from_label(label: str) -> UUID:
    """Return the name-based (version 5) UUID of *label* in the DNS namespace.

    The same label always yields the same UUID.
    """
    return uuid5(NAMESPACE_DNS, label)


def _first_present_label(result: dict, *keys: str) -> str:
    """Return the first value among *keys* in *result* that is not None.

    The value is converted with ``str``; an empty string is returned when
    every key is missing or None, so the caller always gets something that
    ``generate_id_from_label`` can hash.
    """
    for key in keys:
        value = result.get(key)
        if value is not None:
            return str(value)
    return ""


class ChunkSearchResult(R2RSerializable):
    """Result of a search operation."""

    id: UUID
    document_id: UUID
    owner_id: Optional[UUID]
    collection_ids: list[UUID]
    score: Optional[float] = None
    text: str
    metadata: dict[str, Any]

    def __str__(self) -> str:
        """Return a short rendering, including the score when one is set."""
        # Compare against None explicitly: a score of 0.0 is a real score
        # and must not be reported as "no score".
        if self.score is not None:
            return (
                f"ChunkSearchResult(score={self.score:.3f}, text={self.text})"
            )
        return f"ChunkSearchResult(text={self.text})"

    def __repr__(self) -> str:
        return self.__str__()

    def as_dict(self) -> dict:
        """Return the fields of this result as a plain dictionary."""
        return {
            "id": self.id,
            "document_id": self.document_id,
            "owner_id": self.owner_id,
            "collection_ids": self.collection_ids,
            "score": self.score,
            "text": self.text,
            "metadata": self.metadata,
        }

    class Config:
        populate_by_name = True
        json_schema_extra = {
            "example": {
                "id": "3f3d47f3-8baf-58eb-8bc2-0171fb1c6e09",
                "document_id": "3e157b3a-8469-51db-90d9-52e7d896b49b",
                "owner_id": "2acb499e-8428-543b-bd85-0d9098718220",
                "collection_ids": [],
                "score": 0.23943702876567796,
                "text": "Example text from the document",
                "metadata": {
                    "title": "example_document.pdf",
                    "associated_query": "What is the capital of France?",
                },
            }
        }


class GraphSearchResultType(str, Enum):
    """Kinds of content a graph search result can carry."""

    ENTITY = "entity"
    RELATIONSHIP = "relationship"
    COMMUNITY = "community"


class GraphEntityResult(R2RSerializable):
    """An entity returned by graph search."""

    id: Optional[UUID] = None
    name: str
    description: str
    metadata: Optional[dict[str, Any]] = None

    class Config:
        json_schema_extra = {
            "example": {
                "name": "Entity Name",
                "description": "Entity Description",
                "metadata": {},
            }
        }


class GraphRelationshipResult(R2RSerializable):
    """A subject/predicate/object relationship returned by graph search."""

    id: Optional[UUID] = None
    subject: str
    predicate: str
    object: str
    subject_id: Optional[UUID] = None
    object_id: Optional[UUID] = None
    metadata: Optional[dict[str, Any]] = None
    score: Optional[float] = None
    description: str | None = None

    class Config:
        json_schema_extra = {
            "example": {
                "name": "Relationship Name",
                "description": "Relationship Description",
                "metadata": {},
            }
        }

    def __str__(self) -> str:
        return f"GraphRelationshipResult(subject={self.subject}, predicate={self.predicate}, object={self.object})"


class GraphCommunityResult(R2RSerializable):
    """A community (name plus summary) returned by graph search."""

    id: Optional[UUID] = None
    name: str
    summary: str
    metadata: Optional[dict[str, Any]] = None

    class Config:
        json_schema_extra = {
            "example": {
                "name": "Community Name",
                "summary": "Community Summary",
                "rating": 9,
                "rating_explanation": "Rating Explanation",
                "metadata": {},
            }
        }

    def __str__(self) -> str:
        return (
            f"GraphCommunityResult(name={self.name}, summary={self.summary})"
        )


class GraphSearchResult(R2RSerializable):
    """A single graph search hit wrapping an entity, relationship or community."""

    content: GraphEntityResult | GraphRelationshipResult | GraphCommunityResult
    result_type: Optional[GraphSearchResultType] = None
    chunk_ids: Optional[list[UUID]] = None
    metadata: dict[str, Any] = {}
    score: Optional[float] = None
    id: UUID

    def __str__(self) -> str:
        return f"GraphSearchResult(content={self.content}, result_type={self.result_type})"

    class Config:
        populate_by_name = True
        json_schema_extra = {
            "example": {
                "content": {
                    "id": "3f3d47f3-8baf-58eb-8bc2-0171fb1c6e09",
                    "name": "Entity Name",
                    "description": "Entity Description",
                    "metadata": {},
                },
                "result_type": "entity",
                "chunk_ids": ["c68dc72e-fc23-5452-8f49-d7bd46088a96"],
                "metadata": {
                    "associated_query": "What is the capital of France?"
                },
            }
        }


class WebPageSearchResult(R2RSerializable):
    """An organic web page hit."""

    title: Optional[str] = None
    link: Optional[str] = None
    snippet: Optional[str] = None
    position: int
    type: str = "organic"
    date: Optional[str] = None
    sitelinks: Optional[list[dict]] = None
    id: UUID

    class Config:
        json_schema_extra = {
            "example": {
                "title": "Page Title",
                "link": "https://example.com/page",
                "snippet": "Page snippet",
                "position": 1,
                "date": "2021-01-01",
                "sitelinks": [
                    {
                        "title": "Sitelink Title",
                        "link": "https://example.com/sitelink",
                    }
                ],
            }
        }

    def __str__(self) -> str:
        return f"WebPageSearchResult(title={self.title}, link={self.link}, snippet={self.snippet})"


class RelatedSearchResult(R2RSerializable):
    """A related-search suggestion."""

    query: str
    type: str = "related"
    id: UUID


class PeopleAlsoAskResult(R2RSerializable):
    """A "people also ask" entry."""

    question: str
    snippet: str
    link: str
    title: str
    id: UUID
    type: str = "peopleAlsoAsk"


class WebSearchResult(R2RSerializable):
    """Web search results grouped by kind."""

    organic_results: list[WebPageSearchResult] = []
    related_searches: list[RelatedSearchResult] = []
    people_also_ask: list[PeopleAlsoAskResult] = []

    @classmethod
    def from_serper_results(cls, results: list[dict]) -> "WebSearchResult":
        """Build a ``WebSearchResult`` from a list of raw result dictionaries.

        Each dictionary is routed by its ``"type"`` value (``"organic"``,
        ``"relatedSearches"`` or ``"peopleAlsoAsk"``); entries with a missing
        or unrecognised type are skipped. Every entry receives a deterministic
        id derived from its link (or query, for related searches).

        Args:
            results: Raw result dictionaries; their keys are passed to the
                matching result model as keyword arguments.

        Returns:
            A ``WebSearchResult`` holding the three grouped lists.
        """
        organic = []
        related = []
        paa = []

        for result in results:
            # .get() so an entry without a "type" is skipped like any other
            # unrecognised entry instead of raising KeyError.
            result_type = result.get("type")
            if result_type == "organic":
                # "link" is optional for organic pages; fall back to other
                # text so uuid5 is never handed None.
                label = _first_present_label(
                    result, "link", "title", "snippet"
                )
                organic.append(
                    WebPageSearchResult(
                        **result, id=generate_id_from_label(label)
                    )
                )
            elif result_type == "relatedSearches":
                label = _first_present_label(result, "query")
                related.append(
                    RelatedSearchResult(
                        **result,
                        id=generate_id_from_label(label),
                    )
                )
            elif result_type == "peopleAlsoAsk":
                label = _first_present_label(result, "link")
                paa.append(
                    PeopleAlsoAskResult(
                        **result, id=generate_id_from_label(label)
                    )
                )

        return cls(
            organic_results=organic,
            related_searches=related,
            people_also_ask=paa,
        )


class AggregateSearchResult(R2RSerializable):
    """Result of an aggregate search operation."""

    chunk_search_results: Optional[list[ChunkSearchResult]] = None
    graph_search_results: Optional[list[GraphSearchResult]] = None
    web_search_results: Optional[list[WebPageSearchResult]] = None
    document_search_results: Optional[list[DocumentResponse]] = None

    def __str__(self) -> str:
        return f"AggregateSearchResult(chunk_search_results={self.chunk_search_results}, graph_search_results={self.graph_search_results}, web_search_results={self.web_search_results}, document_search_results={str(self.document_search_results)})"

    def __repr__(self) -> str:
        # Same text as __str__; delegate rather than repeat the f-string.
        return self.__str__()

    def as_dict(self) -> dict:
        """Return every result group as a list of dictionaries.

        A group that is None or empty is returned as an empty list.
        """
        return {
            "chunk_search_results": (
                [result.as_dict() for result in self.chunk_search_results]
                if self.chunk_search_results
                else []
            ),
            "graph_search_results": (
                [result.to_dict() for result in self.graph_search_results]
                if self.graph_search_results
                else []
            ),
            "web_search_results": (
                [result.to_dict() for result in self.web_search_results]
                if self.web_search_results
                else []
            ),
            "document_search_results": (
                [cdr.to_dict() for cdr in self.document_search_results]
                if self.document_search_results
                else []
            ),
        }

    class Config:
        populate_by_name = True
        json_schema_extra = {
            "example": {
                "chunk_search_results": [
                    {
                        "id": "3f3d47f3-8baf-58eb-8bc2-0171fb1c6e09",
                        "document_id": "3e157b3a-8469-51db-90d9-52e7d896b49b",
                        "owner_id": "2acb499e-8428-543b-bd85-0d9098718220",
                        "collection_ids": [],
                        "score": 0.23943702876567796,
                        "text": "Example text from the document",
                        "metadata": {
                            "title": "example_document.pdf",
                            "associated_query": "What is the capital of France?",
                        },
                    }
                ],
                "graph_search_results": [
                    {
                        "content": {
                            "id": "3f3d47f3-8baf-58eb-8bc2-0171fb1c6e09",
                            "name": "Entity Name",
                            "description": "Entity Description",
                            "metadata": {},
                        },
                        "result_type": "entity",
                        "chunk_ids": ["c68dc72e-fc23-5452-8f49-d7bd46088a96"],
                        "metadata": {
                            "associated_query": "What is the capital of France?"
                        },
                    }
                ],
                "web_search_results": [
                    {
                        "title": "Page Title",
                        "link": "https://example.com/page",
                        "snippet": "Page snippet",
                        "position": 1,
                        "date": "2021-01-01",
                        "sitelinks": [
                            {
                                "title": "Sitelink Title",
                                "link": "https://example.com/sitelink",
                            }
                        ],
                    }
                ],
                "document_search_results": [
                    {
                        "document": {
                            "id": "3f3d47f3-8baf-58eb-8bc2-0171fb1c6e09",
                            "title": "Document Title",
                            "chunks": ["Chunk 1", "Chunk 2"],
                            "metadata": {},
                        },
                    }
                ],
            }
        }


class HybridSearchSettings(R2RSerializable):
    """Settings for hybrid search combining full-text and semantic search."""

    full_text_weight: float = Field(
        default=1.0, description="Weight to apply to full text search"
    )
    semantic_weight: float = Field(
        default=5.0, description="Weight to apply to semantic search"
    )
    full_text_limit: int = Field(
        default=200,
        description="Maximum number of results to return from full text search",
    )
    rrf_k: int = Field(
        default=50, description="K-value for RRF (Rank Reciprocal Fusion)"
    )


class ChunkSearchSettings(R2RSerializable):
    """Settings specific to chunk/vector search."""

    index_measure: IndexMeasure = Field(
        default=IndexMeasure.cosine_distance,
        description="The distance measure to use for indexing",
    )
    probes: int = Field(
        default=10,
        description="Number of ivfflat index lists to query. Higher increases accuracy but decreases speed.",
    )
    ef_search: int = Field(
        default=40,
        description="Size of the dynamic candidate list for HNSW index search. Higher increases accuracy but decreases speed.",
    )
    enabled: bool = Field(
        default=True,
        description="Whether to enable chunk search",
    )


class GraphSearchSettings(R2RSerializable):
    """Settings specific to knowledge graph search."""

    generation_config: Optional[GenerationConfig] = Field(
        default=None,
        description="Configuration for text generation during graph search.",
    )
    max_community_description_length: int = Field(
        default=65536,
    )
    max_llm_queries_for_global_search: int = Field(
        default=250,
    )
    limits: dict[str, int] = Field(
        default={},
    )
    enabled: bool = Field(
        default=True,
        description="Whether to enable graph search",
    )


class SearchSettings(R2RSerializable):
    """Main search settings class that combines shared settings with
    specialized settings for chunks and graph."""

    # Search type flags
    use_hybrid_search: bool = Field(
        default=False,
        description="Whether to perform a hybrid search. This is equivalent to setting `use_semantic_search=True` and `use_fulltext_search=True`, e.g. combining vector and keyword search.",
    )
    use_semantic_search: bool = Field(
        default=True,
        description="Whether to use semantic search",
    )
    use_fulltext_search: bool = Field(
        default=False,
        description="Whether to use full-text search",
    )

    # Common search parameters
    filters: dict[str, Any] = Field(
        default_factory=dict,
        description="""Filters to apply to the search. Allowed operators include `eq`, `neq`, `gt`, `gte`, `lt`, `lte`, `like`, `ilike`, `in`, and `nin`.

        Commonly seen filters include operations include the following:

        `{"document_id": {"$eq": "9fbe403b-..."}}`

        `{"document_id": {"$in": ["9fbe403b-...", "3e157b3a-..."]}}`

        `{"collection_ids": {"$overlap": ["122fdf6a-...", "..."]}}`

        `{"$and": {"$document_id": ..., "collection_ids": ...}}`""",
    )
    limit: int = Field(
        default=10,
        description="Maximum number of results to return",
        ge=1,
        le=1_000,
    )
    offset: int = Field(
        default=0,
        ge=0,
        description="Offset to paginate search results",
    )
    include_metadatas: bool = Field(
        default=True,
        description="Whether to include element metadata in the search results",
    )
    include_scores: bool = Field(
        default=True,
        description="""Whether to include search score values in the
        search results""",
    )

    # Search strategy and settings
    search_strategy: str = Field(
        default="vanilla",
        description="""Search strategy to use
        (e.g., 'vanilla', 'query_fusion', 'hyde')""",
    )
    hybrid_settings: HybridSearchSettings = Field(
        default_factory=HybridSearchSettings,
        description="""Settings for hybrid search (only used if
        `use_semantic_search` and `use_fulltext_search` are both true)""",
    )

    # Specialized settings
    chunk_settings: ChunkSearchSettings = Field(
        default_factory=ChunkSearchSettings,
        description="Settings specific to chunk/vector search",
    )
    graph_settings: GraphSearchSettings = Field(
        default_factory=GraphSearchSettings,
        description="Settings specific to knowledge graph search",
    )

    # For HyDE or multi-query:
    num_sub_queries: int = Field(
        default=5,
        description="Number of sub-queries/hypothetical docs to generate when using hyde or rag_fusion search strategies.",
    )

    class Config:
        populate_by_name = True
        json_encoders = {UUID: str}
        json_schema_extra = {
            "example": {
                "use_semantic_search": True,
                "use_fulltext_search": False,
                "use_hybrid_search": False,
                "filters": {"category": "technology"},
                "limit": 20,
                "offset": 0,
                "search_strategy": "vanilla",
                "hybrid_settings": {
                    "full_text_weight": 1.0,
                    "semantic_weight": 5.0,
                    "full_text_limit": 200,
                    "rrf_k": 50,
                },
                "chunk_settings": {
                    "enabled": True,
                    "index_measure": "cosine_distance",
                    "include_metadata": True,
                    "probes": 10,
                    "ef_search": 40,
                },
                "graph_settings": {
                    "enabled": True,
                    "generation_config": GenerationConfig.Config.json_schema_extra,
                    "max_community_description_length": 65536,
                    "max_llm_queries_for_global_search": 250,
                    "limits": {
                        "entity": 20,
                        "relationship": 20,
                        "community": 20,
                    },
                },
            }
        }

    def __init__(self, **data):
        """Create the settings, merging the legacy ``search_filters`` field.

        Entries of ``search_filters`` are merged into ``filters`` and win on
        key conflicts. Either argument may be omitted or passed as None.
        """
        # Handle legacy search_filters field. `or {}` covers an explicit
        # None, which would otherwise fail to unpack with `**`.
        data["filters"] = {
            **(data.get("filters") or {}),
            **(data.get("search_filters") or {}),
        }
        super().__init__(**data)

    def model_dump(self, *args, **kwargs):
        """Forward unchanged to the parent implementation."""
        return super().model_dump(*args, **kwargs)

    @classmethod
    def get_default(cls, mode: str) -> "SearchSettings":
        """Return default search settings for a given mode."""
        if mode == "basic":
            # A simpler search that relies primarily on semantic search.
            return cls(
                use_semantic_search=True,
                use_fulltext_search=False,
                use_hybrid_search=False,
                search_strategy="vanilla",
                # Other relevant defaults can be provided here as needed
            )
        elif mode == "advanced":
            # A more powerful, combined search that leverages both semantic and fulltext.
            return cls(
                use_semantic_search=True,
                use_fulltext_search=True,
                use_hybrid_search=True,
                search_strategy="hyde",
                # Other advanced defaults as needed
            )
        else:
            # For 'custom' or unrecognized modes, return a basic empty config.
            return cls()


class SearchMode(str, Enum):
    """Search modes for the search endpoint."""

    basic = "basic"
    advanced = "advanced"
    custom = "custom"


def select_search_filters(
    auth_user: Any,
    search_settings: SearchSettings,
) -> dict[str, Any]:
    """Return the search filters to apply on behalf of *auth_user*.

    Superusers get a shallow copy of ``search_settings.filters`` unchanged.
    For any other user the filters are combined with an access restriction:
    a row must be owned by the user or overlap the allowed collections. The
    allowed collections are the user's own collections, narrowed to those
    named by a ``collection_ids`` filter when one is present and non-empty.

    Args:
        auth_user: Object exposing ``is_superuser``, ``collection_ids`` and
            ``id``.
        search_settings: Settings whose ``filters`` are the starting point.

    Returns:
        The filter dictionary, wrapped in ``$and`` with the access
        restriction when the caller supplied other filters.

    NOTE(review): a collection filter without an ``$overlap`` entry raises
    KeyError below; confirm whether other operators should be accepted.
    """
    filters = copy(search_settings.filters)
    if auth_user.is_superuser:
        return filters

    user_collections = set(auth_user.collection_ids)
    selected_collections = None
    for key in filters:
        if "collection_ids" in key:
            selected_collections = set(map(UUID, filters[key]["$overlap"]))
            break

    if selected_collections:
        allowed_collections = user_collections.intersection(
            selected_collections
        )
    else:
        allowed_collections = user_collections

    # for non-superusers, we filter by user_id and selected & allowed collections
    collection_filters = {
        "$or": [
            {"owner_id": {"$eq": auth_user.id}},
            {"collection_ids": {"$overlap": list(allowed_collections)}},
        ]  # type: ignore
    }

    # The caller's own collection filter is replaced by the restriction above.
    filters.pop("collection_ids", None)
    if filters != {}:
        filters = {"$and": [collection_filters, filters]}  # type: ignore
    else:
        filters = collection_filters
    return filters