diff options
author | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
---|---|---|
committer | S. Solomon Darnell | 2025-03-28 21:52:21 -0500 |
commit | 4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch) | |
tree | ee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/core/base/providers/auth.py | |
parent | cc961e04ba734dd72309fb548a2f97d67d578813 (diff) | |
download | gn-ai-master.tar.gz |
Diffstat (limited to '.venv/lib/python3.12/site-packages/core/base/providers/auth.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/core/base/providers/auth.py | 231 |
1 files changed, 231 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/core/base/providers/auth.py b/.venv/lib/python3.12/site-packages/core/base/providers/auth.py new file mode 100644 index 00000000..352c3331 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/core/base/providers/auth.py @@ -0,0 +1,231 @@ +import logging +from abc import ABC, abstractmethod +from datetime import datetime +from typing import TYPE_CHECKING, Optional + +from fastapi import Security +from fastapi.security import ( + APIKeyHeader, + HTTPAuthorizationCredentials, + HTTPBearer, +) + +from ..abstractions import R2RException, Token, TokenData +from ..api.models import User +from .base import Provider, ProviderConfig +from .crypto import CryptoProvider +from .email import EmailProvider + +logger = logging.getLogger() + +if TYPE_CHECKING: + from core.providers.database import PostgresDatabaseProvider + +api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False) + + +class AuthConfig(ProviderConfig): + secret_key: Optional[str] = None + require_authentication: bool = False + require_email_verification: bool = False + default_admin_email: str = "admin@example.com" + default_admin_password: str = "change_me_immediately" + access_token_lifetime_in_minutes: Optional[int] = None + refresh_token_lifetime_in_days: Optional[int] = None + + @property + def supported_providers(self) -> list[str]: + return ["r2r"] + + def validate_config(self) -> None: + pass + + +class AuthProvider(Provider, ABC): + security = HTTPBearer(auto_error=False) + crypto_provider: CryptoProvider + email_provider: EmailProvider + database_provider: "PostgresDatabaseProvider" + + def __init__( + self, + config: AuthConfig, + crypto_provider: CryptoProvider, + database_provider: "PostgresDatabaseProvider", + email_provider: EmailProvider, + ): + if not isinstance(config, AuthConfig): + raise ValueError( + "AuthProvider must be initialized with an AuthConfig" + ) + self.config = config + self.admin_email = config.default_admin_email + self.admin_password = config.default_admin_password + self.crypto_provider = crypto_provider + self.database_provider = database_provider + self.email_provider = email_provider + super().__init__(config) + self.config: AuthConfig = config + self.database_provider: "PostgresDatabaseProvider" = database_provider + + async def _get_default_admin_user(self) -> User: + return await self.database_provider.users_handler.get_user_by_email( + self.admin_email + ) + + @abstractmethod + def create_access_token(self, data: dict) -> str: + pass + + @abstractmethod + def create_refresh_token(self, data: dict) -> str: + pass + + @abstractmethod + async def decode_token(self, token: str) -> TokenData: + pass + + @abstractmethod + async def user(self, token: str) -> User: + pass + + @abstractmethod + def get_current_active_user(self, current_user: User) -> User: + pass + + @abstractmethod + async def register(self, email: str, password: str) -> User: + pass + + @abstractmethod + async def send_verification_email( + self, email: str, user: Optional[User] = None + ) -> tuple[str, datetime]: + pass + + @abstractmethod + async def verify_email( + self, email: str, verification_code: str + ) -> dict[str, str]: + pass + + @abstractmethod + async def login(self, email: str, password: str) -> dict[str, Token]: + pass + + @abstractmethod + async def refresh_access_token( + self, refresh_token: str + ) -> dict[str, Token]: + pass + + def auth_wrapper( + self, + public: bool = False, + ): + async def _auth_wrapper( + auth: Optional[HTTPAuthorizationCredentials] = Security( + self.security + ), + api_key: Optional[str] = Security(api_key_header), + ) -> User: + # If authentication is not required and no credentials are provided, return the default admin user + if ( + ((not self.config.require_authentication) or public) + and auth is None + and api_key is None + ): + return await self._get_default_admin_user() + if not auth and not api_key: + raise R2RException( + message="No credentials provided. Create an account at https://app.sciphi.ai and set your API key using `r2r configure key` OR change your base URL to a custom deployment.", + status_code=401, + ) + if auth and api_key: + raise R2RException( + message="Cannot have both Bearer token and API key", + status_code=400, + ) + # 1. Try JWT if `auth` is present (Bearer token) + if auth is not None: + credentials = auth.credentials + try: + token_data = await self.decode_token(credentials) + user = await self.database_provider.users_handler.get_user_by_email( + token_data.email + ) + if user is not None: + return user + except R2RException: + # JWT decoding failed for logical reasons (invalid token) + pass + except Exception as e: + # JWT decoding failed unexpectedly, log and continue + logger.debug(f"JWT verification failed: {e}") + + # 2. If JWT failed, try API key from Bearer token + # Expected format: key_id.raw_api_key + if "." in credentials: + key_id, raw_api_key = credentials.split(".", 1) + api_key_record = await self.database_provider.users_handler.get_api_key_record( + key_id + ) + if api_key_record is not None: + hashed_key = api_key_record["hashed_key"] + if self.crypto_provider.verify_api_key( + raw_api_key, hashed_key + ): + user = await self.database_provider.users_handler.get_user_by_id( + api_key_record["user_id"] + ) + if user is not None and user.is_active: + return user + + # 3. If no Bearer token worked, try the X-API-Key header + if api_key is not None and "." in api_key: + key_id, raw_api_key = api_key.split(".", 1) + api_key_record = await self.database_provider.users_handler.get_api_key_record( + key_id + ) + if api_key_record is not None: + hashed_key = api_key_record["hashed_key"] + if self.crypto_provider.verify_api_key( + raw_api_key, hashed_key + ): + user = await self.database_provider.users_handler.get_user_by_id( + api_key_record["user_id"] + ) + if user is not None and user.is_active: + return user + + # If we reach here, both JWT and API key auth failed + raise R2RException( + message="Invalid token or API key", + status_code=401, + ) + + return _auth_wrapper + + @abstractmethod + async def change_password( + self, user: User, current_password: str, new_password: str + ) -> dict[str, str]: + pass + + @abstractmethod + async def request_password_reset(self, email: str) -> dict[str, str]: + pass + + @abstractmethod + async def confirm_password_reset( + self, reset_token: str, new_password: str + ) -> dict[str, str]: + pass + + @abstractmethod + async def logout(self, token: str) -> dict[str, str]: + pass + + @abstractmethod + async def send_reset_email(self, email: str) -> dict[str, str]: + pass |