aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/core/providers/auth
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/core/providers/auth')
-rw-r--r--.venv/lib/python3.12/site-packages/core/providers/auth/__init__.py11
-rw-r--r--.venv/lib/python3.12/site-packages/core/providers/auth/clerk.py133
-rw-r--r--.venv/lib/python3.12/site-packages/core/providers/auth/jwt.py166
-rw-r--r--.venv/lib/python3.12/site-packages/core/providers/auth/r2r_auth.py701
-rw-r--r--.venv/lib/python3.12/site-packages/core/providers/auth/supabase.py249
5 files changed, 1260 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/core/providers/auth/__init__.py b/.venv/lib/python3.12/site-packages/core/providers/auth/__init__.py
new file mode 100644
index 00000000..9f116ffa
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/auth/__init__.py
@@ -0,0 +1,11 @@
+from .clerk import ClerkAuthProvider
+from .jwt import JwtAuthProvider
+from .r2r_auth import R2RAuthProvider
+from .supabase import SupabaseAuthProvider
+
+__all__ = [
+ "R2RAuthProvider",
+ "SupabaseAuthProvider",
+ "JwtAuthProvider",
+ "ClerkAuthProvider",
+]
diff --git a/.venv/lib/python3.12/site-packages/core/providers/auth/clerk.py b/.venv/lib/python3.12/site-packages/core/providers/auth/clerk.py
new file mode 100644
index 00000000..0db665e0
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/auth/clerk.py
@@ -0,0 +1,133 @@
+import logging
+import os
+from datetime import datetime
+
+from core.base import (
+ AuthConfig,
+ CryptoProvider,
+ EmailProvider,
+ R2RException,
+ TokenData,
+)
+
+from ..database import PostgresDatabaseProvider
+from .jwt import JwtAuthProvider
+
+logger = logging.getLogger(__name__)
+
+
+class ClerkAuthProvider(JwtAuthProvider):
+ """
+ ClerkAuthProvider extends JwtAuthProvider to support token verification with Clerk.
+ It uses Clerk's SDK to verify the JWT token and extract user information.
+ """
+
+ def __init__(
+ self,
+ config: AuthConfig,
+ crypto_provider: CryptoProvider,
+ database_provider: PostgresDatabaseProvider,
+ email_provider: EmailProvider,
+ ):
+ super().__init__(
+ config=config,
+ crypto_provider=crypto_provider,
+ database_provider=database_provider,
+ email_provider=email_provider,
+ )
+ try:
+ from clerk_backend_api.jwks_helpers.verifytoken import (
+ VerifyTokenOptions,
+ verify_token,
+ )
+
+ self.verify_token = verify_token
+ self.VerifyTokenOptions = VerifyTokenOptions
+ except ImportError as e:
+ raise R2RException(
+ status_code=500,
+ message="Clerk SDK is not installed. Run `pip install clerk-backend-api`",
+ ) from e
+
+ async def decode_token(self, token: str) -> TokenData:
+ """
+ Decode and verify the JWT token using Clerk's verify_token function.
+
+ Args:
+ token: The JWT token to decode
+
+ Returns:
+ TokenData: The decoded token data with user information
+
+ Raises:
+ R2RException: If the token is invalid or verification fails
+ """
+ clerk_secret_key = os.getenv("CLERK_SECRET_KEY")
+ if not clerk_secret_key:
+ raise R2RException(
+ status_code=500,
+ message="CLERK_SECRET_KEY environment variable is not set",
+ )
+
+ try:
+ # Configure verification options
+ options = self.VerifyTokenOptions(
+ secret_key=clerk_secret_key,
+ # Optional: specify audience if needed
+ # audience="your-audience",
+ # Optional: specify authorized parties if needed
+ # authorized_parties=["https://your-domain.com"]
+ )
+
+ # Verify the token using Clerk's SDK
+ payload = self.verify_token(token, options)
+
+ # Check for the expected claims in the token payload
+ if not payload.get("sub") or not payload.get("email"):
+ raise R2RException(
+ status_code=401,
+ message="Invalid token: missing required claims",
+ )
+
+ # Create user in database if not exists
+ try:
+ await self.database_provider.users_handler.get_user_by_email(
+ payload.get("email")
+ )
+ # TODO do we want to update user info here based on what's in the token?
+ except Exception:
+ # user doesn't exist, create in db
+ logger.debug(f"Creating new user: {payload.get('email')}")
+ try:
+ # Construct name from first_name and last_name if available
+ first_name = payload.get("first_name", "")
+ last_name = payload.get("last_name", "")
+ name = payload.get("name")
+
+ # If name not directly provided, try to build it from first and last names
+ if not name and (first_name or last_name):
+ name = f"{first_name} {last_name}".strip()
+
+ await self.database_provider.users_handler.create_user(
+ email=payload.get("email"),
+ account_type="external",
+ name=name,
+ )
+ except Exception as e:
+ logger.error(f"Error creating user: {e}")
+ raise R2RException(
+ status_code=500, message="Failed to create user"
+ ) from e
+
+ # Return the token data
+ return TokenData(
+ email=payload.get("email"),
+ token_type="bearer",
+ exp=datetime.fromtimestamp(payload.get("exp")),
+ )
+
+ except Exception as e:
+ logger.info(f"Clerk token verification failed: {e}")
+ raise R2RException(
+ status_code=401, message="Invalid token", detail=str(e)
+ ) from e
diff --git a/.venv/lib/python3.12/site-packages/core/providers/auth/jwt.py b/.venv/lib/python3.12/site-packages/core/providers/auth/jwt.py
new file mode 100644
index 00000000..08f85e6d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/auth/jwt.py
@@ -0,0 +1,166 @@
+import logging
+import os
+from datetime import datetime
+from typing import Optional
+from uuid import UUID
+
+import jwt
+from fastapi import Depends
+
+from core.base import (
+ AuthConfig,
+ AuthProvider,
+ CryptoProvider,
+ EmailProvider,
+ R2RException,
+ Token,
+ TokenData,
+)
+from core.base.api.models import User
+
+from ..database import PostgresDatabaseProvider
+
+logger = logging.getLogger()
+
+
+class JwtAuthProvider(AuthProvider):
+ def __init__(
+ self,
+ config: AuthConfig,
+ crypto_provider: CryptoProvider,
+ database_provider: PostgresDatabaseProvider,
+ email_provider: EmailProvider,
+ ):
+ super().__init__(
+ config, crypto_provider, database_provider, email_provider
+ )
+
+ async def login(self, email: str, password: str) -> dict[str, Token]:
+ raise NotImplementedError("Not implemented")
+
+ async def oauth_callback(self, code: str) -> dict[str, Token]:
+ raise NotImplementedError("Not implemented")
+
+ async def user(self, token: str) -> User:
+ raise NotImplementedError("Not implemented")
+
+ async def change_password(
+ self, user: User, current_password: str, new_password: str
+ ) -> dict[str, str]:
+ raise NotImplementedError("Not implemented")
+
+ async def confirm_password_reset(
+ self, reset_token: str, new_password: str
+ ) -> dict[str, str]:
+ raise NotImplementedError("Not implemented")
+
+ def create_access_token(self, data: dict) -> str:
+ raise NotImplementedError("Not implemented")
+
+ def create_refresh_token(self, data: dict) -> str:
+ raise NotImplementedError("Not implemented")
+
+ async def decode_token(self, token: str) -> TokenData:
+ # use JWT library to validate and decode JWT token
+ jwtSecret = os.getenv("JWT_SECRET")
+ if jwtSecret is None:
+ raise R2RException(
+ status_code=500,
+ message="JWT_SECRET environment variable is not set",
+ )
+ try:
+ user = jwt.decode(token, jwtSecret, algorithms=["HS256"])
+ except Exception as e:
+ logger.info(f"JWT verification failed: {e}")
+ raise R2RException(
+ status_code=401, message="Invalid JWT token", detail=e
+ ) from e
+ if user:
+ # Create user in database if not exists
+ try:
+ await self.database_provider.users_handler.get_user_by_email(
+ user.get("email")
+ )
+ # TODO do we want to update user info here based on what's in the token?
+ except Exception:
+ # user doesn't exist, create in db
+ logger.debug(f"Creating new user: {user.get('email')}")
+ try:
+ await self.database_provider.users_handler.create_user(
+ email=user.get("email"),
+ account_type="external",
+ name=user.get("name"),
+ )
+ except Exception as e:
+ logger.error(f"Error creating user: {e}")
+ raise R2RException(
+ status_code=500, message="Failed to create user"
+ ) from e
+ return TokenData(
+ email=user.get("email"),
+ token_type="bearer",
+ exp=user.get("exp"),
+ )
+ else:
+ raise R2RException(status_code=401, message="Invalid JWT token")
+
+ async def refresh_access_token(
+ self, refresh_token: str
+ ) -> dict[str, Token]:
+ raise NotImplementedError("Not implemented")
+
+ def get_current_active_user(
+ self, current_user: User = Depends(user)
+ ) -> User:
+ # Check if user is active
+ if not current_user.is_active:
+ raise R2RException(status_code=400, message="Inactive user")
+ return current_user
+
+ async def logout(self, token: str) -> dict[str, str]:
+ raise NotImplementedError("Not implemented")
+
+ async def register(
+ self,
+ email: str,
+ password: str,
+ name: Optional[str] = None,
+ bio: Optional[str] = None,
+ profile_picture: Optional[str] = None,
+ ) -> User: # type: ignore
+ raise NotImplementedError("Not implemented")
+
+ async def request_password_reset(self, email: str) -> dict[str, str]:
+ raise NotImplementedError("Not implemented")
+
+ async def send_reset_email(self, email: str) -> dict[str, str]:
+ raise NotImplementedError("Not implemented")
+
+ async def create_user_api_key(
+ self,
+ user_id: UUID,
+ name: Optional[str] = None,
+ description: Optional[str] = None,
+ ) -> dict[str, str]:
+ raise NotImplementedError("Not implemented")
+
+ async def verify_email(
+ self, email: str, verification_code: str
+ ) -> dict[str, str]:
+ raise NotImplementedError("Not implemented")
+
+ async def send_verification_email(
+ self, email: str, user: Optional[User] = None
+ ) -> tuple[str, datetime]:
+ raise NotImplementedError("Not implemented")
+
+ async def list_user_api_keys(self, user_id: UUID) -> list[dict]:
+ raise NotImplementedError("Not implemented")
+
+ async def delete_user_api_key(self, user_id: UUID, key_id: UUID) -> bool:
+ raise NotImplementedError("Not implemented")
+
+ async def oauth_callback_handler(
+ self, provider: str, oauth_id: str, email: str
+ ) -> dict[str, Token]:
+ raise NotImplementedError("Not implemented")
diff --git a/.venv/lib/python3.12/site-packages/core/providers/auth/r2r_auth.py b/.venv/lib/python3.12/site-packages/core/providers/auth/r2r_auth.py
new file mode 100644
index 00000000..762884ce
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/auth/r2r_auth.py
@@ -0,0 +1,701 @@
+import logging
+import os
+from datetime import datetime, timedelta, timezone
+from typing import Optional
+from uuid import UUID
+
+from fastapi import Depends, HTTPException
+from fastapi.security import OAuth2PasswordBearer
+
+from core.base import (
+ AuthConfig,
+ AuthProvider,
+ CollectionResponse,
+ CryptoProvider,
+ EmailProvider,
+ R2RException,
+ Token,
+ TokenData,
+)
+from core.base.api.models import User
+
+from ..database import PostgresDatabaseProvider
+
+DEFAULT_ACCESS_LIFETIME_IN_MINUTES = 3600
+DEFAULT_REFRESH_LIFETIME_IN_DAYS = 7
+
+logger = logging.getLogger()
+oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
+
+
+def normalize_email(email: str) -> str:
+ """Normalizes an email address by converting it to lowercase. This ensures
+ consistent email handling throughout the application.
+
+ Args:
+ email: The email address to normalize
+
+ Returns:
+ The normalized (lowercase) email address
+ """
+ return email.lower() if email else ""
+
+
+class R2RAuthProvider(AuthProvider):
+ def __init__(
+ self,
+ config: AuthConfig,
+ crypto_provider: CryptoProvider,
+ database_provider: PostgresDatabaseProvider,
+ email_provider: EmailProvider,
+ ):
+ super().__init__(
+ config, crypto_provider, database_provider, email_provider
+ )
+ self.database_provider: PostgresDatabaseProvider = database_provider
+ logger.debug(f"Initializing R2RAuthProvider with config: {config}")
+
+ # We no longer use a local secret_key or defaults here.
+ # All key handling is done in the crypto_provider.
+ self.access_token_lifetime_in_minutes = (
+ config.access_token_lifetime_in_minutes
+ or os.getenv("R2R_ACCESS_LIFE_IN_MINUTES")
+ or DEFAULT_ACCESS_LIFETIME_IN_MINUTES
+ )
+ self.refresh_token_lifetime_in_days = (
+ config.refresh_token_lifetime_in_days
+ or os.getenv("R2R_REFRESH_LIFE_IN_DAYS")
+ or DEFAULT_REFRESH_LIFETIME_IN_DAYS
+ )
+ self.config: AuthConfig = config
+
+ async def initialize(self):
+ try:
+ user = await self.register(
+ email=normalize_email(self.admin_email),
+ password=self.admin_password,
+ is_superuser=True,
+ )
+ await self.database_provider.users_handler.mark_user_as_superuser(
+ id=user.id
+ )
+ except R2RException:
+ logger.info("Default admin user already exists.")
+
+ def create_access_token(self, data: dict) -> str:
+ expire = datetime.now(timezone.utc) + timedelta(
+ minutes=float(self.access_token_lifetime_in_minutes)
+ )
+ # Add token_type and pass data/expiry to crypto_provider
+ data_with_type = {**data, "token_type": "access"}
+ return self.crypto_provider.generate_secure_token(
+ data=data_with_type,
+ expiry=expire,
+ )
+
+ def create_refresh_token(self, data: dict) -> str:
+ expire = datetime.now(timezone.utc) + timedelta(
+ days=float(self.refresh_token_lifetime_in_days)
+ )
+ data_with_type = {**data, "token_type": "refresh"}
+ return self.crypto_provider.generate_secure_token(
+ data=data_with_type,
+ expiry=expire,
+ )
+
+ async def decode_token(self, token: str) -> TokenData:
+ if "token=" in token:
+ token = token.split("token=")[1]
+ if "&tokenType=refresh" in token:
+ token = token.split("&tokenType=refresh")[0]
+ # First, check if the token is blacklisted
+ if await self.database_provider.token_handler.is_token_blacklisted(
+ token=token
+ ):
+ raise R2RException(
+ status_code=401, message="Token has been invalidated"
+ )
+
+ # Verify token using crypto_provider
+ payload = self.crypto_provider.verify_secure_token(token=token)
+ if payload is None:
+ raise R2RException(
+ status_code=401, message="Invalid or expired token"
+ )
+
+ email = payload.get("sub")
+ token_type = payload.get("token_type")
+ exp = payload.get("exp")
+
+ if email is None or token_type is None or exp is None:
+ raise R2RException(status_code=401, message="Invalid token claims")
+
+ email_str: str = email
+ token_type_str: str = token_type
+ exp_float: float = exp
+
+ exp_datetime = datetime.fromtimestamp(exp_float, tz=timezone.utc)
+ if exp_datetime < datetime.now(timezone.utc):
+ raise R2RException(status_code=401, message="Token has expired")
+
+ return TokenData(
+ email=normalize_email(email_str),
+ token_type=token_type_str,
+ exp=exp_datetime,
+ )
+
+ async def authenticate_api_key(self, api_key: str) -> User:
+ """Authenticate using an API key of the form "public_key.raw_key".
+
+ Returns a User if successful, or raises R2RException if not.
+ """
+ try:
+ key_id, raw_key = api_key.split(".", 1)
+ except ValueError as e:
+ raise R2RException(
+ status_code=401, message="Invalid API key format"
+ ) from e
+
+ key_record = (
+ await self.database_provider.users_handler.get_api_key_record(
+ key_id=key_id
+ )
+ )
+ if not key_record:
+ raise R2RException(status_code=401, message="Invalid API key")
+
+ if not self.crypto_provider.verify_api_key(
+ raw_api_key=raw_key, hashed_key=key_record["hashed_key"]
+ ):
+ raise R2RException(status_code=401, message="Invalid API key")
+
+ user = await self.database_provider.users_handler.get_user_by_id(
+ id=key_record["user_id"]
+ )
+ if not user.is_active:
+ raise R2RException(
+ status_code=401, message="User account is inactive"
+ )
+
+ return user
+
+ async def user(self, token: str = Depends(oauth2_scheme)) -> User:
+ """Attempt to authenticate via JWT first, then fallback to API key."""
+ # Try JWT auth
+ try:
+ token_data = await self.decode_token(token=token)
+ if not token_data.email:
+ raise R2RException(
+ status_code=401, message="Could not validate credentials"
+ )
+ user = (
+ await self.database_provider.users_handler.get_user_by_email(
+ email=normalize_email(token_data.email)
+ )
+ )
+ if user is None:
+ raise R2RException(
+ status_code=401,
+ message="Invalid authentication credentials",
+ )
+ return user
+ except R2RException:
+ # If JWT fails, try API key auth
+ # OAuth2PasswordBearer provides token as "Bearer xxx", strip it if needed
+ token = token.removeprefix("Bearer ")
+ return await self.authenticate_api_key(api_key=token)
+
+ def get_current_active_user(
+ self, current_user: User = Depends(user)
+ ) -> User:
+ if not current_user.is_active:
+ raise R2RException(status_code=400, message="Inactive user")
+ return current_user
+
+ async def register(
+ self,
+ email: str,
+ password: Optional[str] = None,
+ is_superuser: bool = False,
+ account_type: str = "password",
+ github_id: Optional[str] = None,
+ google_id: Optional[str] = None,
+ name: Optional[str] = None,
+ bio: Optional[str] = None,
+ profile_picture: Optional[str] = None,
+ ) -> User:
+ if account_type == "password":
+ if not password:
+ raise R2RException(
+ status_code=400,
+ message="Password is required for password accounts",
+ )
+ else:
+ if github_id and google_id:
+ raise R2RException(
+ status_code=400,
+ message="Cannot register OAuth with both GitHub and Google IDs",
+ )
+ if not github_id and not google_id:
+ raise R2RException(
+ status_code=400,
+ message="Invalid OAuth specification without GitHub or Google ID",
+ )
+ new_user = await self.database_provider.users_handler.create_user(
+ email=normalize_email(email),
+ password=password,
+ is_superuser=is_superuser,
+ account_type=account_type,
+ github_id=github_id,
+ google_id=google_id,
+ name=name,
+ bio=bio,
+ profile_picture=profile_picture,
+ )
+ default_collection: CollectionResponse = (
+ await self.database_provider.collections_handler.create_collection(
+ owner_id=new_user.id,
+ )
+ )
+ await self.database_provider.graphs_handler.create(
+ collection_id=default_collection.id,
+ name=default_collection.name,
+ description=default_collection.description,
+ )
+
+ await self.database_provider.users_handler.add_user_to_collection(
+ new_user.id, default_collection.id
+ )
+
+ new_user = await self.database_provider.users_handler.get_user_by_id(
+ new_user.id
+ )
+
+ if self.config.require_email_verification:
+ verification_code, _ = await self.send_verification_email(
+ email=normalize_email(email), user=new_user
+ )
+ else:
+ expiry = datetime.now(timezone.utc) + timedelta(hours=366 * 10)
+ await self.database_provider.users_handler.store_verification_code(
+ id=new_user.id,
+ verification_code=str(-1),
+ expiry=expiry,
+ )
+ await self.database_provider.users_handler.mark_user_as_verified(
+ id=new_user.id
+ )
+
+ return new_user
+
+ async def send_verification_email(
+ self, email: str, user: Optional[User] = None
+ ) -> tuple[str, datetime]:
+ if user is None:
+ user = (
+ await self.database_provider.users_handler.get_user_by_email(
+ email=normalize_email(email)
+ )
+ )
+ if not user:
+ raise R2RException(status_code=404, message="User not found")
+
+ verification_code = self.crypto_provider.generate_verification_code()
+ expiry = datetime.now(timezone.utc) + timedelta(hours=24)
+
+ await self.database_provider.users_handler.store_verification_code(
+ id=user.id,
+ verification_code=verification_code,
+ expiry=expiry,
+ )
+
+ if hasattr(user, "verification_code_expiry"):
+ user.verification_code_expiry = expiry
+
+ first_name = (
+ user.name.split(" ")[0] if user.name else email.split("@")[0]
+ )
+
+ await self.email_provider.send_verification_email(
+ to_email=user.email,
+ verification_code=verification_code,
+ dynamic_template_data={"first_name": first_name},
+ )
+
+ return verification_code, expiry
+
+ async def verify_email(
+ self, email: str, verification_code: str
+ ) -> dict[str, str]:
+ user_id = await self.database_provider.users_handler.get_user_id_by_verification_code(
+ verification_code=verification_code
+ )
+ await self.database_provider.users_handler.mark_user_as_verified(
+ id=user_id
+ )
+ await self.database_provider.users_handler.remove_verification_code(
+ verification_code=verification_code
+ )
+ return {"message": "Email verified successfully"}
+
+ async def login(self, email: str, password: str) -> dict[str, Token]:
+ logger.debug(f"Attempting login for email: {email}")
+ user = await self.database_provider.users_handler.get_user_by_email(
+ email=normalize_email(email)
+ )
+
+ if user.account_type != "password":
+ logger.warning(
+ f"Password login not allowed for {user.account_type} accounts: {email}"
+ )
+ raise R2RException(
+ status_code=401,
+ message=f"This account is configured for {user.account_type} login, not password.",
+ )
+
+ logger.debug(f"User found: {user}")
+
+ if not isinstance(user.hashed_password, str):
+ logger.error(
+ f"Invalid hashed_password type: {type(user.hashed_password)}"
+ )
+ raise HTTPException(
+ status_code=500,
+ detail="Invalid password hash in database",
+ )
+
+ try:
+ password_verified = self.crypto_provider.verify_password(
+ plain_password=password,
+ hashed_password=user.hashed_password,
+ )
+ except Exception as e:
+ logger.error(f"Error during password verification: {str(e)}")
+ raise HTTPException(
+ status_code=500,
+ detail="Error during password verification",
+ ) from e
+
+ if not password_verified:
+ logger.warning(f"Invalid password for user: {email}")
+ raise R2RException(
+ status_code=401, message="Incorrect email or password"
+ )
+
+ if not user.is_verified and self.config.require_email_verification:
+ logger.warning(f"Unverified user attempted login: {email}")
+ raise R2RException(status_code=401, message="Email not verified")
+
+ access_token = self.create_access_token(
+ data={"sub": normalize_email(user.email)}
+ )
+ refresh_token = self.create_refresh_token(
+ data={"sub": normalize_email(user.email)}
+ )
+ return {
+ "access_token": Token(token=access_token, token_type="access"),
+ "refresh_token": Token(token=refresh_token, token_type="refresh"),
+ }
+
+ async def refresh_access_token(
+ self, refresh_token: str
+ ) -> dict[str, Token]:
+ token_data = await self.decode_token(refresh_token)
+ if token_data.token_type != "refresh":
+ raise R2RException(
+ status_code=401, message="Invalid refresh token"
+ )
+
+ # Invalidate the old refresh token and create a new one
+ await self.database_provider.token_handler.blacklist_token(
+ token=refresh_token
+ )
+
+ new_access_token = self.create_access_token(
+ data={"sub": normalize_email(token_data.email)}
+ )
+ new_refresh_token = self.create_refresh_token(
+ data={"sub": normalize_email(token_data.email)}
+ )
+ return {
+ "access_token": Token(token=new_access_token, token_type="access"),
+ "refresh_token": Token(
+ token=new_refresh_token, token_type="refresh"
+ ),
+ }
+
+ async def change_password(
+ self, user: User, current_password: str, new_password: str
+ ) -> dict[str, str]:
+ if not isinstance(user.hashed_password, str):
+ logger.error(
+ f"Invalid hashed_password type: {type(user.hashed_password)}"
+ )
+ raise HTTPException(
+ status_code=500,
+ detail="Invalid password hash in database",
+ )
+
+ if not self.crypto_provider.verify_password(
+ plain_password=current_password,
+ hashed_password=user.hashed_password,
+ ):
+ raise R2RException(
+ status_code=400, message="Incorrect current password"
+ )
+
+ hashed_new_password = self.crypto_provider.get_password_hash(
+ password=new_password
+ )
+ await self.database_provider.users_handler.update_user_password(
+ id=user.id,
+ new_hashed_password=hashed_new_password,
+ )
+ try:
+ await self.email_provider.send_password_changed_email(
+ to_email=normalize_email(user.email),
+ dynamic_template_data={
+ "first_name": (
+ user.name.split(" ")[0] or "User"
+ if user.name
+ else "User"
+ )
+ },
+ )
+ except Exception as e:
+ logger.error(
+ f"Failed to send password change notification: {str(e)}"
+ )
+
+ return {"message": "Password changed successfully"}
+
+ async def request_password_reset(self, email: str) -> dict[str, str]:
+ try:
+ user = (
+ await self.database_provider.users_handler.get_user_by_email(
+ email=normalize_email(email)
+ )
+ )
+
+ reset_token = self.crypto_provider.generate_verification_code()
+ expiry = datetime.now(timezone.utc) + timedelta(hours=1)
+ await self.database_provider.users_handler.store_reset_token(
+ id=user.id,
+ reset_token=reset_token,
+ expiry=expiry,
+ )
+
+ first_name = (
+ user.name.split(" ")[0] if user.name else email.split("@")[0]
+ )
+ await self.email_provider.send_password_reset_email(
+ to_email=normalize_email(email),
+ reset_token=reset_token,
+ dynamic_template_data={"first_name": first_name},
+ )
+
+ return {
+ "message": "If the email exists, a reset link has been sent"
+ }
+ except R2RException as e:
+ if e.status_code == 404:
+ # User doesn't exist; return a success message anyway
+ return {
+ "message": "If the email exists, a reset link has been sent"
+ }
+ else:
+ raise
+
+ async def confirm_password_reset(
+ self, reset_token: str, new_password: str
+ ) -> dict[str, str]:
+ user_id = await self.database_provider.users_handler.get_user_id_by_reset_token(
+ reset_token=reset_token
+ )
+ if not user_id:
+ raise R2RException(
+ status_code=400, message="Invalid or expired reset token"
+ )
+
+ hashed_new_password = self.crypto_provider.get_password_hash(
+ password=new_password
+ )
+ await self.database_provider.users_handler.update_user_password(
+ id=user_id,
+ new_hashed_password=hashed_new_password,
+ )
+ await self.database_provider.users_handler.remove_reset_token(
+ id=user_id
+ )
+ # Get the user information
+ user = await self.database_provider.users_handler.get_user_by_id(
+ id=user_id
+ )
+
+ try:
+ await self.email_provider.send_password_changed_email(
+ to_email=normalize_email(user.email),
+ dynamic_template_data={
+ "first_name": (
+ user.name.split(" ")[0] or "User"
+ if user.name
+ else "User"
+ )
+ },
+ )
+ except Exception as e:
+ logger.error(
+ f"Failed to send password change notification: {str(e)}"
+ )
+
+ return {"message": "Password reset successfully"}
+
+ async def logout(self, token: str) -> dict[str, str]:
+ await self.database_provider.token_handler.blacklist_token(token=token)
+ return {"message": "Logged out successfully"}
+
+ async def clean_expired_blacklisted_tokens(self):
+ await self.database_provider.token_handler.clean_expired_blacklisted_tokens()
+
+ async def send_reset_email(self, email: str) -> dict:
+ verification_code, expiry = await self.send_verification_email(
+ email=normalize_email(email)
+ )
+
+ return {
+ "verification_code": verification_code,
+ "expiry": expiry,
+ "message": f"Verification email sent successfully to {email}",
+ }
+
+ async def create_user_api_key(
+ self,
+ user_id: UUID,
+ name: Optional[str] = None,
+ description: Optional[str] = None,
+ ) -> dict[str, str]:
+ key_id, raw_api_key = self.crypto_provider.generate_api_key()
+ hashed_key = self.crypto_provider.hash_api_key(raw_api_key)
+
+ api_key_uuid = (
+ await self.database_provider.users_handler.store_user_api_key(
+ user_id=user_id,
+ key_id=key_id,
+ hashed_key=hashed_key,
+ name=name,
+ description=description,
+ )
+ )
+
+ return {
+ "api_key": f"{key_id}.{raw_api_key}",
+ "key_id": str(api_key_uuid),
+ "public_key": key_id,
+ "name": name or "",
+ }
+
+ async def list_user_api_keys(self, user_id: UUID) -> list[dict]:
+ return await self.database_provider.users_handler.get_user_api_keys(
+ user_id=user_id
+ )
+
+ async def delete_user_api_key(self, user_id: UUID, key_id: UUID) -> bool:
+ return await self.database_provider.users_handler.delete_api_key(
+ user_id=user_id,
+ key_id=key_id,
+ )
+
+ async def rename_api_key(
+ self, user_id: UUID, key_id: UUID, new_name: str
+ ) -> bool:
+ return await self.database_provider.users_handler.update_api_key_name(
+ user_id=user_id,
+ key_id=key_id,
+ name=new_name,
+ )
+
+ async def oauth_callback_handler(
+ self, provider: str, oauth_id: str, email: str
+ ) -> dict[str, Token]:
+ """Handles a login/registration flow for OAuth providers (e.g., Google
+ or GitHub).
+
+ :param provider: "google" or "github"
+ :param oauth_id: The unique ID from the OAuth provider (e.g. Google's
+ 'sub')
+ :param email: The user's email from the provider, if available.
+ :return: dict with access_token and refresh_token
+ """
+ # 1) Attempt to find user by google_id or github_id, or by email
+ # The logic depends on your preference. We'll assume "google" => google_id, etc.
+ try:
+ if provider == "google":
+ try:
+ user = await self.database_provider.users_handler.get_user_by_email(
+ normalize_email(email)
+ )
+ # If user found, check if user.google_id matches or is null. If null, update it
+ if user and not user.google_id:
+ raise R2RException(
+ status_code=401,
+ message="User already exists and is not linked to Google account",
+ )
+ except Exception:
+ # Create new user
+ user = await self.register(
+ email=normalize_email(email)
+ or f"{oauth_id}@google_oauth.fake", # fallback
+ password=None, # no password
+ account_type="oauth",
+ google_id=oauth_id,
+ )
+ elif provider == "github":
+ try:
+ user = await self.database_provider.users_handler.get_user_by_email(
+ normalize_email(email)
+ )
+ # If user found, check if user.google_id matches or is null. If null, update it
+ if user and not user.github_id:
+ raise R2RException(
+ status_code=401,
+ message="User already exists and is not linked to Github account",
+ )
+ except Exception:
+ # Create new user
+ user = await self.register(
+ email=normalize_email(email)
+ or f"{oauth_id}@github_oauth.fake", # fallback
+ password=None, # no password
+ account_type="oauth",
+ github_id=oauth_id,
+ )
+ # else handle other providers
+
+ except R2RException:
+ # If no user found or creation fails
+ raise R2RException(
+ status_code=401, message="Could not create or fetch user"
+ ) from None
+
+ # If user is inactive, etc.
+ if not user.is_active:
+ raise R2RException(
+ status_code=401, message="User account is inactive"
+ )
+
+ # Possibly mark user as verified if you trust the OAuth provider's email
+ user.is_verified = True
+ await self.database_provider.users_handler.update_user(user)
+
+ # 2) Generate tokens
+ access_token = self.create_access_token(
+ data={"sub": normalize_email(user.email)}
+ )
+ refresh_token = self.create_refresh_token(
+ data={"sub": normalize_email(user.email)}
+ )
+
+ return {
+ "access_token": Token(token=access_token, token_type="access"),
+ "refresh_token": Token(token=refresh_token, token_type="refresh"),
+ }
diff --git a/.venv/lib/python3.12/site-packages/core/providers/auth/supabase.py b/.venv/lib/python3.12/site-packages/core/providers/auth/supabase.py
new file mode 100644
index 00000000..5fc0e0bf
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/auth/supabase.py
@@ -0,0 +1,249 @@
+import logging
+import os
+from datetime import datetime
+from typing import Optional
+from uuid import UUID
+
+from fastapi import Depends, HTTPException
+from fastapi.security import OAuth2PasswordBearer
+from supabase import Client, create_client
+
+from core.base import (
+ AuthConfig,
+ AuthProvider,
+ CryptoProvider,
+ EmailProvider,
+ R2RException,
+ Token,
+ TokenData,
+)
+from core.base.api.models import User
+
+from ..database import PostgresDatabaseProvider
+
+logger = logging.getLogger()
+
+logger = logging.getLogger()
+oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
+
+
+class SupabaseAuthProvider(AuthProvider):
+ def __init__(
+ self,
+ config: AuthConfig,
+ crypto_provider: CryptoProvider,
+ database_provider: PostgresDatabaseProvider,
+ email_provider: EmailProvider,
+ ):
+ super().__init__(
+ config, crypto_provider, database_provider, email_provider
+ )
+ self.supabase_url = config.extra_fields.get(
+ "supabase_url", None
+ ) or os.getenv("SUPABASE_URL")
+ self.supabase_key = config.extra_fields.get(
+ "supabase_key", None
+ ) or os.getenv("SUPABASE_KEY")
+ if not self.supabase_url or not self.supabase_key:
+ raise HTTPException(
+ status_code=500,
+ detail="Supabase URL and key must be provided",
+ )
+ self.supabase: Client = create_client(
+ self.supabase_url, self.supabase_key
+ )
+
+ async def initialize(self):
+ # No initialization needed for Supabase
+ pass
+
+ def create_access_token(self, data: dict) -> str:
+ raise NotImplementedError(
+ "create_access_token is not used with Supabase authentication"
+ )
+
+ def create_refresh_token(self, data: dict) -> str:
+ raise NotImplementedError(
+ "create_refresh_token is not used with Supabase authentication"
+ )
+
+ async def decode_token(self, token: str) -> TokenData:
+ raise NotImplementedError(
+ "decode_token is not used with Supabase authentication"
+ )
+
+ async def register(
+ self,
+ email: str,
+ password: str,
+ name: Optional[str] = None,
+ bio: Optional[str] = None,
+ profile_picture: Optional[str] = None,
+ ) -> User: # type: ignore
+ # Use Supabase client to create a new user
+
+ if self.supabase.auth.sign_up(email=email, password=password):
+ raise R2RException(
+ status_code=400,
+ message="Supabase provider implementation is still under construction",
+ )
+ else:
+ raise R2RException(
+ status_code=400, message="User registration failed"
+ )
+
+ async def send_verification_email(
+ self, email: str, user: Optional[User] = None
+ ) -> tuple[str, datetime]:
+ raise NotImplementedError(
+ "send_verification_email is not used with Supabase"
+ )
+
+ async def verify_email(
+ self, email: str, verification_code: str
+ ) -> dict[str, str]:
+ # Use Supabase client to verify email
+ if self.supabase.auth.verify_email(email, verification_code):
+ return {"message": "Email verified successfully"}
+ else:
+ raise R2RException(
+ status_code=400, message="Invalid or expired verification code"
+ )
+
+ async def login(self, email: str, password: str) -> dict[str, Token]:
+ # Use Supabase client to authenticate user and get tokens
+ if response := self.supabase.auth.sign_in(
+ email=email, password=password
+ ):
+ access_token = response.access_token
+ refresh_token = response.refresh_token
+ return {
+ "access_token": Token(token=access_token, token_type="access"),
+ "refresh_token": Token(
+ token=refresh_token, token_type="refresh"
+ ),
+ }
+ else:
+ raise R2RException(
+ status_code=401, message="Invalid email or password"
+ )
+
+ async def refresh_access_token(
+ self, refresh_token: str
+ ) -> dict[str, Token]:
+ # Use Supabase client to refresh access token
+ if response := self.supabase.auth.refresh_access_token(refresh_token):
+ new_access_token = response.access_token
+ new_refresh_token = response.refresh_token
+ return {
+ "access_token": Token(
+ token=new_access_token, token_type="access"
+ ),
+ "refresh_token": Token(
+ token=new_refresh_token, token_type="refresh"
+ ),
+ }
+ else:
+ raise R2RException(
+ status_code=401, message="Invalid refresh token"
+ )
+
+ async def user(self, token: str = Depends(oauth2_scheme)) -> User:
+ # Use Supabase client to get user details from token
+ if user := self.supabase.auth.get_user(token).user:
+ return User(
+ id=user.id,
+ email=user.email,
+ is_active=True, # Assuming active if exists in Supabase
+ is_superuser=False, # Default to False unless explicitly set
+ created_at=user.created_at,
+ updated_at=user.updated_at,
+ is_verified=user.email_confirmed_at is not None,
+ name=user.user_metadata.get("full_name"),
+ # Set other optional fields if available in user metadata
+ )
+
+ else:
+ raise R2RException(status_code=401, message="Invalid token")
+
+ def get_current_active_user(
+ self, current_user: User = Depends(user)
+ ) -> User:
+ # Check if user is active
+ if not current_user.is_active:
+ raise R2RException(status_code=400, message="Inactive user")
+ return current_user
+
+ async def change_password(
+ self, user: User, current_password: str, new_password: str
+ ) -> dict[str, str]:
+ # Use Supabase client to update user password
+ if self.supabase.auth.update(user.id, {"password": new_password}):
+ return {"message": "Password changed successfully"}
+ else:
+ raise R2RException(
+ status_code=400, message="Failed to change password"
+ )
+
+ async def request_password_reset(self, email: str) -> dict[str, str]:
+ # Use Supabase client to send password reset email
+ if self.supabase.auth.send_password_reset_email(email):
+ return {
+ "message": "If the email exists, a reset link has been sent"
+ }
+ else:
+ raise R2RException(
+ status_code=400, message="Failed to send password reset email"
+ )
+
+ async def confirm_password_reset(
+ self, reset_token: str, new_password: str
+ ) -> dict[str, str]:
+ # Use Supabase client to reset password with token
+ if self.supabase.auth.reset_password_for_email(
+ reset_token, new_password
+ ):
+ return {"message": "Password reset successfully"}
+ else:
+ raise R2RException(
+ status_code=400, message="Invalid or expired reset token"
+ )
+
+ async def logout(self, token: str) -> dict[str, str]:
+ # Use Supabase client to logout user and revoke token
+ self.supabase.auth.sign_out(token)
+ return {"message": "Logged out successfully"}
+
+ async def clean_expired_blacklisted_tokens(self):
+ # Not applicable for Supabase, tokens are managed by Supabase
+ pass
+
+ async def send_reset_email(self, email: str) -> dict[str, str]:
+ raise NotImplementedError("send_reset_email is not used with Supabase")
+
+ async def create_user_api_key(
+ self,
+ user_id: UUID,
+ name: Optional[str] = None,
+ description: Optional[str] = None,
+ ) -> dict[str, str]:
+ raise NotImplementedError(
+ "API key management is not supported with Supabase authentication"
+ )
+
+ async def list_user_api_keys(self, user_id: UUID) -> list[dict]:
+ raise NotImplementedError(
+ "API key management is not supported with Supabase authentication"
+ )
+
+ async def delete_user_api_key(self, user_id: UUID, key_id: UUID) -> bool:
+ raise NotImplementedError(
+ "API key management is not supported with Supabase authentication"
+ )
+
+ async def oauth_callback_handler(
+ self, provider: str, oauth_id: str, email: str
+ ) -> dict[str, Token]:
+ raise NotImplementedError(
+ "API key management is not supported with Supabase authentication"
+ )