about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/core/base/providers/auth.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/core/base/providers/auth.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two version of R2R are here HEAD master
Diffstat (limited to '.venv/lib/python3.12/site-packages/core/base/providers/auth.py')
-rw-r--r--.venv/lib/python3.12/site-packages/core/base/providers/auth.py231
1 files changed, 231 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/core/base/providers/auth.py b/.venv/lib/python3.12/site-packages/core/base/providers/auth.py
new file mode 100644
index 00000000..352c3331
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/base/providers/auth.py
@@ -0,0 +1,231 @@
+import logging
+from abc import ABC, abstractmethod
+from datetime import datetime
+from typing import TYPE_CHECKING, Optional
+
+from fastapi import Security
+from fastapi.security import (
+    APIKeyHeader,
+    HTTPAuthorizationCredentials,
+    HTTPBearer,
+)
+
+from ..abstractions import R2RException, Token, TokenData
+from ..api.models import User
+from .base import Provider, ProviderConfig
+from .crypto import CryptoProvider
+from .email import EmailProvider
+
+logger = logging.getLogger()
+
+if TYPE_CHECKING:
+    from core.providers.database import PostgresDatabaseProvider
+
+api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
+
+
+class AuthConfig(ProviderConfig):
+    secret_key: Optional[str] = None
+    require_authentication: bool = False
+    require_email_verification: bool = False
+    default_admin_email: str = "admin@example.com"
+    default_admin_password: str = "change_me_immediately"
+    access_token_lifetime_in_minutes: Optional[int] = None
+    refresh_token_lifetime_in_days: Optional[int] = None
+
+    @property
+    def supported_providers(self) -> list[str]:
+        return ["r2r"]
+
+    def validate_config(self) -> None:
+        pass
+
+
+class AuthProvider(Provider, ABC):
+    security = HTTPBearer(auto_error=False)
+    crypto_provider: CryptoProvider
+    email_provider: EmailProvider
+    database_provider: "PostgresDatabaseProvider"
+
+    def __init__(
+        self,
+        config: AuthConfig,
+        crypto_provider: CryptoProvider,
+        database_provider: "PostgresDatabaseProvider",
+        email_provider: EmailProvider,
+    ):
+        if not isinstance(config, AuthConfig):
+            raise ValueError(
+                "AuthProvider must be initialized with an AuthConfig"
+            )
+        self.config = config
+        self.admin_email = config.default_admin_email
+        self.admin_password = config.default_admin_password
+        self.crypto_provider = crypto_provider
+        self.database_provider = database_provider
+        self.email_provider = email_provider
+        super().__init__(config)
+        self.config: AuthConfig = config
+        self.database_provider: "PostgresDatabaseProvider" = database_provider
+
+    async def _get_default_admin_user(self) -> User:
+        return await self.database_provider.users_handler.get_user_by_email(
+            self.admin_email
+        )
+
+    @abstractmethod
+    def create_access_token(self, data: dict) -> str:
+        pass
+
+    @abstractmethod
+    def create_refresh_token(self, data: dict) -> str:
+        pass
+
+    @abstractmethod
+    async def decode_token(self, token: str) -> TokenData:
+        pass
+
+    @abstractmethod
+    async def user(self, token: str) -> User:
+        pass
+
+    @abstractmethod
+    def get_current_active_user(self, current_user: User) -> User:
+        pass
+
+    @abstractmethod
+    async def register(self, email: str, password: str) -> User:
+        pass
+
+    @abstractmethod
+    async def send_verification_email(
+        self, email: str, user: Optional[User] = None
+    ) -> tuple[str, datetime]:
+        pass
+
+    @abstractmethod
+    async def verify_email(
+        self, email: str, verification_code: str
+    ) -> dict[str, str]:
+        pass
+
+    @abstractmethod
+    async def login(self, email: str, password: str) -> dict[str, Token]:
+        pass
+
+    @abstractmethod
+    async def refresh_access_token(
+        self, refresh_token: str
+    ) -> dict[str, Token]:
+        pass
+
+    def auth_wrapper(
+        self,
+        public: bool = False,
+    ):
+        async def _auth_wrapper(
+            auth: Optional[HTTPAuthorizationCredentials] = Security(
+                self.security
+            ),
+            api_key: Optional[str] = Security(api_key_header),
+        ) -> User:
+            # If authentication is not required and no credentials are provided, return the default admin user
+            if (
+                ((not self.config.require_authentication) or public)
+                and auth is None
+                and api_key is None
+            ):
+                return await self._get_default_admin_user()
+            if not auth and not api_key:
+                raise R2RException(
+                    message="No credentials provided. Create an account at https://app.sciphi.ai and set your API key using `r2r configure key` OR change your base URL to a custom deployment.",
+                    status_code=401,
+                )
+            if auth and api_key:
+                raise R2RException(
+                    message="Cannot have both Bearer token and API key",
+                    status_code=400,
+                )
+            # 1. Try JWT if `auth` is present (Bearer token)
+            if auth is not None:
+                credentials = auth.credentials
+                try:
+                    token_data = await self.decode_token(credentials)
+                    user = await self.database_provider.users_handler.get_user_by_email(
+                        token_data.email
+                    )
+                    if user is not None:
+                        return user
+                except R2RException:
+                    # JWT decoding failed for logical reasons (invalid token)
+                    pass
+                except Exception as e:
+                    # JWT decoding failed unexpectedly, log and continue
+                    logger.debug(f"JWT verification failed: {e}")
+
+                # 2. If JWT failed, try API key from Bearer token
+                # Expected format: key_id.raw_api_key
+                if "." in credentials:
+                    key_id, raw_api_key = credentials.split(".", 1)
+                    api_key_record = await self.database_provider.users_handler.get_api_key_record(
+                        key_id
+                    )
+                    if api_key_record is not None:
+                        hashed_key = api_key_record["hashed_key"]
+                        if self.crypto_provider.verify_api_key(
+                            raw_api_key, hashed_key
+                        ):
+                            user = await self.database_provider.users_handler.get_user_by_id(
+                                api_key_record["user_id"]
+                            )
+                            if user is not None and user.is_active:
+                                return user
+
+            # 3. If no Bearer token worked, try the X-API-Key header
+            if api_key is not None and "." in api_key:
+                key_id, raw_api_key = api_key.split(".", 1)
+                api_key_record = await self.database_provider.users_handler.get_api_key_record(
+                    key_id
+                )
+                if api_key_record is not None:
+                    hashed_key = api_key_record["hashed_key"]
+                    if self.crypto_provider.verify_api_key(
+                        raw_api_key, hashed_key
+                    ):
+                        user = await self.database_provider.users_handler.get_user_by_id(
+                            api_key_record["user_id"]
+                        )
+                        if user is not None and user.is_active:
+                            return user
+
+            # If we reach here, both JWT and API key auth failed
+            raise R2RException(
+                message="Invalid token or API key",
+                status_code=401,
+            )
+
+        return _auth_wrapper
+
+    @abstractmethod
+    async def change_password(
+        self, user: User, current_password: str, new_password: str
+    ) -> dict[str, str]:
+        pass
+
+    @abstractmethod
+    async def request_password_reset(self, email: str) -> dict[str, str]:
+        pass
+
+    @abstractmethod
+    async def confirm_password_reset(
+        self, reset_token: str, new_password: str
+    ) -> dict[str, str]:
+        pass
+
+    @abstractmethod
+    async def logout(self, token: str) -> dict[str, str]:
+        pass
+
+    @abstractmethod
+    async def send_reset_email(self, email: str) -> dict[str, str]:
+        pass