import logging
import uuid
from typing import Any, AsyncGenerator, List, Optional, Tuple

from r2r.base import (
    AggregateSearchResult,
    AsyncPipe,
    AsyncState,
    LLMProvider,
    PipeType,
    PromptProvider,
)
from r2r.base.abstractions.llm import GenerationConfig, RAGCompletion

from ..abstractions.generator_pipe import GeneratorPipe

logger = logging.getLogger(__name__)


class SearchRAGPipe(GeneratorPipe):
    """Generator pipe that turns streamed search results into one RAG completion.

    Every ``(query, search results)`` pair received on the input stream is
    rendered into a numbered text context. The combined context and the first
    query are then formatted through the prompt provider and sent to the LLM
    provider for a single completion.
    """

    class Input(AsyncPipe.Input):
        # Stream of (query, aggregated search results) pairs to build context from.
        message: AsyncGenerator[Tuple[str, AggregateSearchResult], None]

    def __init__(
        self,
        llm_provider: LLMProvider,
        prompt_provider: PromptProvider,
        type: PipeType = PipeType.GENERATOR,
        config: Optional[GeneratorPipe.Config] = None,
        *args,
        **kwargs,
    ):
        """Initialise the pipe, forwarding everything to ``GeneratorPipe``.

        Args:
            llm_provider: Provider used to obtain the completion.
            prompt_provider: Provider used to render the system and task prompts.
            type: Pipe type passed to the base class.
            config: Pipe configuration. When falsy, a ``GeneratorPipe.Config``
                named ``"default_rag_pipe"`` with task prompt ``"default_rag"``
                is used instead.
            *args: Extra positional arguments forwarded to the base class.
            **kwargs: Extra keyword arguments forwarded to the base class.
        """
        super().__init__(
            llm_provider=llm_provider,
            prompt_provider=prompt_provider,
            type=type,
            config=config
            or GeneratorPipe.Config(
                name="default_rag_pipe", task_prompt="default_rag"
            ),
            *args,
            **kwargs,
        )

    async def _run_logic(
        self,
        input: Input,
        state: AsyncState,
        run_id: uuid.UUID,
        rag_generation_config: GenerationConfig,
        *args: Any,
        **kwargs: Any,
    ) -> AsyncGenerator[RAGCompletion, None]:
        """Collect context from all search results and yield one completion.

        Args:
            input: Pipe input whose ``message`` streams ``(query, results)`` pairs.
            state: Pipeline state (not used by this method).
            run_id: Identifier under which the LLM response is logged.
            rag_generation_config: Generation settings passed to the LLM provider.
            *args: Unused.
            **kwargs: Unused.

        Yields:
            A single ``RAGCompletion`` holding the LLM response together with
            the search results of the last pair received.

        Raises:
            ValueError: If ``input.message`` produces no items. Previously this
                case fell through to an ``UnboundLocalError`` after the LLM had
                already been called.
        """
        context_pieces: List[str] = []
        search_iteration = 0
        total_results = 0
        # must select a query if there are multiple
        sel_query = None
        last_results = None

        async for query, search_results in input.message:
            search_iteration += 1
            if search_iteration == 1:
                # The first query is the one placed in the task prompt.
                sel_query = query
            context_piece, total_results = await self._collect_context(
                query, search_results, search_iteration, total_results
            )
            context_pieces.append(context_piece)
            last_results = search_results

        if search_iteration == 0:
            # Fail before spending an LLM call on an empty prompt.
            raise ValueError(
                "SearchRAGPipe received no (query, search results) pairs "
                "to build a RAG prompt from."
            )

        messages = self._get_message_payload(sel_query, "".join(context_pieces))
        response = await self.llm_provider.aget_completion(
            messages=messages, generation_config=rag_generation_config
        )

        yield RAGCompletion(completion=response, search_results=last_results)

        # NOTE: this statement follows the yield, so it executes only once the
        # consumer resumes the generator after receiving the completion.
        await self.enqueue_log(
            run_id=run_id,
            key="llm_response",
            value=response.choices[0].message.content,
        )

    def _get_message_payload(self, query: str, context: str) -> List[dict]:
        """Build the chat messages sent to the LLM.

        Args:
            query: Query inserted into the task prompt.
            context: Rendered search context inserted into the task prompt.

        Returns:
            A two-element list: a ``system`` message rendered from
            ``self.config.system_prompt`` and a ``user`` message rendered from
            ``self.config.task_prompt`` with ``query`` and ``context`` inputs.
        """
        system_content = self.prompt_provider.get_prompt(
            self.config.system_prompt,
        )
        user_content = self.prompt_provider.get_prompt(
            self.config.task_prompt,
            inputs={
                "query": query,
                "context": context,
            },
        )
        return [
            {"role": "system", "content": system_content},
            {"role": "user", "content": user_content},
        ]

    async def _collect_context(
        self,
        query: str,
        results: AggregateSearchResult,
        iteration: int,
        total_results: int,
    ) -> Tuple[str, int]:
        """Render one set of search results as numbered context text.

        Args:
            query: Query that produced ``results``; written as the header.
            results: Search results; ``vector_search_results`` and
                ``kg_search_results`` are each rendered when truthy.
            iteration: 1-based position of this result set, shown in the
                section headings.
            total_results: Number of results already numbered by earlier calls.

        Returns:
            A tuple ``(context, total_results)`` where ``context`` is the
            rendered text and ``total_results`` is the running count including
            the entries numbered here.
        """
        parts: List[str] = [f"Query:\n{query}\n\n"]
        # Single running counter so numbering continues across both sections
        # and across successive calls.
        index = total_results

        if results.vector_search_results:
            parts.append(f"Vector Search Results({iteration}):\n")
            for result in results.vector_search_results:
                index += 1
                parts.append(f"[{index}]: {result.metadata['text']}\n\n")

        if results.kg_search_results:
            parts.append(f"Knowledge Graph ({iteration}):\n")
            # Each entry unpacks to a (query, results) pair; distinct names
            # avoid rebinding the ``query`` parameter.
            for kg_query, kg_results in results.kg_search_results:
                parts.append(f"Query: {kg_query}\n\n")
                parts.append("Results:\n")
                for kg_result in kg_results:
                    index += 1
                    parts.append(f"[{index}]: {kg_result}\n\n")

        return "".join(parts), index