about summary refs log tree commit diff
"""Abstractions for search functionality."""

import uuid
from typing import Any, Dict, List, Optional, Tuple

from pydantic import BaseModel, Field

from .llm import GenerationConfig


class VectorSearchRequest(BaseModel):
    """Request for a search operation."""

    query: str
    limit: int
    filters: Optional[dict[str, Any]] = None


class VectorSearchResult(BaseModel):
    """Result of a search operation."""

    id: uuid.UUID
    score: float
    metadata: dict[str, Any]

    def __str__(self) -> str:
        return f"VectorSearchResult(id={self.id}, score={self.score}, metadata={self.metadata})"

    def __repr__(self) -> str:
        return f"VectorSearchResult(id={self.id}, score={self.score}, metadata={self.metadata})"

    def dict(self) -> dict:
        return {
            "id": self.id,
            "score": self.score,
            "metadata": self.metadata,
        }


class KGSearchRequest(BaseModel):
    """Request for a knowledge graph search operation."""

    query: str


# [query, ...]
KGSearchResult = List[Tuple[str, List[Dict[str, Any]]]]


class AggregateSearchResult(BaseModel):
    """Result of an aggregate search operation."""

    vector_search_results: Optional[List[VectorSearchResult]]
    kg_search_results: Optional[KGSearchResult] = None

    def __str__(self) -> str:
        return f"AggregateSearchResult(vector_search_results={self.vector_search_results}, kg_search_results={self.kg_search_results})"

    def __repr__(self) -> str:
        return f"AggregateSearchResult(vector_search_results={self.vector_search_results}, kg_search_results={self.kg_search_results})"

    def dict(self) -> dict:
        return {
            "vector_search_results": (
                [result.dict() for result in self.vector_search_results]
                if self.vector_search_results
                else []
            ),
            "kg_search_results": self.kg_search_results or [],
        }


class VectorSearchSettings(BaseModel):
    use_vector_search: bool = True
    search_filters: dict[str, Any] = Field(default_factory=dict)
    search_limit: int = 10
    do_hybrid_search: bool = False


class KGSearchSettings(BaseModel):
    use_kg_search: bool = False
    agent_generation_config: Optional[GenerationConfig] = Field(
        default_factory=GenerationConfig
    )