aboutsummaryrefslogtreecommitdiff
path: root/R2R/r2r/main/api/routes/base_router.py
diff options
context:
space:
mode:
Diffstat (limited to 'R2R/r2r/main/api/routes/base_router.py')
-rwxr-xr-xR2R/r2r/main/api/routes/base_router.py75
1 files changed, 75 insertions, 0 deletions
diff --git a/R2R/r2r/main/api/routes/base_router.py b/R2R/r2r/main/api/routes/base_router.py
new file mode 100755
index 00000000..d06a9935
--- /dev/null
+++ b/R2R/r2r/main/api/routes/base_router.py
@@ -0,0 +1,75 @@
+import functools
+import logging
+
+from fastapi import APIRouter, HTTPException
+from fastapi.responses import StreamingResponse
+
+from r2r.base import R2RException, manage_run
+
+logger = logging.getLogger(__name__)
+
+
+class BaseRouter:
+ def __init__(self, engine):
+ self.engine = engine
+ self.router = APIRouter()
+
+ def base_endpoint(self, func):
+ @functools.wraps(func)
+ async def wrapper(*args, **kwargs):
+ async with manage_run(
+ self.engine.run_manager, func.__name__
+ ) as run_id:
+ try:
+ results = await func(*args, **kwargs)
+ if isinstance(results, StreamingResponse):
+ return results
+
+ return {"results": results}
+ except R2RException as re:
+ raise HTTPException(
+ status_code=re.status_code,
+ detail={
+ "message": re.message,
+ "error_type": type(re).__name__,
+ },
+ )
+ except Exception as e:
+ # Get the pipeline name based on the function name
+ pipeline_name = f"{func.__name__.split('_')[0]}_pipeline"
+
+ # Safely get the pipeline object and its type
+ pipeline = getattr(
+ self.engine.pipelines, pipeline_name, None
+ )
+ pipeline_type = getattr(
+ pipeline, "pipeline_type", "unknown"
+ )
+
+ await self.engine.logging_connection.log(
+ log_id=run_id,
+ key="pipeline_type",
+ value=pipeline_type,
+ is_info_log=True,
+ )
+ await self.engine.logging_connection.log(
+ log_id=run_id,
+ key="error",
+ value=str(e),
+ is_info_log=False,
+ )
+ logger.error(f"{func.__name__}() - \n\n{str(e)})")
+ raise HTTPException(
+ status_code=500,
+ detail={
+ "message": f"An error occurred during {func.__name__}",
+ "error": str(e),
+ "error_type": type(e).__name__,
+ },
+ ) from e
+
+ return wrapper
+
+ @classmethod
+ def build_router(cls, engine):
+ return cls(engine).router