import logging
import os
from datetime import datetime
from typing import Optional
from uuid import UUID

import jwt
from fastapi import Depends

from core.base import (
    AuthConfig,
    AuthProvider,
    CryptoProvider,
    EmailProvider,
    R2RException,
    Token,
    TokenData,
)
from core.base.api.models import User

from ..database import PostgresDatabaseProvider

logger = logging.getLogger()


class JwtAuthProvider(AuthProvider):
    """Auth provider that validates externally issued JWTs.

    Only ``decode_token`` and ``get_current_active_user`` are implemented.
    Every other ``AuthProvider`` operation defined here (login, registration,
    password management, API keys, OAuth, email verification, ...) raises
    ``NotImplementedError``.

    Tokens are verified with the HS256 algorithm using the shared secret read
    from the ``JWT_SECRET`` environment variable.
    """

    def __init__(
        self,
        config: AuthConfig,
        crypto_provider: CryptoProvider,
        database_provider: PostgresDatabaseProvider,
        email_provider: EmailProvider,
    ):
        """Forward all collaborators unchanged to ``AuthProvider.__init__``."""
        super().__init__(
            config, crypto_provider, database_provider, email_provider
        )

    async def login(self, email: str, password: str) -> dict[str, Token]:
        """Not supported by this provider; always raises NotImplementedError."""
        raise NotImplementedError("Not implemented")

    async def oauth_callback(self, code: str) -> dict[str, Token]:
        """Not supported by this provider; always raises NotImplementedError."""
        raise NotImplementedError("Not implemented")

    async def user(self, token: str) -> User:
        """Not supported by this provider; always raises NotImplementedError.

        NOTE: this function object is also used as the FastAPI dependency in
        ``get_current_active_user`` (``Depends(user)``), so it must stay
        defined above that method in the class body.
        """
        raise NotImplementedError("Not implemented")

    async def change_password(
        self, user: User, current_password: str, new_password: str
    ) -> dict[str, str]:
        """Not supported by this provider; always raises NotImplementedError."""
        raise NotImplementedError("Not implemented")

    async def confirm_password_reset(
        self, reset_token: str, new_password: str
    ) -> dict[str, str]:
        """Not supported by this provider; always raises NotImplementedError."""
        raise NotImplementedError("Not implemented")

    def create_access_token(self, data: dict) -> str:
        """Not supported by this provider; always raises NotImplementedError."""
        raise NotImplementedError("Not implemented")

    def create_refresh_token(self, data: dict) -> str:
        """Not supported by this provider; always raises NotImplementedError."""
        raise NotImplementedError("Not implemented")

    async def decode_token(self, token: str) -> TokenData:
        """Validate a JWT and return its token data, provisioning the user.

        The token is verified with HS256 against the ``JWT_SECRET``
        environment variable. If the user lookup by the token's ``email``
        claim fails, a user with ``account_type="external"`` is created from
        the ``email`` and ``name`` claims.

        Args:
            token: The encoded JWT.

        Returns:
            ``TokenData`` carrying the token's ``email`` and ``exp`` claims
            with ``token_type="bearer"``.

        Raises:
            R2RException: 500 if ``JWT_SECRET`` is not set or the user could
                not be created; 401 if the token fails verification, decodes
                to an empty payload, or has no ``email`` claim.
        """
        jwt_secret = os.getenv("JWT_SECRET")
        if jwt_secret is None:
            raise R2RException(
                status_code=500,
                message="JWT_SECRET environment variable is not set",
            )

        try:
            payload = jwt.decode(token, jwt_secret, algorithms=["HS256"])
        except Exception as e:
            logger.info("JWT verification failed: %s", e)
            raise R2RException(
                status_code=401, message="Invalid JWT token", detail=e
            ) from e

        if not payload:
            raise R2RException(status_code=401, message="Invalid JWT token")

        # A signed token without an email claim cannot be mapped to a user.
        # Reject it up front instead of passing None to the lookup/creation
        # calls below.
        email = payload.get("email")
        if not email:
            raise R2RException(
                status_code=401,
                message="JWT token is missing the required email claim",
            )

        await self._ensure_user_exists(email, payload.get("name"))

        return TokenData(
            email=email,
            token_type="bearer",
            exp=payload.get("exp"),
        )

    async def _ensure_user_exists(
        self, email: str, name: Optional[str]
    ) -> None:
        """Create the user for ``email`` if looking it up fails.

        Raises:
            R2RException: 500 if the user could not be created.
        """
        users_handler = self.database_provider.users_handler
        try:
            await users_handler.get_user_by_email(email)
            # TODO do we want to update user info here based on what's in the token?
            return
        except Exception:
            # NOTE(review): any lookup failure (not only "user not found") is
            # treated as a missing user here - confirm which exception
            # get_user_by_email raises and narrow this handler accordingly.
            logger.debug("Creating new user: %s", email)

        try:
            await users_handler.create_user(
                email=email,
                account_type="external",
                name=name,
            )
        except Exception as e:
            logger.error("Error creating user: %s", e)
            raise R2RException(
                status_code=500, message="Failed to create user"
            ) from e

    async def refresh_access_token(
        self, refresh_token: str
    ) -> dict[str, Token]:
        """Not supported by this provider; always raises NotImplementedError."""
        raise NotImplementedError("Not implemented")

    def get_current_active_user(
        self, current_user: User = Depends(user)
    ) -> User:
        """Return ``current_user`` if it is active.

        Raises:
            R2RException: 400 if ``current_user.is_active`` is falsy.
        """
        if not current_user.is_active:
            raise R2RException(status_code=400, message="Inactive user")
        return current_user

    async def logout(self, token: str) -> dict[str, str]:
        """Not supported by this provider; always raises NotImplementedError."""
        raise NotImplementedError("Not implemented")

    async def register(
        self,
        email: str,
        password: str,
        name: Optional[str] = None,
        bio: Optional[str] = None,
        profile_picture: Optional[str] = None,
    ) -> User:  # type: ignore
        """Not supported by this provider; always raises NotImplementedError."""
        raise NotImplementedError("Not implemented")

    async def request_password_reset(self, email: str) -> dict[str, str]:
        """Not supported by this provider; always raises NotImplementedError."""
        raise NotImplementedError("Not implemented")

    async def send_reset_email(self, email: str) -> dict[str, str]:
        """Not supported by this provider; always raises NotImplementedError."""
        raise NotImplementedError("Not implemented")

    async def create_user_api_key(
        self,
        user_id: UUID,
        name: Optional[str] = None,
        description: Optional[str] = None,
    ) -> dict[str, str]:
        """Not supported by this provider; always raises NotImplementedError."""
        raise NotImplementedError("Not implemented")

    async def verify_email(
        self, email: str, verification_code: str
    ) -> dict[str, str]:
        """Not supported by this provider; always raises NotImplementedError."""
        raise NotImplementedError("Not implemented")

    async def send_verification_email(
        self, email: str, user: Optional[User] = None
    ) -> tuple[str, datetime]:
        """Not supported by this provider; always raises NotImplementedError."""
        raise NotImplementedError("Not implemented")

    async def list_user_api_keys(self, user_id: UUID) -> list[dict]:
        """Not supported by this provider; always raises NotImplementedError."""
        raise NotImplementedError("Not implemented")

    async def delete_user_api_key(self, user_id: UUID, key_id: UUID) -> bool:
        """Not supported by this provider; always raises NotImplementedError."""
        raise NotImplementedError("Not implemented")

    async def oauth_callback_handler(
        self, provider: str, oauth_id: str, email: str
    ) -> dict[str, Token]:
        """Not supported by this provider; always raises NotImplementedError."""
        raise NotImplementedError("Not implemented")