diff options
Diffstat (limited to 'R2R/r2r/main/api/routes')
-rwxr-xr-x | R2R/r2r/main/api/routes/__init__.py | 0 | ||||
-rwxr-xr-x | R2R/r2r/main/api/routes/base_router.py | 75 | ||||
-rwxr-xr-x | R2R/r2r/main/api/routes/ingestion.py | 42 | ||||
-rwxr-xr-x | R2R/r2r/main/api/routes/management.py | 101 | ||||
-rwxr-xr-x | R2R/r2r/main/api/routes/retrieval.py | 91 |
5 files changed, 309 insertions, 0 deletions
diff --git a/R2R/r2r/main/api/routes/__init__.py b/R2R/r2r/main/api/routes/__init__.py new file mode 100755 index 00000000..e69de29b --- /dev/null +++ b/R2R/r2r/main/api/routes/__init__.py diff --git a/R2R/r2r/main/api/routes/base_router.py b/R2R/r2r/main/api/routes/base_router.py new file mode 100755 index 00000000..d06a9935 --- /dev/null +++ b/R2R/r2r/main/api/routes/base_router.py @@ -0,0 +1,75 @@ +import functools +import logging + +from fastapi import APIRouter, HTTPException +from fastapi.responses import StreamingResponse + +from r2r.base import R2RException, manage_run + +logger = logging.getLogger(__name__) + + +class BaseRouter: + def __init__(self, engine): + self.engine = engine + self.router = APIRouter() + + def base_endpoint(self, func): + @functools.wraps(func) + async def wrapper(*args, **kwargs): + async with manage_run( + self.engine.run_manager, func.__name__ + ) as run_id: + try: + results = await func(*args, **kwargs) + if isinstance(results, StreamingResponse): + return results + + return {"results": results} + except R2RException as re: + raise HTTPException( + status_code=re.status_code, + detail={ + "message": re.message, + "error_type": type(re).__name__, + }, + ) + except Exception as e: + # Get the pipeline name based on the function name + pipeline_name = f"{func.__name__.split('_')[0]}_pipeline" + + # Safely get the pipeline object and its type + pipeline = getattr( + self.engine.pipelines, pipeline_name, None + ) + pipeline_type = getattr( + pipeline, "pipeline_type", "unknown" + ) + + await self.engine.logging_connection.log( + log_id=run_id, + key="pipeline_type", + value=pipeline_type, + is_info_log=True, + ) + await self.engine.logging_connection.log( + log_id=run_id, + key="error", + value=str(e), + is_info_log=False, + ) + logger.error(f"{func.__name__}() - \n\n{str(e)})") + raise HTTPException( + status_code=500, + detail={ + "message": f"An error occurred during {func.__name__}", + "error": str(e), + "error_type": type(e).__name__, + }, + ) from e + + return wrapper + + @classmethod + def build_router(cls, engine): + return cls(engine).router diff --git a/R2R/r2r/main/api/routes/ingestion.py b/R2R/r2r/main/api/routes/ingestion.py new file mode 100755 index 00000000..be583602 --- /dev/null +++ b/R2R/r2r/main/api/routes/ingestion.py @@ -0,0 +1,42 @@ +from fastapi import Depends, File, UploadFile + +from ...engine import R2REngine +from ...services.ingestion_service import IngestionService +from ..requests import R2RIngestFilesRequest, R2RUpdateFilesRequest +from .base_router import BaseRouter + + +class IngestionRouter(BaseRouter): + def __init__(self, engine: R2REngine): + super().__init__(engine) + self.setup_routes() + + def setup_routes(self): + @self.router.post("/ingest_files") + @self.base_endpoint + async def ingest_files_app( + files: list[UploadFile] = File(...), + request: R2RIngestFilesRequest = Depends( + IngestionService.parse_ingest_files_form_data + ), + ): + return await self.engine.aingest_files( + files=files, + metadatas=request.metadatas, + document_ids=request.document_ids, + versions=request.versions, + ) + + @self.router.post("/update_files") + @self.base_endpoint + async def update_files_app( + files: list[UploadFile] = File(...), + request: R2RUpdateFilesRequest = Depends( + IngestionService.parse_update_files_form_data + ), + ): + return await self.engine.aupdate_files( + files=files, + metadatas=request.metadatas, + document_ids=request.document_ids, + ) diff --git a/R2R/r2r/main/api/routes/management.py b/R2R/r2r/main/api/routes/management.py new file mode 100755 index 00000000..921fb534 --- /dev/null +++ b/R2R/r2r/main/api/routes/management.py @@ -0,0 +1,101 @@ +from ...engine import R2REngine +from ..requests import ( + R2RAnalyticsRequest, + R2RDeleteRequest, + R2RDocumentChunksRequest, + R2RDocumentsOverviewRequest, + R2RLogsRequest, + R2RPrintRelationshipsRequest, + R2RUpdatePromptRequest, + R2RUsersOverviewRequest, +) +from .base_router import BaseRouter + + +class ManagementRouter(BaseRouter): + def __init__(self, engine: R2REngine): + super().__init__(engine) + self.setup_routes() + + def setup_routes(self): + @self.router.get("/health") + async def health_check(): + return {"response": "ok"} + + @self.router.post("/update_prompt") + @self.base_endpoint + async def update_prompt_app(request: R2RUpdatePromptRequest): + return await self.engine.aupdate_prompt( + request.name, request.template, request.input_types + ) + + @self.router.post("/logs") + @self.router.get("/logs") + @self.base_endpoint + async def get_logs_app(request: R2RLogsRequest): + return await self.engine.alogs( + log_type_filter=request.log_type_filter, + max_runs_requested=request.max_runs_requested, + ) + + @self.router.post("/analytics") + @self.router.get("/analytics") + @self.base_endpoint + async def get_analytics_app(request: R2RAnalyticsRequest): + return await self.engine.aanalytics( + filter_criteria=request.filter_criteria, + analysis_types=request.analysis_types, + ) + + @self.router.post("/users_overview") + @self.router.get("/users_overview") + @self.base_endpoint + async def get_users_overview_app(request: R2RUsersOverviewRequest): + return await self.engine.ausers_overview(user_ids=request.user_ids) + + @self.router.delete("/delete") + @self.base_endpoint + async def delete_app(request: R2RDeleteRequest): + return await self.engine.adelete( + keys=request.keys, values=request.values + ) + + @self.router.post("/documents_overview") + @self.router.get("/documents_overview") + @self.base_endpoint + async def get_documents_overview_app( + request: R2RDocumentsOverviewRequest, + ): + return await self.engine.adocuments_overview( + document_ids=request.document_ids, user_ids=request.user_ids + ) + + @self.router.post("/document_chunks") + @self.router.get("/document_chunks") + @self.base_endpoint + async def get_document_chunks_app(request: R2RDocumentChunksRequest): + return await self.engine.adocument_chunks(request.document_id) + + @self.router.post("/inspect_knowledge_graph") + @self.router.get("/inspect_knowledge_graph") + @self.base_endpoint + async def inspect_knowledge_graph( + request: R2RPrintRelationshipsRequest, + ): + return await self.engine.inspect_knowledge_graph( + limit=request.limit + ) + + @self.router.get("/app_settings") + @self.base_endpoint + async def get_app_settings_app(): + return await self.engine.aapp_settings() + + @self.router.get("/openapi_spec") + @self.base_endpoint + def get_openapi_spec_app(): + return self.engine.openapi_spec() + + +def create_management_router(engine: R2REngine): + return ManagementRouter(engine).router diff --git a/R2R/r2r/main/api/routes/retrieval.py b/R2R/r2r/main/api/routes/retrieval.py new file mode 100755 index 00000000..b2d352aa --- /dev/null +++ b/R2R/r2r/main/api/routes/retrieval.py @@ -0,0 +1,91 @@ +from fastapi.responses import StreamingResponse + +from r2r.base import GenerationConfig, KGSearchSettings, VectorSearchSettings + +from ...engine import R2REngine +from ..requests import R2REvalRequest, R2RRAGRequest, R2RSearchRequest +from .base_router import BaseRouter + + +class RetrievalRouter(BaseRouter): + def __init__(self, engine: R2REngine): + super().__init__(engine) + self.setup_routes() + + def setup_routes(self): + @self.router.post("/search") + @self.base_endpoint + async def search_app(request: R2RSearchRequest): + if "agent_generation_config" in request.kg_search_settings: + request.kg_search_settings["agent_generation_config"] = ( + GenerationConfig( + **request.kg_search_settings["agent_generation_config"] + or {} + ) + ) + + results = await self.engine.asearch( + query=request.query, + vector_search_settings=VectorSearchSettings( + **(request.vector_search_settings or {}) + ), + kg_search_settings=KGSearchSettings( + **(request.kg_search_settings or {}) + ), + ) + return results + + @self.router.post("/rag") + @self.base_endpoint + async def rag_app(request: R2RRAGRequest): + if "agent_generation_config" in request.kg_search_settings: + request.kg_search_settings["agent_generation_config"] = ( + GenerationConfig( + **( + request.kg_search_settings[ + "agent_generation_config" + ] + or {} + ) + ) + ) + response = await self.engine.arag( + query=request.query, + vector_search_settings=VectorSearchSettings( + **(request.vector_search_settings or {}) + ), + kg_search_settings=KGSearchSettings( + **(request.kg_search_settings or {}) + ), + rag_generation_config=GenerationConfig( + **(request.rag_generation_config or {}) + ), + ) + if ( + request.rag_generation_config + and request.rag_generation_config.get("stream", False) + ): + + async def stream_generator(): + async for chunk in response: + yield chunk + + return StreamingResponse( + stream_generator(), media_type="application/json" + ) + else: + return response + + @self.router.post("/evaluate") + @self.base_endpoint + async def evaluate_app(request: R2REvalRequest): + results = await self.engine.aevaluate( + query=request.query, + context=request.context, + completion=request.completion, + ) + return results + + +def create_retrieval_router(engine: R2REngine): + return RetrievalRouter(engine).router |