path: root/R2R/r2r/pipes/retrieval/streaming_rag_pipe.py
author     S. Solomon Darnell  2025-03-28 21:52:21 -0500
committer  S. Solomon Darnell  2025-03-28 21:52:21 -0500
commit     4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree       ee3dc5af3b6313e921cd920906356f5d4febc4ed  /R2R/r2r/pipes/retrieval/streaming_rag_pipe.py
parent     cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
download   gn-ai-4a52a71956a8d46fcb7294ac71734504bb09bcc2.tar.gz
two versions of R2R are here (HEAD, master)
Diffstat (limited to 'R2R/r2r/pipes/retrieval/streaming_rag_pipe.py')
-rwxr-xr-x  R2R/r2r/pipes/retrieval/streaming_rag_pipe.py  131
1 file changed, 131 insertions(+), 0 deletions(-)
diff --git a/R2R/r2r/pipes/retrieval/streaming_rag_pipe.py b/R2R/r2r/pipes/retrieval/streaming_rag_pipe.py
new file mode 100755
index 00000000..b01f6445
--- /dev/null
+++ b/R2R/r2r/pipes/retrieval/streaming_rag_pipe.py
@@ -0,0 +1,131 @@
+import json
+import logging
+import uuid
+from typing import Any, AsyncGenerator, Generator, Optional
+
+from r2r.base import (
+    AsyncState,
+    LLMChatCompletionChunk,
+    LLMProvider,
+    PipeType,
+    PromptProvider,
+)
+from r2r.base.abstractions.llm import GenerationConfig
+
+from ..abstractions.generator_pipe import GeneratorPipe
+from .search_rag_pipe import SearchRAGPipe
+
+logger = logging.getLogger(__name__)
+
+
+class StreamingSearchRAGPipe(SearchRAGPipe):
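+    """Streams a RAG response in two phases: search results as
+    comma-separated JSON between <search>...</search> markers, then the
+    LLM completion text between <completion>...</completion> markers."""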
+    SEARCH_STREAM_MARKER = "search"
+    COMPLETION_STREAM_MARKER = "completion"
+
+    def __init__(
+        self,
+        llm_provider: LLMProvider,
+        prompt_provider: PromptProvider,
+        type: PipeType = PipeType.GENERATOR,
+        config: Optional[GeneratorPipe.Config] = None,
+        *args,
+        **kwargs,
+    ):
+        super().__init__(
+            llm_provider=llm_provider,
+            prompt_provider=prompt_provider,
+            type=type,
+            config=config
+            or GeneratorPipe.Config(
+                name="default_streaming_rag_pipe", task_prompt="default_rag"
+            ),
+            *args,
+            **kwargs,
+        )
+
+    async def _run_logic(
+        self,
+        input: SearchRAGPipe.Input,
+        state: AsyncState,
+        run_id: uuid.UUID,
+        rag_generation_config: GenerationConfig,
+        *args: Any,
+        **kwargs: Any,
+    ) -> AsyncGenerator[str, None]:
+        iteration = 0
+        context = ""
+        # stream the search results while accumulating the context for the LLM
+        async for query, search_results in input.message:
+            yield f"<{self.SEARCH_STREAM_MARKER}>"
+            if search_results.vector_search_results:
+                context += "Vector Search Results:\n"
+                for result in search_results.vector_search_results:
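+                    # NOTE: result.json() already returns a JSON string;
+                    # json.dumps re-encodes it, so each result streams as a
+                    # quoted JSON string literal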
+                    if iteration >= 1:
+                        yield ","
+                    yield json.dumps(result.json())
+                    context += (
+                        f"{iteration + 1}:\n{result.metadata['text']}\n\n"
+                    )
+                    iteration += 1
+
+            # if search_results.kg_search_results:
+            #     for result in search_results.kg_search_results:
+            #         if iteration >= 1:
+            #             yield ","
+            #         yield json.dumps(result.json())
+            #         context += f"Result {iteration+1}:\n{result.metadata['text']}\n\n"
+            #         iteration += 1
+
+            yield f"</{self.SEARCH_STREAM_MARKER}>"
+
+            messages = self._get_message_payload(query, context)
+            yield f"<{self.COMPLETION_STREAM_MARKER}>"
+            response = ""
+            for chunk in self.llm_provider.get_completion_stream(
+                messages=messages, generation_config=rag_generation_config
+            ):
+                chunk = StreamingSearchRAGPipe._process_chunk(chunk)
+                response += chunk
+                yield chunk
+
+            yield f"</{self.COMPLETION_STREAM_MARKER}>"
+
+            await self.enqueue_log(
+                run_id=run_id,
+                key="llm_response",
+                value=response,
+            )
+
+    async def _yield_chunks(
+        self,
+        start_marker: str,
+        chunks: Generator[str, None, None],
+        end_marker: str,
+    ) -> AsyncGenerator[str, None]:
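+        # bracket a synchronous chunk generator between start and end markers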
+        yield start_marker
+        for chunk in chunks:
+            yield chunk
+        yield end_marker
+
+    def _get_message_payload(
+        self, query: str, context: str
+    ) -> list[dict[str, str]]:
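+        # render the system prompt and the task prompt with query + context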
+        return [
+            {
+                "role": "system",
+                "content": self.prompt_provider.get_prompt(
+                    self.config.system_prompt
+                ),
+            },
+            {
+                "role": "user",
+                "content": self.prompt_provider.get_prompt(
+                    self.config.task_prompt,
+                    inputs={"query": query, "context": context},
+                ),
+            },
+        ]
+
+    @staticmethod
+    def _process_chunk(chunk: LLMChatCompletionChunk) -> str:
+        return chunk.choices[0].delta.content or ""
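
A minimal consumer sketch for the stream format above. It assumes a constructed StreamingSearchRAGPipe (`pipe`), an input object (`pipe_input`), an `AsyncState` (`state`), and a `GenerationConfig` (`gen_config`) built elsewhere; those names are placeholders, not part of this diff. Because the markers are yielded as standalone chunks, they can be matched by equality, and the comma-separated search payload parses as JSON once wrapped in brackets:

import json
import uuid

async def consume(pipe, pipe_input, state, gen_config):
    # accumulate the two stream sections separately
    search_payload, completion, section = "", "", None
    async for chunk in pipe._run_logic(
        pipe_input, state, uuid.uuid4(), rag_generation_config=gen_config
    ):
        if chunk == "<search>":
            section = "search"
        elif chunk == "<completion>":
            section = "completion"
        elif chunk in ("</search>", "</completion>"):
            section = None
        elif section == "search":
            search_payload += chunk
        elif section == "completion":
            completion += chunk
    # each element is a JSON-encoded string (see the double encoding above),
    # so decode twice: once for the array, once per element
    results = [json.loads(s) for s in json.loads(f"[{search_payload}]")]
    return results, completion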