path: root/R2R/r2r/pipes/retrieval/streaming_rag_pipe.py
import json
import logging
import uuid
from typing import Any, AsyncGenerator, Generator, Optional

from r2r.base import (
    AsyncState,
    LLMChatCompletionChunk,
    LLMProvider,
    PipeType,
    PromptProvider,
)
from r2r.base.abstractions.llm import GenerationConfig

from ..abstractions.generator_pipe import GeneratorPipe
from .search_rag_pipe import SearchRAGPipe

logger = logging.getLogger(__name__)


class StreamingSearchRAGPipe(SearchRAGPipe):
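    """Streaming variant of `SearchRAGPipe`.

    Emits the serialized vector search results between `<search>` ... `</search>`
    markers, then streams the LLM completion between `<completion>` ...
    `</completion>` markers.
    """
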
    SEARCH_STREAM_MARKER = "search"
    COMPLETION_STREAM_MARKER = "completion"

    def __init__(
        self,
        llm_provider: LLMProvider,
        prompt_provider: PromptProvider,
        type: PipeType = PipeType.GENERATOR,
        config: Optional[GeneratorPipe.Config] = None,
        *args,
        **kwargs,
    ):
        super().__init__(
            llm_provider=llm_provider,
            prompt_provider=prompt_provider,
            type=type,
            config=config
            or GeneratorPipe.Config(
                name="default_streaming_rag_pipe", task_prompt="default_rag"
            ),
            *args,
            **kwargs,
        )

    async def _run_logic(
        self,
        input: SearchRAGPipe.Input,
        state: AsyncState,
        run_id: uuid.UUID,
        rag_generation_config: GenerationConfig,
        *args: Any,
        **kwargs: Any,
    ) -> AsyncGenerator[str, None]:
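        """Stream search results and the LLM completion for each incoming query.

        For every `(query, search_results)` pair, yields the vector search
        results as JSON wrapped in search markers, builds the RAG prompt from
        the accumulated context, streams the completion wrapped in completion
        markers, and logs the full response.
        """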
        iteration = 0
        context = ""
        # dump the search results and construct the context
        async for query, search_results in input.message:
            yield f"<{self.SEARCH_STREAM_MARKER}>"
            if search_results.vector_search_results:
                context += "Vector Search Results:\n"
                for result in search_results.vector_search_results:
                    if iteration >= 1:
                        yield ","
                    yield json.dumps(result.json())
                    context += (
                        f"{iteration + 1}:\n{result.metadata['text']}\n\n"
                    )
                    iteration += 1

            # if search_results.kg_search_results:
            #     for result in search_results.kg_search_results:
            #         if iteration >= 1:
            #             yield ","
            #         yield json.dumps(result.json())
            #         context += f"Result {iteration+1}:\n{result.metadata['text']}\n\n"
            #         iteration += 1

            yield f"</{self.SEARCH_STREAM_MARKER}>"

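            # build the prompt from the query and the accumulated context,
            # then stream the model completion back to the caller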
            messages = self._get_message_payload(query, context)
            yield f"<{self.COMPLETION_STREAM_MARKER}>"
            response = ""
            for chunk in self.llm_provider.get_completion_stream(
                messages=messages, generation_config=rag_generation_config
            ):
                chunk = StreamingSearchRAGPipe._process_chunk(chunk)
                response += chunk
                yield chunk

            yield f"</{self.COMPLETION_STREAM_MARKER}>"

            await self.enqueue_log(
                run_id=run_id,
                key="llm_response",
                value=response,
            )

    async def _yield_chunks(
        self,
        start_marker: str,
        chunks: Generator[str, None, None],
        end_marker: str,
    ) -> AsyncGenerator[str, None]:
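        """Wrap a stream of chunks between explicit start and end markers."""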
        yield start_marker
        for chunk in chunks:
            yield chunk
        yield end_marker

    def _get_message_payload(
        self, query: str, context: str
    ) -> list[dict[str, str]]:
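        """Build the system and task messages from the configured prompts."""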
        return [
            {
                "role": "system",
                "content": self.prompt_provider.get_prompt(
                    self.config.system_prompt
                ),
            },
            {
                "role": "user",
                "content": self.prompt_provider.get_prompt(
                    self.config.task_prompt,
                    inputs={"query": query, "context": context},
                ),
            },
        ]

    @staticmethod
    def _process_chunk(chunk: LLMChatCompletionChunk) -> str:
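        """Extract the text delta from a completion chunk, or "" if absent."""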
        return chunk.choices[0].delta.content or ""