author     S. Solomon Darnell   2025-03-28 21:52:21 -0500
committer  S. Solomon Darnell   2025-03-28 21:52:21 -0500
commit     4a52a71956a8d46fcb7294ac71734504bb09bcc2
tree       ee3dc5af3b6313e921cd920906356f5d4febc4ed  /R2R/r2r/pipes/retrieval/streaming_rag_pipe.py
parent     cc961e04ba734dd72309fb548a2f97d67d578813
Diffstat (limited to 'R2R/r2r/pipes/retrieval/streaming_rag_pipe.py')

-rwxr-xr-x  R2R/r2r/pipes/retrieval/streaming_rag_pipe.py  131
1 file changed, 131 insertions, 0 deletions
diff --git a/R2R/r2r/pipes/retrieval/streaming_rag_pipe.py b/R2R/r2r/pipes/retrieval/streaming_rag_pipe.py
new file mode 100755
index 00000000..b01f6445
--- /dev/null
+++ b/R2R/r2r/pipes/retrieval/streaming_rag_pipe.py
@@ -0,0 +1,131 @@
+import json
+import logging
+import uuid
+from typing import Any, AsyncGenerator, Generator, Optional
+
+from r2r.base import (
+    AsyncState,
+    LLMChatCompletionChunk,
+    LLMProvider,
+    PipeType,
+    PromptProvider,
+)
+from r2r.base.abstractions.llm import GenerationConfig
+
+from ..abstractions.generator_pipe import GeneratorPipe
+from .search_rag_pipe import SearchRAGPipe
+
+logger = logging.getLogger(__name__)
+
+
+class StreamingSearchRAGPipe(SearchRAGPipe):
+    SEARCH_STREAM_MARKER = "search"
+    COMPLETION_STREAM_MARKER = "completion"
+
+    def __init__(
+        self,
+        llm_provider: LLMProvider,
+        prompt_provider: PromptProvider,
+        type: PipeType = PipeType.GENERATOR,
+        config: Optional[GeneratorPipe] = None,
+        *args,
+        **kwargs,
+    ):
+        super().__init__(
+            llm_provider=llm_provider,
+            prompt_provider=prompt_provider,
+            type=type,
+            config=config
+            or GeneratorPipe.Config(
+                name="default_streaming_rag_pipe", task_prompt="default_rag"
+            ),
+            *args,
+            **kwargs,
+        )
+
+    async def _run_logic(
+        self,
+        input: SearchRAGPipe.Input,
+        state: AsyncState,
+        run_id: uuid.UUID,
+        rag_generation_config: GenerationConfig,
+        *args: Any,
+        **kwargs: Any,
+    ) -> AsyncGenerator[str, None]:
+        iteration = 0
+        context = ""
+        # dump the search results and construct the context
+        async for query, search_results in input.message:
+            yield f"<{self.SEARCH_STREAM_MARKER}>"
+            if search_results.vector_search_results:
+                context += "Vector Search Results:\n"
+                for result in search_results.vector_search_results:
+                    if iteration >= 1:
+                        yield ","
+                    yield json.dumps(result.json())
+                    context += (
+                        f"{iteration + 1}:\n{result.metadata['text']}\n\n"
+                    )
+                    iteration += 1
+
+            # if search_results.kg_search_results:
+            #     for result in search_results.kg_search_results:
+            #         if iteration >= 1:
+            #             yield ","
+            #         yield json.dumps(result.json())
+            #         context += f"Result {iteration+1}:\n{result.metadata['text']}\n\n"
+            #         iteration += 1
+
+            yield f"</{self.SEARCH_STREAM_MARKER}>"
+
+            messages = self._get_message_payload(query, context)
+            yield f"<{self.COMPLETION_STREAM_MARKER}>"
+            response = ""
+            for chunk in self.llm_provider.get_completion_stream(
+                messages=messages, generation_config=rag_generation_config
+            ):
+                chunk = StreamingSearchRAGPipe._process_chunk(chunk)
+                response += chunk
+                yield chunk
+
+            yield f"</{self.COMPLETION_STREAM_MARKER}>"
+
+            await self.enqueue_log(
+                run_id=run_id,
+                key="llm_response",
+                value=response,
+            )
+
+    async def _yield_chunks(
+        self,
+        start_marker: str,
+        chunks: Generator[str, None, None],
+        end_marker: str,
+    ) -> str:
+        yield start_marker
+        for chunk in chunks:
+            yield chunk
+        yield end_marker
+
+    def _get_message_payload(
+        self, query: str, context: str
+    ) -> list[dict[str, str]]:
+        return [
+            {
+                "role": "system",
+                "content": self.prompt_provider.get_prompt(
+                    self.config.system_prompt
+                ),
+            },
+            {
+                "role": "user",
+                "content": self.prompt_provider.get_prompt(
+                    self.config.task_prompt,
+                    inputs={"query": query, "context": context},
+                ),
+            },
+        ]
+
+    @staticmethod
+    def _process_chunk(chunk: LLMChatCompletionChunk) -> str:
+        return chunk.choices[0].delta.content or ""
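The new pipe frames its stream with `<search>` … `</search>` and `<completion>` … `</completion>` markers (the `SEARCH_STREAM_MARKER` and `COMPLETION_STREAM_MARKER` constants above) and yields each vector search result as a comma-separated, JSON-encoded string before streaming the LLM completion text. Below is a minimal consumer sketch, assuming the stream has already been buffered into one string; the `parse_stream` helper and its regex-based splitting are illustrative assumptions, not part of this commit.

```python
import json
import re


def parse_stream(raw: str) -> tuple[list[dict], str]:
    """Split a fully buffered stream into search results and completion text.

    Hypothetical helper: the marker names come from the pipe's class
    constants; everything else is an assumption about how a caller might
    consume the output.
    """
    search_match = re.search(r"<search>(.*?)</search>", raw, re.DOTALL)
    completion_match = re.search(r"<completion>(.*?)</completion>", raw, re.DOTALL)

    # The pipe yields comma-separated JSON strings, so wrapping the span in
    # brackets produces a JSON array.
    items = json.loads(f"[{search_match.group(1)}]") if search_match else []
    # Each item is itself a JSON string because the pipe emits
    # json.dumps(result.json()); decode once more to get dictionaries.
    search_results = [json.loads(item) for item in items]

    completion = completion_match.group(1) if completion_match else ""
    return search_results, completion
```

The double decode in the sketch mirrors the pipe's `json.dumps(result.json())` call, which string-encodes each already-serialized search result.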