import logging
import time
import uuid
from typing import Optional

from r2r.base import (
    GenerationConfig,
    KGSearchSettings,
    KVLoggingSingleton,
    R2RException,
    RunManager,
    VectorSearchSettings,
    manage_run,
    to_async_generator,
)
from r2r.pipes import EvalPipe
from r2r.telemetry.telemetry_decorator import telemetry_event

from ..abstractions import R2RPipelines, R2RProviders
from ..assembly.config import R2RConfig
from .base import Service

logger = logging.getLogger(__name__)


class RetrievalService(Service):
    def __init__(
        self,
        config: R2RConfig,
        providers: R2RProviders,
        pipelines: R2RPipelines,
        run_manager: RunManager,
        logging_connection: KVLoggingSingleton,
    ):
        super().__init__(
            config, providers, pipelines, run_manager, logging_connection
        )

    @telemetry_event("Search")
    async def search(
        self,
        query: str,
        vector_search_settings: VectorSearchSettings = VectorSearchSettings(),
        kg_search_settings: KGSearchSettings = KGSearchSettings(),
        *args,
        **kwargs,
    ):
        async with manage_run(self.run_manager, "search_app") as run_id:
            t0 = time.time()

            if (
                kg_search_settings.use_kg_search
                and self.config.kg.provider is None
            ):
                raise R2RException(
                    status_code=400,
                    message="Knowledge Graph search is not enabled in the configuration.",
                )

            if (
                vector_search_settings.use_vector_search
                and self.config.vector_database.provider is None
            ):
                raise R2RException(
                    status_code=400,
                    message="Vector search is not enabled in the configuration.",
                )

            # TODO - Remove these transforms once we have a better way to handle this
            for (
                filter_key,
                value,
            ) in vector_search_settings.search_filters.items():
                if isinstance(value, uuid.UUID):
                    vector_search_settings.search_filters[filter_key] = str(
                        value
                    )

            results = await self.pipelines.search_pipeline.run(
                input=to_async_generator([query]),
                vector_search_settings=vector_search_settings,
                kg_search_settings=kg_search_settings,
                run_manager=self.run_manager,
                *args,
                **kwargs,
            )

            t1 = time.time()
            latency = f"{t1 - t0:.2f}"

            await self.logging_connection.log(
                log_id=run_id,
                key="search_latency",
                value=latency,
                is_info_log=False,
            )

            return results.dict()

    @telemetry_event("RAG")
    async def rag(
        self,
        query: str,
        rag_generation_config: GenerationConfig,
        vector_search_settings: VectorSearchSettings = VectorSearchSettings(),
        kg_search_settings: KGSearchSettings = KGSearchSettings(),
        *args,
        **kwargs,
    ):
        async with manage_run(self.run_manager, "rag_app") as run_id:
            try:
                t0 = time.time()

                # TODO - Remove these transforms once we have a better way to handle this
                for (
                    filter_key,
                    value,
                ) in vector_search_settings.search_filters.items():
                    if isinstance(value, uuid.UUID):
                        vector_search_settings.search_filters[filter_key] = (
                            str(value)
                        )

                if rag_generation_config.stream:
                    t1 = time.time()
                    latency = f"{t1 - t0:.2f}"

                    await self.logging_connection.log(
                        log_id=run_id,
                        key="rag_generation_latency",
                        value=latency,
                        is_info_log=False,
                    )

                    async def stream_response():
                        async with manage_run(self.run_manager, "arag"):
                            async for (
                                chunk
                            ) in await self.pipelines.streaming_rag_pipeline.run(
                                input=to_async_generator([query]),
                                run_manager=self.run_manager,
                                vector_search_settings=vector_search_settings,
                                kg_search_settings=kg_search_settings,
                                rag_generation_config=rag_generation_config,
                            ):
                                yield chunk

                    return stream_response()

                results = await self.pipelines.rag_pipeline.run(
                    input=to_async_generator([query]),
                    run_manager=self.run_manager,
                    vector_search_settings=vector_search_settings,
                    kg_search_settings=kg_search_settings,
                    rag_generation_config=rag_generation_config,
                    *args,
                    **kwargs,
                )

                t1 = time.time()
                latency = f"{t1 - t0:.2f}"
                await self.logging_connection.log(
                    log_id=run_id,
                    key="rag_generation_latency",
                    value=latency,
                    is_info_log=False,
                )

                if len(results) == 0:
                    raise R2RException(
                        status_code=404, message="No results found"
                    )
                if len(results) > 1:
                    logger.warning(
                        f"Multiple results found for query: {query}"
                    )
                # unpack the first result
                return results[0]

            except Exception as e:
                logger.error(f"Pipeline error: {str(e)}")
                if "NoneType" in str(e):
                    raise R2RException(
                        status_code=502,
                        message="Ollama server not reachable or returned an invalid response",
                    )
                raise R2RException(
                    status_code=500, message="Internal Server Error"
                )

    @telemetry_event("Evaluate")
    async def evaluate(
        self,
        query: str,
        context: str,
        completion: str,
        eval_generation_config: Optional[GenerationConfig],
        *args,
        **kwargs,
    ):
        eval_payload = EvalPipe.EvalPayload(
            query=query,
            context=context,
            completion=completion,
        )
        # Access the eval pipeline through the shared pipelines container,
        # consistent with how the search and RAG pipelines are used above.
        result = await self.pipelines.eval_pipeline.run(
            input=to_async_generator([eval_payload]),
            run_manager=self.run_manager,
            eval_generation_config=eval_generation_config,
        )
        return result
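
# ---------------------------------------------------------------------------
# Illustrative usage (comment-only sketch, not executed by this module):
# assuming `config`, `providers`, `pipelines`, `run_manager`, and
# `logging_connection` have already been assembled elsewhere (e.g. by the R2R
# assembly layer), the service above could be driven roughly as follows. The
# specific queries and settings values are assumptions for illustration.
#
#   service = RetrievalService(
#       config, providers, pipelines, run_manager, logging_connection
#   )
#
#   # Vector / knowledge-graph search; returns the search pipeline output
#   # as a dict.
#   search_results = await service.search(
#       "What does R2R do?",
#       vector_search_settings=VectorSearchSettings(),
#       kg_search_settings=KGSearchSettings(),
#   )
#
#   # Non-streaming RAG; returns the first pipeline result.
#   rag_result = await service.rag(
#       "What does R2R do?",
#       rag_generation_config=GenerationConfig(stream=False),
#   )
#
#   # Streaming RAG; when `rag_generation_config.stream` is True, `rag`
#   # returns an async generator of chunks.
#   async for chunk in await service.rag(
#       "What does R2R do?",
#       rag_generation_config=GenerationConfig(stream=True),
#   ):
#       print(chunk)
# ---------------------------------------------------------------------------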