aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/core/base/providers/auth.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/core/base/providers/auth.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two version of R2R are hereHEADmaster
Diffstat (limited to '.venv/lib/python3.12/site-packages/core/base/providers/auth.py')
-rw-r--r--.venv/lib/python3.12/site-packages/core/base/providers/auth.py231
1 files changed, 231 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/core/base/providers/auth.py b/.venv/lib/python3.12/site-packages/core/base/providers/auth.py
new file mode 100644
index 00000000..352c3331
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/base/providers/auth.py
@@ -0,0 +1,231 @@
+import logging
+from abc import ABC, abstractmethod
+from datetime import datetime
+from typing import TYPE_CHECKING, Optional
+
+from fastapi import Security
+from fastapi.security import (
+ APIKeyHeader,
+ HTTPAuthorizationCredentials,
+ HTTPBearer,
+)
+
+from ..abstractions import R2RException, Token, TokenData
+from ..api.models import User
+from .base import Provider, ProviderConfig
+from .crypto import CryptoProvider
+from .email import EmailProvider
+
+logger = logging.getLogger()
+
+if TYPE_CHECKING:
+ from core.providers.database import PostgresDatabaseProvider
+
+api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
+
+
+class AuthConfig(ProviderConfig):
+ secret_key: Optional[str] = None
+ require_authentication: bool = False
+ require_email_verification: bool = False
+ default_admin_email: str = "admin@example.com"
+ default_admin_password: str = "change_me_immediately"
+ access_token_lifetime_in_minutes: Optional[int] = None
+ refresh_token_lifetime_in_days: Optional[int] = None
+
+ @property
+ def supported_providers(self) -> list[str]:
+ return ["r2r"]
+
+ def validate_config(self) -> None:
+ pass
+
+
+class AuthProvider(Provider, ABC):
+ security = HTTPBearer(auto_error=False)
+ crypto_provider: CryptoProvider
+ email_provider: EmailProvider
+ database_provider: "PostgresDatabaseProvider"
+
+ def __init__(
+ self,
+ config: AuthConfig,
+ crypto_provider: CryptoProvider,
+ database_provider: "PostgresDatabaseProvider",
+ email_provider: EmailProvider,
+ ):
+ if not isinstance(config, AuthConfig):
+ raise ValueError(
+ "AuthProvider must be initialized with an AuthConfig"
+ )
+ self.config = config
+ self.admin_email = config.default_admin_email
+ self.admin_password = config.default_admin_password
+ self.crypto_provider = crypto_provider
+ self.database_provider = database_provider
+ self.email_provider = email_provider
+ super().__init__(config)
+ self.config: AuthConfig = config
+ self.database_provider: "PostgresDatabaseProvider" = database_provider
+
+ async def _get_default_admin_user(self) -> User:
+ return await self.database_provider.users_handler.get_user_by_email(
+ self.admin_email
+ )
+
+ @abstractmethod
+ def create_access_token(self, data: dict) -> str:
+ pass
+
+ @abstractmethod
+ def create_refresh_token(self, data: dict) -> str:
+ pass
+
+ @abstractmethod
+ async def decode_token(self, token: str) -> TokenData:
+ pass
+
+ @abstractmethod
+ async def user(self, token: str) -> User:
+ pass
+
+ @abstractmethod
+ def get_current_active_user(self, current_user: User) -> User:
+ pass
+
+ @abstractmethod
+ async def register(self, email: str, password: str) -> User:
+ pass
+
+ @abstractmethod
+ async def send_verification_email(
+ self, email: str, user: Optional[User] = None
+ ) -> tuple[str, datetime]:
+ pass
+
+ @abstractmethod
+ async def verify_email(
+ self, email: str, verification_code: str
+ ) -> dict[str, str]:
+ pass
+
+ @abstractmethod
+ async def login(self, email: str, password: str) -> dict[str, Token]:
+ pass
+
+ @abstractmethod
+ async def refresh_access_token(
+ self, refresh_token: str
+ ) -> dict[str, Token]:
+ pass
+
+ def auth_wrapper(
+ self,
+ public: bool = False,
+ ):
+ async def _auth_wrapper(
+ auth: Optional[HTTPAuthorizationCredentials] = Security(
+ self.security
+ ),
+ api_key: Optional[str] = Security(api_key_header),
+ ) -> User:
+ # If authentication is not required and no credentials are provided, return the default admin user
+ if (
+ ((not self.config.require_authentication) or public)
+ and auth is None
+ and api_key is None
+ ):
+ return await self._get_default_admin_user()
+ if not auth and not api_key:
+ raise R2RException(
+ message="No credentials provided. Create an account at https://app.sciphi.ai and set your API key using `r2r configure key` OR change your base URL to a custom deployment.",
+ status_code=401,
+ )
+ if auth and api_key:
+ raise R2RException(
+ message="Cannot have both Bearer token and API key",
+ status_code=400,
+ )
+ # 1. Try JWT if `auth` is present (Bearer token)
+ if auth is not None:
+ credentials = auth.credentials
+ try:
+ token_data = await self.decode_token(credentials)
+ user = await self.database_provider.users_handler.get_user_by_email(
+ token_data.email
+ )
+ if user is not None:
+ return user
+ except R2RException:
+ # JWT decoding failed for logical reasons (invalid token)
+ pass
+ except Exception as e:
+ # JWT decoding failed unexpectedly, log and continue
+ logger.debug(f"JWT verification failed: {e}")
+
+ # 2. If JWT failed, try API key from Bearer token
+ # Expected format: key_id.raw_api_key
+ if "." in credentials:
+ key_id, raw_api_key = credentials.split(".", 1)
+ api_key_record = await self.database_provider.users_handler.get_api_key_record(
+ key_id
+ )
+ if api_key_record is not None:
+ hashed_key = api_key_record["hashed_key"]
+ if self.crypto_provider.verify_api_key(
+ raw_api_key, hashed_key
+ ):
+ user = await self.database_provider.users_handler.get_user_by_id(
+ api_key_record["user_id"]
+ )
+ if user is not None and user.is_active:
+ return user
+
+ # 3. If no Bearer token worked, try the X-API-Key header
+ if api_key is not None and "." in api_key:
+ key_id, raw_api_key = api_key.split(".", 1)
+ api_key_record = await self.database_provider.users_handler.get_api_key_record(
+ key_id
+ )
+ if api_key_record is not None:
+ hashed_key = api_key_record["hashed_key"]
+ if self.crypto_provider.verify_api_key(
+ raw_api_key, hashed_key
+ ):
+ user = await self.database_provider.users_handler.get_user_by_id(
+ api_key_record["user_id"]
+ )
+ if user is not None and user.is_active:
+ return user
+
+ # If we reach here, both JWT and API key auth failed
+ raise R2RException(
+ message="Invalid token or API key",
+ status_code=401,
+ )
+
+ return _auth_wrapper
+
+ @abstractmethod
+ async def change_password(
+ self, user: User, current_password: str, new_password: str
+ ) -> dict[str, str]:
+ pass
+
+ @abstractmethod
+ async def request_password_reset(self, email: str) -> dict[str, str]:
+ pass
+
+ @abstractmethod
+ async def confirm_password_reset(
+ self, reset_token: str, new_password: str
+ ) -> dict[str, str]:
+ pass
+
+ @abstractmethod
+ async def logout(self, token: str) -> dict[str, str]:
+ pass
+
+ @abstractmethod
+ async def send_reset_email(self, email: str) -> dict[str, str]:
+ pass