Diffstat (limited to 'R2R/r2r/main/api')
-rwxr-xr-x  R2R/r2r/main/api/__init__.py               0
-rwxr-xr-x  R2R/r2r/main/api/client.py               377
-rwxr-xr-x  R2R/r2r/main/api/requests.py              79
-rwxr-xr-x  R2R/r2r/main/api/routes/__init__.py        0
-rwxr-xr-x  R2R/r2r/main/api/routes/base_router.py    75
-rwxr-xr-x  R2R/r2r/main/api/routes/ingestion.py      42
-rwxr-xr-x  R2R/r2r/main/api/routes/management.py    101
-rwxr-xr-x  R2R/r2r/main/api/routes/retrieval.py      91
8 files changed, 765 insertions, 0 deletions
diff --git a/R2R/r2r/main/api/__init__.py b/R2R/r2r/main/api/__init__.py
new file mode 100755
index 00000000..e69de29b
--- /dev/null
+++ b/R2R/r2r/main/api/__init__.py
diff --git a/R2R/r2r/main/api/client.py b/R2R/r2r/main/api/client.py
new file mode 100755
index 00000000..b0f5b966
--- /dev/null
+++ b/R2R/r2r/main/api/client.py
@@ -0,0 +1,377 @@
+import asyncio
+import functools
+import json
+import os
+import threading
+import time
+import uuid
+from contextlib import ExitStack
+from typing import Any, AsyncGenerator, Generator, Optional, Union
+
+import fire
+import httpx
+import nest_asyncio
+import requests
+
+from .requests import (
+    R2RAnalyticsRequest,
+    R2RDeleteRequest,
+    R2RDocumentChunksRequest,
+    R2RDocumentsOverviewRequest,
+    R2RIngestFilesRequest,
+    R2RLogsRequest,
+    R2RPrintRelationshipsRequest,
+    R2RRAGRequest,
+    R2RSearchRequest,
+    R2RUpdateFilesRequest,
+    R2RUpdatePromptRequest,
+    R2RUsersOverviewRequest,
+)
+
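+# Patch asyncio so the synchronous streaming helpers below can drive an async
+# generator even when an event loop is already running (e.g. in notebooks).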
+nest_asyncio.apply()
+
+
+class R2RHTTPError(Exception):
+    def __init__(self, status_code, error_type, message):
+        self.status_code = status_code
+        self.error_type = error_type
+        self.message = message
+        super().__init__(f"[{status_code}] {error_type}: {message}")
+
+
+def handle_request_error(response):
+    if response.status_code >= 400:
+        try:
+            error_content = response.json()
+            if isinstance(error_content, dict) and "detail" in error_content:
+                detail = error_content["detail"]
+                if isinstance(detail, dict):
+                    message = detail.get("message", str(response.text))
+                    error_type = detail.get("error_type", "UnknownError")
+                else:
+                    message = str(detail)
+                    error_type = "HTTPException"
+            else:
+                message = str(error_content)
+                error_type = "UnknownError"
+        except json.JSONDecodeError:
+            message = response.text
+            error_type = "UnknownError"
+
+        raise R2RHTTPError(
+            status_code=response.status_code,
+            error_type=error_type,
+            message=message,
+        )
+
+
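+# When called with monitor=True, run the wrapped request in a background
+# thread and print a simple progress indicator until it completes.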
+def monitor_request(func):
+    @functools.wraps(func)
+    def wrapper(*args, monitor=False, **kwargs):
+        if not monitor:
+            return func(*args, **kwargs)
+
+        result = None
+        exception = None
+
+        def run_func():
+            nonlocal result, exception
+            try:
+                result = func(*args, **kwargs)
+            except Exception as e:
+                exception = e
+
+        thread = threading.Thread(target=run_func)
+        thread.start()
+
+        dots = [".", "..", "..."]
+        i = 0
+        while thread.is_alive():
+            print(f"\rRequesting{dots[i % 3]}", end="", flush=True)
+            i += 1
+            time.sleep(0.5)
+
+        thread.join()
+
+        print("\r", end="", flush=True)
+
+        if exception:
+            raise exception
+        return result
+
+    return wrapper
+
+
+class R2RClient:
+    def __init__(self, base_url: str, prefix: str = "/v1"):
+        self.base_url = base_url
+        self.prefix = prefix
+
+    def _make_request(self, method, endpoint, **kwargs):
+        url = f"{self.base_url}{self.prefix}/{endpoint}"
+        response = requests.request(method, url, **kwargs)
+        handle_request_error(response)
+        return response.json()
+
+    def health(self) -> dict:
+        return self._make_request("GET", "health")
+
+    def update_prompt(
+        self,
+        name: str = "default_system",
+        template: Optional[str] = None,
+        input_types: Optional[dict] = None,
+    ) -> dict:
+        request = R2RUpdatePromptRequest(
+            name=name, template=template, input_types=input_types
+        )
+        return self._make_request(
+            "POST", "update_prompt", json=json.loads(request.json())
+        )
+
+    @monitor_request
+    def ingest_files(
+        self,
+        file_paths: list[str],
+        metadatas: Optional[list[dict]] = None,
+        document_ids: Optional[list[Union[uuid.UUID, str]]] = None,
+        versions: Optional[list[str]] = None,
+    ) -> dict:
+        all_file_paths = []
+
+        for path in file_paths:
+            if os.path.isdir(path):
+                for root, _, files in os.walk(path):
+                    all_file_paths.extend(
+                        os.path.join(root, file) for file in files
+                    )
+            else:
+                all_file_paths.append(path)
+
+        files_to_upload = [
+            (
+                "files",
+                (
+                    os.path.basename(file),
+                    open(file, "rb"),
+                    "application/octet-stream",
+                ),
+            )
+            for file in all_file_paths
+        ]
+        request = R2RIngestFilesRequest(
+            metadatas=metadatas,
+            document_ids=(
+                [str(ele) for ele in document_ids] if document_ids else None
+            ),
+            versions=versions,
+        )
+        try:
+            return self._make_request(
+                "POST",
+                "ingest_files",
+                data={
+                    k: json.dumps(v)
+                    for k, v in json.loads(request.json()).items()
+                },
+                files=files_to_upload,
+            )
+        finally:
+            for _, file_tuple in files_to_upload:
+                file_tuple[1].close()
+
+    @monitor_request
+    def update_files(
+        self,
+        file_paths: list[str],
+        document_ids: list[str],
+        metadatas: Optional[list[dict]] = None,
+    ) -> dict:
+        request = R2RUpdateFilesRequest(
+            metadatas=metadatas,
+            document_ids=document_ids,
+        )
+        with ExitStack() as stack:
+            return self._make_request(
+                "POST",
+                "update_files",
+                data={
+                    k: json.dumps(v)
+                    for k, v in json.loads(request.json()).items()
+                },
+                files=[
+                    (
+                        "files",
+                        (
+                            path.split("/")[-1],
+                            stack.enter_context(open(path, "rb")),
+                            "application/octet-stream",
+                        ),
+                    )
+                    for path in file_paths
+                ],
+            )
+
+    def search(
+        self,
+        query: str,
+        use_vector_search: bool = True,
+        search_filters: Optional[dict[str, Any]] = None,
+        search_limit: int = 10,
+        do_hybrid_search: bool = False,
+        use_kg_search: bool = False,
+        kg_agent_generation_config: Optional[dict] = None,
+    ) -> dict:
+        request = R2RSearchRequest(
+            query=query,
+            vector_search_settings={
+                "use_vector_search": use_vector_search,
+                "search_filters": search_filters or {},
+                "search_limit": search_limit,
+                "do_hybrid_search": do_hybrid_search,
+            },
+            kg_search_settings={
+                "use_kg_search": use_kg_search,
+                "agent_generation_config": kg_agent_generation_config,
+            },
+        )
+        return self._make_request(
+            "POST", "search", json=json.loads(request.json())
+        )
+
+    def rag(
+        self,
+        query: str,
+        use_vector_search: bool = True,
+        search_filters: Optional[dict[str, Any]] = None,
+        search_limit: int = 10,
+        do_hybrid_search: bool = False,
+        use_kg_search: bool = False,
+        kg_agent_generation_config: Optional[dict] = None,
+        rag_generation_config: Optional[dict] = None,
+    ) -> dict:
+        request = R2RRAGRequest(
+            query=query,
+            vector_search_settings={
+                "use_vector_search": use_vector_search,
+                "search_filters": search_filters or {},
+                "search_limit": search_limit,
+                "do_hybrid_search": do_hybrid_search,
+            },
+            kg_search_settings={
+                "use_kg_search": use_kg_search,
+                "agent_generation_config": kg_agent_generation_config,
+            },
+            rag_generation_config=rag_generation_config,
+        )
+
+        if rag_generation_config and rag_generation_config.get(
+            "stream", False
+        ):
+            return self._stream_rag_sync(request)
+        else:
+            return self._make_request(
+                "POST", "rag", json=json.loads(request.json())
+            )
+
+    async def _stream_rag(
+        self, rag_request: R2RRAGRequest
+    ) -> AsyncGenerator[str, None]:
+        url = f"{self.base_url}{self.prefix}/rag"
+        async with httpx.AsyncClient() as client:
+            async with client.stream(
+                "POST", url, json=json.loads(rag_request.json())
+            ) as response:
+                # A streamed body must be read before its text/JSON can be
+                # inspected for error details.
+                if response.status_code >= 400:
+                    await response.aread()
+                handle_request_error(response)
+                async for chunk in response.aiter_text():
+                    yield chunk
+
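+    # Bridge the async streaming generator into a synchronous one by driving
+    # a dedicated event loop until the stream is exhausted.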
+    def _stream_rag_sync(
+        self, rag_request: R2RRAGRequest
+    ) -> Generator[str, None, None]:
+        async def run_async_generator():
+            async for chunk in self._stream_rag(rag_request):
+                yield chunk
+
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+
+        async_gen = run_async_generator()
+
+        try:
+            while True:
+                chunk = loop.run_until_complete(async_gen.__anext__())
+                yield chunk
+        except StopAsyncIteration:
+            pass
+        finally:
+            loop.close()
+
+    def delete(
+        self, keys: list[str], values: list[Union[bool, int, str]]
+    ) -> dict:
+        request = R2RDeleteRequest(keys=keys, values=values)
+        return self._make_request(
+            "DELETE", "delete", json=json.loads(request.json())
+        )
+
+    def logs(self, log_type_filter: Optional[str] = None) -> dict:
+        request = R2RLogsRequest(log_type_filter=log_type_filter)
+        return self._make_request(
+            "GET", "logs", json=json.loads(request.json())
+        )
+
+    def app_settings(self) -> dict:
+        return self._make_request("GET", "app_settings")
+
+    def analytics(self, filter_criteria: dict, analysis_types: dict) -> dict:
+        request = R2RAnalyticsRequest(
+            filter_criteria=filter_criteria, analysis_types=analysis_types
+        )
+        return self._make_request(
+            "GET", "analytics", json=json.loads(request.json())
+        )
+
+    def users_overview(
+        self, user_ids: Optional[list[uuid.UUID]] = None
+    ) -> dict:
+        request = R2RUsersOverviewRequest(user_ids=user_ids)
+        return self._make_request(
+            "GET", "users_overview", json=json.loads(request.json())
+        )
+
+    def documents_overview(
+        self,
+        document_ids: Optional[list[str]] = None,
+        user_ids: Optional[list[str]] = None,
+    ) -> dict:
+        request = R2RDocumentsOverviewRequest(
+            document_ids=(
+                [uuid.UUID(did) for did in document_ids]
+                if document_ids
+                else None
+            ),
+            user_ids=(
+                [uuid.UUID(uid) for uid in user_ids] if user_ids else None
+            ),
+        )
+        return self._make_request(
+            "GET", "documents_overview", json=json.loads(request.json())
+        )
+
+    def document_chunks(self, document_id: str) -> dict:
+        request = R2RDocumentChunksRequest(document_id=document_id)
+        return self._make_request(
+            "GET", "document_chunks", json=json.loads(request.json())
+        )
+
+    def inspect_knowledge_graph(self, limit: int = 100) -> str:
+        request = R2RPrintRelationshipsRequest(limit=limit)
+        return self._make_request(
+            "POST", "inspect_knowledge_graph", json=json.loads(request.json())
+        )
+
+
+if __name__ == "__main__":
+    client = R2RClient(base_url="http://localhost:8000")
+    fire.Fire(client)
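For reference, a minimal usage sketch of R2RClient, assuming an R2R server is reachable at http://localhost:8000; the import path is inferred from this diff, and the file path and queries are placeholders.

from r2r.main.api.client import R2RClient  # import path inferred from this diff

client = R2RClient(base_url="http://localhost:8000")
print(client.health())                      # -> {"response": "ok"}
client.ingest_files(["./data/report.pdf"])  # uploaded as multipart form data
print(client.search(query="What does the report conclude?"))

# With stream=True the client returns a synchronous generator of text chunks.
for chunk in client.rag(
    query="Summarize the report", rag_generation_config={"stream": True}
):
    print(chunk, end="")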
diff --git a/R2R/r2r/main/api/requests.py b/R2R/r2r/main/api/requests.py
new file mode 100755
index 00000000..5c63ab82
--- /dev/null
+++ b/R2R/r2r/main/api/requests.py
@@ -0,0 +1,79 @@
+import uuid
+from typing import Optional, Union
+
+from pydantic import BaseModel
+
+from r2r.base import AnalysisTypes, FilterCriteria
+
+
+class R2RUpdatePromptRequest(BaseModel):
+    name: str
+    template: Optional[str] = None
+    input_types: Optional[dict[str, str]] = {}
+
+
+class R2RIngestFilesRequest(BaseModel):
+    document_ids: Optional[list[uuid.UUID]] = None
+    metadatas: Optional[list[dict]] = None
+    versions: Optional[list[str]] = None
+
+
+class R2RUpdateFilesRequest(BaseModel):
+    metadatas: Optional[list[dict]] = None
+    document_ids: Optional[list[uuid.UUID]] = None
+
+
+class R2RSearchRequest(BaseModel):
+    query: str
+    vector_search_settings: Optional[dict] = None
+    kg_search_settings: Optional[dict] = None
+
+
+class R2RRAGRequest(BaseModel):
+    query: str
+    vector_search_settings: Optional[dict] = None
+    kg_search_settings: Optional[dict] = None
+    rag_generation_config: Optional[dict] = None
+
+
+class R2REvalRequest(BaseModel):
+    query: str
+    context: str
+    completion: str
+
+
+class R2RDeleteRequest(BaseModel):
+    keys: list[str]
+    values: list[Union[bool, int, str]]
+
+
+class R2RAnalyticsRequest(BaseModel):
+    filter_criteria: FilterCriteria
+    analysis_types: AnalysisTypes
+
+
+class R2RUsersOverviewRequest(BaseModel):
+    user_ids: Optional[list[uuid.UUID]]
+
+
+class R2RDocumentsOverviewRequest(BaseModel):
+    document_ids: Optional[list[uuid.UUID]]
+    user_ids: Optional[list[uuid.UUID]]
+
+
+class R2RDocumentChunksRequest(BaseModel):
+    document_id: uuid.UUID
+
+
+class R2RLogsRequest(BaseModel):
+    log_type_filter: Optional[str] = None
+    max_runs_requested: int = 100
+
+
+class R2RPrintRelationshipsRequest(BaseModel):
+    limit: int = 100
+
+
+class R2RExtractionRequest(BaseModel):
+    entity_types: list[str]
+    relations: list[str]
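To illustrate how these models travel over the wire, the client above serializes them with pydantic's v1-style .json() before posting; the values here are placeholders and the import path is inferred from this diff.

from r2r.main.api.requests import R2RSearchRequest  # path inferred from this diff

request = R2RSearchRequest(
    query="What is R2R?",
    vector_search_settings={"use_vector_search": True, "search_limit": 10},
    kg_search_settings={"use_kg_search": False},
)
print(request.json())  # the JSON body R2RClient posts to /v1/search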
diff --git a/R2R/r2r/main/api/routes/__init__.py b/R2R/r2r/main/api/routes/__init__.py
new file mode 100755
index 00000000..e69de29b
--- /dev/null
+++ b/R2R/r2r/main/api/routes/__init__.py
diff --git a/R2R/r2r/main/api/routes/base_router.py b/R2R/r2r/main/api/routes/base_router.py
new file mode 100755
index 00000000..d06a9935
--- /dev/null
+++ b/R2R/r2r/main/api/routes/base_router.py
@@ -0,0 +1,75 @@
+import functools
+import logging
+
+from fastapi import APIRouter, HTTPException
+from fastapi.responses import StreamingResponse
+
+from r2r.base import R2RException, manage_run
+
+logger = logging.getLogger(__name__)
+
+
+class BaseRouter:
+    def __init__(self, engine):
+        self.engine = engine
+        self.router = APIRouter()
+
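+    # Wraps an endpoint so successful results are returned as {"results": ...},
+    # R2RException keeps its own status code, and any other error is logged
+    # and surfaced as a 500 with structured detail.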
+    def base_endpoint(self, func):
+        @functools.wraps(func)
+        async def wrapper(*args, **kwargs):
+            async with manage_run(
+                self.engine.run_manager, func.__name__
+            ) as run_id:
+                try:
+                    results = await func(*args, **kwargs)
+                    if isinstance(results, StreamingResponse):
+                        return results
+
+                    return {"results": results}
+                except R2RException as re:
+                    raise HTTPException(
+                        status_code=re.status_code,
+                        detail={
+                            "message": re.message,
+                            "error_type": type(re).__name__,
+                        },
+                    )
+                except Exception as e:
+                    # Get the pipeline name based on the function name
+                    pipeline_name = f"{func.__name__.split('_')[0]}_pipeline"
+
+                    # Safely get the pipeline object and its type
+                    pipeline = getattr(
+                        self.engine.pipelines, pipeline_name, None
+                    )
+                    pipeline_type = getattr(
+                        pipeline, "pipeline_type", "unknown"
+                    )
+
+                    await self.engine.logging_connection.log(
+                        log_id=run_id,
+                        key="pipeline_type",
+                        value=pipeline_type,
+                        is_info_log=True,
+                    )
+                    await self.engine.logging_connection.log(
+                        log_id=run_id,
+                        key="error",
+                        value=str(e),
+                        is_info_log=False,
+                    )
+                    logger.error(f"{func.__name__}() - \n\n{str(e)}")
+                    raise HTTPException(
+                        status_code=500,
+                        detail={
+                            "message": f"An error occurred during {func.__name__}",
+                            "error": str(e),
+                            "error_type": type(e).__name__,
+                        },
+                    ) from e
+
+        return wrapper
+
+    @classmethod
+    def build_router(cls, engine):
+        return cls(engine).router
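A hypothetical subclass (not part of this diff) showing how base_endpoint is used: successful results come back wrapped as {"results": ...}, while R2RException and unexpected errors are converted into structured HTTPException details.

class EchoRouter(BaseRouter):  # hypothetical example router
    def __init__(self, engine):
        super().__init__(engine)
        self.setup_routes()

    def setup_routes(self):
        @self.router.get("/echo")
        @self.base_endpoint
        async def echo_app(message: str):
            # Returned to the caller as {"results": {"message": ...}}
            return {"message": message}

# app.include_router(EchoRouter.build_router(engine), prefix="/v1")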
diff --git a/R2R/r2r/main/api/routes/ingestion.py b/R2R/r2r/main/api/routes/ingestion.py
new file mode 100755
index 00000000..be583602
--- /dev/null
+++ b/R2R/r2r/main/api/routes/ingestion.py
@@ -0,0 +1,42 @@
+from fastapi import Depends, File, UploadFile
+
+from ...engine import R2REngine
+from ...services.ingestion_service import IngestionService
+from ..requests import R2RIngestFilesRequest, R2RUpdateFilesRequest
+from .base_router import BaseRouter
+
+
+class IngestionRouter(BaseRouter):
+    def __init__(self, engine: R2REngine):
+        super().__init__(engine)
+        self.setup_routes()
+
+    def setup_routes(self):
+        @self.router.post("/ingest_files")
+        @self.base_endpoint
+        async def ingest_files_app(
+            files: list[UploadFile] = File(...),
+            request: R2RIngestFilesRequest = Depends(
+                IngestionService.parse_ingest_files_form_data
+            ),
+        ):
+            return await self.engine.aingest_files(
+                files=files,
+                metadatas=request.metadatas,
+                document_ids=request.document_ids,
+                versions=request.versions,
+            )
+
+        @self.router.post("/update_files")
+        @self.base_endpoint
+        async def update_files_app(
+            files: list[UploadFile] = File(...),
+            request: R2RUpdateFilesRequest = Depends(
+                IngestionService.parse_update_files_form_data
+            ),
+        ):
+            return await self.engine.aupdate_files(
+                files=files,
+                metadatas=request.metadatas,
+                document_ids=request.document_ids,
+            )
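A hedged sketch of calling /v1/ingest_files directly with requests, mirroring the multipart layout R2RClient builds; the URL and file name are placeholders, and the form parser (IngestionService.parse_ingest_files_form_data) is not shown in this diff.

import json
import requests

with open("report.pdf", "rb") as f:
    response = requests.post(
        "http://localhost:8000/v1/ingest_files",
        # Each request field is JSON-encoded as a form value, as in
        # R2RClient.ingest_files above.
        data={
            "metadatas": json.dumps(None),
            "document_ids": json.dumps(None),
            "versions": json.dumps(None),
        },
        files=[("files", ("report.pdf", f, "application/octet-stream"))],
    )
print(response.json())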
diff --git a/R2R/r2r/main/api/routes/management.py b/R2R/r2r/main/api/routes/management.py
new file mode 100755
index 00000000..921fb534
--- /dev/null
+++ b/R2R/r2r/main/api/routes/management.py
@@ -0,0 +1,101 @@
+from ...engine import R2REngine
+from ..requests import (
+    R2RAnalyticsRequest,
+    R2RDeleteRequest,
+    R2RDocumentChunksRequest,
+    R2RDocumentsOverviewRequest,
+    R2RLogsRequest,
+    R2RPrintRelationshipsRequest,
+    R2RUpdatePromptRequest,
+    R2RUsersOverviewRequest,
+)
+from .base_router import BaseRouter
+
+
+class ManagementRouter(BaseRouter):
+    def __init__(self, engine: R2REngine):
+        super().__init__(engine)
+        self.setup_routes()
+
+    def setup_routes(self):
+        @self.router.get("/health")
+        async def health_check():
+            return {"response": "ok"}
+
+        @self.router.post("/update_prompt")
+        @self.base_endpoint
+        async def update_prompt_app(request: R2RUpdatePromptRequest):
+            return await self.engine.aupdate_prompt(
+                request.name, request.template, request.input_types
+            )
+
+        @self.router.post("/logs")
+        @self.router.get("/logs")
+        @self.base_endpoint
+        async def get_logs_app(request: R2RLogsRequest):
+            return await self.engine.alogs(
+                log_type_filter=request.log_type_filter,
+                max_runs_requested=request.max_runs_requested,
+            )
+
+        @self.router.post("/analytics")
+        @self.router.get("/analytics")
+        @self.base_endpoint
+        async def get_analytics_app(request: R2RAnalyticsRequest):
+            return await self.engine.aanalytics(
+                filter_criteria=request.filter_criteria,
+                analysis_types=request.analysis_types,
+            )
+
+        @self.router.post("/users_overview")
+        @self.router.get("/users_overview")
+        @self.base_endpoint
+        async def get_users_overview_app(request: R2RUsersOverviewRequest):
+            return await self.engine.ausers_overview(user_ids=request.user_ids)
+
+        @self.router.delete("/delete")
+        @self.base_endpoint
+        async def delete_app(request: R2RDeleteRequest):
+            return await self.engine.adelete(
+                keys=request.keys, values=request.values
+            )
+
+        @self.router.post("/documents_overview")
+        @self.router.get("/documents_overview")
+        @self.base_endpoint
+        async def get_documents_overview_app(
+            request: R2RDocumentsOverviewRequest,
+        ):
+            return await self.engine.adocuments_overview(
+                document_ids=request.document_ids, user_ids=request.user_ids
+            )
+
+        @self.router.post("/document_chunks")
+        @self.router.get("/document_chunks")
+        @self.base_endpoint
+        async def get_document_chunks_app(request: R2RDocumentChunksRequest):
+            return await self.engine.adocument_chunks(request.document_id)
+
+        @self.router.post("/inspect_knowledge_graph")
+        @self.router.get("/inspect_knowledge_graph")
+        @self.base_endpoint
+        async def inspect_knowledge_graph(
+            request: R2RPrintRelationshipsRequest,
+        ):
+            return await self.engine.inspect_knowledge_graph(
+                limit=request.limit
+            )
+
+        @self.router.get("/app_settings")
+        @self.base_endpoint
+        async def get_app_settings_app():
+            return await self.engine.aapp_settings()
+
+        @self.router.get("/openapi_spec")
+        @self.base_endpoint
+        async def get_openapi_spec_app():
+            return self.engine.openapi_spec()
+
+
+def create_management_router(engine: R2REngine):
+    return ManagementRouter(engine).router
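A small sketch of exercising the management routes directly; handlers wrapped by base_endpoint respond with {"results": ...}, while /health responds with {"response": "ok"}. The base URL is a placeholder.

import requests

base = "http://localhost:8000/v1"
print(requests.get(f"{base}/health").json())        # {"response": "ok"}
print(requests.get(f"{base}/app_settings").json())  # {"results": ...}
print(
    requests.post(
        f"{base}/logs",
        json={"log_type_filter": None, "max_runs_requested": 5},
    ).json()
)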
diff --git a/R2R/r2r/main/api/routes/retrieval.py b/R2R/r2r/main/api/routes/retrieval.py
new file mode 100755
index 00000000..b2d352aa
--- /dev/null
+++ b/R2R/r2r/main/api/routes/retrieval.py
@@ -0,0 +1,91 @@
+from fastapi.responses import StreamingResponse
+
+from r2r.base import GenerationConfig, KGSearchSettings, VectorSearchSettings
+
+from ...engine import R2REngine
+from ..requests import R2REvalRequest, R2RRAGRequest, R2RSearchRequest
+from .base_router import BaseRouter
+
+
+class RetrievalRouter(BaseRouter):
+    def __init__(self, engine: R2REngine):
+        super().__init__(engine)
+        self.setup_routes()
+
+    def setup_routes(self):
+        @self.router.post("/search")
+        @self.base_endpoint
+        async def search_app(request: R2RSearchRequest):
+            if request.kg_search_settings and (
+                "agent_generation_config" in request.kg_search_settings
+            ):
+                request.kg_search_settings["agent_generation_config"] = (
+                    GenerationConfig(
+                        **request.kg_search_settings["agent_generation_config"]
+                        or {}
+                    )
+                )
+
+            results = await self.engine.asearch(
+                query=request.query,
+                vector_search_settings=VectorSearchSettings(
+                    **(request.vector_search_settings or {})
+                ),
+                kg_search_settings=KGSearchSettings(
+                    **(request.kg_search_settings or {})
+                ),
+            )
+            return results
+
+        @self.router.post("/rag")
+        @self.base_endpoint
+        async def rag_app(request: R2RRAGRequest):
+            if request.kg_search_settings and (
+                "agent_generation_config" in request.kg_search_settings
+            ):
+                request.kg_search_settings["agent_generation_config"] = (
+                    GenerationConfig(
+                        **(
+                            request.kg_search_settings[
+                                "agent_generation_config"
+                            ]
+                            or {}
+                        )
+                    )
+                )
+            response = await self.engine.arag(
+                query=request.query,
+                vector_search_settings=VectorSearchSettings(
+                    **(request.vector_search_settings or {})
+                ),
+                kg_search_settings=KGSearchSettings(
+                    **(request.kg_search_settings or {})
+                ),
+                rag_generation_config=GenerationConfig(
+                    **(request.rag_generation_config or {})
+                ),
+            )
+            if (
+                request.rag_generation_config
+                and request.rag_generation_config.get("stream", False)
+            ):
+
+                async def stream_generator():
+                    async for chunk in response:
+                        yield chunk
+
+                return StreamingResponse(
+                    stream_generator(), media_type="application/json"
+                )
+            else:
+                return response
+
+        @self.router.post("/evaluate")
+        @self.base_endpoint
+        async def evaluate_app(request: R2REvalRequest):
+            results = await self.engine.aevaluate(
+                query=request.query,
+                context=request.context,
+                completion=request.completion,
+            )
+            return results
+
+
+def create_retrieval_router(engine: R2REngine):
+    return RetrievalRouter(engine).router
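Finally, a hedged sketch of consuming the streaming /v1/rag response directly with httpx, mirroring what R2RClient._stream_rag does; the URL and query are placeholders.

import httpx

payload = {
    "query": "Summarize the ingested documents",
    "kg_search_settings": {"use_kg_search": False},
    "rag_generation_config": {"stream": True},
}
with httpx.Client() as client:
    with client.stream(
        "POST", "http://localhost:8000/v1/rag", json=payload
    ) as response:
        for chunk in response.iter_text():  # chunks of the generated answer
            print(chunk, end="")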