1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
|
import logging
from abc import ABC, abstractmethod
from datetime import datetime
from typing import TYPE_CHECKING, Optional
from fastapi import Security
from fastapi.security import (
APIKeyHeader,
HTTPAuthorizationCredentials,
HTTPBearer,
)
from ..abstractions import R2RException, Token, TokenData
from ..api.models import User
from .base import Provider, ProviderConfig
from .crypto import CryptoProvider
from .email import EmailProvider
# Module-scoped logger so records are attributed to this module; it still
# propagates to the root logger's handlers.
logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    # Only needed for string annotations below; presumably guarded to avoid
    # an import cycle at runtime — TODO confirm.
    from core.providers.database import PostgresDatabaseProvider

# Optional ``X-API-Key`` header scheme. With auto_error=False FastAPI yields
# None when the header is absent instead of raising, so the auth dependency
# decides how to respond.
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
class AuthConfig(ProviderConfig):
    """Settings shared by authentication providers.

    Only the ``"r2r"`` provider identifier is accepted, and no extra
    validation beyond field typing is performed.
    """

    # Presumably the token-signing secret used by concrete providers —
    # confirm in subclasses. Unset by default.
    secret_key: Optional[str] = None
    # When False, requests carrying no credentials are served as the default
    # admin user by AuthProvider.auth_wrapper.
    require_authentication: bool = False
    require_email_verification: bool = False
    default_admin_email: str = "admin@example.com"
    # NOTE(review): placeholder value; deployments should override it.
    default_admin_password: str = "change_me_immediately"
    # Token lifetimes; None leaves the choice to the concrete provider.
    access_token_lifetime_in_minutes: Optional[int] = None
    refresh_token_lifetime_in_days: Optional[int] = None

    @property
    def supported_providers(self) -> list[str]:
        """List the provider identifiers this configuration supports."""
        providers = ["r2r"]
        return providers

    def validate_config(self) -> None:
        """Accept the configuration as-is; no additional checks are run."""
        return None
