import logging
from abc import ABC, abstractmethod
from datetime import datetime
from typing import TYPE_CHECKING, Optional
from fastapi import Security
from fastapi.security import (
APIKeyHeader,
HTTPAuthorizationCredentials,
HTTPBearer,
)
from ..abstractions import R2RException, Token, TokenData
from ..api.models import User
from .base import Provider, ProviderConfig
from .crypto import CryptoProvider
from .email import EmailProvider
# Module-scoped logger (named after this module instead of the root logger);
# records still propagate to whatever handlers are configured on the root.
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from core.providers.database import PostgresDatabaseProvider
# Extractor for the optional ``X-API-Key`` request header. With
# ``auto_error=False`` a missing header is passed on as ``None`` so the
# auth wrapper, not FastAPI, decides how to respond.
api_key_header = APIKeyHeader(auto_error=False, name="X-API-Key")
class AuthConfig(ProviderConfig):
    """Settings consumed by authentication providers.

    Only the ``"r2r"`` provider name is supported. No additional validation
    is performed beyond what the base config class does.
    """

    # Secret for the provider; unset by default
    # (presumably the token-signing key — TODO confirm in the concrete provider).
    secret_key: Optional[str] = None
    # When False, requests carrying no credentials are served as the default
    # admin user by ``AuthProvider.auth_wrapper`` instead of being rejected.
    require_authentication: bool = False
    # Flag for concrete providers; it is not read anywhere in this module.
    require_email_verification: bool = False
    # Default admin account credentials; AuthProvider copies these onto
    # ``admin_email`` / ``admin_password``. Override them in real deployments.
    default_admin_email: str = "admin@example.com"
    default_admin_password: str = "change_me_immediately"
    # Token lifetimes. ``None`` leaves the choice to the concrete provider.
    access_token_lifetime_in_minutes: Optional[int] = None
    refresh_token_lifetime_in_days: Optional[int] = None

    @property
    def supported_providers(self) -> list[str]:
        """Provider names this configuration accepts (only ``"r2r"``)."""
        names = ["r2r"]
        return names

    def validate_config(self) -> None:
        """Intentionally a no-op: auth settings need no extra validation."""
