author     S. Solomon Darnell  2025-03-28 21:52:21 -0500
committer  S. Solomon Darnell  2025-03-28 21:52:21 -0500
commit     4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree       ee3dc5af3b6313e921cd920906356f5d4febc4ed /R2R/r2r/main/services/retrieval_service.py
parent     cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
download   gn-ai-master.tar.gz
two versions of R2R are here  (HEAD, master)
Diffstat (limited to 'R2R/r2r/main/services/retrieval_service.py')
-rwxr-xr-x  R2R/r2r/main/services/retrieval_service.py  207
1 file changed, 207 insertions(+), 0 deletions(-)
diff --git a/R2R/r2r/main/services/retrieval_service.py b/R2R/r2r/main/services/retrieval_service.py
new file mode 100755
index 00000000..c4f6aff5
--- /dev/null
+++ b/R2R/r2r/main/services/retrieval_service.py
@@ -0,0 +1,207 @@
+import logging
+import time
+import uuid
+from typing import Optional
+
+from r2r.base import (
+    GenerationConfig,
+    KGSearchSettings,
+    KVLoggingSingleton,
+    R2RException,
+    RunManager,
+    VectorSearchSettings,
+    manage_run,
+    to_async_generator,
+)
+from r2r.pipes import EvalPipe
+from r2r.telemetry.telemetry_decorator import telemetry_event
+
+from ..abstractions import R2RPipelines, R2RProviders
+from ..assembly.config import R2RConfig
+from .base import Service
+
+logger = logging.getLogger(__name__)
+
+
+class RetrievalService(Service):
+    def __init__(
+        self,
+        config: R2RConfig,
+        providers: R2RProviders,
+        pipelines: R2RPipelines,
+        run_manager: RunManager,
+        logging_connection: KVLoggingSingleton,
+    ):
+        super().__init__(
+            config, providers, pipelines, run_manager, logging_connection
+        )
+
+    @telemetry_event("Search")
+    async def search(
+        self,
+        query: str,
+        vector_search_settings: VectorSearchSettings = VectorSearchSettings(),
+        kg_search_settings: KGSearchSettings = KGSearchSettings(),
+        *args,
+        **kwargs,
+    ):
+        async with manage_run(self.run_manager, "search_app") as run_id:
+            t0 = time.time()
+
+            if (
+                kg_search_settings.use_kg_search
+                and self.config.kg.provider is None
+            ):
+                raise R2RException(
+                    status_code=400,
+                    message="Knowledge Graph search is not enabled in the configuration.",
+                )
+
+            if (
+                vector_search_settings.use_vector_search
+                and self.config.vector_database.provider is None
+            ):
+                raise R2RException(
+                    status_code=400,
+                    message="Vector search is not enabled in the configuration.",
+                )
+
+            # TODO - Remove these transforms once we have a better way to handle this
+            for key, value in vector_search_settings.search_filters.items():
+                if isinstance(value, uuid.UUID):
+                    vector_search_settings.search_filters[key] = str(value)
+
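+            # The search pipeline consumes an async stream, so the single
+            # query is wrapped with to_async_generator.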
+            results = await self.pipelines.search_pipeline.run(
+                input=to_async_generator([query]),
+                vector_search_settings=vector_search_settings,
+                kg_search_settings=kg_search_settings,
+                run_manager=self.run_manager,
+                *args,
+                **kwargs,
+            )
+
+            t1 = time.time()
+            latency = f"{t1 - t0:.2f}"
+
+            await self.logging_connection.log(
+                log_id=run_id,
+                key="search_latency",
+                value=latency,
+                is_info_log=False,
+            )
+
+            return results.dict()
+
+    @telemetry_event("RAG")
+    async def rag(
+        self,
+        query: str,
+        rag_generation_config: GenerationConfig,
+        vector_search_settings: VectorSearchSettings = VectorSearchSettings(),
+        kg_search_settings: KGSearchSettings = KGSearchSettings(),
+        *args,
+        **kwargs,
+    ):
+        async with manage_run(self.run_manager, "rag_app") as run_id:
+            try:
+                t0 = time.time()
+
+                # TODO - Remove these transforms once we have a better way to handle this
+                for (
+                    key,
+                    value,
+                ) in vector_search_settings.search_filters.items():
+                    if isinstance(value, uuid.UUID):
+                        vector_search_settings.search_filters[key] = str(
+                            value
+                        )
+
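+                # Streaming path: the latency logged here covers setup only;
+                # generation happens lazily inside stream_response().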
+                if rag_generation_config.stream:
+                    t1 = time.time()
+                    latency = f"{t1 - t0:.2f}"
+
+                    await self.logging_connection.log(
+                        log_id=run_id,
+                        key="rag_generation_latency",
+                        value=latency,
+                        is_info_log=False,
+                    )
+
+                    async def stream_response():
+                        async with manage_run(self.run_manager, "arag"):
+                            async for (
+                                chunk
+                            ) in await self.pipelines.streaming_rag_pipeline.run(
+                                input=to_async_generator([query]),
+                                run_manager=self.run_manager,
+                                vector_search_settings=vector_search_settings,
+                                kg_search_settings=kg_search_settings,
+                                rag_generation_config=rag_generation_config,
+                            ):
+                                yield chunk
+
+                    return stream_response()
+
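+                # Non-streaming path: run the full RAG pipeline to completion.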
+                results = await self.pipelines.rag_pipeline.run(
+                    input=to_async_generator([query]),
+                    run_manager=self.run_manager,
+                    vector_search_settings=vector_search_settings,
+                    kg_search_settings=kg_search_settings,
+                    rag_generation_config=rag_generation_config,
+                    *args,
+                    **kwargs,
+                )
+
+                t1 = time.time()
+                latency = f"{t1 - t0:.2f}"
+
+                await self.logging_connection.log(
+                    log_id=run_id,
+                    key="rag_generation_latency",
+                    value=latency,
+                    is_info_log=False,
+                )
+
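+                # A single query is expected to yield exactly one result;
+                # anything else is surfaced below.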
+                if len(results) == 0:
+                    raise R2RException(
+                        status_code=404, message="No results found"
+                    )
+                if len(results) > 1:
+                    logger.warning(
+                        f"Multiple results found for query: {query}"
+                    )
+                # unpack the first result
+                return results[0]
+
+            except Exception as e:
+                logger.error(f"Pipeline error: {e}")
+                if "NoneType" in str(e):
+                    raise R2RException(
+                        status_code=502,
+                        message="Ollama server not reachable or returned an invalid response",
+                    )
+                raise R2RException(
+                    status_code=500, message="Internal Server Error"
+                )
+
+ @telemetry_event("Evaluate")
+ async def evaluate(
+ self,
+ query: str,
+ context: str,
+ completion: str,
+ eval_generation_config: Optional[GenerationConfig],
+ *args,
+ **kwargs,
+ ):
+ eval_payload = EvalPipe.EvalPayload(
+ query=query,
+ context=context,
+ completion=completion,
+ )
+ result = await self.eval_pipeline.run(
+ input=to_async_generator([eval_payload]),
+ run_manager=self.run_manager,
+ eval_generation_config=eval_generation_config,
+ )
+ return result
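
Below is a minimal usage sketch for the service added in this commit. The query strings and the model name are illustrative, and constructing the service's config, providers, and pipelines is left to R2R's assembly layer (not shown); only the search()/rag() calls mirror the API defined in the file above.

from r2r.base import GenerationConfig, KGSearchSettings

async def run_queries(service) -> None:
    # `service` is an already-assembled RetrievalService; its config,
    # providers, and pipelines come from R2R's assembly layer (not shown).

    # Vector-only search: keep KG search off so the config check in
    # search() does not require a graph provider.
    results = await service.search(
        query="What does the retrieval service do?",
        kg_search_settings=KGSearchSettings(use_kg_search=False),
    )
    print(results)

    # Non-streaming RAG: rag() unpacks and returns the single result.
    answer = await service.rag(
        query="What does the retrieval service do?",
        # the model name below is illustrative
        rag_generation_config=GenerationConfig(model="gpt-4", stream=False),
    )
    print(answer)

    # Streaming RAG: with stream=True, rag() returns an async generator
    # of chunks instead of a finished result.
    stream = await service.rag(
        query="What does the retrieval service do?",
        rag_generation_config=GenerationConfig(model="gpt-4", stream=True),
    )
    async for chunk in stream:
        print(chunk, end="")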