aboutsummaryrefslogtreecommitdiff
path: root/R2R/r2r/main/api/routes/base_router.py
blob: d06a99358f87cfcfabba0241fcc00144b2ff35ed (about) (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import functools
import logging

from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse

from r2r.base import R2RException, manage_run

# Module-level logger named after this module; used by BaseRouter to report
# endpoint failures.
logger: logging.Logger = logging.getLogger(__name__)


class BaseRouter:
    """Shared base for R2R API routers.

    Holds the engine and a FastAPI ``APIRouter`` (``self.router``) and
    provides :meth:`base_endpoint`, a decorator that adds run tracking and
    uniform error translation to async endpoint handlers.
    """

    def __init__(self, engine):
        """Store *engine* and create an empty ``APIRouter``.

        Args:
            engine: The R2R engine. :meth:`base_endpoint` reads its
                ``run_manager``, ``pipelines`` and ``logging_connection``
                attributes.
        """
        self.engine = engine
        self.router = APIRouter()

    def base_endpoint(self, func):
        """Wrap an async endpoint handler with run tracking and error handling.

        The returned coroutine function:

        * runs ``func`` inside ``manage_run`` so each call is associated
          with a ``run_id``;
        * returns a ``StreamingResponse`` as-is and wraps any other result
          as ``{"results": ...}``;
        * converts ``R2RException`` into ``HTTPException`` using the
          exception's own ``status_code`` and ``message``;
        * re-raises ``HTTPException`` raised by ``func`` unchanged, so the
          handler's chosen status code is preserved;
        * records any other exception against the run and converts it into
          an HTTP 500 ``HTTPException``.

        Args:
            func: The async handler to wrap.

        Returns:
            The wrapped async function; ``functools.wraps`` copies ``func``'s
            name and metadata onto it.
        """

        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            async with manage_run(
                self.engine.run_manager, func.__name__
            ) as run_id:
                try:
                    results = await func(*args, **kwargs)
                except R2RException as r2r_error:
                    raise HTTPException(
                        status_code=r2r_error.status_code,
                        detail={
                            "message": r2r_error.message,
                            "error_type": type(r2r_error).__name__,
                        },
                    ) from r2r_error
                except HTTPException:
                    # The handler already chose an HTTP status; the generic
                    # branch below would otherwise rewrite it as a 500.
                    raise
                except Exception as e:
                    await self._record_failure(run_id, func.__name__, e)
                    # NOTE(review): str(e) is sent to the client and may expose
                    # internal details; consider restricting it.
                    raise HTTPException(
                        status_code=500,
                        detail={
                            "message": f"An error occurred during {func.__name__}",
                            "error": str(e),
                            "error_type": type(e).__name__,
                        },
                    ) from e

                # Streaming responses are returned as-is; any other result is
                # wrapped in the standard response envelope.
                if isinstance(results, StreamingResponse):
                    return results
                return {"results": results}

        return wrapper

    async def _record_failure(self, run_id, func_name, error):
        """Log an unexpected endpoint failure locally and against *run_id*.

        The local log record (with traceback) is written first so it is never
        lost. Persisting to ``self.engine.logging_connection`` is best-effort:
        if it fails, that failure is logged and suppressed so it cannot mask
        *error*.

        Args:
            run_id: Identifier yielded by ``manage_run`` for this call.
            func_name: Name of the wrapped handler.
            error: The exception raised by the handler.
        """
        logger.error("%s() - \n\n%s", func_name, error, exc_info=error)
        try:
            # The pipeline is looked up by the handler name's first
            # underscore-separated token, e.g. "search_app" -> "search_pipeline".
            pipeline_name = f"{func_name.split('_')[0]}_pipeline"
            pipeline = getattr(self.engine.pipelines, pipeline_name, None)
            pipeline_type = getattr(pipeline, "pipeline_type", "unknown")

            await self.engine.logging_connection.log(
                log_id=run_id,
                key="pipeline_type",
                value=pipeline_type,
                is_info_log=True,
            )
            await self.engine.logging_connection.log(
                log_id=run_id,
                key="error",
                value=str(error),
                is_info_log=False,
            )
        except Exception:
            logger.exception(
                "Failed to record error for run %s of %s()", run_id, func_name
            )

    @classmethod
    def build_router(cls, engine):
        """Instantiate this router class for *engine* and return its ``APIRouter``."""
        return cls(engine).router