aboutsummaryrefslogtreecommitdiff
path: root/R2R/r2r/main/app.py
blob: 981445e44a07d486844df7d1abe07ca23887bd54 (about) (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
from typing import Optional, Sequence

from fastapi import FastAPI

from .engine import R2REngine


class R2RApp:
    """HTTP front end for an ``R2REngine``, built on FastAPI.

    Construction creates ``self.app`` (a ``FastAPI`` instance), mounts the
    ingestion, management and retrieval routers under the ``/v1`` prefix, and
    installs CORS middleware. Call :meth:`serve` to run the app with uvicorn.
    """

    # Origins used when the caller does not pass ``cors_origins``. "*" already
    # admits every origin, so the localhost entries only take effect if "*" is
    # removed; the tuple is kept verbatim so the default behaviour is unchanged.
    DEFAULT_CORS_ORIGINS = (
        "*",
        "http://localhost:3000",
        "http://localhost:8000",
    )

    def __init__(
        self,
        engine: R2REngine,
        *,
        cors_origins: Optional[Sequence[str]] = None,
    ):
        """Build the FastAPI app for ``engine``.

        Args:
            engine: Engine passed to each router's ``build_router``.
            cors_origins: Origins allowed by the CORS middleware. ``None``
                (the default) uses ``DEFAULT_CORS_ORIGINS``. Pass an explicit
                list without "*" to restrict cross-origin access. A single
                string is treated as one origin.
        """
        self.engine = engine
        if cors_origins is None:
            cors_origins = self.DEFAULT_CORS_ORIGINS
        elif isinstance(cors_origins, str):
            # Guard against list("http://...") splitting into characters.
            cors_origins = [cors_origins]
        self.cors_origins = list(cors_origins)
        self._setup_routes()
        self._apply_cors()

    async def openapi_spec(self, *args, **kwargs):
        """Return an OpenAPI schema describing every route on ``self.app``.

        Any positional or keyword arguments are accepted and ignored. The
        schema is regenerated on every call; it is not cached.
        """
        from fastapi.openapi.utils import get_openapi

        return get_openapi(
            title="R2R Application API",
            version="1.0.0",
            routes=self.app.routes,
        )

    def _setup_routes(self):
        """Create ``self.app`` and mount the API routers under ``/v1``."""
        from .api.routes import ingestion, management, retrieval

        self.app = FastAPI()

        # Each router is built around the same engine instance.
        ingestion_router = ingestion.IngestionRouter.build_router(self.engine)
        management_router = management.ManagementRouter.build_router(
            self.engine
        )
        retrieval_router = retrieval.RetrievalRouter.build_router(self.engine)

        for router in (ingestion_router, management_router, retrieval_router):
            self.app.include_router(router, prefix="/v1")

    def _apply_cors(self):
        """Install ``CORSMiddleware`` on ``self.app``.

        Allowed origins come from ``self.cors_origins``. All methods and
        headers are allowed, and credentials are enabled.
        """
        from fastapi.middleware.cors import CORSMiddleware

        # NOTE(review): FastAPI's CORS docs state that allow_origins=["*"]
        # cannot be combined with allow_credentials=True, and that origins
        # must be listed explicitly for credentialed requests. The default
        # keeps "*" for backward compatibility; confirm whether deployments
        # should pass explicit ``cors_origins`` instead.
        self.app.add_middleware(
            CORSMiddleware,
            allow_origins=list(self.cors_origins),
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )

    def serve(self, host: str = "0.0.0.0", port: int = 8000):
        """Run ``self.app`` with uvicorn on ``host``:``port`` (blocking call)."""
        import uvicorn

        uvicorn.run(self.app, host=host, port=port)