| field | value | date |
|---|---|---|
| author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
| committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
| commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) | |
| tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /R2R/r2r/pipes/retrieval/search_rag_pipe.py | |
| parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff) | |
| download | gn-ai-4a52a71956a8d46fcb7294ac71734504bb09bcc2.tar.gz | |
Diffstat (limited to 'R2R/r2r/pipes/retrieval/search_rag_pipe.py')
| mode | path | lines changed |
|---|---|---|
| -rwxr-xr-x | R2R/r2r/pipes/retrieval/search_rag_pipe.py | 130 |
1 files changed, 130 insertions, 0 deletions
```diff
diff --git a/R2R/r2r/pipes/retrieval/search_rag_pipe.py b/R2R/r2r/pipes/retrieval/search_rag_pipe.py
new file mode 100755
index 00000000..4d01d2df
--- /dev/null
+++ b/R2R/r2r/pipes/retrieval/search_rag_pipe.py
@@ -0,0 +1,130 @@
+import logging
+import uuid
+from typing import Any, AsyncGenerator, List, Optional, Tuple
+
+from r2r.base import (
+    AggregateSearchResult,
+    AsyncPipe,
+    AsyncState,
+    LLMProvider,
+    PipeType,
+    PromptProvider,
+)
+from r2r.base.abstractions.llm import GenerationConfig, RAGCompletion
+
+from ..abstractions.generator_pipe import GeneratorPipe
+
+logger = logging.getLogger(__name__)
+
+
+class SearchRAGPipe(GeneratorPipe):
+    class Input(AsyncPipe.Input):
+        message: AsyncGenerator[Tuple[str, AggregateSearchResult], None]
+
+    def __init__(
+        self,
+        llm_provider: LLMProvider,
+        prompt_provider: PromptProvider,
+        type: PipeType = PipeType.GENERATOR,
+        config: Optional[GeneratorPipe.Config] = None,
+        *args,
+        **kwargs,
+    ):
+        super().__init__(
+            llm_provider=llm_provider,
+            prompt_provider=prompt_provider,
+            type=type,
+            config=config
+            or GeneratorPipe.Config(
+                name="default_rag_pipe", task_prompt="default_rag"
+            ),
+            *args,
+            **kwargs,
+        )
+
+    async def _run_logic(
+        self,
+        input: Input,
+        state: AsyncState,
+        run_id: uuid.UUID,
+        rag_generation_config: GenerationConfig,
+        *args: Any,
+        **kwargs: Any,
+    ) -> AsyncGenerator[RAGCompletion, None]:
+        context = ""
+        search_iteration = 1
+        total_results = 0
+        # A single query must be selected for the prompt; use the first one.
+        sel_query = None
+        async for query, search_results in input.message:
+            if search_iteration == 1:
+                sel_query = query
+            context_piece, total_results = await self._collect_context(
+                query, search_results, search_iteration, total_results
+            )
+            context += context_piece
+            search_iteration += 1
+
+        messages = self._get_message_payload(sel_query, context)
+
+        response = await self.llm_provider.aget_completion(
+            messages=messages, generation_config=rag_generation_config
+        )
+        yield RAGCompletion(completion=response, search_results=search_results)
+
+        await self.enqueue_log(
+            run_id=run_id,
+            key="llm_response",
+            value=response.choices[0].message.content,
+        )
+
+    def _get_message_payload(self, query: str, context: str) -> List[dict]:
+        return [
+            {
+                "role": "system",
+                "content": self.prompt_provider.get_prompt(
+                    self.config.system_prompt,
+                ),
+            },
+            {
+                "role": "user",
+                "content": self.prompt_provider.get_prompt(
+                    self.config.task_prompt,
+                    inputs={
+                        "query": query,
+                        "context": context,
+                    },
+                ),
+            },
+        ]
+
+    async def _collect_context(
+        self,
+        query: str,
+        results: AggregateSearchResult,
+        iteration: int,
+        total_results: int,
+    ) -> Tuple[str, int]:
+        context = f"Query:\n{query}\n\n"
+        if results.vector_search_results:
+            context += f"Vector Search Results({iteration}):\n"
+            it = total_results + 1
+            for result in results.vector_search_results:
+                context += f"[{it}]: {result.metadata['text']}\n\n"
+                it += 1
+            total_results = (
+                it - 1
+            )  # Update total_results based on the last index used
+        if results.kg_search_results:
+            context += f"Knowledge Graph ({iteration}):\n"
+            it = total_results + 1
+            for kg_query, kg_results in results.kg_search_results:
+                context += f"Query: {kg_query}\n\n"
+                context += "Results:\n"
+                for search_result in kg_results:
+                    context += f"[{it}]: {search_result}\n\n"
+                    it += 1
+            total_results = (
+                it - 1
+            )  # Update total_results based on the last index used
+        return context, total_results
```
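The subtle part of `_collect_context` is its index bookkeeping: citation numbers keep counting up across vector results, knowledge-graph results, and successive search iterations, with `total_results` carrying the last index used into the next group. A self-contained toy run of the same arithmetic (plain Python, made-up data, not part of the commit):

```python
# Standalone illustration of _collect_context's numbering scheme: the second
# group of results continues at [3] rather than restarting at [1].
total_results = 0
for iteration, hits in enumerate([["alpha", "beta"], ["gamma"]], start=1):
    it = total_results + 1
    for text in hits:
        print(f"[{it}]: {text}")
        it += 1
    total_results = it - 1  # carry the last index used into the next group

# Output:
# [1]: alpha
# [2]: beta
# [3]: gamma
```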
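For orientation, here is a minimal sketch of driving the new pipe, assuming configured `llm_provider` and `prompt_provider` instances from an r2r deployment. The `GenerationConfig` field name, the `AggregateSearchResult` constructor kwargs, and the direct `_run_logic` call are illustrative assumptions, not part of this commit:

```python
import asyncio
import uuid

from r2r.base import AggregateSearchResult
from r2r.base.abstractions.llm import GenerationConfig
from r2r.pipes.retrieval.search_rag_pipe import SearchRAGPipe


async def query_stream():
    # The pipe consumes (query, AggregateSearchResult) pairs, matching
    # SearchRAGPipe.Input.message. An empty result set keeps the demo small;
    # the constructor kwargs are inferred from the attributes the pipe reads.
    results = AggregateSearchResult(
        vector_search_results=[], kg_search_results=None
    )
    yield ("How do RAG pipes assemble context?", results)


async def main(llm_provider, prompt_provider):
    pipe = SearchRAGPipe(
        llm_provider=llm_provider, prompt_provider=prompt_provider
    )
    # _run_logic is called directly here purely for illustration; in a
    # deployment the pipeline runner invokes the pipe instead.
    async for rag_completion in pipe._run_logic(
        SearchRAGPipe.Input(message=query_stream()),
        state=None,  # a real AsyncState in production
        run_id=uuid.uuid4(),
        rag_generation_config=GenerationConfig(model="gpt-4"),  # field assumed
    ):
        print(rag_completion.completion.choices[0].message.content)


# asyncio.run(main(my_llm_provider, my_prompt_provider))
```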