class AuthProvider(Provider, ABC):
    """Abstract base class for authentication providers.

    Holds the crypto, database and email collaborators. It implements the
    FastAPI dependency returned by :meth:`auth_wrapper`, which resolves a
    request's Bearer token or ``X-API-Key`` header to a :class:`User`.
    Concrete subclasses supply token handling and the account workflows
    declared as abstract methods below.
    """

    # Optional Bearer scheme: with auto_error=False, a missing Authorization
    # header yields None instead of an automatic error response.
    security = HTTPBearer(auto_error=False)
    crypto_provider: CryptoProvider
    email_provider: EmailProvider
    database_provider: "PostgresDatabaseProvider"

    def __init__(
        self,
        config: AuthConfig,
        crypto_provider: CryptoProvider,
        database_provider: "PostgresDatabaseProvider",
        email_provider: EmailProvider,
    ):
        """Validate ``config`` and store the provider's collaborators.

        Raises:
            ValueError: if ``config`` is not an :class:`AuthConfig`.
        """
        if not isinstance(config, AuthConfig):
            raise ValueError(
                "AuthProvider must be initialized with an AuthConfig"
            )
        self.admin_email = config.default_admin_email
        self.admin_password = config.default_admin_password
        self.crypto_provider = crypto_provider
        self.email_provider = email_provider
        super().__init__(config)
        # Assigned after the base initializer so these are the final values,
        # whatever the base class sets, and carry the narrowed annotations.
        self.config: AuthConfig = config
        self.database_provider: "PostgresDatabaseProvider" = database_provider

    async def _get_default_admin_user(self) -> User:
        """Fetch the user whose email is ``config.default_admin_email``."""
        return await self.database_provider.users_handler.get_user_by_email(
            self.admin_email
        )

    @abstractmethod
    def create_access_token(self, data: dict) -> str:
        """Encode ``data`` as an access-token string."""

    @abstractmethod
    def create_refresh_token(self, data: dict) -> str:
        """Encode ``data`` as a refresh-token string."""

    @abstractmethod
    async def decode_token(self, token: str) -> TokenData:
        """Decode ``token`` into :class:`TokenData`.

        ``auth_wrapper`` reads ``.email`` from the result and treats a raised
        :class:`R2RException` as "not a valid JWT".
        """

    @abstractmethod
    async def user(self, token: str) -> User:
        """Return the :class:`User` associated with ``token``."""

    @abstractmethod
    def get_current_active_user(self, current_user: User) -> User:
        """Return ``current_user`` if it may act.

        Presumably an active-account check — confirm in subclasses.
        """

    @abstractmethod
    async def register(self, email: str, password: str) -> User:
        """Create an account for ``email``/``password`` and return it."""

    @abstractmethod
    async def send_verification_email(
        self, email: str, user: Optional[User] = None
    ) -> tuple[str, datetime]:
        """Send a verification email to ``email``.

        Returns a ``(str, datetime)`` pair. This is presumably the
        verification code and its expiry; confirm against implementations.
        """

    @abstractmethod
    async def verify_email(
        self, email: str, verification_code: str
    ) -> dict[str, str]:
        """Check ``verification_code`` for ``email``; return a str->str dict."""

    @abstractmethod
    async def login(self, email: str, password: str) -> dict[str, Token]:
        """Authenticate ``email``/``password``; return the issued tokens.

        Key names of the returned mapping are defined by subclasses.
        """

    @abstractmethod
    async def refresh_access_token(
        self, refresh_token: str
    ) -> dict[str, Token]:
        """Exchange ``refresh_token`` for newly issued tokens."""

    async def _user_from_api_key(self, presented_key: str) -> Optional[User]:
        """Resolve an API key of the form ``key_id.raw_api_key`` to a user.

        The string is split on its first ``.``. The key record is looked up
        by ``key_id``, and the raw secret is verified against the stored
        ``hashed_key``.

        Returns:
            The owning user when every step succeeds and ``user.is_active``
            is true. Otherwise ``None``, which is also returned immediately
            for strings without a ``.``.
        """
        if "." not in presented_key:
            return None
        key_id, raw_api_key = presented_key.split(".", 1)
        users_handler = self.database_provider.users_handler
        api_key_record = await users_handler.get_api_key_record(key_id)
        if api_key_record is None:
            return None
        hashed_key = api_key_record["hashed_key"]
        if not self.crypto_provider.verify_api_key(raw_api_key, hashed_key):
            return None
        user = await users_handler.get_user_by_id(api_key_record["user_id"])
        if user is not None and user.is_active:
            return user
        return None

    def auth_wrapper(
        self,
        public: bool = False,
    ):
        """Build a FastAPI dependency that resolves the calling user.

        Resolution order:
          1. No credentials, and authentication not required or ``public``:
             the default admin user.
          2. No credentials otherwise: 401.
          3. Both a Bearer token and ``X-API-Key``: 400.
          4. Bearer token: tried as a JWT, then as a ``key_id.raw_api_key``
             API key.
          5. ``X-API-Key`` header: tried as an API key.
          6. Nothing matched: 401 "Invalid token or API key".

        Args:
            public: when True, requests without credentials are accepted even
                if ``config.require_authentication`` is set.

        Returns:
            An async callable whose parameters are declared with ``Security``
            so FastAPI can inject the credentials.
        """

        async def _auth_wrapper(
            auth: Optional[HTTPAuthorizationCredentials] = Security(
                self.security
            ),
            api_key: Optional[str] = Security(api_key_header),
        ) -> User:
            # If authentication is not required and no credentials are
            # provided, return the default admin user.
            # NOTE(review): for ``public`` routes this also happens when
            # require_authentication is True, i.e. anonymous callers get the
            # admin identity — confirm this is intended.
            if (
                ((not self.config.require_authentication) or public)
                and auth is None
                and api_key is None
            ):
                return await self._get_default_admin_user()

            if not auth and not api_key:
                raise R2RException(
                    message="No credentials provided. Create an account at https://app.sciphi.ai and set your API key using `r2r configure key` OR change your base URL to a custom deployment.",
                    status_code=401,
                )

            if auth and api_key:
                raise R2RException(
                    message="Cannot have both Bearer token and API key",
                    status_code=400,
                )

            # 1. Bearer token: try it as a JWT first.
            if auth is not None:
                credentials = auth.credentials
                try:
                    token_data = await self.decode_token(credentials)
                    user = await self.database_provider.users_handler.get_user_by_email(
                        token_data.email
                    )
                    # NOTE(review): unlike the API-key path, is_active is not
                    # checked here — confirm callers rely on
                    # get_current_active_user for that.
                    if user is not None:
                        return user
                except R2RException:
                    # Invalid token (logical failure): fall through to API key.
                    pass
                except Exception as e:
                    # Unexpected failure: best-effort, log and fall through.
                    logger.debug("JWT verification failed: %s", e)

                # 2. Not a usable JWT: try the Bearer value as an API key.
                user = await self._user_from_api_key(credentials)
                if user is not None:
                    return user

            # 3. X-API-Key header (only reachable without a Bearer token,
            # because supplying both is rejected above).
            if api_key is not None:
                user = await self._user_from_api_key(api_key)
                if user is not None:
                    return user

            # Both JWT and API key authentication failed.
            raise R2RException(
                message="Invalid token or API key",
                status_code=401,
            )

        return _auth_wrapper

    @abstractmethod
    async def change_password(
        self, user: User, current_password: str, new_password: str
    ) -> dict[str, str]:
        """Change ``user``'s password from ``current_password``."""

    @abstractmethod
    async def request_password_reset(self, email: str) -> dict[str, str]:
        """Start a password reset for ``email``."""

    @abstractmethod
    async def confirm_password_reset(
        self, reset_token: str, new_password: str
    ) -> dict[str, str]:
        """Complete a password reset using ``reset_token``."""

    @abstractmethod
    async def logout(self, token: str) -> dict[str, str]:
        """Log out the session identified by ``token``."""

    @abstractmethod
    async def send_reset_email(self, email: str) -> dict[str, str]:
        """Send a password-reset email to ``email``."""
|