1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
|
from fastapi.responses import StreamingResponse
from r2r.base import GenerationConfig, KGSearchSettings, VectorSearchSettings
from ...engine import R2REngine
from ..requests import R2REvalRequest, R2RRAGRequest, R2RSearchRequest
from .base_router import BaseRouter
class RetrievalRouter(BaseRouter):
    """Router that registers the ``/search``, ``/rag`` and ``/evaluate`` routes.

    Each endpoint turns the plain-dict settings carried by the request into
    the matching settings objects and forwards the call to ``self.engine``.
    """

    def __init__(self, engine: R2REngine):
        """Pass ``engine`` to ``BaseRouter`` and register this router's routes."""
        super().__init__(engine)
        self.setup_routes()

    @staticmethod
    def _build_kg_search_settings(raw_settings):
        """Convert a raw KG-search settings mapping into ``KGSearchSettings``.

        Args:
            raw_settings: Mapping of keyword arguments for
                ``KGSearchSettings``, or ``None`` when the request did not
                provide any.

        Returns:
            A ``KGSearchSettings`` built from the mapping. If the mapping has
            an ``"agent_generation_config"`` key, its value (or an empty
            mapping when that value is falsy) is first expanded into a
            ``GenerationConfig``.
        """
        # Work on a copy: a missing (None) value must not raise on the
        # membership test below, and the caller's request object must not be
        # mutated as a side effect of handling it.
        settings = dict(raw_settings or {})
        if "agent_generation_config" in settings:
            settings["agent_generation_config"] = GenerationConfig(
                **(settings["agent_generation_config"] or {})
            )
        return KGSearchSettings(**settings)

    def setup_routes(self):
        """Attach the endpoint handlers to ``self.router``."""

        @self.router.post("/search")
        @self.base_endpoint
        async def search_app(request: R2RSearchRequest):
            """Await ``engine.asearch`` for the request and return its result."""
            results = await self.engine.asearch(
                query=request.query,
                vector_search_settings=VectorSearchSettings(
                    **(request.vector_search_settings or {})
                ),
                kg_search_settings=self._build_kg_search_settings(
                    request.kg_search_settings
                ),
            )
            return results

        @self.router.post("/rag")
        @self.base_endpoint
        async def rag_app(request: R2RRAGRequest):
            """Await ``engine.arag``; stream the result when ``stream`` is set."""
            response = await self.engine.arag(
                query=request.query,
                vector_search_settings=VectorSearchSettings(
                    **(request.vector_search_settings or {})
                ),
                kg_search_settings=self._build_kg_search_settings(
                    request.kg_search_settings
                ),
                rag_generation_config=GenerationConfig(
                    **(request.rag_generation_config or {})
                ),
            )

            if (
                request.rag_generation_config
                and request.rag_generation_config.get("stream", False)
            ):
                # With streaming requested, the awaited response is consumed
                # as an async iterable and relayed chunk by chunk.
                async def stream_generator():
                    async for chunk in response:
                        yield chunk

                return StreamingResponse(
                    stream_generator(), media_type="application/json"
                )
            else:
                return response

        @self.router.post("/evaluate")
        @self.base_endpoint
        async def evaluate_app(request: R2REvalRequest):
            """Await ``engine.aevaluate`` for the request and return its result."""
            results = await self.engine.aevaluate(
                query=request.query,
                context=request.context,
                completion=request.completion,
            )
            return results
def create_retrieval_router(engine: R2REngine):
    """Build a ``RetrievalRouter`` around ``engine`` and return its ``router``."""
    retrieval_router = RetrievalRouter(engine)
    return retrieval_router.router
|