aboutsummaryrefslogtreecommitdiff
path: root/R2R/r2r/pipelines/rag_pipeline.py
blob: b257ccaa91b7c887ce4bb04568ca85a7d3b5e82f (about) (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import asyncio
import logging
from typing import Any, Optional

from ..base.abstractions.llm import GenerationConfig
from ..base.abstractions.search import KGSearchSettings, VectorSearchSettings
from ..base.logging.kv_logger import KVLoggingSingleton
from ..base.logging.run_manager import RunManager, manage_run
from ..base.pipeline.base_pipeline import AsyncPipeline
from ..base.pipes.base_pipe import AsyncPipe, AsyncState
from ..base.utils import to_async_generator

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)


class RAGPipeline(AsyncPipeline):
    """A pipeline for RAG.

    The pipeline is composed of two inner pipelines:

    * a search pipeline, supplied via :meth:`set_search_pipeline`, which is
      run once per incoming query, and
    * an inner ``AsyncPipeline`` built up through :meth:`add_pipe`, which
      receives ``(query, search_result)`` pairs and produces the RAG output.
    """

    pipeline_type: str = "rag"

    def __init__(
        self,
        pipe_logger: Optional[KVLoggingSingleton] = None,
        run_manager: Optional[RunManager] = None,
    ):
        """Create an empty RAG pipeline.

        Args:
            pipe_logger: Forwarded to the base ``AsyncPipeline`` constructor.
            run_manager: Forwarded to the base ``AsyncPipeline`` constructor.
        """
        super().__init__(pipe_logger, run_manager)
        # Both inner pipelines start unset; see set_search_pipeline / add_pipe.
        self._search_pipeline: Optional[AsyncPipeline] = None
        self._rag_pipeline: Optional[AsyncPipeline] = None

    async def run(
        self,
        input: Any,
        state: Optional[AsyncState] = None,
        run_manager: Optional[RunManager] = None,
        log_run_info=True,
        vector_search_settings: Optional[VectorSearchSettings] = None,
        kg_search_settings: Optional[KGSearchSettings] = None,
        rag_generation_config: Optional[GenerationConfig] = None,
        *args: Any,
        **kwargs: Any,
    ):
        """Search for every query in ``input`` and run the RAG pipes on the results.

        Args:
            input: Async iterable of queries.
            state: Optional state object. ``self.state`` is set to it (or to a
                fresh ``AsyncState``); the value forwarded to the inner
                pipelines is ``state`` exactly as passed by the caller.
            run_manager: Run manager used for this run. When ``None``, the
                pipeline's own ``run_manager`` attribute is used if present.
            log_run_info: When true, log the pipeline type through the run
                manager before doing any work.
            vector_search_settings: Forwarded to the search pipeline. A fresh
                ``VectorSearchSettings()`` is created when ``None``.
            kg_search_settings: Forwarded to the search pipeline. A fresh
                ``KGSearchSettings()`` is created when ``None``.
            rag_generation_config: Forwarded to the inner RAG pipeline; its
                ``stream`` attribute selects streaming. A fresh
                ``GenerationConfig()`` is created when ``None``.
            *args: Extra positional arguments forwarded to both inner runs.
            **kwargs: Extra keyword arguments forwarded to both inner runs.

        Returns:
            Whatever the inner RAG pipeline's ``run`` returns.

        Raises:
            ValueError: If the search pipeline has not been set, or if no RAG
                pipe has been added yet.
        """
        # Build defaults per call: default expressions in the signature would
        # be evaluated once at definition time and shared between all calls.
        if vector_search_settings is None:
            vector_search_settings = VectorSearchSettings()
        if kg_search_settings is None:
            kg_search_settings = KGSearchSettings()
        if rag_generation_config is None:
            rag_generation_config = GenerationConfig()

        if run_manager is None:
            # NOTE(review): presumably the base class stores the constructor's
            # run manager as ``self.run_manager`` - confirm against AsyncPipeline.
            run_manager = getattr(self, "run_manager", None)

        self.state = state or AsyncState()
        async with manage_run(run_manager, self.pipeline_type):
            if log_run_info:
                await run_manager.log_run_info(
                    key="pipeline_type",
                    value=self.pipeline_type,
                    is_info_log=True,
                )

            if not self._search_pipeline:
                raise ValueError(
                    "_search_pipeline must be set before running the RAG pipeline"
                )
            if self._rag_pipeline is None:
                # Without this check the call below fails with an opaque
                # AttributeError on None.
                raise ValueError(
                    "At least one pipe must be added before running the RAG pipeline"
                )

            async def multi_query_generator(queries):
                """Yield ``(query, search_result)`` pairs in input order."""
                tasks = []
                # Start one search task per query so the searches overlap ...
                async for query in queries:
                    # NOTE(review): extra positional args are forwarded
                    # positionally and may collide with the keyword arguments
                    # below - confirm callers only pass keyword extras.
                    task = asyncio.create_task(
                        self._search_pipeline.run(
                            to_async_generator([query]),
                            *args,
                            state=state,
                            stream=False,  # do not stream the search results
                            run_manager=run_manager,
                            log_run_info=False,  # already logged above
                            vector_search_settings=vector_search_settings,
                            kg_search_settings=kg_search_settings,
                            **kwargs,
                        )
                    )
                    tasks.append((query, task))

                # ... then hand the results downstream in the original order.
                for query, task in tasks:
                    yield (query, await task)

            rag_results = await self._rag_pipeline.run(
                *args,
                input=multi_query_generator(input),
                state=state,
                stream=rag_generation_config.stream,
                run_manager=run_manager,
                log_run_info=False,
                rag_generation_config=rag_generation_config,
                **kwargs,
            )
            return rag_results

    def add_pipe(
        self,
        pipe: AsyncPipe,
        add_upstream_outputs: Optional[list[dict[str, str]]] = None,
        rag_pipe: bool = True,
        *args,
        **kwargs,
    ) -> None:
        """Append ``pipe`` to the inner RAG pipeline, creating it on first use.

        Args:
            pipe: The pipe to add; ``pipe.config.name`` is used for logging.
            add_upstream_outputs: Forwarded to the inner pipeline's ``add_pipe``.
            rag_pipe: Must be true; any falsy value is rejected.
            *args: Forwarded to the inner pipeline's ``add_pipe``.
            **kwargs: Forwarded to the inner pipeline's ``add_pipe``.

        Raises:
            ValueError: If ``rag_pipe`` is falsy.
        """
        logger.debug("Adding pipe %s to the RAGPipeline", pipe.config.name)
        if not rag_pipe:
            raise ValueError(
                "Only pipes that are part of the RAG pipeline can be added to the RAG pipeline"
            )
        if not self._rag_pipeline:
            self._rag_pipeline = AsyncPipeline()
        self._rag_pipeline.add_pipe(
            pipe, add_upstream_outputs, *args, **kwargs
        )

    def set_search_pipeline(
        self,
        _search_pipeline: AsyncPipeline,
        *args,
        **kwargs,
    ) -> None:
        """Set the pipeline that ``run`` uses to search for each query.

        Args:
            _search_pipeline: The search pipeline to use.
            *args: Accepted but unused.
            **kwargs: Accepted but unused.
        """
        logger.debug("Setting search pipeline for the RAGPipeline")
        self._search_pipeline = _search_pipeline