import logging import os from datetime import datetime, timedelta, timezone from typing import Optional from uuid import UUID from fastapi import Depends, HTTPException from fastapi.security import OAuth2PasswordBearer from core.base import ( AuthConfig, AuthProvider, CollectionResponse, CryptoProvider, EmailProvider, R2RException, Token, TokenData, ) from core.base.api.models import User from ..database import PostgresDatabaseProvider DEFAULT_ACCESS_LIFETIME_IN_MINUTES = 3600 DEFAULT_REFRESH_LIFETIME_IN_DAYS = 7 logger = logging.getLogger() oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") def normalize_email(email: str) -> str: """Normalizes an email address by converting it to lowercase. This ensures consistent email handling throughout the application. Args: email: The email address to normalize Returns: The normalized (lowercase) email address """ return email.lower() if email else "" class R2RAuthProvider(AuthProvider): def __init__( self, config: AuthConfig, crypto_provider: CryptoProvider, database_provider: PostgresDatabaseProvider, email_provider: EmailProvider, ): super().__init__( config, crypto_provider, database_provider, email_provider ) self.database_provider: PostgresDatabaseProvider = database_provider logger.debug(f"Initializing R2RAuthProvider with config: {config}") # We no longer use a local secret_key or defaults here. # All key handling is done in the crypto_provider. self.access_token_lifetime_in_minutes = ( config.access_token_lifetime_in_minutes or os.getenv("R2R_ACCESS_LIFE_IN_MINUTES") or DEFAULT_ACCESS_LIFETIME_IN_MINUTES ) self.refresh_token_lifetime_in_days = ( config.refresh_token_lifetime_in_days or os.getenv("R2R_REFRESH_LIFE_IN_DAYS") or DEFAULT_REFRESH_LIFETIME_IN_DAYS ) self.config: AuthConfig = config async def initialize(self): try: user = await self.register( email=normalize_email(self.admin_email), password=self.admin_password, is_superuser=True, ) await self.database_provider.users_handler.mark_user_as_superuser( id=user.id ) except R2RException: logger.info("Default admin user already exists.") def create_access_token(self, data: dict) -> str: expire = datetime.now(timezone.utc) + timedelta( minutes=float(self.access_token_lifetime_in_minutes) ) # Add token_type and pass data/expiry to crypto_provider data_with_type = {**data, "token_type": "access"} return self.crypto_provider.generate_secure_token( data=data_with_type, expiry=expire, ) def create_refresh_token(self, data: dict) -> str: expire = datetime.now(timezone.utc) + timedelta( days=float(self.refresh_token_lifetime_in_days) ) data_with_type = {**data, "token_type": "refresh"} return self.crypto_provider.generate_secure_token( data=data_with_type, expiry=expire, ) async def decode_token(self, token: str) -> TokenData: if "token=" in token: token = token.split("token=")[1] if "&tokenType=refresh" in token: token = token.split("&tokenType=refresh")[0] # First, check if the token is blacklisted if await self.database_provider.token_handler.is_token_blacklisted( token=token ): raise R2RException( status_code=401, message="Token has been invalidated" ) # Verify token using crypto_provider payload = self.crypto_provider.verify_secure_token(token=token) if payload is None: raise R2RException( status_code=401, message="Invalid or expired token" ) email = payload.get("sub") token_type = payload.get("token_type") exp = payload.get("exp") if email is None or token_type is None or exp is None: raise R2RException(status_code=401, message="Invalid token claims") email_str: str = email token_type_str: str = token_type exp_float: float = exp exp_datetime = datetime.fromtimestamp(exp_float, tz=timezone.utc) if exp_datetime < datetime.now(timezone.utc): raise R2RException(status_code=401, message="Token has expired") return TokenData( email=normalize_email(email_str), token_type=token_type_str, exp=exp_datetime, ) async def authenticate_api_key(self, api_key: str) -> User: """Authenticate using an API key of the form "public_key.raw_key". Returns a User if successful, or raises R2RException if not. """ try: key_id, raw_key = api_key.split(".", 1) except ValueError as e: raise R2RException( status_code=401, message="Invalid API key format" ) from e key_record = ( await self.database_provider.users_handler.get_api_key_record( key_id=key_id ) ) if not key_record: raise R2RException(status_code=401, message="Invalid API key") if not self.crypto_provider.verify_api_key( raw_api_key=raw_key, hashed_key=key_record["hashed_key"] ): raise R2RException(status_code=401, message="Invalid API key") user = await self.database_provider.users_handler.get_user_by_id( id=key_record["user_id"] ) if not user.is_active: raise R2RException( status_code=401, message="User account is inactive" ) return user async def user(self, token: str = Depends(oauth2_scheme)) -> User: """Attempt to authenticate via JWT first, then fallback to API key.""" # Try JWT auth try: token_data = await self.decode_token(token=token) if not token_data.email: raise R2RException( status_code=401, message="Could not validate credentials" ) user = ( await self.database_provider.users_handler.get_user_by_email( email=normalize_email(token_data.email) ) ) if user is None: raise R2RException( status_code=401, message="Invalid authentication credentials", ) return user except R2RException: # If JWT fails, try API key auth # OAuth2PasswordBearer provides token as "Bearer xxx", strip it if needed token = token.removeprefix("Bearer ") return await self.authenticate_api_key(api_key=token) def get_current_active_user( self, current_user: User = Depends(user) ) -> User: if not current_user.is_active: raise R2RException(status_code=400, message="Inactive user") return current_user async def register( self, email: str, password: Optional[str] = None, is_superuser: bool = False, account_type: str = "password", github_id: Optional[str] = None, google_id: Optional[str] = None, name: Optional[str] = None, bio: Optional[str] = None, profile_picture: Optional[str] = None, ) -> User: if account_type == "password": if not password: raise R2RException( status_code=400, message="Password is required for password accounts", ) else: if github_id and google_id: raise R2RException( status_code=400, message="Cannot register OAuth with both GitHub and Google IDs", ) if not github_id and not google_id: raise R2RException( status_code=400, message="Invalid OAuth specification without GitHub or Google ID", ) new_user = await self.database_provider.users_handler.create_user( email=normalize_email(email), password=password, is_superuser=is_superuser, account_type=account_type, github_id=github_id, google_id=google_id, name=name, bio=bio, profile_picture=profile_picture, ) default_collection: CollectionResponse = ( await self.database_provider.collections_handler.create_collection( owner_id=new_user.id, ) ) await self.database_provider.graphs_handler.create( collection_id=default_collection.id, name=default_collection.name, description=default_collection.description, ) await self.database_provider.users_handler.add_user_to_collection( new_user.id, default_collection.id ) new_user = await self.database_provider.users_handler.get_user_by_id( new_user.id ) if self.config.require_email_verification: verification_code, _ = await self.send_verification_email( email=normalize_email(email), user=new_user ) else: expiry = datetime.now(timezone.utc) + timedelta(hours=366 * 10) await self.database_provider.users_handler.store_verification_code( id=new_user.id, verification_code=str(-1), expiry=expiry, ) await self.database_provider.users_handler.mark_user_as_verified( id=new_user.id ) return new_user async def send_verification_email( self, email: str, user: Optional[User] = None ) -> tuple[str, datetime]: if user is None: user = ( await self.database_provider.users_handler.get_user_by_email( email=normalize_email(email) ) ) if not user: raise R2RException(status_code=404, message="User not found") verification_code = self.crypto_provider.generate_verification_code() expiry = datetime.now(timezone.utc) + timedelta(hours=24) await self.database_provider.users_handler.store_verification_code( id=user.id, verification_code=verification_code, expiry=expiry, ) if hasattr(user, "verification_code_expiry"): user.verification_code_expiry = expiry first_name = ( user.name.split(" ")[0] if user.name else email.split("@")[0] ) await self.email_provider.send_verification_email( to_email=user.email, verification_code=verification_code, dynamic_template_data={"first_name": first_name}, ) return verification_code, expiry async def verify_email( self, email: str, verification_code: str ) -> dict[str, str]: user_id = await self.database_provider.users_handler.get_user_id_by_verification_code( verification_code=verification_code ) await self.database_provider.users_handler.mark_user_as_verified( id=user_id ) await self.database_provider.users_handler.remove_verification_code( verification_code=verification_code ) return {"message": "Email verified successfully"} async def login(self, email: str, password: str) -> dict[str, Token]: logger.debug(f"Attempting login for email: {email}") user = await self.database_provider.users_handler.get_user_by_email( email=normalize_email(email) ) if user.account_type != "password": logger.warning( f"Password login not allowed for {user.account_type} accounts: {email}" ) raise R2RException( status_code=401, message=f"This account is configured for {user.account_type} login, not password.", ) logger.debug(f"User found: {user}") if not isinstance(user.hashed_password, str): logger.error( f"Invalid hashed_password type: {type(user.hashed_password)}" ) raise HTTPException( status_code=500, detail="Invalid password hash in database", ) try: password_verified = self.crypto_provider.verify_password( plain_password=password, hashed_password=user.hashed_password, ) except Exception as e: logger.error(f"Error during password verification: {str(e)}") raise HTTPException( status_code=500, detail="Error during password verification", ) from e if not password_verified: logger.warning(f"Invalid password for user: {email}") raise R2RException( status_code=401, message="Incorrect email or password" ) if not user.is_verified and self.config.require_email_verification: logger.warning(f"Unverified user attempted login: {email}") raise R2RException(status_code=401, message="Email not verified") access_token = self.create_access_token( data={"sub": normalize_email(user.email)} ) refresh_token = self.create_refresh_token( data={"sub": normalize_email(user.email)} ) return { "access_token": Token(token=access_token, token_type="access"), "refresh_token": Token(token=refresh_token, token_type="refresh"), } async def refresh_access_token( self, refresh_token: str ) -> dict[str, Token]: token_data = await self.decode_token(refresh_token) if token_data.token_type != "refresh": raise R2RException( status_code=401, message="Invalid refresh token" ) # Invalidate the old refresh token and create a new one await self.database_provider.token_handler.blacklist_token( token=refresh_token ) new_access_token = self.create_access_token( data={"sub": normalize_email(token_data.email)} ) new_refresh_token = self.create_refresh_token( data={"sub": normalize_email(token_data.email)} ) return { "access_token": Token(token=new_access_token, token_type="access"), "refresh_token": Token( token=new_refresh_token, token_type="refresh" ), } async def change_password( self, user: User, current_password: str, new_password: str ) -> dict[str, str]: if not isinstance(user.hashed_password, str): logger.error( f"Invalid hashed_password type: {type(user.hashed_password)}" ) raise HTTPException( status_code=500, detail="Invalid password hash in database", ) if not self.crypto_provider.verify_password( plain_password=current_password, hashed_password=user.hashed_password, ): raise R2RException( status_code=400, message="Incorrect current password" ) hashed_new_password = self.crypto_provider.get_password_hash( password=new_password ) await self.database_provider.users_handler.update_user_password( id=user.id, new_hashed_password=hashed_new_password, ) try: await self.email_provider.send_password_changed_email( to_email=normalize_email(user.email), dynamic_template_data={ "first_name": ( user.name.split(" ")[0] or "User" if user.name else "User" ) }, ) except Exception as e: logger.error( f"Failed to send password change notification: {str(e)}" ) return {"message": "Password changed successfully"} async def request_password_reset(self, email: str) -> dict[str, str]: try: user = ( await self.database_provider.users_handler.get_user_by_email( email=normalize_email(email) ) ) reset_token = self.crypto_provider.generate_verification_code() expiry = datetime.now(timezone.utc) + timedelta(hours=1) await self.database_provider.users_handler.store_reset_token( id=user.id, reset_token=reset_token, expiry=expiry, ) first_name = ( user.name.split(" ")[0] if user.name else email.split("@")[0] ) await self.email_provider.send_password_reset_email( to_email=normalize_email(email), reset_token=reset_token, dynamic_template_data={"first_name": first_name}, ) return { "message": "If the email exists, a reset link has been sent" } except R2RException as e: if e.status_code == 404: # User doesn't exist; return a success message anyway return { "message": "If the email exists, a reset link has been sent" } else: raise async def confirm_password_reset( self, reset_token: str, new_password: str ) -> dict[str, str]: user_id = await self.database_provider.users_handler.get_user_id_by_reset_token( reset_token=reset_token ) if not user_id: raise R2RException( status_code=400, message="Invalid or expired reset token" ) hashed_new_password = self.crypto_provider.get_password_hash( password=new_password ) await self.database_provider.users_handler.update_user_password( id=user_id, new_hashed_password=hashed_new_password, ) await self.database_provider.users_handler.remove_reset_token( id=user_id ) # Get the user information user = await self.database_provider.users_handler.get_user_by_id( id=user_id ) try: await self.email_provider.send_password_changed_email( to_email=normalize_email(user.email), dynamic_template_data={ "first_name": ( user.name.split(" ")[0] or "User" if user.name else "User" ) }, ) except Exception as e: logger.error( f"Failed to send password change notification: {str(e)}" ) return {"message": "Password reset successfully"} async def logout(self, token: str) -> dict[str, str]: await self.database_provider.token_handler.blacklist_token(token=token) return {"message": "Logged out successfully"} async def clean_expired_blacklisted_tokens(self): await self.database_provider.token_handler.clean_expired_blacklisted_tokens() async def send_reset_email(self, email: str) -> dict: verification_code, expiry = await self.send_verification_email( email=normalize_email(email) ) return { "verification_code": verification_code, "expiry": expiry, "message": f"Verification email sent successfully to {email}", } async def create_user_api_key( self, user_id: UUID, name: Optional[str] = None, description: Optional[str] = None, ) -> dict[str, str]: key_id, raw_api_key = self.crypto_provider.generate_api_key() hashed_key = self.crypto_provider.hash_api_key(raw_api_key) api_key_uuid = ( await self.database_provider.users_handler.store_user_api_key( user_id=user_id, key_id=key_id, hashed_key=hashed_key, name=name, description=description, ) ) return { "api_key": f"{key_id}.{raw_api_key}", "key_id": str(api_key_uuid), "public_key": key_id, "name": name or "", } async def list_user_api_keys(self, user_id: UUID) -> list[dict]: return await self.database_provider.users_handler.get_user_api_keys( user_id=user_id ) async def delete_user_api_key(self, user_id: UUID, key_id: UUID) -> bool: return await self.database_provider.users_handler.delete_api_key( user_id=user_id, key_id=key_id, ) async def rename_api_key( self, user_id: UUID, key_id: UUID, new_name: str ) -> bool: return await self.database_provider.users_handler.update_api_key_name( user_id=user_id, key_id=key_id, name=new_name, ) async def oauth_callback_handler( self, provider: str, oauth_id: str, email: str ) -> dict[str, Token]: """Handles a login/registration flow for OAuth providers (e.g., Google or GitHub). :param provider: "google" or "github" :param oauth_id: The unique ID from the OAuth provider (e.g. Google's 'sub') :param email: The user's email from the provider, if available. :return: dict with access_token and refresh_token """ # 1) Attempt to find user by google_id or github_id, or by email # The logic depends on your preference. We'll assume "google" => google_id, etc. try: if provider == "google": try: user = await self.database_provider.users_handler.get_user_by_email( normalize_email(email) ) # If user found, check if user.google_id matches or is null. If null, update it if user and not user.google_id: raise R2RException( status_code=401, message="User already exists and is not linked to Google account", ) except Exception: # Create new user user = await self.register( email=normalize_email(email) or f"{oauth_id}@google_oauth.fake", # fallback password=None, # no password account_type="oauth", google_id=oauth_id, ) elif provider == "github": try: user = await self.database_provider.users_handler.get_user_by_email( normalize_email(email) ) # If user found, check if user.google_id matches or is null. If null, update it if user and not user.github_id: raise R2RException( status_code=401, message="User already exists and is not linked to Github account", ) except Exception: # Create new user user = await self.register( email=normalize_email(email) or f"{oauth_id}@github_oauth.fake", # fallback password=None, # no password account_type="oauth", github_id=oauth_id, ) # else handle other providers except R2RException: # If no user found or creation fails raise R2RException( status_code=401, message="Could not create or fetch user" ) from None # If user is inactive, etc. if not user.is_active: raise R2RException( status_code=401, message="User account is inactive" ) # Possibly mark user as verified if you trust the OAuth provider's email user.is_verified = True await self.database_provider.users_handler.update_user(user) # 2) Generate tokens access_token = self.create_access_token( data={"sub": normalize_email(user.email)} ) refresh_token = self.create_refresh_token( data={"sub": normalize_email(user.email)} ) return { "access_token": Token(token=access_token, token_type="access"), "refresh_token": Token(token=refresh_token, token_type="refresh"), }