aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/core/base/providers/auth.py
blob: 352c3331def3dad887068d41a8fad224c8a95744 (about) (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
import logging
from abc import ABC, abstractmethod
from datetime import datetime
from typing import TYPE_CHECKING, Optional

from fastapi import Security
from fastapi.security import (
    APIKeyHeader,
    HTTPAuthorizationCredentials,
    HTTPBearer,
)

from ..abstractions import R2RException, Token, TokenData
from ..api.models import User
from .base import Provider, ProviderConfig
from .crypto import CryptoProvider
from .email import EmailProvider

# Module-scoped logger: records are attributed to this module and can be
# configured/filtered per package instead of being emitted on the root logger.
logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    # Needed only for annotations; guarded presumably to avoid a runtime
    # import cycle with core.providers.database — TODO confirm.
    from core.providers.database import PostgresDatabaseProvider

# Optional ``X-API-Key`` header scheme. auto_error=False lets the auth
# dependency receive None when the header is absent and decide for itself.
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)


class AuthConfig(ProviderConfig):
    """Settings consumed by ``AuthProvider`` implementations.

    ``require_authentication`` controls whether ``AuthProvider.auth_wrapper``
    lets credential-less requests through as the default admin user
    (``default_admin_email``). Token lifetimes of ``None`` leave the choice
    to the concrete provider.
    """

    secret_key: Optional[str] = None
    require_authentication: bool = False
    require_email_verification: bool = False
    default_admin_email: str = "admin@example.com"
    # NOTE(review): well-known default credential — deployments must override.
    default_admin_password: str = "change_me_immediately"
    access_token_lifetime_in_minutes: Optional[int] = None
    refresh_token_lifetime_in_days: Optional[int] = None

    @property
    def supported_providers(self) -> list[str]:
        """Names of the auth providers this config can drive."""
        names: list[str] = ["r2r"]
        return names

    def validate_config(self) -> None:
        """Accept any combination of values; no extra validation is done."""
        return None


