aboutsummaryrefslogtreecommitdiff
import json
import logging
import uuid
from typing import Any, AsyncGenerator, Optional

from r2r.base import (
    AsyncPipe,
    AsyncState,
    PipeType,
    VectorSearchResult,
    generate_id_from_label,
)
from r2r.integrations import SerperClient

from ..abstractions.search_pipe import SearchPipe

logger = logging.getLogger(__name__)


class WebSearchPipe(SearchPipe):
    """Search pipe that answers text queries with web results from a Serper client.

    Each raw result returned by ``SerperClient.get_raw`` that carries a
    ``snippet`` is turned into a ``VectorSearchResult``. The result's metadata
    is the raw result dict, with the snippet moved under the ``text`` key.
    """

    def __init__(
        self,
        serper_client: SerperClient,
        type: PipeType = PipeType.SEARCH,
        config: Optional[SearchPipe.SearchConfig] = None,
        *args,
        **kwargs,
    ):
        """Create the pipe.

        Args:
            serper_client: Client used to fetch raw web search results.
            type: Pipe type forwarded to the base class (defaults to
                ``PipeType.SEARCH``).
            config: Search configuration. When omitted (or falsy), a default
                ``SearchPipe.SearchConfig()`` is used.
            *args, **kwargs: Forwarded unchanged to ``SearchPipe.__init__``.
        """
        super().__init__(
            type=type,
            config=config or SearchPipe.SearchConfig(),
            *args,
            **kwargs,
        )
        self.serper_client = serper_client

    async def search(
        self,
        message: str,
        run_id: uuid.UUID,
        *args: Any,
        **kwargs: Any,
    ) -> AsyncGenerator[VectorSearchResult, None]:
        """Run a single web search and yield its results.

        Args:
            message: Query string sent to the Serper client.
            run_id: Identifier attached to the log entries enqueued here.
            **kwargs: ``search_limit`` optionally overrides
                ``self.config.search_limit``. A missing or falsy value falls
                back to the configured limit.

        Yields:
            One ``VectorSearchResult`` per raw result that has a non-``None``
            ``snippet``. Results without a snippet are skipped.

        Note:
            The ``search_results`` log entry is enqueued only after the caller
            has exhausted this generator.
        """
        search_limit_override = kwargs.get("search_limit", None)
        await self.enqueue_log(
            run_id=run_id, key="search_query", value=message
        )
        # TODO - Make more general in the future by creating a SearchProvider interface
        # NOTE: get_raw is invoked synchronously (not awaited) inside this coroutine.
        results = self.serper_client.get_raw(
            query=message,
            limit=search_limit_override or self.config.search_limit,
        )

        search_results = []
        for result in results:
            # Without a snippet there is no text to return for this hit.
            if result.get("snippet") is None:
                continue
            # Rename "snippet" -> "text" in place; the mutated dict becomes
            # the result's metadata.
            result["text"] = result.pop("snippet")
            search_result = VectorSearchResult(
                # Id is derived from the stringified (already renamed) result.
                id=generate_id_from_label(str(result)),
                score=result.get(
                    "score", 0
                ),  # TODO - Consider dynamically generating scores based on similarity
                metadata=result,
            )
            search_results.append(search_result)
            yield search_result

        await self.enqueue_log(
            run_id=run_id,
            key="search_results",
            value=json.dumps([ele.json() for ele in search_results]),
        )

    async def _run_logic(
        self,
        input: AsyncPipe.Input,
        state: AsyncState,
        run_id: uuid.UUID,
        *args: Any,
        **kwargs,
    ) -> AsyncGenerator[VectorSearchResult, None]:
        """Search the web for every incoming query and stream the results.

        Args:
            input: Pipe input whose ``message`` is async-iterated for queries.
            state: Shared pipeline state. After all queries are processed, it
                receives ``{"output": {"search_queries", "search_results"}}``
                under ``self.config.name``.
            run_id: Identifier forwarded to ``search`` for logging.
            *args, **kwargs: Forwarded to ``search`` (e.g. ``search_limit``).

        Yields:
            Every ``VectorSearchResult`` produced by ``search``, in order.
        """
        search_queries = []
        search_results = []
        async for search_request in input.message:
            search_queries.append(search_request)
            # Pass message/run_id positionally: with keyword arguments, any
            # extra positional *args would bind to `message` and raise
            # "got multiple values for argument".
            async for result in self.search(
                search_request, run_id, *args, **kwargs
            ):
                search_results.append(result)
                yield result

        # A single write carrying both queries and results. The previous
        # results-only write to the same "output" key was redundant.
        await state.update(
            self.config.name,
            {
                "output": {
                    "search_queries": search_queries,
                    "search_results": search_results,
                }
            },
        )