aboutsummaryrefslogtreecommitdiff
import uuid
from copy import copy
from typing import Any, AsyncGenerator, Optional

from r2r.base.abstractions.llm import GenerationConfig
from r2r.base.abstractions.search import VectorSearchResult
from r2r.base.pipes.base_pipe import AsyncPipe

from ..abstractions.search_pipe import SearchPipe
from .query_transform_pipe import QueryTransformPipe


class MultiSearchPipe(AsyncPipe):
    """Chain a query-transform pipe into a search pipe.

    The wrapped ``query_transform_pipe`` runs first. Its output is passed to
    ``inner_search_pipe`` as the ``message`` of the search pipe's ``Input``.
    Every result the search pipe yields is re-yielded unchanged.
    """

    class PipeConfig(AsyncPipe.PipeConfig):
        name: str = "multi_search_pipe"

    def __init__(
        self,
        query_transform_pipe: QueryTransformPipe,
        inner_search_pipe: SearchPipe,
        config: Optional[PipeConfig] = None,
        *args,
        **kwargs,
    ):
        """Wrap a query-transform pipe and a search pipe into one pipe.

        Args:
            query_transform_pipe: Pipe run first, on the incoming input.
            inner_search_pipe: Pipe run on the transform pipe's output.
            config: Optional config for this pipe. When omitted, a
                ``MultiSearchPipe.PipeConfig`` whose ``name`` equals
                ``query_transform_pipe.config.name`` is created.
            *args, **kwargs: Forwarded to ``AsyncPipe.__init__``.

        Raises:
            ValueError: If the two inner pipes have different config names,
                or if ``config`` is given and its name differs from the query
                transform pipe's name.
        """
        self.query_transform_pipe = query_transform_pipe
        # Attribute name kept for backward compatibility: it holds whichever
        # search pipe was passed in as ``inner_search_pipe``.
        self.vector_search_pipe = inner_search_pipe
        if (
            query_transform_pipe.config.name
            != inner_search_pipe.config.name
        ):
            raise ValueError(
                "The query transform pipe and search pipe must have the same name."
            )
        if config and config.name != query_transform_pipe.config.name:
            raise ValueError(
                "The pipe config name must match the query transform pipe name."
            )

        super().__init__(
            *args,
            config=config
            or MultiSearchPipe.PipeConfig(
                name=query_transform_pipe.config.name
            ),
            **kwargs,
        )

    async def _run_logic(
        self,
        input: Any,
        state: Any,
        run_id: uuid.UUID,
        query_transform_generation_config: Optional[GenerationConfig] = None,
        *args: Any,
        **kwargs: Any,
    ) -> AsyncGenerator[VectorSearchResult, None]:
        """Transform the input into queries, search with them, yield results.

        Args:
            input: Passed as-is to the query transform pipe's ``run``.
            state: Passed to both inner pipes' ``run``.
            run_id: Accepted for the pipe interface. It is not forwarded to
                the inner pipes.
            query_transform_generation_config: Generation config for the
                query transform step. If falsy, ``kwargs["rag_generation_config"]``
                is used, and failing that ``GenerationConfig(model="gpt-4o")``.
                Whichever is chosen is shallow-copied and the copy gets
                ``stream = False``. The caller's object is never modified.
            *args, **kwargs: Forwarded to both inner pipes' ``run``.
                ``num_query_xf_outputs`` (default 3) is consumed here and
                passed only to the query transform pipe.

        Yields:
            Each item produced by iterating the inner search pipe's output.
        """
        base_config = (
            query_transform_generation_config
            or kwargs.get("rag_generation_config", None)
            or GenerationConfig(model="gpt-4o")
        )
        # Copy before forcing non-streaming so neither an explicitly passed
        # config nor ``rag_generation_config`` is mutated for the caller.
        transform_config = copy(base_config)
        transform_config.stream = False

        # Pop rather than get. Passing it explicitly below while it is also in
        # **kwargs would raise "got multiple values for keyword argument".
        num_query_xf_outputs = kwargs.pop("num_query_xf_outputs", 3)

        # NOTE(review): presumably an async generator of transformed queries;
        # it is handed to the search pipe without being consumed here.
        query_generator = await self.query_transform_pipe.run(
            input,
            state,
            *args,
            query_transform_generation_config=transform_config,
            num_query_xf_outputs=num_query_xf_outputs,
            **kwargs,
        )

        search_results = await self.vector_search_pipe.run(
            self.vector_search_pipe.Input(message=query_generator),
            state,
            *args,
            **kwargs,
        )
        async for search_result in search_results:
            yield search_result