# Repository web-viewer header (navigation: about, summary, refs, log, tree, commit, diff)
# path: root/.venv/lib/python3.12/site-packages/core/main/app.py
# blob: ceb13cce3c3483443241c04c4980e722e07c448d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.openapi.utils import get_openapi
from fastapi.responses import JSONResponse

from core.base import R2RException
from core.providers import (
    HatchetOrchestrationProvider,
    SimpleOrchestrationProvider,
)
from core.utils.sentry import init_sentry

from .abstractions import R2RServices
from .api.v3.chunks_router import ChunksRouter
from .api.v3.collections_router import CollectionsRouter
from .api.v3.conversations_router import ConversationsRouter
from .api.v3.documents_router import DocumentsRouter
from .api.v3.graph_router import GraphRouter
from .api.v3.indices_router import IndicesRouter
from .api.v3.prompts_router import PromptsRouter
from .api.v3.retrieval_router import RetrievalRouter
from .api.v3.system_router import SystemRouter
from .api.v3.users_router import UsersRouter
from .config import R2RConfig


class R2RApp:
    """FastAPI application wrapper for the R2R v3 API.

    On construction the instance stores the supplied config, services,
    orchestration provider and routers, creates a ``FastAPI`` app, registers
    a JSON handler for ``R2RException``, mounts every router under ``/v3``
    and installs CORS middleware. ``serve`` runs the app with uvicorn.
    """

    def __init__(
        self,
        config: R2RConfig,
        orchestration_provider: (
            HatchetOrchestrationProvider | SimpleOrchestrationProvider
        ),
        services: R2RServices,
        chunks_router: ChunksRouter,
        collections_router: CollectionsRouter,
        conversations_router: ConversationsRouter,
        documents_router: DocumentsRouter,
        graph_router: GraphRouter,
        indices_router: IndicesRouter,
        prompts_router: PromptsRouter,
        retrieval_router: RetrievalRouter,
        system_router: SystemRouter,
        users_router: UsersRouter,
    ):
        """Store the collaborators and assemble the FastAPI app.

        Args:
            config: Application configuration, kept on ``self.config``.
            orchestration_provider: Hatchet or simple orchestration
                provider, kept on ``self.orchestration_provider``.
            services: Service container, kept on ``self.services``.
            chunks_router: Router mounted under ``/v3``.
            collections_router: Router mounted under ``/v3``.
            conversations_router: Router mounted under ``/v3``.
            documents_router: Router mounted under ``/v3``.
            graph_router: Router mounted under ``/v3``.
            indices_router: Router mounted under ``/v3``.
            prompts_router: Router mounted under ``/v3``.
            retrieval_router: Router mounted under ``/v3``.
            system_router: Router mounted under ``/v3``.
            users_router: Router mounted under ``/v3``.
        """
        # Sentry is initialised before anything else is set up.
        init_sentry()

        # Non-router collaborators.
        self.config = config
        self.orchestration_provider = orchestration_provider
        self.services = services

        # Routers, one attribute per API area.
        self.chunks_router = chunks_router
        self.collections_router = collections_router
        self.conversations_router = conversations_router
        self.documents_router = documents_router
        self.graph_router = graph_router
        self.indices_router = indices_router
        self.prompts_router = prompts_router
        self.retrieval_router = retrieval_router
        self.system_router = system_router
        self.users_router = users_router

        self.app = FastAPI()

        @self.app.exception_handler(R2RException)
        async def r2r_exception_handler(request: Request, exc: R2RException):
            # Translate an R2RException into a JSON body carrying its
            # message and concrete class name, using the exception's own
            # HTTP status code.
            payload = {
                "message": exc.message,
                "error_type": type(exc).__name__,
            }
            return JSONResponse(status_code=exc.status_code, content=payload)

        self._setup_routes()
        self._apply_cors()

    def _setup_routes(self):
        """Mount all routers under ``/v3`` and add ``/openapi_spec``."""
        api_prefix = "/v3"

        # Mount order is fixed (alphabetical by router name).
        v3_routers = (
            self.chunks_router,
            self.collections_router,
            self.conversations_router,
            self.documents_router,
            self.graph_router,
            self.indices_router,
            self.prompts_router,
            self.retrieval_router,
            self.system_router,
            self.users_router,
        )
        for router in v3_routers:
            self.app.include_router(router, prefix=api_prefix)

        @self.app.get("/openapi_spec", include_in_schema=False)
        async def openapi_spec():
            # Build the OpenAPI document from the routes registered on the
            # app at request time; this endpoint itself is excluded from
            # the schema.
            return get_openapi(
                title="R2R Application API",
                version="1.0.0",
                routes=self.app.routes,
            )

    def _apply_cors(self):
        """Install ``CORSMiddleware`` on the app."""
        # NOTE(review): the "*" entry appears to make the two localhost
        # entries redundant, and FastAPI's CORS docs advise against
        # wildcards together with allow_credentials=True. Confirm this
        # permissive policy is intended before tightening it.
        origins = ["*", "http://localhost:3000", "http://localhost:7272"]
        cors_options = {
            "allow_origins": origins,
            "allow_credentials": True,
            "allow_methods": ["*"],
            "allow_headers": ["*"],
        }
        self.app.add_middleware(CORSMiddleware, **cors_options)

    async def serve(self, host: str = "0.0.0.0", port: int = 7272):
        """Configure logging, then run the app with uvicorn.

        Args:
            host: Interface to bind to.
            port: TCP port to listen on.
        """
        # Both imports are done inside the method rather than at module
        # level.
        import uvicorn

        from core.utils.logging_config import configure_logging

        configure_logging()

        # log_config=None is passed to uvicorn; presumably so the setup
        # done by configure_logging() is left in place. TODO confirm.
        server = uvicorn.Server(
            uvicorn.Config(
                self.app,
                host=host,
                port=port,
                log_config=None,
            )
        )
        await server.serve()