about summary refs log tree commit diff
path: root/R2R/r2r/main/api/routes/base_router.py
diff options
context:
space:
mode:
Diffstat (limited to 'R2R/r2r/main/api/routes/base_router.py')
-rwxr-xr-xR2R/r2r/main/api/routes/base_router.py75
1 files changed, 75 insertions, 0 deletions
diff --git a/R2R/r2r/main/api/routes/base_router.py b/R2R/r2r/main/api/routes/base_router.py
new file mode 100755
index 00000000..d06a9935
--- /dev/null
+++ b/R2R/r2r/main/api/routes/base_router.py
@@ -0,0 +1,75 @@
+import functools
+import logging
+
+from fastapi import APIRouter, HTTPException
+from fastapi.responses import StreamingResponse
+
+from r2r.base import R2RException, manage_run
+
+logger = logging.getLogger(__name__)
+
+
+class BaseRouter:
+    def __init__(self, engine):
+        self.engine = engine
+        self.router = APIRouter()
+
+    def base_endpoint(self, func):
+        @functools.wraps(func)
+        async def wrapper(*args, **kwargs):
+            async with manage_run(
+                self.engine.run_manager, func.__name__
+            ) as run_id:
+                try:
+                    results = await func(*args, **kwargs)
+                    if isinstance(results, StreamingResponse):
+                        return results
+
+                    return {"results": results}
+                except R2RException as re:
+                    raise HTTPException(
+                        status_code=re.status_code,
+                        detail={
+                            "message": re.message,
+                            "error_type": type(re).__name__,
+                        },
+                    )
+                except Exception as e:
+                    # Get the pipeline name based on the function name
+                    pipeline_name = f"{func.__name__.split('_')[0]}_pipeline"
+
+                    # Safely get the pipeline object and its type
+                    pipeline = getattr(
+                        self.engine.pipelines, pipeline_name, None
+                    )
+                    pipeline_type = getattr(
+                        pipeline, "pipeline_type", "unknown"
+                    )
+
+                    await self.engine.logging_connection.log(
+                        log_id=run_id,
+                        key="pipeline_type",
+                        value=pipeline_type,
+                        is_info_log=True,
+                    )
+                    await self.engine.logging_connection.log(
+                        log_id=run_id,
+                        key="error",
+                        value=str(e),
+                        is_info_log=False,
+                    )
+                    logger.error(f"{func.__name__}() - \n\n{str(e)})")
+                    raise HTTPException(
+                        status_code=500,
+                        detail={
+                            "message": f"An error occurred during {func.__name__}",
+                            "error": str(e),
+                            "error_type": type(e).__name__,
+                        },
+                    ) from e
+
+        return wrapper
+
+    @classmethod
+    def build_router(cls, engine):
+        return cls(engine).router