"""Vector search pipe: embed a query, search the vector DB, rerank results."""
import json
import logging
import uuid
from typing import Any, AsyncGenerator, Optional

from r2r.base import (
    AsyncPipe,
    AsyncState,
    EmbeddingProvider,
    PipeType,
    VectorDBProvider,
    VectorSearchResult,
    VectorSearchSettings,
)

from ..abstractions.search_pipe import SearchPipe

logger = logging.getLogger(__name__)


class VectorSearchPipe(SearchPipe):
    """Search pipe that embeds queries and looks them up in a vector database.

    Each incoming query is embedded with the embedding provider, searched
    against the vector database provider (hybrid or plain vector search,
    depending on the search settings), reranked by the embedding provider,
    and streamed back result by result.
    """

    def __init__(
        self,
        vector_db_provider: VectorDBProvider,
        embedding_provider: EmbeddingProvider,
        type: PipeType = PipeType.SEARCH,
        config: Optional[SearchPipe.SearchConfig] = None,
        *args,
        **kwargs,
    ):
        """Store the providers and initialise the underlying search pipe.

        Args:
            vector_db_provider: Provider used to run ``search`` and
                ``hybrid_search`` queries.
            embedding_provider: Provider used to embed queries and to rerank
                the search results.
            type: Pipe type forwarded to the parent constructor.
            config: Search configuration; a default
                ``SearchPipe.SearchConfig()`` is used when falsy.
            *args: Extra positional arguments forwarded to the parent.
            **kwargs: Extra keyword arguments forwarded to the parent.
        """
        super().__init__(
            type=type,
            config=config or SearchPipe.SearchConfig(),
            *args,
            **kwargs,
        )
        self.embedding_provider = embedding_provider
        self.vector_db_provider = vector_db_provider

    async def search(
        self,
        message: str,
        run_id: uuid.UUID,
        vector_search_settings: VectorSearchSettings,
        *args: Any,
        **kwargs: Any,
    ) -> AsyncGenerator[VectorSearchResult, None]:
        """Run one query and yield its reranked results.

        The query is logged under ``"search_query"`` before searching. Once
        the generator has been fully consumed, the yielded results are logged
        under ``"search_results"``.

        Args:
            message: The query text.
            run_id: Identifier of the current run, attached to log entries.
            vector_search_settings: Per-request settings. Falsy
                ``search_filters`` / ``search_limit`` values fall back to the
                pipe configuration.
            *args: Unused.
            **kwargs: Unused.

        Yields:
            Each reranked result, with ``metadata["associatedQuery"]`` set to
            ``message``.
        """
        await self.enqueue_log(
            run_id=run_id, key="search_query", value=message
        )
        search_filters = (
            vector_search_settings.search_filters or self.config.search_filters
        )
        search_limit = (
            vector_search_settings.search_limit or self.config.search_limit
        )
        results = []
        query_vector = self.embedding_provider.get_embedding(
            message,
        )
        if vector_search_settings.do_hybrid_search:
            # Hybrid search additionally receives the raw query text.
            search_results = self.vector_db_provider.hybrid_search(
                query_vector=query_vector,
                query_text=message,
                filters=search_filters,
                limit=search_limit,
            )
        else:
            search_results = self.vector_db_provider.search(
                query_vector=query_vector,
                filters=search_filters,
                limit=search_limit,
            )
        reranked_results = self.embedding_provider.rerank(
            query=message, results=search_results, limit=search_limit
        )
        for result in reranked_results:
            # Tag every hit with the query that produced it.
            result.metadata["associatedQuery"] = message
            results.append(result)
            yield result
        # Only reached when the consumer exhausts the generator.
        await self.enqueue_log(
            run_id=run_id,
            key="search_results",
            value=json.dumps([ele.json() for ele in results]),
        )

    async def _run_logic(
        self,
        input: AsyncPipe.Input,
        state: AsyncState,
        run_id: uuid.UUID,
        vector_search_settings: Optional[VectorSearchSettings] = None,
        *args: Any,
        **kwargs: Any,
    ) -> AsyncGenerator[VectorSearchResult, None]:
        """Search every query from the input stream and yield all results.

        After the input stream is exhausted, the collected queries and
        results are written to ``state`` under ``self.config.name``.

        Args:
            input: Pipe input whose ``message`` is an async iterable of
                query strings.
            state: Shared state object updated with the pipe output.
            run_id: Identifier of the current run.
            vector_search_settings: Per-request settings; a fresh
                ``VectorSearchSettings()`` is created when omitted or None.
            *args: Extra positional arguments forwarded to ``search``.
            **kwargs: Extra keyword arguments forwarded to ``search``.

        Yields:
            Every result produced by ``search`` for every query, in order.
        """
        # Build the default per call instead of sharing one instance that
        # was created at function-definition time.
        if vector_search_settings is None:
            vector_search_settings = VectorSearchSettings()

        search_queries = []
        search_results = []
        async for search_request in input.message:
            search_queries.append(search_request)
            # Pass the named parameters positionally: mixing keyword
            # arguments with *args would bind the first extra positional to
            # ``message`` and raise TypeError.
            async for result in self.search(
                search_request,
                run_id,
                vector_search_settings,
                *args,
                **kwargs,
            ):
                search_results.append(result)
                yield result

        # Record queries and results together in a single state update.
        await state.update(
            self.config.name,
            {
                "output": {
                    "search_queries": search_queries,
                    "search_results": search_results,
                }
            },
        )