class AuthProvider(Provider, ABC):
    """Abstract base for authentication providers.

    Holds the crypto, database and email providers that a concrete
    implementation needs. It also supplies :meth:`auth_wrapper`, a FastAPI
    dependency factory that resolves the calling :class:`User` from either:

    * a Bearer token, tried first as a JWT and then as an API key; or
    * an ``X-API-Key`` header.

    Token, registration, login, email-verification and password flows are
    left to subclasses.
    """

    # With ``auto_error=False`` a missing ``Authorization`` header yields
    # ``None`` rather than raising, so ``auth_wrapper`` decides the outcome.
    security = HTTPBearer(auto_error=False)
    crypto_provider: CryptoProvider
    email_provider: EmailProvider
    database_provider: "PostgresDatabaseProvider"

    def __init__(
        self,
        config: AuthConfig,
        crypto_provider: CryptoProvider,
        database_provider: "PostgresDatabaseProvider",
        email_provider: EmailProvider,
    ):
        """Store the collaborating providers and the default admin credentials.

        Args:
            config: Authentication settings; must be an :class:`AuthConfig`.
            crypto_provider: Used to verify API keys.
            database_provider: Source of user and API-key records.
            email_provider: Kept for use by concrete implementations.

        Raises:
            ValueError: If ``config`` is not an :class:`AuthConfig`.
        """
        if not isinstance(config, AuthConfig):
            raise ValueError(
                "AuthProvider must be initialized with an AuthConfig"
            )
        self.config = config
        self.admin_email = config.default_admin_email
        self.admin_password = config.default_admin_password
        self.crypto_provider = crypto_provider
        self.database_provider = database_provider
        self.email_provider = email_provider
        super().__init__(config)
        # Re-bind after the base initializer with the narrower types. The base
        # class may assign ``self.config`` itself; keep ours authoritative.
        self.config: AuthConfig = config
        self.database_provider: "PostgresDatabaseProvider" = database_provider

    async def _get_default_admin_user(self) -> User:
        """Look up the user whose email is ``config.default_admin_email``."""
        return await self.database_provider.users_handler.get_user_by_email(
            self.admin_email
        )

    @abstractmethod
    def create_access_token(self, data: dict) -> str:
        """Return an access token string built from ``data``."""

    @abstractmethod
    def create_refresh_token(self, data: dict) -> str:
        """Return a refresh token string built from ``data``."""

    @abstractmethod
    async def decode_token(self, token: str) -> TokenData:
        """Decode ``token`` into :class:`TokenData`.

        ``auth_wrapper`` treats an :class:`R2RException` raised here as a
        rejected token and falls back to API-key authentication.
        """

    @abstractmethod
    async def user(self, token: str) -> User:
        """Return the :class:`User` identified by ``token``."""

    @abstractmethod
    def get_current_active_user(self, current_user: User) -> User:
        """Return ``current_user`` if the implementation accepts it as active."""

    @abstractmethod
    async def register(self, email: str, password: str) -> User:
        """Create and return a new user for ``email`` / ``password``."""

    @abstractmethod
    async def send_verification_email(
        self, email: str, user: Optional[User] = None
    ) -> tuple[str, datetime]:
        """Send a verification email to ``email``.

        Returns a ``(str, datetime)`` pair; presumably the verification code
        and its expiry (TODO confirm in the concrete provider).
        """

    @abstractmethod
    async def verify_email(
        self, email: str, verification_code: str
    ) -> dict[str, str]:
        """Verify ``email`` using ``verification_code``."""

    @abstractmethod
    async def login(self, email: str, password: str) -> dict[str, Token]:
        """Authenticate ``email`` / ``password`` and return the issued tokens."""

    @abstractmethod
    async def refresh_access_token(
        self, refresh_token: str
    ) -> dict[str, Token]:
        """Exchange ``refresh_token`` for new tokens."""

    async def _user_from_api_key(self, api_key: str) -> Optional[User]:
        """Resolve an active user from a ``key_id.raw_api_key`` string.

        The text before the first ``.`` is looked up as the key id. The rest
        is checked against the record's ``hashed_key`` with
        ``crypto_provider.verify_api_key``.

        Returns:
            The key's owner, but only when all of the following hold:
            the string contains ``.``, the record exists, the secret
            verifies, and the owner exists and ``is_active``.
            Otherwise ``None``.
        """
        if "." not in api_key:
            return None
        key_id, raw_api_key = api_key.split(".", 1)
        users_handler = self.database_provider.users_handler
        api_key_record = await users_handler.get_api_key_record(key_id)
        if api_key_record is None:
            return None
        hashed_key = api_key_record["hashed_key"]
        if not self.crypto_provider.verify_api_key(raw_api_key, hashed_key):
            return None
        user = await users_handler.get_user_by_id(api_key_record["user_id"])
        if user is None or not user.is_active:
            return None
        return user

    def auth_wrapper(
        self,
        public: bool = False,
    ):
        """Build a FastAPI dependency that resolves the authenticated user.

        Resolution order:

        1. No credentials, and either ``config.require_authentication`` is
           False or ``public`` is True: return the default admin user.
        2. No credentials otherwise: raise 401. Both a Bearer token and an
           ``X-API-Key`` header: raise 400.
        3. Bearer token: try it as a JWT (the user is looked up by the
           token's email). If that yields no user, try the same value as a
           ``key_id.raw_api_key`` API key.
        4. ``X-API-Key`` header: try it as a ``key_id.raw_api_key`` API key.
        5. Nothing matched: raise 401 "Invalid token or API key".

        Args:
            public: When True, requests without credentials resolve to the
                default admin user even if authentication is required.

        Returns:
            An async dependency returning a :class:`User`. It raises
            :class:`R2RException` on failure.
        """

        async def _auth_wrapper(
            auth: Optional[HTTPAuthorizationCredentials] = Security(
                self.security
            ),
            api_key: Optional[str] = Security(api_key_header),
        ) -> User:
            # If authentication is not required and no credentials are
            # provided, return the default admin user.
            # NOTE(review): for ``public`` routes this also hands anonymous
            # callers the default admin user when authentication IS required
            # — confirm this is intended.
            if (
                ((not self.config.require_authentication) or public)
                and auth is None
                and api_key is None
            ):
                return await self._get_default_admin_user()
            if not auth and not api_key:
                raise R2RException(
                    message="No credentials provided. Create an account at https://app.sciphi.ai and set your API key using `r2r configure key` OR change your base URL to a custom deployment.",
                    status_code=401,
                )
            if auth and api_key:
                raise R2RException(
                    message="Cannot have both Bearer token and API key",
                    status_code=400,
                )

            # 1. Bearer token present: try it as a JWT first.
            if auth is not None:
                credentials = auth.credentials
                try:
                    token_data = await self.decode_token(credentials)
                    user = await self.database_provider.users_handler.get_user_by_email(
                        token_data.email
                    )
                    if user is not None:
                        # NOTE(review): unlike the API-key paths, ``is_active``
                        # is not checked here — confirm this is intended.
                        return user
                except R2RException:
                    # Token (or user lookup) rejected for logical reasons;
                    # fall back to API-key authentication below.
                    pass
                except Exception as e:
                    # Unexpected failure (this also covers the user lookup);
                    # log it and fall back.
                    logger.debug("JWT verification failed: %s", e)

                # 2. JWT failed: treat the Bearer value as key_id.raw_api_key.
                user = await self._user_from_api_key(credentials)
                if user is not None:
                    return user

            # 3. Otherwise try the X-API-Key header.
            if api_key is not None:
                user = await self._user_from_api_key(api_key)
                if user is not None:
                    return user

            # Both JWT and API-key authentication failed.
            raise R2RException(
                message="Invalid token or API key",
                status_code=401,
            )

        return _auth_wrapper

    @abstractmethod
    async def change_password(
        self, user: User, current_password: str, new_password: str
    ) -> dict[str, str]:
        """Change ``user``'s password from ``current_password`` to ``new_password``."""

    @abstractmethod
    async def request_password_reset(self, email: str) -> dict[str, str]:
        """Start a password reset for ``email``."""

    @abstractmethod
    async def confirm_password_reset(
        self, reset_token: str, new_password: str
    ) -> dict[str, str]:
        """Complete a password reset using ``reset_token`` and ``new_password``."""

    @abstractmethod
    async def logout(self, token: str) -> dict[str, str]:
        """Log out the session identified by ``token``."""

    @abstractmethod
    async def send_reset_email(self, email: str) -> dict[str, str]:
        """Send a password-reset email to ``email``."""