about summary refs log tree commit diff
path: root/R2R/r2r/main/api/routes/retrieval.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /R2R/r2r/main/api/routes/retrieval.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two version of R2R are here HEAD master
Diffstat (limited to 'R2R/r2r/main/api/routes/retrieval.py')
-rwxr-xr-xR2R/r2r/main/api/routes/retrieval.py91
1 files changed, 91 insertions, 0 deletions
diff --git a/R2R/r2r/main/api/routes/retrieval.py b/R2R/r2r/main/api/routes/retrieval.py
new file mode 100755
index 00000000..b2d352aa
--- /dev/null
+++ b/R2R/r2r/main/api/routes/retrieval.py
@@ -0,0 +1,91 @@
+from fastapi.responses import StreamingResponse
+
+from r2r.base import GenerationConfig, KGSearchSettings, VectorSearchSettings
+
+from ...engine import R2REngine
+from ..requests import R2REvalRequest, R2RRAGRequest, R2RSearchRequest
+from .base_router import BaseRouter
+
+
+class RetrievalRouter(BaseRouter):
+    def __init__(self, engine: R2REngine):
+        super().__init__(engine)
+        self.setup_routes()
+
+    def setup_routes(self):
+        @self.router.post("/search")
+        @self.base_endpoint
+        async def search_app(request: R2RSearchRequest):
+            if "agent_generation_config" in request.kg_search_settings:
+                request.kg_search_settings["agent_generation_config"] = (
+                    GenerationConfig(
+                        **request.kg_search_settings["agent_generation_config"]
+                        or {}
+                    )
+                )
+
+            results = await self.engine.asearch(
+                query=request.query,
+                vector_search_settings=VectorSearchSettings(
+                    **(request.vector_search_settings or {})
+                ),
+                kg_search_settings=KGSearchSettings(
+                    **(request.kg_search_settings or {})
+                ),
+            )
+            return results
+
+        @self.router.post("/rag")
+        @self.base_endpoint
+        async def rag_app(request: R2RRAGRequest):
+            if "agent_generation_config" in request.kg_search_settings:
+                request.kg_search_settings["agent_generation_config"] = (
+                    GenerationConfig(
+                        **(
+                            request.kg_search_settings[
+                                "agent_generation_config"
+                            ]
+                            or {}
+                        )
+                    )
+                )
+            response = await self.engine.arag(
+                query=request.query,
+                vector_search_settings=VectorSearchSettings(
+                    **(request.vector_search_settings or {})
+                ),
+                kg_search_settings=KGSearchSettings(
+                    **(request.kg_search_settings or {})
+                ),
+                rag_generation_config=GenerationConfig(
+                    **(request.rag_generation_config or {})
+                ),
+            )
+            if (
+                request.rag_generation_config
+                and request.rag_generation_config.get("stream", False)
+            ):
+
+                async def stream_generator():
+                    async for chunk in response:
+                        yield chunk
+
+                return StreamingResponse(
+                    stream_generator(), media_type="application/json"
+                )
+            else:
+                return response
+
+        @self.router.post("/evaluate")
+        @self.base_endpoint
+        async def evaluate_app(request: R2REvalRequest):
+            results = await self.engine.aevaluate(
+                query=request.query,
+                context=request.context,
+                completion=request.completion,
+            )
+            return results
+
+
+def create_retrieval_router(engine: R2REngine):
+    return RetrievalRouter(engine).router