class AuthProvider(Provider, ABC):
    """Abstract base class for authentication providers.

    Holds the crypto, email and database providers that concrete subclasses
    use, declares the token / account-management interface they must
    implement, and supplies ``auth_wrapper``: a FastAPI dependency factory
    that resolves the calling ``User`` from a Bearer token (JWT, or an API
    key of the form ``key_id.raw_api_key``) or from the ``X-API-Key`` header.
    """

    # Optional ``Authorization: Bearer`` scheme; auto_error=False so the
    # dependency receives None when the header is missing.
    security = HTTPBearer(auto_error=False)
    crypto_provider: CryptoProvider
    email_provider: EmailProvider
    database_provider: "PostgresDatabaseProvider"

    def __init__(
        self,
        config: AuthConfig,
        crypto_provider: CryptoProvider,
        database_provider: "PostgresDatabaseProvider",
        email_provider: EmailProvider,
    ):
        """Store the auth config and collaborating providers.

        Raises:
            ValueError: if ``config`` is not an ``AuthConfig``.
        """
        if not isinstance(config, AuthConfig):
            raise ValueError(
                "AuthProvider must be initialized with an AuthConfig"
            )
        self.config = config
        self.admin_email = config.default_admin_email
        self.admin_password = config.default_admin_password
        self.crypto_provider = crypto_provider
        self.database_provider = database_provider
        self.email_provider = email_provider
        super().__init__(config)
        # Re-bound after the base initializer with explicit annotations so the
        # narrowed types win (and in case Provider.__init__ rebinds them).
        self.config: AuthConfig = config
        self.database_provider: "PostgresDatabaseProvider" = database_provider

    async def _get_default_admin_user(self) -> User:
        """Fetch the configured default admin user by its email address."""
        return await self.database_provider.users_handler.get_user_by_email(
            self.admin_email
        )

    async def _resolve_user_from_api_key(
        self, presented_key: str
    ) -> Optional[User]:
        """Authenticate a ``key_id.raw_api_key`` string.

        Returns the owning user when the string contains a dot, an API-key
        record exists for ``key_id``, ``raw_api_key`` verifies against the
        stored hash, and the user exists and is active; otherwise None.
        """
        if "." not in presented_key:
            return None
        key_id, raw_api_key = presented_key.split(".", 1)
        users_handler = self.database_provider.users_handler
        api_key_record = await users_handler.get_api_key_record(key_id)
        if api_key_record is None:
            return None
        hashed_key = api_key_record["hashed_key"]
        if not self.crypto_provider.verify_api_key(raw_api_key, hashed_key):
            return None
        user = await users_handler.get_user_by_id(api_key_record["user_id"])
        if user is not None and user.is_active:
            return user
        return None

    @abstractmethod
    def create_access_token(self, data: dict) -> str:
        """Return an access-token string built from ``data``."""

    @abstractmethod
    def create_refresh_token(self, data: dict) -> str:
        """Return a refresh-token string built from ``data``."""

    @abstractmethod
    async def decode_token(self, token: str) -> TokenData:
        """Decode ``token``; ``auth_wrapper`` treats R2RException as invalid."""

    @abstractmethod
    async def user(self, token: str) -> User:
        """Return the user identified by ``token``."""

    @abstractmethod
    def get_current_active_user(self, current_user: User) -> User:
        """Return ``current_user`` if it may be treated as active."""

    @abstractmethod
    async def register(self, email: str, password: str) -> User:
        """Create and return a new user account."""

    @abstractmethod
    async def send_verification_email(
        self, email: str, user: Optional[User] = None
    ) -> tuple[str, datetime]:
        """Send a verification email; returns a (code, datetime) pair."""

    @abstractmethod
    async def verify_email(
        self, email: str, verification_code: str
    ) -> dict[str, str]:
        """Verify ``email`` using ``verification_code``."""

    @abstractmethod
    async def login(self, email: str, password: str) -> dict[str, Token]:
        """Authenticate by email/password and return named tokens."""

    @abstractmethod
    async def refresh_access_token(
        self, refresh_token: str
    ) -> dict[str, Token]:
        """Exchange ``refresh_token`` for new named tokens."""

    def auth_wrapper(
        self,
        public: bool = False,
    ):
        """Build a FastAPI dependency that returns the authenticated ``User``.

        Resolution order of the returned dependency:
          0. If authentication is not required (or ``public`` is True) and no
             credentials were sent, return the default admin user.
          1. A Bearer token is first decoded as a JWT.
          2. If that yields no user, the Bearer value is tried as an API key.
          3. Otherwise the ``X-API-Key`` header is tried as an API key.

        Args:
            public: treat the endpoint as usable without credentials.

        The returned dependency raises:
            R2RException(401): no credentials, or none of them authenticated.
            R2RException(400): both a Bearer token and an API key were sent.
        """

        async def _auth_wrapper(
            auth: Optional[HTTPAuthorizationCredentials] = Security(
                self.security
            ),
            api_key: Optional[str] = Security(api_key_header),
        ) -> User:
            # 0. Anonymous access is allowed: fall back to the default admin.
            if (
                ((not self.config.require_authentication) or public)
                and auth is None
                and api_key is None
            ):
                return await self._get_default_admin_user()
            if not auth and not api_key:
                raise R2RException(
                    message="No credentials provided. Create an account at https://app.sciphi.ai and set your API key using `r2r configure key` OR change your base URL to a custom deployment.",
                    status_code=401,
                )
            if auth and api_key:
                raise R2RException(
                    message="Cannot have both Bearer token and API key",
                    status_code=400,
                )

            # 1. Bearer token present: try it as a JWT first.
            if auth is not None:
                credentials = auth.credentials
                try:
                    token_data = await self.decode_token(credentials)
                    user = await self.database_provider.users_handler.get_user_by_email(
                        token_data.email
                    )
                    # NOTE(review): unlike the API-key paths, this path does
                    # not check ``user.is_active`` — confirm that is intended.
                    if user is not None:
                        return user
                except R2RException:
                    # Token rejected for logical reasons: try it as an API key.
                    pass
                except Exception as e:
                    # Unexpected failure while decoding/looking up: log and
                    # fall through to the API-key attempt.
                    logger.debug("JWT verification failed: %s", e)

                # 2. JWT did not authenticate: treat the Bearer value as an
                # API key. JWTs contain dots too, so a rejected JWT reaches
                # this lookup with its first segment used as the key id.
                user = await self._resolve_user_from_api_key(credentials)
                if user is not None:
                    return user

            # 3. Fall back to the X-API-Key header.
            if api_key is not None:
                user = await self._resolve_user_from_api_key(api_key)
                if user is not None:
                    return user

            # Neither JWT nor API-key authentication succeeded.
            raise R2RException(
                message="Invalid token or API key",
                status_code=401,
            )

        return _auth_wrapper

    @abstractmethod
    async def change_password(
        self, user: User, current_password: str, new_password: str
    ) -> dict[str, str]:
        """Change ``user``'s password after checking ``current_password``."""

    @abstractmethod
    async def request_password_reset(self, email: str) -> dict[str, str]:
        """Start a password reset for ``email``."""

    @abstractmethod
    async def confirm_password_reset(
        self, reset_token: str, new_password: str
    ) -> dict[str, str]:
        """Complete a password reset using ``reset_token``."""

    @abstractmethod
    async def logout(self, token: str) -> dict[str, str]:
        """Log out the session identified by ``token``."""

    @abstractmethod
    async def send_reset_email(self, email: str) -> dict[str, str]:
        """Send a password-reset email to ``email``."""