diff options
author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
---|---|---|
committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) | |
tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/core/providers/auth | |
parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff) | |
download | gn-ai-master.tar.gz |
Diffstat (limited to '.venv/lib/python3.12/site-packages/core/providers/auth')
5 files changed, 1260 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/core/providers/auth/__init__.py b/.venv/lib/python3.12/site-packages/core/providers/auth/__init__.py new file mode 100644 index 00000000..9f116ffa --- /dev/null +++ b/.venv/lib/python3.12/site-packages/core/providers/auth/__init__.py @@ -0,0 +1,11 @@ +from .clerk import ClerkAuthProvider +from .jwt import JwtAuthProvider +from .r2r_auth import R2RAuthProvider +from .supabase import SupabaseAuthProvider + +__all__ = [ + "R2RAuthProvider", + "SupabaseAuthProvider", + "JwtAuthProvider", + "ClerkAuthProvider", +] diff --git a/.venv/lib/python3.12/site-packages/core/providers/auth/clerk.py b/.venv/lib/python3.12/site-packages/core/providers/auth/clerk.py new file mode 100644 index 00000000..0db665e0 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/core/providers/auth/clerk.py @@ -0,0 +1,133 @@ +import logging +import os +from datetime import datetime + +from core.base import ( + AuthConfig, + CryptoProvider, + EmailProvider, + R2RException, + TokenData, +) + +from ..database import PostgresDatabaseProvider +from .jwt import JwtAuthProvider + +logger = logging.getLogger(__name__) + + +class ClerkAuthProvider(JwtAuthProvider): + """ + ClerkAuthProvider extends JwtAuthProvider to support token verification with Clerk. + It uses Clerk's SDK to verify the JWT token and extract user information. + """ + + def __init__( + self, + config: AuthConfig, + crypto_provider: CryptoProvider, + database_provider: PostgresDatabaseProvider, + email_provider: EmailProvider, + ): + super().__init__( + config=config, + crypto_provider=crypto_provider, + database_provider=database_provider, + email_provider=email_provider, + ) + try: + from clerk_backend_api.jwks_helpers.verifytoken import ( + VerifyTokenOptions, + verify_token, + ) + + self.verify_token = verify_token + self.VerifyTokenOptions = VerifyTokenOptions + except ImportError as e: + raise R2RException( + status_code=500, + message="Clerk SDK is not installed. Run `pip install clerk-backend-api`", + ) from e + + async def decode_token(self, token: str) -> TokenData: + """ + Decode and verify the JWT token using Clerk's verify_token function. + + Args: + token: The JWT token to decode + + Returns: + TokenData: The decoded token data with user information + + Raises: + R2RException: If the token is invalid or verification fails + """ + clerk_secret_key = os.getenv("CLERK_SECRET_KEY") + if not clerk_secret_key: + raise R2RException( + status_code=500, + message="CLERK_SECRET_KEY environment variable is not set", + ) + + try: + # Configure verification options + options = self.VerifyTokenOptions( + secret_key=clerk_secret_key, + # Optional: specify audience if needed + # audience="your-audience", + # Optional: specify authorized parties if needed + # authorized_parties=["https://your-domain.com"] + ) + + # Verify the token using Clerk's SDK + payload = self.verify_token(token, options) + + # Check for the expected claims in the token payload + if not payload.get("sub") or not payload.get("email"): + raise R2RException( + status_code=401, + message="Invalid token: missing required claims", + ) + + # Create user in database if not exists + try: + await self.database_provider.users_handler.get_user_by_email( + payload.get("email") + ) + # TODO do we want to update user info here based on what's in the token? + except Exception: + # user doesn't exist, create in db + logger.debug(f"Creating new user: {payload.get('email')}") + try: + # Construct name from first_name and last_name if available + first_name = payload.get("first_name", "") + last_name = payload.get("last_name", "") + name = payload.get("name") + + # If name not directly provided, try to build it from first and last names + if not name and (first_name or last_name): + name = f"{first_name} {last_name}".strip() + + await self.database_provider.users_handler.create_user( + email=payload.get("email"), + account_type="external", + name=name, + ) + except Exception as e: + logger.error(f"Error creating user: {e}") + raise R2RException( + status_code=500, message="Failed to create user" + ) from e + + # Return the token data + return TokenData( + email=payload.get("email"), + token_type="bearer", + exp=datetime.fromtimestamp(payload.get("exp")), + ) + + except Exception as e: + logger.info(f"Clerk token verification failed: {e}") + raise R2RException( + status_code=401, message="Invalid token", detail=str(e) + ) from e diff --git a/.venv/lib/python3.12/site-packages/core/providers/auth/jwt.py b/.venv/lib/python3.12/site-packages/core/providers/auth/jwt.py new file mode 100644 index 00000000..08f85e6d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/core/providers/auth/jwt.py @@ -0,0 +1,166 @@ +import logging +import os +from datetime import datetime +from typing import Optional +from uuid import UUID + +import jwt +from fastapi import Depends + +from core.base import ( + AuthConfig, + AuthProvider, + CryptoProvider, + EmailProvider, + R2RException, + Token, + TokenData, +) +from core.base.api.models import User + +from ..database import PostgresDatabaseProvider + +logger = logging.getLogger() + + +class JwtAuthProvider(AuthProvider): + def __init__( + self, + config: AuthConfig, + crypto_provider: CryptoProvider, + database_provider: PostgresDatabaseProvider, + email_provider: EmailProvider, + ): + super().__init__( + config, crypto_provider, database_provider, email_provider + ) + + async def login(self, email: str, password: str) -> dict[str, Token]: + raise NotImplementedError("Not implemented") + + async def oauth_callback(self, code: str) -> dict[str, Token]: + raise NotImplementedError("Not implemented") + + async def user(self, token: str) -> User: + raise NotImplementedError("Not implemented") + + async def change_password( + self, user: User, current_password: str, new_password: str + ) -> dict[str, str]: + raise NotImplementedError("Not implemented") + + async def confirm_password_reset( + self, reset_token: str, new_password: str + ) -> dict[str, str]: + raise NotImplementedError("Not implemented") + + def create_access_token(self, data: dict) -> str: + raise NotImplementedError("Not implemented") + + def create_refresh_token(self, data: dict) -> str: + raise NotImplementedError("Not implemented") + + async def decode_token(self, token: str) -> TokenData: + # use JWT library to validate and decode JWT token + jwtSecret = os.getenv("JWT_SECRET") + if jwtSecret is None: + raise R2RException( + status_code=500, + message="JWT_SECRET environment variable is not set", + ) + try: + user = jwt.decode(token, jwtSecret, algorithms=["HS256"]) + except Exception as e: + logger.info(f"JWT verification failed: {e}") + raise R2RException( + status_code=401, message="Invalid JWT token", detail=e + ) from e + if user: + # Create user in database if not exists + try: + await self.database_provider.users_handler.get_user_by_email( + user.get("email") + ) + # TODO do we want to update user info here based on what's in the token? + except Exception: + # user doesn't exist, create in db + logger.debug(f"Creating new user: {user.get('email')}") + try: + await self.database_provider.users_handler.create_user( + email=user.get("email"), + account_type="external", + name=user.get("name"), + ) + except Exception as e: + logger.error(f"Error creating user: {e}") + raise R2RException( + status_code=500, message="Failed to create user" + ) from e + return TokenData( + email=user.get("email"), + token_type="bearer", + exp=user.get("exp"), + ) + else: + raise R2RException(status_code=401, message="Invalid JWT token") + + async def refresh_access_token( + self, refresh_token: str + ) -> dict[str, Token]: + raise NotImplementedError("Not implemented") + + def get_current_active_user( + self, current_user: User = Depends(user) + ) -> User: + # Check if user is active + if not current_user.is_active: + raise R2RException(status_code=400, message="Inactive user") + return current_user + + async def logout(self, token: str) -> dict[str, str]: + raise NotImplementedError("Not implemented") + + async def register( + self, + email: str, + password: str, + name: Optional[str] = None, + bio: Optional[str] = None, + profile_picture: Optional[str] = None, + ) -> User: # type: ignore + raise NotImplementedError("Not implemented") + + async def request_password_reset(self, email: str) -> dict[str, str]: + raise NotImplementedError("Not implemented") + + async def send_reset_email(self, email: str) -> dict[str, str]: + raise NotImplementedError("Not implemented") + + async def create_user_api_key( + self, + user_id: UUID, + name: Optional[str] = None, + description: Optional[str] = None, + ) -> dict[str, str]: + raise NotImplementedError("Not implemented") + + async def verify_email( + self, email: str, verification_code: str + ) -> dict[str, str]: + raise NotImplementedError("Not implemented") + + async def send_verification_email( + self, email: str, user: Optional[User] = None + ) -> tuple[str, datetime]: + raise NotImplementedError("Not implemented") + + async def list_user_api_keys(self, user_id: UUID) -> list[dict]: + raise NotImplementedError("Not implemented") + + async def delete_user_api_key(self, user_id: UUID, key_id: UUID) -> bool: + raise NotImplementedError("Not implemented") + + async def oauth_callback_handler( + self, provider: str, oauth_id: str, email: str + ) -> dict[str, Token]: + raise NotImplementedError("Not implemented") diff --git a/.venv/lib/python3.12/site-packages/core/providers/auth/r2r_auth.py b/.venv/lib/python3.12/site-packages/core/providers/auth/r2r_auth.py new file mode 100644 index 00000000..762884ce --- /dev/null +++ b/.venv/lib/python3.12/site-packages/core/providers/auth/r2r_auth.py @@ -0,0 +1,701 @@ +import logging +import os +from datetime import datetime, timedelta, timezone +from typing import Optional +from uuid import UUID + +from fastapi import Depends, HTTPException +from fastapi.security import OAuth2PasswordBearer + +from core.base import ( + AuthConfig, + AuthProvider, + CollectionResponse, + CryptoProvider, + EmailProvider, + R2RException, + Token, + TokenData, +) +from core.base.api.models import User + +from ..database import PostgresDatabaseProvider + +DEFAULT_ACCESS_LIFETIME_IN_MINUTES = 3600 +DEFAULT_REFRESH_LIFETIME_IN_DAYS = 7 + +logger = logging.getLogger() +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") + + +def normalize_email(email: str) -> str: + """Normalizes an email address by converting it to lowercase. This ensures + consistent email handling throughout the application. + + Args: + email: The email address to normalize + + Returns: + The normalized (lowercase) email address + """ + return email.lower() if email else "" + + +class R2RAuthProvider(AuthProvider): + def __init__( + self, + config: AuthConfig, + crypto_provider: CryptoProvider, + database_provider: PostgresDatabaseProvider, + email_provider: EmailProvider, + ): + super().__init__( + config, crypto_provider, database_provider, email_provider + ) + self.database_provider: PostgresDatabaseProvider = database_provider + logger.debug(f"Initializing R2RAuthProvider with config: {config}") + + # We no longer use a local secret_key or defaults here. + # All key handling is done in the crypto_provider. + self.access_token_lifetime_in_minutes = ( + config.access_token_lifetime_in_minutes + or os.getenv("R2R_ACCESS_LIFE_IN_MINUTES") + or DEFAULT_ACCESS_LIFETIME_IN_MINUTES + ) + self.refresh_token_lifetime_in_days = ( + config.refresh_token_lifetime_in_days + or os.getenv("R2R_REFRESH_LIFE_IN_DAYS") + or DEFAULT_REFRESH_LIFETIME_IN_DAYS + ) + self.config: AuthConfig = config + + async def initialize(self): + try: + user = await self.register( + email=normalize_email(self.admin_email), + password=self.admin_password, + is_superuser=True, + ) + await self.database_provider.users_handler.mark_user_as_superuser( + id=user.id + ) + except R2RException: + logger.info("Default admin user already exists.") + + def create_access_token(self, data: dict) -> str: + expire = datetime.now(timezone.utc) + timedelta( + minutes=float(self.access_token_lifetime_in_minutes) + ) + # Add token_type and pass data/expiry to crypto_provider + data_with_type = {**data, "token_type": "access"} + return self.crypto_provider.generate_secure_token( + data=data_with_type, + expiry=expire, + ) + + def create_refresh_token(self, data: dict) -> str: + expire = datetime.now(timezone.utc) + timedelta( + days=float(self.refresh_token_lifetime_in_days) + ) + data_with_type = {**data, "token_type": "refresh"} + return self.crypto_provider.generate_secure_token( + data=data_with_type, + expiry=expire, + ) + + async def decode_token(self, token: str) -> TokenData: + if "token=" in token: + token = token.split("token=")[1] + if "&tokenType=refresh" in token: + token = token.split("&tokenType=refresh")[0] + # First, check if the token is blacklisted + if await self.database_provider.token_handler.is_token_blacklisted( + token=token + ): + raise R2RException( + status_code=401, message="Token has been invalidated" + ) + + # Verify token using crypto_provider + payload = self.crypto_provider.verify_secure_token(token=token) + if payload is None: + raise R2RException( + status_code=401, message="Invalid or expired token" + ) + + email = payload.get("sub") + token_type = payload.get("token_type") + exp = payload.get("exp") + + if email is None or token_type is None or exp is None: + raise R2RException(status_code=401, message="Invalid token claims") + + email_str: str = email + token_type_str: str = token_type + exp_float: float = exp + + exp_datetime = datetime.fromtimestamp(exp_float, tz=timezone.utc) + if exp_datetime < datetime.now(timezone.utc): + raise R2RException(status_code=401, message="Token has expired") + + return TokenData( + email=normalize_email(email_str), + token_type=token_type_str, + exp=exp_datetime, + ) + + async def authenticate_api_key(self, api_key: str) -> User: + """Authenticate using an API key of the form "public_key.raw_key". + + Returns a User if successful, or raises R2RException if not. + """ + try: + key_id, raw_key = api_key.split(".", 1) + except ValueError as e: + raise R2RException( + status_code=401, message="Invalid API key format" + ) from e + + key_record = ( + await self.database_provider.users_handler.get_api_key_record( + key_id=key_id + ) + ) + if not key_record: + raise R2RException(status_code=401, message="Invalid API key") + + if not self.crypto_provider.verify_api_key( + raw_api_key=raw_key, hashed_key=key_record["hashed_key"] + ): + raise R2RException(status_code=401, message="Invalid API key") + + user = await self.database_provider.users_handler.get_user_by_id( + id=key_record["user_id"] + ) + if not user.is_active: + raise R2RException( + status_code=401, message="User account is inactive" + ) + + return user + + async def user(self, token: str = Depends(oauth2_scheme)) -> User: + """Attempt to authenticate via JWT first, then fallback to API key.""" + # Try JWT auth + try: + token_data = await self.decode_token(token=token) + if not token_data.email: + raise R2RException( + status_code=401, message="Could not validate credentials" + ) + user = ( + await self.database_provider.users_handler.get_user_by_email( + email=normalize_email(token_data.email) + ) + ) + if user is None: + raise R2RException( + status_code=401, + message="Invalid authentication credentials", + ) + return user + except R2RException: + # If JWT fails, try API key auth + # OAuth2PasswordBearer provides token as "Bearer xxx", strip it if needed + token = token.removeprefix("Bearer ") + return await self.authenticate_api_key(api_key=token) + + def get_current_active_user( + self, current_user: User = Depends(user) + ) -> User: + if not current_user.is_active: + raise R2RException(status_code=400, message="Inactive user") + return current_user + + async def register( + self, + email: str, + password: Optional[str] = None, + is_superuser: bool = False, + account_type: str = "password", + github_id: Optional[str] = None, + google_id: Optional[str] = None, + name: Optional[str] = None, + bio: Optional[str] = None, + profile_picture: Optional[str] = None, + ) -> User: + if account_type == "password": + if not password: + raise R2RException( + status_code=400, + message="Password is required for password accounts", + ) + else: + if github_id and google_id: + raise R2RException( + status_code=400, + message="Cannot register OAuth with both GitHub and Google IDs", + ) + if not github_id and not google_id: + raise R2RException( + status_code=400, + message="Invalid OAuth specification without GitHub or Google ID", + ) + new_user = await self.database_provider.users_handler.create_user( + email=normalize_email(email), + password=password, + is_superuser=is_superuser, + account_type=account_type, + github_id=github_id, + google_id=google_id, + name=name, + bio=bio, + profile_picture=profile_picture, + ) + default_collection: CollectionResponse = ( + await self.database_provider.collections_handler.create_collection( + owner_id=new_user.id, + ) + ) + await self.database_provider.graphs_handler.create( + collection_id=default_collection.id, + name=default_collection.name, + description=default_collection.description, + ) + + await self.database_provider.users_handler.add_user_to_collection( + new_user.id, default_collection.id + ) + + new_user = await self.database_provider.users_handler.get_user_by_id( + new_user.id + ) + + if self.config.require_email_verification: + verification_code, _ = await self.send_verification_email( + email=normalize_email(email), user=new_user + ) + else: + expiry = datetime.now(timezone.utc) + timedelta(hours=366 * 10) + await self.database_provider.users_handler.store_verification_code( + id=new_user.id, + verification_code=str(-1), + expiry=expiry, + ) + await self.database_provider.users_handler.mark_user_as_verified( + id=new_user.id + ) + + return new_user + + async def send_verification_email( + self, email: str, user: Optional[User] = None + ) -> tuple[str, datetime]: + if user is None: + user = ( + await self.database_provider.users_handler.get_user_by_email( + email=normalize_email(email) + ) + ) + if not user: + raise R2RException(status_code=404, message="User not found") + + verification_code = self.crypto_provider.generate_verification_code() + expiry = datetime.now(timezone.utc) + timedelta(hours=24) + + await self.database_provider.users_handler.store_verification_code( + id=user.id, + verification_code=verification_code, + expiry=expiry, + ) + + if hasattr(user, "verification_code_expiry"): + user.verification_code_expiry = expiry + + first_name = ( + user.name.split(" ")[0] if user.name else email.split("@")[0] + ) + + await self.email_provider.send_verification_email( + to_email=user.email, + verification_code=verification_code, + dynamic_template_data={"first_name": first_name}, + ) + + return verification_code, expiry + + async def verify_email( + self, email: str, verification_code: str + ) -> dict[str, str]: + user_id = await self.database_provider.users_handler.get_user_id_by_verification_code( + verification_code=verification_code + ) + await self.database_provider.users_handler.mark_user_as_verified( + id=user_id + ) + await self.database_provider.users_handler.remove_verification_code( + verification_code=verification_code + ) + return {"message": "Email verified successfully"} + + async def login(self, email: str, password: str) -> dict[str, Token]: + logger.debug(f"Attempting login for email: {email}") + user = await self.database_provider.users_handler.get_user_by_email( + email=normalize_email(email) + ) + + if user.account_type != "password": + logger.warning( + f"Password login not allowed for {user.account_type} accounts: {email}" + ) + raise R2RException( + status_code=401, + message=f"This account is configured for {user.account_type} login, not password.", + ) + + logger.debug(f"User found: {user}") + + if not isinstance(user.hashed_password, str): + logger.error( + f"Invalid hashed_password type: {type(user.hashed_password)}" + ) + raise HTTPException( + status_code=500, + detail="Invalid password hash in database", + ) + + try: + password_verified = self.crypto_provider.verify_password( + plain_password=password, + hashed_password=user.hashed_password, + ) + except Exception as e: + logger.error(f"Error during password verification: {str(e)}") + raise HTTPException( + status_code=500, + detail="Error during password verification", + ) from e + + if not password_verified: + logger.warning(f"Invalid password for user: {email}") + raise R2RException( + status_code=401, message="Incorrect email or password" + ) + + if not user.is_verified and self.config.require_email_verification: + logger.warning(f"Unverified user attempted login: {email}") + raise R2RException(status_code=401, message="Email not verified") + + access_token = self.create_access_token( + data={"sub": normalize_email(user.email)} + ) + refresh_token = self.create_refresh_token( + data={"sub": normalize_email(user.email)} + ) + return { + "access_token": Token(token=access_token, token_type="access"), + "refresh_token": Token(token=refresh_token, token_type="refresh"), + } + + async def refresh_access_token( + self, refresh_token: str + ) -> dict[str, Token]: + token_data = await self.decode_token(refresh_token) + if token_data.token_type != "refresh": + raise R2RException( + status_code=401, message="Invalid refresh token" + ) + + # Invalidate the old refresh token and create a new one + await self.database_provider.token_handler.blacklist_token( + token=refresh_token + ) + + new_access_token = self.create_access_token( + data={"sub": normalize_email(token_data.email)} + ) + new_refresh_token = self.create_refresh_token( + data={"sub": normalize_email(token_data.email)} + ) + return { + "access_token": Token(token=new_access_token, token_type="access"), + "refresh_token": Token( + token=new_refresh_token, token_type="refresh" + ), + } + + async def change_password( + self, user: User, current_password: str, new_password: str + ) -> dict[str, str]: + if not isinstance(user.hashed_password, str): + logger.error( + f"Invalid hashed_password type: {type(user.hashed_password)}" + ) + raise HTTPException( + status_code=500, + detail="Invalid password hash in database", + ) + + if not self.crypto_provider.verify_password( + plain_password=current_password, + hashed_password=user.hashed_password, + ): + raise R2RException( + status_code=400, message="Incorrect current password" + ) + + hashed_new_password = self.crypto_provider.get_password_hash( + password=new_password + ) + await self.database_provider.users_handler.update_user_password( + id=user.id, + new_hashed_password=hashed_new_password, + ) + try: + await self.email_provider.send_password_changed_email( + to_email=normalize_email(user.email), + dynamic_template_data={ + "first_name": ( + user.name.split(" ")[0] or "User" + if user.name + else "User" + ) + }, + ) + except Exception as e: + logger.error( + f"Failed to send password change notification: {str(e)}" + ) + + return {"message": "Password changed successfully"} + + async def request_password_reset(self, email: str) -> dict[str, str]: + try: + user = ( + await self.database_provider.users_handler.get_user_by_email( + email=normalize_email(email) + ) + ) + + reset_token = self.crypto_provider.generate_verification_code() + expiry = datetime.now(timezone.utc) + timedelta(hours=1) + await self.database_provider.users_handler.store_reset_token( + id=user.id, + reset_token=reset_token, + expiry=expiry, + ) + + first_name = ( + user.name.split(" ")[0] if user.name else email.split("@")[0] + ) + await self.email_provider.send_password_reset_email( + to_email=normalize_email(email), + reset_token=reset_token, + dynamic_template_data={"first_name": first_name}, + ) + + return { + "message": "If the email exists, a reset link has been sent" + } + except R2RException as e: + if e.status_code == 404: + # User doesn't exist; return a success message anyway + return { + "message": "If the email exists, a reset link has been sent" + } + else: + raise + + async def confirm_password_reset( + self, reset_token: str, new_password: str + ) -> dict[str, str]: + user_id = await self.database_provider.users_handler.get_user_id_by_reset_token( + reset_token=reset_token + ) + if not user_id: + raise R2RException( + status_code=400, message="Invalid or expired reset token" + ) + + hashed_new_password = self.crypto_provider.get_password_hash( + password=new_password + ) + await self.database_provider.users_handler.update_user_password( + id=user_id, + new_hashed_password=hashed_new_password, + ) + await self.database_provider.users_handler.remove_reset_token( + id=user_id + ) + # Get the user information + user = await self.database_provider.users_handler.get_user_by_id( + id=user_id + ) + + try: + await self.email_provider.send_password_changed_email( + to_email=normalize_email(user.email), + dynamic_template_data={ + "first_name": ( + user.name.split(" ")[0] or "User" + if user.name + else "User" + ) + }, + ) + except Exception as e: + logger.error( + f"Failed to send password change notification: {str(e)}" + ) + + return {"message": "Password reset successfully"} + + async def logout(self, token: str) -> dict[str, str]: + await self.database_provider.token_handler.blacklist_token(token=token) + return {"message": "Logged out successfully"} + + async def clean_expired_blacklisted_tokens(self): + await self.database_provider.token_handler.clean_expired_blacklisted_tokens() + + async def send_reset_email(self, email: str) -> dict: + verification_code, expiry = await self.send_verification_email( + email=normalize_email(email) + ) + + return { + "verification_code": verification_code, + "expiry": expiry, + "message": f"Verification email sent successfully to {email}", + } + + async def create_user_api_key( + self, + user_id: UUID, + name: Optional[str] = None, + description: Optional[str] = None, + ) -> dict[str, str]: + key_id, raw_api_key = self.crypto_provider.generate_api_key() + hashed_key = self.crypto_provider.hash_api_key(raw_api_key) + + api_key_uuid = ( + await self.database_provider.users_handler.store_user_api_key( + user_id=user_id, + key_id=key_id, + hashed_key=hashed_key, + name=name, + description=description, + ) + ) + + return { + "api_key": f"{key_id}.{raw_api_key}", + "key_id": str(api_key_uuid), + "public_key": key_id, + "name": name or "", + } + + async def list_user_api_keys(self, user_id: UUID) -> list[dict]: + return await self.database_provider.users_handler.get_user_api_keys( + user_id=user_id + ) + + async def delete_user_api_key(self, user_id: UUID, key_id: UUID) -> bool: + return await self.database_provider.users_handler.delete_api_key( + user_id=user_id, + key_id=key_id, + ) + + async def rename_api_key( + self, user_id: UUID, key_id: UUID, new_name: str + ) -> bool: + return await self.database_provider.users_handler.update_api_key_name( + user_id=user_id, + key_id=key_id, + name=new_name, + ) + + async def oauth_callback_handler( + self, provider: str, oauth_id: str, email: str + ) -> dict[str, Token]: + """Handles a login/registration flow for OAuth providers (e.g., Google + or GitHub). + + :param provider: "google" or "github" + :param oauth_id: The unique ID from the OAuth provider (e.g. Google's + 'sub') + :param email: The user's email from the provider, if available. + :return: dict with access_token and refresh_token + """ + # 1) Attempt to find user by google_id or github_id, or by email + # The logic depends on your preference. We'll assume "google" => google_id, etc. + try: + if provider == "google": + try: + user = await self.database_provider.users_handler.get_user_by_email( + normalize_email(email) + ) + # If user found, check if user.google_id matches or is null. If null, update it + if user and not user.google_id: + raise R2RException( + status_code=401, + message="User already exists and is not linked to Google account", + ) + except Exception: + # Create new user + user = await self.register( + email=normalize_email(email) + or f"{oauth_id}@google_oauth.fake", # fallback + password=None, # no password + account_type="oauth", + google_id=oauth_id, + ) + elif provider == "github": + try: + user = await self.database_provider.users_handler.get_user_by_email( + normalize_email(email) + ) + # If user found, check if user.google_id matches or is null. If null, update it + if user and not user.github_id: + raise R2RException( + status_code=401, + message="User already exists and is not linked to Github account", + ) + except Exception: + # Create new user + user = await self.register( + email=normalize_email(email) + or f"{oauth_id}@github_oauth.fake", # fallback + password=None, # no password + account_type="oauth", + github_id=oauth_id, + ) + # else handle other providers + + except R2RException: + # If no user found or creation fails + raise R2RException( + status_code=401, message="Could not create or fetch user" + ) from None + + # If user is inactive, etc. + if not user.is_active: + raise R2RException( + status_code=401, message="User account is inactive" + ) + + # Possibly mark user as verified if you trust the OAuth provider's email + user.is_verified = True + await self.database_provider.users_handler.update_user(user) + + # 2) Generate tokens + access_token = self.create_access_token( + data={"sub": normalize_email(user.email)} + ) + refresh_token = self.create_refresh_token( + data={"sub": normalize_email(user.email)} + ) + + return { + "access_token": Token(token=access_token, token_type="access"), + "refresh_token": Token(token=refresh_token, token_type="refresh"), + } diff --git a/.venv/lib/python3.12/site-packages/core/providers/auth/supabase.py b/.venv/lib/python3.12/site-packages/core/providers/auth/supabase.py new file mode 100644 index 00000000..5fc0e0bf --- /dev/null +++ b/.venv/lib/python3.12/site-packages/core/providers/auth/supabase.py @@ -0,0 +1,249 @@ +import logging +import os +from datetime import datetime +from typing import Optional +from uuid import UUID + +from fastapi import Depends, HTTPException +from fastapi.security import OAuth2PasswordBearer +from supabase import Client, create_client + +from core.base import ( + AuthConfig, + AuthProvider, + CryptoProvider, + EmailProvider, + R2RException, + Token, + TokenData, +) +from core.base.api.models import User + +from ..database import PostgresDatabaseProvider + +logger = logging.getLogger() + +logger = logging.getLogger() +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") + + +class SupabaseAuthProvider(AuthProvider): + def __init__( + self, + config: AuthConfig, + crypto_provider: CryptoProvider, + database_provider: PostgresDatabaseProvider, + email_provider: EmailProvider, + ): + super().__init__( + config, crypto_provider, database_provider, email_provider + ) + self.supabase_url = config.extra_fields.get( + "supabase_url", None + ) or os.getenv("SUPABASE_URL") + self.supabase_key = config.extra_fields.get( + "supabase_key", None + ) or os.getenv("SUPABASE_KEY") + if not self.supabase_url or not self.supabase_key: + raise HTTPException( + status_code=500, + detail="Supabase URL and key must be provided", + ) + self.supabase: Client = create_client( + self.supabase_url, self.supabase_key + ) + + async def initialize(self): + # No initialization needed for Supabase + pass + + def create_access_token(self, data: dict) -> str: + raise NotImplementedError( + "create_access_token is not used with Supabase authentication" + ) + + def create_refresh_token(self, data: dict) -> str: + raise NotImplementedError( + "create_refresh_token is not used with Supabase authentication" + ) + + async def decode_token(self, token: str) -> TokenData: + raise NotImplementedError( + "decode_token is not used with Supabase authentication" + ) + + async def register( + self, + email: str, + password: str, + name: Optional[str] = None, + bio: Optional[str] = None, + profile_picture: Optional[str] = None, + ) -> User: # type: ignore + # Use Supabase client to create a new user + + if self.supabase.auth.sign_up(email=email, password=password): + raise R2RException( + status_code=400, + message="Supabase provider implementation is still under construction", + ) + else: + raise R2RException( + status_code=400, message="User registration failed" + ) + + async def send_verification_email( + self, email: str, user: Optional[User] = None + ) -> tuple[str, datetime]: + raise NotImplementedError( + "send_verification_email is not used with Supabase" + ) + + async def verify_email( + self, email: str, verification_code: str + ) -> dict[str, str]: + # Use Supabase client to verify email + if self.supabase.auth.verify_email(email, verification_code): + return {"message": "Email verified successfully"} + else: + raise R2RException( + status_code=400, message="Invalid or expired verification code" + ) + + async def login(self, email: str, password: str) -> dict[str, Token]: + # Use Supabase client to authenticate user and get tokens + if response := self.supabase.auth.sign_in( + email=email, password=password + ): + access_token = response.access_token + refresh_token = response.refresh_token + return { + "access_token": Token(token=access_token, token_type="access"), + "refresh_token": Token( + token=refresh_token, token_type="refresh" + ), + } + else: + raise R2RException( + status_code=401, message="Invalid email or password" + ) + + async def refresh_access_token( + self, refresh_token: str + ) -> dict[str, Token]: + # Use Supabase client to refresh access token + if response := self.supabase.auth.refresh_access_token(refresh_token): + new_access_token = response.access_token + new_refresh_token = response.refresh_token + return { + "access_token": Token( + token=new_access_token, token_type="access" + ), + "refresh_token": Token( + token=new_refresh_token, token_type="refresh" + ), + } + else: + raise R2RException( + status_code=401, message="Invalid refresh token" + ) + + async def user(self, token: str = Depends(oauth2_scheme)) -> User: + # Use Supabase client to get user details from token + if user := self.supabase.auth.get_user(token).user: + return User( + id=user.id, + email=user.email, + is_active=True, # Assuming active if exists in Supabase + is_superuser=False, # Default to False unless explicitly set + created_at=user.created_at, + updated_at=user.updated_at, + is_verified=user.email_confirmed_at is not None, + name=user.user_metadata.get("full_name"), + # Set other optional fields if available in user metadata + ) + + else: + raise R2RException(status_code=401, message="Invalid token") + + def get_current_active_user( + self, current_user: User = Depends(user) + ) -> User: + # Check if user is active + if not current_user.is_active: + raise R2RException(status_code=400, message="Inactive user") + return current_user + + async def change_password( + self, user: User, current_password: str, new_password: str + ) -> dict[str, str]: + # Use Supabase client to update user password + if self.supabase.auth.update(user.id, {"password": new_password}): + return {"message": "Password changed successfully"} + else: + raise R2RException( + status_code=400, message="Failed to change password" + ) + + async def request_password_reset(self, email: str) -> dict[str, str]: + # Use Supabase client to send password reset email + if self.supabase.auth.send_password_reset_email(email): + return { + "message": "If the email exists, a reset link has been sent" + } + else: + raise R2RException( + status_code=400, message="Failed to send password reset email" + ) + + async def confirm_password_reset( + self, reset_token: str, new_password: str + ) -> dict[str, str]: + # Use Supabase client to reset password with token + if self.supabase.auth.reset_password_for_email( + reset_token, new_password + ): + return {"message": "Password reset successfully"} + else: + raise R2RException( + status_code=400, message="Invalid or expired reset token" + ) + + async def logout(self, token: str) -> dict[str, str]: + # Use Supabase client to logout user and revoke token + self.supabase.auth.sign_out(token) + return {"message": "Logged out successfully"} + + async def clean_expired_blacklisted_tokens(self): + # Not applicable for Supabase, tokens are managed by Supabase + pass + + async def send_reset_email(self, email: str) -> dict[str, str]: + raise NotImplementedError("send_reset_email is not used with Supabase") + + async def create_user_api_key( + self, + user_id: UUID, + name: Optional[str] = None, + description: Optional[str] = None, + ) -> dict[str, str]: + raise NotImplementedError( + "API key management is not supported with Supabase authentication" + ) + + async def list_user_api_keys(self, user_id: UUID) -> list[dict]: + raise NotImplementedError( + "API key management is not supported with Supabase authentication" + ) + + async def delete_user_api_key(self, user_id: UUID, key_id: UUID) -> bool: + raise NotImplementedError( + "API key management is not supported with Supabase authentication" + ) + + async def oauth_callback_handler( + self, provider: str, oauth_id: str, email: str + ) -> dict[str, Token]: + raise NotImplementedError( + "API key management is not supported with Supabase authentication" + ) |