# NOTE: stray page-navigation text, not part of the module: "about summary refs log tree commit diff"
import logging
import os
from datetime import datetime
from typing import Optional
from uuid import UUID

import jwt
from fastapi import Depends

from core.base import (
    AuthConfig,
    AuthProvider,
    CryptoProvider,
    EmailProvider,
    R2RException,
    Token,
    TokenData,
)
from core.base.api.models import User

from ..database import PostgresDatabaseProvider

# Module-scoped logger: records carry this module's name so they can be
# filtered/configured per module; they still propagate to the root handlers.
logger = logging.getLogger(__name__)


class JwtAuthProvider(AuthProvider):
    """Auth provider that validates externally issued HS256 JWTs.

    Only `decode_token` and `get_current_active_user` are implemented.
    Every other operation of the provider interface raises
    `NotImplementedError`.
    """

    def __init__(
        self,
        config: AuthConfig,
        crypto_provider: CryptoProvider,
        database_provider: PostgresDatabaseProvider,
        email_provider: EmailProvider,
    ):
        """Forward all collaborators unchanged to the base `AuthProvider`."""
        super().__init__(
            config, crypto_provider, database_provider, email_provider
        )

    async def login(self, email: str, password: str) -> dict[str, Token]:
        """Not supported by this provider; always raises NotImplementedError."""
        raise NotImplementedError("Not implemented")

    async def oauth_callback(self, code: str) -> dict[str, Token]:
        """Not supported by this provider; always raises NotImplementedError."""
        raise NotImplementedError("Not implemented")

    async def user(self, token: str) -> User:
        """Not supported by this provider; always raises NotImplementedError."""
        raise NotImplementedError("Not implemented")

    async def change_password(
        self, user: User, current_password: str, new_password: str
    ) -> dict[str, str]:
        """Not supported by this provider; always raises NotImplementedError."""
        raise NotImplementedError("Not implemented")

    async def confirm_password_reset(
        self, reset_token: str, new_password: str
    ) -> dict[str, str]:
        """Not supported by this provider; always raises NotImplementedError."""
        raise NotImplementedError("Not implemented")

    def create_access_token(self, data: dict) -> str:
        """Not supported by this provider; always raises NotImplementedError."""
        raise NotImplementedError("Not implemented")

    def create_refresh_token(self, data: dict) -> str:
        """Not supported by this provider; always raises NotImplementedError."""
        raise NotImplementedError("Not implemented")

    async def _ensure_user_exists(
        self, email: str, name: Optional[str]
    ) -> None:
        """Create an "external" user record for *email* if lookup fails.

        Raises:
            R2RException: 500 when the user cannot be created.
        """
        users_handler = self.database_provider.users_handler
        try:
            await users_handler.get_user_by_email(email)
            # TODO do we want to update user info here based on what's in the token?
            return
        except Exception:
            # NOTE(review): any lookup failure is treated as "user does not
            # exist"; narrow this once the handler's not-found exception is
            # confirmed, so real database errors are not masked.
            logger.debug("Creating new user: %s", email)

        try:
            await users_handler.create_user(
                email=email,
                account_type="external",
                name=name,
            )
        except Exception as e:
            logger.error("Error creating user: %s", e)
            raise R2RException(
                status_code=500, message="Failed to create user"
            ) from e

    async def decode_token(self, token: str) -> TokenData:
        """Validate *token* and return the token data derived from its claims.

        The token is verified with the HS256 secret read from the
        ``JWT_SECRET`` environment variable. If no user can be looked up for
        the token's ``email`` claim, one is created with account type
        ``"external"``.

        Args:
            token: The encoded JWT.

        Returns:
            A `TokenData` built from the ``email`` and ``exp`` claims with
            ``token_type="bearer"``.

        Raises:
            R2RException: 500 if ``JWT_SECRET`` is unset or the user cannot
                be created; 401 if the token fails verification, decodes to
                an empty payload, or lacks a usable ``email`` claim.
        """
        jwt_secret = os.getenv("JWT_SECRET")
        if jwt_secret is None:
            raise R2RException(
                status_code=500,
                message="JWT_SECRET environment variable is not set",
            )

        try:
            claims = jwt.decode(token, jwt_secret, algorithms=["HS256"])
        except Exception as e:
            logger.info("JWT verification failed: %s", e)
            raise R2RException(
                status_code=401, message="Invalid JWT token", detail=e
            ) from e

        if not claims:
            raise R2RException(status_code=401, message="Invalid JWT token")

        email = claims.get("email")
        if not isinstance(email, str) or not email:
            # A correctly signed token without an email cannot be mapped to a
            # user; reject it instead of looking up / creating a user with
            # email=None.
            logger.info("JWT verification failed: missing email claim")
            raise R2RException(
                status_code=401,
                message="Invalid JWT token: missing email claim",
            )

        await self._ensure_user_exists(email, claims.get("name"))

        return TokenData(
            email=email,
            token_type="bearer",
            exp=claims.get("exp"),
        )

    async def refresh_access_token(
        self, refresh_token: str
    ) -> dict[str, Token]:
        """Not supported by this provider; always raises NotImplementedError."""
        raise NotImplementedError("Not implemented")

    # `user` in the default below is resolved in the class body at definition
    # time, i.e. it is the `user` coroutine function defined above.
    def get_current_active_user(
        self, current_user: User = Depends(user)
    ) -> User:
        """Return *current_user* if it is active.

        Raises:
            R2RException: 400 when ``current_user.is_active`` is falsy.
        """
        if not current_user.is_active:
            raise R2RException(status_code=400, message="Inactive user")
        return current_user

    async def logout(self, token: str) -> dict[str, str]:
        """Not supported by this provider; always raises NotImplementedError."""
        raise NotImplementedError("Not implemented")

    async def register(
        self,
        email: str,
        password: str,
        name: Optional[str] = None,
        bio: Optional[str] = None,
        profile_picture: Optional[str] = None,
    ) -> User:  # type: ignore
        """Not supported by this provider; always raises NotImplementedError."""
        raise NotImplementedError("Not implemented")

    async def request_password_reset(self, email: str) -> dict[str, str]:
        """Not supported by this provider; always raises NotImplementedError."""
        raise NotImplementedError("Not implemented")

    async def send_reset_email(self, email: str) -> dict[str, str]:
        """Not supported by this provider; always raises NotImplementedError."""
        raise NotImplementedError("Not implemented")

    async def create_user_api_key(
        self,
        user_id: UUID,
        name: Optional[str] = None,
        description: Optional[str] = None,
    ) -> dict[str, str]:
        """Not supported by this provider; always raises NotImplementedError."""
        raise NotImplementedError("Not implemented")

    async def verify_email(
        self, email: str, verification_code: str
    ) -> dict[str, str]:
        """Not supported by this provider; always raises NotImplementedError."""
        raise NotImplementedError("Not implemented")

    async def send_verification_email(
        self, email: str, user: Optional[User] = None
    ) -> tuple[str, datetime]:
        """Not supported by this provider; always raises NotImplementedError."""
        raise NotImplementedError("Not implemented")

    async def list_user_api_keys(self, user_id: UUID) -> list[dict]:
        """Not supported by this provider; always raises NotImplementedError."""
        raise NotImplementedError("Not implemented")

    async def delete_user_api_key(self, user_id: UUID, key_id: UUID) -> bool:
        """Not supported by this provider; always raises NotImplementedError."""
        raise NotImplementedError("Not implemented")

    async def oauth_callback_handler(
        self, provider: str, oauth_id: str, email: str
    ) -> dict[str, Token]:
        """Not supported by this provider; always raises NotImplementedError."""
        raise NotImplementedError("Not implemented")