import logging
import uuid
from typing import Any, AsyncGenerator, Optional, Tuple
from r2r.base import (
AggregateSearchResult,
AsyncPipe,
AsyncState,
LLMProvider,
PipeType,
PromptProvider,
)
from r2r.base.abstractions.llm import GenerationConfig, RAGCompletion
from ..abstractions.generator_pipe import GeneratorPipe
# Module-level logger namespaced to this module, per stdlib logging convention.
logger = logging.getLogger(__name__)
class SearchRAGPipe(GeneratorPipe):
    """Generator pipe that turns streamed search results into one RAG completion.

    Consumes ``(query, AggregateSearchResult)`` pairs from its input stream,
    folds every result set into a single numbered context string, then asks
    the LLM once using the first query seen plus the combined context.
    """

    class Input(AsyncPipe.Input):
        # Stream of (query, aggregated search results) pairs; multiple pairs
        # arrive when an upstream pipe fans a request out into sub-queries.
        message: AsyncGenerator[Tuple[str, AggregateSearchResult], None]

    def __init__(
        self,
        llm_provider: LLMProvider,
        prompt_provider: PromptProvider,
        type: PipeType = PipeType.GENERATOR,
        # FIX: annotation was Optional[GeneratorPipe]; the value consumed here
        # (and the fallback built below) is a GeneratorPipe.Config.
        config: Optional[GeneratorPipe.Config] = None,
        *args,
        **kwargs,
    ):
        """Wire the providers into the base GeneratorPipe, defaulting to the
        stock RAG prompt configuration when no config is supplied."""
        super().__init__(
            llm_provider=llm_provider,
            prompt_provider=prompt_provider,
            type=type,
            config=config
            or GeneratorPipe.Config(
                name="default_rag_pipe", task_prompt="default_rag"
            ),
            *args,
            **kwargs,
        )

    async def _run_logic(
        self,
        input: Input,
        state: AsyncState,
        run_id: uuid.UUID,
        rag_generation_config: GenerationConfig,
        *args: Any,
        **kwargs: Any,
    ) -> AsyncGenerator[RAGCompletion, None]:
        """Accumulate context from every search iteration, then yield a single
        RAGCompletion produced by the LLM and log its text.

        Raises:
            ValueError: if the input stream yields no search results at all
                (previously this surfaced as an UnboundLocalError after a
                pointless LLM call with ``query=None``).
        """
        context = ""
        search_iteration = 1
        total_results = 0
        # The query actually sent to the LLM: the first one seen on the stream.
        sel_query = None
        # FIX: initialize so an empty stream is detectable instead of crashing
        # with UnboundLocalError at the yield below.
        search_results = None
        async for query, search_results in input.message:
            if search_iteration == 1:
                sel_query = query
            context_piece, total_results = await self._collect_context(
                query, search_results, search_iteration, total_results
            )
            context += context_piece
            search_iteration += 1

        if search_results is None:
            raise ValueError("No search results were provided to the RAG pipe.")

        messages = self._get_message_payload(sel_query, context)

        response = await self.llm_provider.aget_completion(
            messages=messages, generation_config=rag_generation_config
        )
        # NOTE: only the *last* AggregateSearchResult is attached here; the
        # context string already contains the results of every iteration.
        yield RAGCompletion(completion=response, search_results=search_results)

        await self.enqueue_log(
            run_id=run_id,
            key="llm_response",
            value=response.choices[0].message.content,
        )

    def _get_message_payload(self, query: str, context: str) -> list:
        """Build the system + user chat messages for the LLM call.

        FIX: return annotation was ``dict`` but the method returns a list of
        message dicts.
        """
        return [
            {
                "role": "system",
                "content": self.prompt_provider.get_prompt(
                    self.config.system_prompt,
                ),
            },
            {
                "role": "user",
                "content": self.prompt_provider.get_prompt(
                    self.config.task_prompt,
                    inputs={
                        "query": query,
                        "context": context,
                    },
                ),
            },
        ]

    async def _collect_context(
        self,
        query: str,
        results: AggregateSearchResult,
        iteration: int,
        total_results: int,
    ) -> Tuple[str, int]:
        """Render one iteration's search results into a numbered context block.

        Args:
            query: the query that produced ``results``.
            results: aggregated vector + knowledge-graph search results.
            iteration: 1-based iteration counter, shown in section headers.
            total_results: running count of results already numbered, so the
                ``[n]`` indices stay globally unique across iterations.

        Returns:
            The formatted context fragment and the updated running count.
        """
        context = f"Query:\n{query}\n\n"
        if results.vector_search_results:
            context += f"Vector Search Results({iteration}):\n"
            it = total_results + 1
            for result in results.vector_search_results:
                context += f"[{it}]: {result.metadata['text']}\n\n"
                it += 1
            # `it` is one past the last index used.
            total_results = it - 1
        if results.kg_search_results:
            context += f"Knowledge Graph ({iteration}):\n"
            it = total_results + 1
            # FIX: loop variable renamed from `query`, which shadowed the
            # parameter of the same name.
            for kg_query, kg_results in results.kg_search_results:
                context += f"Query: {kg_query}\n\n"
                context += "Results:\n"  # was a needless f-string
                for kg_result in kg_results:
                    context += f"[{it}]: {kg_result}\n\n"
                    it += 1
            total_results = it - 1
        return context, total_results