diff options
author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
---|---|---|
committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) | |
tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /R2R/r2r/main/api/routes/retrieval.py | |
parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff) | |
download | gn-ai-4a52a71956a8d46fcb7294ac71734504bb09bcc2.tar.gz |
Diffstat (limited to 'R2R/r2r/main/api/routes/retrieval.py')
-rwxr-xr-x | R2R/r2r/main/api/routes/retrieval.py | 91 |
1 files changed, 91 insertions, 0 deletions
diff --git a/R2R/r2r/main/api/routes/retrieval.py b/R2R/r2r/main/api/routes/retrieval.py new file mode 100755 index 00000000..b2d352aa --- /dev/null +++ b/R2R/r2r/main/api/routes/retrieval.py @@ -0,0 +1,91 @@ +from fastapi.responses import StreamingResponse + +from r2r.base import GenerationConfig, KGSearchSettings, VectorSearchSettings + +from ...engine import R2REngine +from ..requests import R2REvalRequest, R2RRAGRequest, R2RSearchRequest +from .base_router import BaseRouter + + +class RetrievalRouter(BaseRouter): + def __init__(self, engine: R2REngine): + super().__init__(engine) + self.setup_routes() + + def setup_routes(self): + @self.router.post("/search") + @self.base_endpoint + async def search_app(request: R2RSearchRequest): + if "agent_generation_config" in request.kg_search_settings: + request.kg_search_settings["agent_generation_config"] = ( + GenerationConfig( + **request.kg_search_settings["agent_generation_config"] + or {} + ) + ) + + results = await self.engine.asearch( + query=request.query, + vector_search_settings=VectorSearchSettings( + **(request.vector_search_settings or {}) + ), + kg_search_settings=KGSearchSettings( + **(request.kg_search_settings or {}) + ), + ) + return results + + @self.router.post("/rag") + @self.base_endpoint + async def rag_app(request: R2RRAGRequest): + if "agent_generation_config" in request.kg_search_settings: + request.kg_search_settings["agent_generation_config"] = ( + GenerationConfig( + **( + request.kg_search_settings[ + "agent_generation_config" + ] + or {} + ) + ) + ) + response = await self.engine.arag( + query=request.query, + vector_search_settings=VectorSearchSettings( + **(request.vector_search_settings or {}) + ), + kg_search_settings=KGSearchSettings( + **(request.kg_search_settings or {}) + ), + rag_generation_config=GenerationConfig( + **(request.rag_generation_config or {}) + ), + ) + if ( + request.rag_generation_config + and request.rag_generation_config.get("stream", False) + ): + + async def stream_generator(): + async for chunk in response: + yield chunk + + return StreamingResponse( + stream_generator(), media_type="application/json" + ) + else: + return response + + @self.router.post("/evaluate") + @self.base_endpoint + async def evaluate_app(request: R2REvalRequest): + results = await self.engine.aevaluate( + query=request.query, + context=request.context, + completion=request.completion, + ) + return results + + +def create_retrieval_router(engine: R2REngine): + return RetrievalRouter(engine).router |