aboutsummaryrefslogtreecommitdiff
path: root/R2R/r2r/main/api/routes/retrieval.py
blob: b2d352aa472ad45511612b008984f94f04d189ed (about) (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
from fastapi.responses import StreamingResponse

from r2r.base import GenerationConfig, KGSearchSettings, VectorSearchSettings

from ...engine import R2REngine
from ..requests import R2REvalRequest, R2RRAGRequest, R2RSearchRequest
from .base_router import BaseRouter


class RetrievalRouter(BaseRouter):
    """Router exposing the ``/search``, ``/rag`` and ``/evaluate`` POST routes.

    Each route forwards the request payload to the matching coroutine on
    ``self.engine`` (``asearch``, ``arag``, ``aevaluate``).
    """

    def __init__(self, engine: R2REngine):
        """Initialise the base router with *engine* and register the routes."""
        super().__init__(engine)
        self.setup_routes()

    @staticmethod
    def _build_kg_search_settings(raw_settings) -> KGSearchSettings:
        """Turn the request's raw KG settings into a ``KGSearchSettings``.

        Args:
            raw_settings: Mapping of keyword arguments for
                ``KGSearchSettings``, or a falsy value (e.g. ``None``) when
                the caller supplied no KG settings.

        Returns:
            A ``KGSearchSettings`` built from the supplied values. When an
            ``"agent_generation_config"`` entry is present it is first
            expanded into a ``GenerationConfig`` (an empty/``None`` entry
            yields ``GenerationConfig()``).
        """
        # Normalise *before* the membership test: `"key" in None` raises
        # TypeError, which previously made both endpoints fail whenever the
        # KG settings were omitted. Copying also keeps the request payload
        # itself untouched.
        settings = dict(raw_settings or {})
        if "agent_generation_config" in settings:
            settings["agent_generation_config"] = GenerationConfig(
                **(settings["agent_generation_config"] or {})
            )
        return KGSearchSettings(**settings)

    def setup_routes(self):
        """Register the retrieval endpoints on ``self.router``."""

        @self.router.post("/search")
        @self.base_endpoint
        async def search_app(request: R2RSearchRequest):
            """Run ``engine.asearch`` for the query and return its results."""
            results = await self.engine.asearch(
                query=request.query,
                vector_search_settings=VectorSearchSettings(
                    **(request.vector_search_settings or {})
                ),
                kg_search_settings=self._build_kg_search_settings(
                    request.kg_search_settings
                ),
            )
            return results

        @self.router.post("/rag")
        @self.base_endpoint
        async def rag_app(request: R2RRAGRequest):
            """Run ``engine.arag``; stream the response when requested.

            The response is wrapped in a ``StreamingResponse`` only when the
            request's ``rag_generation_config`` has a truthy ``"stream"``
            entry; otherwise it is returned as-is.
            """
            response = await self.engine.arag(
                query=request.query,
                vector_search_settings=VectorSearchSettings(
                    **(request.vector_search_settings or {})
                ),
                kg_search_settings=self._build_kg_search_settings(
                    request.kg_search_settings
                ),
                rag_generation_config=GenerationConfig(
                    **(request.rag_generation_config or {})
                ),
            )

            wants_stream = (request.rag_generation_config or {}).get(
                "stream", False
            )
            if not wants_stream:
                return response

            async def stream_generator():
                # Relay chunks from the engine's async iterator unchanged.
                async for chunk in response:
                    yield chunk

            return StreamingResponse(
                stream_generator(), media_type="application/json"
            )

        @self.router.post("/evaluate")
        @self.base_endpoint
        async def evaluate_app(request: R2REvalRequest):
            """Run ``engine.aevaluate`` on the query, context and completion."""
            results = await self.engine.aevaluate(
                query=request.query,
                context=request.context,
                completion=request.completion,
            )
            return results


def create_retrieval_router(engine: R2REngine):
    """Construct a ``RetrievalRouter`` for *engine* and return its ``router``."""
    retrieval_router = RetrievalRouter(engine)
    return retrieval_router.router