import json
import logging
import uuid
from typing import Any, AsyncGenerator, Generator, Optional

from r2r.base import (
    AsyncState,
    LLMChatCompletionChunk,
    LLMProvider,
    PipeType,
    PromptProvider,
)
from r2r.base.abstractions.llm import GenerationConfig

from ..abstractions.generator_pipe import GeneratorPipe
from .search_rag_pipe import SearchRAGPipe

logger = logging.getLogger(__name__)


class StreamingSearchRAGPipe(SearchRAGPipe):
    """RAG pipe that streams its output as it is produced: first the raw
    search results, then the LLM completion, each wrapped in stream markers."""

    SEARCH_STREAM_MARKER = "search"
    COMPLETION_STREAM_MARKER = "completion"

    def __init__(
        self,
        llm_provider: LLMProvider,
        prompt_provider: PromptProvider,
        type: PipeType = PipeType.GENERATOR,
        config: Optional[GeneratorPipe.Config] = None,
        *args,
        **kwargs,
    ):
        super().__init__(
            llm_provider=llm_provider,
            prompt_provider=prompt_provider,
            type=type,
            config=config
            or GeneratorPipe.Config(
                name="default_streaming_rag_pipe", task_prompt="default_rag"
            ),
            *args,
            **kwargs,
        )

    async def _run_logic(
        self,
        input: SearchRAGPipe.Input,
        state: AsyncState,
        run_id: uuid.UUID,
        rag_generation_config: GenerationConfig,
        *args: Any,
        **kwargs: Any,
    ) -> AsyncGenerator[str, None]:
        iteration = 0
        context = ""
        # Stream the search results while accumulating the context string
        # that will be injected into the completion prompt.
        async for query, search_results in input.message:
            yield f"<{self.SEARCH_STREAM_MARKER}>"
            if search_results.vector_search_results:
                context += "Vector Search Results:\n"
                for result in search_results.vector_search_results:
                    if iteration >= 1:
                        yield ","
                    # Each result is emitted as a JSON-encoded string element.
                    yield json.dumps(result.json())
                    context += (
                        f"{iteration + 1}:\n{result.metadata['text']}\n\n"
                    )
                    iteration += 1
            # if search_results.kg_search_results:
            #     for result in search_results.kg_search_results:
            #         if iteration >= 1:
            #             yield ","
            #         yield json.dumps(result.json())
            #         context += f"Result {iteration+1}:\n{result.metadata['text']}\n\n"
            #         iteration += 1
            yield f"</{self.SEARCH_STREAM_MARKER}>"

            # Stream the completion for this query, then log the full response.
            messages = self._get_message_payload(query, context)
            yield f"<{self.COMPLETION_STREAM_MARKER}>"
            response = ""
            for chunk in self.llm_provider.get_completion_stream(
                messages=messages, generation_config=rag_generation_config
            ):
                chunk = StreamingSearchRAGPipe._process_chunk(chunk)
                response += chunk
                yield chunk
            yield f"</{self.COMPLETION_STREAM_MARKER}>"

            await self.enqueue_log(
                run_id=run_id,
                key="llm_response",
                value=response,
            )

    async def _yield_chunks(
        self,
        start_marker: str,
        chunks: Generator[str, None, None],
        end_marker: str,
    ) -> AsyncGenerator[str, None]:
        # Helper that wraps a synchronous chunk generator in start/end markers.
        yield start_marker
        for chunk in chunks:
            yield chunk
        yield end_marker

    def _get_message_payload(
        self, query: str, context: str
    ) -> list[dict[str, str]]:
        # Build the chat payload: the configured system prompt plus the task
        # prompt templated with the query and the accumulated search context.
        return [
            {
                "role": "system",
                "content": self.prompt_provider.get_prompt(
                    self.config.system_prompt
                ),
            },
            {
                "role": "user",
                "content": self.prompt_provider.get_prompt(
                    self.config.task_prompt,
                    inputs={"query": query, "context": context},
                ),
            },
        ]

    @staticmethod
    def _process_chunk(chunk: LLMChatCompletionChunk) -> str:
        # Extract the text delta from a streaming chat-completion chunk.
        return chunk.choices[0].delta.content or ""
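

# ---------------------------------------------------------------------------
# Illustrative only (not part of the pipe's API): a minimal sketch of how a
# consumer might split the stream produced above back into its search and
# completion segments. It assumes the stream for a single query has already
# been collected into one string; the helper name `_example_split_stream` is
# hypothetical.
# ---------------------------------------------------------------------------
def _example_split_stream(stream: str) -> tuple[list[dict], str]:
    search_open = f"<{StreamingSearchRAGPipe.SEARCH_STREAM_MARKER}>"
    search_close = f"</{StreamingSearchRAGPipe.SEARCH_STREAM_MARKER}>"
    completion_open = f"<{StreamingSearchRAGPipe.COMPLETION_STREAM_MARKER}>"
    completion_close = f"</{StreamingSearchRAGPipe.COMPLETION_STREAM_MARKER}>"

    search_payload = stream.split(search_open)[1].split(search_close)[0]
    completion = stream.split(completion_open)[1].split(completion_close)[0]
    # Each streamed result is a JSON-encoded string (json.dumps(result.json())),
    # so decode the comma-joined payload twice: once to a list of strings,
    # then each string to a dict.
    results = [json.loads(item) for item in json.loads(f"[{search_payload}]")]
    return results, completion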