about summary refs log tree commit diff
path: root/gn_auth/auth/authentication
diff options
context:
space:
mode:
Diffstat (limited to 'gn_auth/auth/authentication')
-rw-r--r--gn_auth/auth/authentication/oauth2/grants/jwt_bearer_grant.py69
-rw-r--r--gn_auth/auth/authentication/oauth2/grants/refresh_token_grant.py10
-rw-r--r--gn_auth/auth/authentication/oauth2/models/jwt_bearer_token.py20
-rw-r--r--gn_auth/auth/authentication/oauth2/models/oauth2client.py36
-rw-r--r--gn_auth/auth/authentication/oauth2/resource_server.py11
-rw-r--r--gn_auth/auth/authentication/oauth2/server.py100
-rw-r--r--gn_auth/auth/authentication/oauth2/views.py17
-rw-r--r--gn_auth/auth/authentication/users.py4
8 files changed, 168 insertions, 99 deletions
diff --git a/gn_auth/auth/authentication/oauth2/grants/jwt_bearer_grant.py b/gn_auth/auth/authentication/oauth2/grants/jwt_bearer_grant.py
index 27783ac..63f979c 100644
--- a/gn_auth/auth/authentication/oauth2/grants/jwt_bearer_grant.py
+++ b/gn_auth/auth/authentication/oauth2/grants/jwt_bearer_grant.py
@@ -1,16 +1,25 @@
 """JWT as Authorisation Grant"""
 import uuid
+import time
+import logging
+from typing import Optional
 
-from flask import current_app as app
-
+from authlib.jose import jwt
+from authlib.common.encoding import to_native
 from authlib.common.security import generate_token
 from authlib.oauth2.rfc7523.jwt_bearer import JWTBearerGrant as _JWTBearerGrant
 from authlib.oauth2.rfc7523.token import (
     JWTBearerTokenGenerator as _JWTBearerTokenGenerator)
 
-from gn_auth.debug import __pk__
+from gn_libs.debug import make_peeker
+
 from gn_auth.auth.db.sqlite3 import with_db_connection
-from gn_auth.auth.authentication.users import user_by_id
+from gn_auth.auth.authentication.users import User, user_by_id
+from gn_auth.auth.authentication.oauth2.models.oauth2client import OAuth2Client
+
+
+logger = logging.getLogger(__name__)
+__pk__ = make_peeker(logger)
 
 
 class JWTBearerTokenGenerator(_JWTBearerTokenGenerator):
@@ -20,12 +29,24 @@ class JWTBearerTokenGenerator(_JWTBearerTokenGenerator):
 
     DEFAULT_EXPIRES_IN = 300
 
-    def get_token_data(#pylint: disable=[too-many-arguments]
+    def get_token_data(#pylint: disable=[too-many-arguments, too-many-positional-arguments]
             self, grant_type, client, expires_in=None, user=None, scope=None
     ):
         """Post process data to prevent JSON serialization problems."""
-        tokendata = super().get_token_data(
-            grant_type, client, expires_in, user, scope)
+        issued_at = int(time.time())
+        tokendata = {
+            "scope": self.get_allowed_scope(client, scope),
+            "grant_type": grant_type,
+            "iat": issued_at,
+            "client_id": client.get_client_id()
+        }
+        if isinstance(expires_in, int) and expires_in > 0:
+            tokendata["exp"] = issued_at + expires_in
+        if self.issuer:
+            tokendata["iss"] = self.issuer
+        if user:
+            tokendata["sub"] = self.get_sub_value(user)
+
         return {
             **{
                 key: str(value) if key.endswith("_id") else value
@@ -36,8 +57,38 @@ class JWTBearerTokenGenerator(_JWTBearerTokenGenerator):
             "oauth2_client_id": str(client.client_id)
         }
 
+    def generate(# pylint: disable=[too-many-arguments, too-many-positional-arguments]
+            self,
+            grant_type: str,
+            client: OAuth2Client,
+            user: Optional[User] = None,
+            scope: Optional[str] = None,
+            expires_in: Optional[int] = None
+    ) -> dict:
+        """Generate a bearer token for OAuth 2.0 authorization token endpoint.
+
+        :param client: the client that making the request.
+        :param grant_type: current requested grant_type.
+        :param user: current authorized user.
+        :param expires_in: if provided, use this value as expires_in.
+        :param scope: current requested scope.
+        :return: Token dict
+        """
+
+        token_data = self.get_token_data(grant_type, client, expires_in, user, scope)
+        access_token = jwt.encode({"alg": self.alg}, token_data, key=self.secret_key, check=False)
+        token = {
+            "token_type": "Bearer",
+            "access_token": to_native(access_token)
+        }
+        if expires_in:
+            token["expires_in"] = expires_in
+        if scope:
+            token["scope"] = scope
+        return token
+
 
-    def __call__(# pylint: disable=[too-many-arguments]
+    def __call__(# pylint: disable=[too-many-arguments, too-many-positional-arguments]
             self, grant_type, client, user=None, scope=None, expires_in=None,
             include_refresh_token=True
     ):
@@ -102,6 +153,6 @@ class JWTBearerGrant(_JWTBearerGrant):
             include_refresh_token=self.request.client.check_grant_type(
                 "refresh_token")
         )
-        app.logger.debug('Issue token %r to %r', token, self.request.client)
+        logger.debug('Issue token %r to %r', token, self.request.client)
         self.save_token(token)
         return 200, token, self.TOKEN_RESPONSE_HEADER
diff --git a/gn_auth/auth/authentication/oauth2/grants/refresh_token_grant.py b/gn_auth/auth/authentication/oauth2/grants/refresh_token_grant.py
index fd6804d..f897d89 100644
--- a/gn_auth/auth/authentication/oauth2/grants/refresh_token_grant.py
+++ b/gn_auth/auth/authentication/oauth2/grants/refresh_token_grant.py
@@ -34,18 +34,18 @@ class RefreshTokenGrant(grants.RefreshTokenGrant):
                     else Nothing)
             ).maybe(None, lambda _tok: _tok)
 
-    def authenticate_user(self, credential):
+    def authenticate_user(self, refresh_token):
         """Check that user is valid for given token."""
         with connection(app.config["AUTH_DB"]) as conn:
             try:
-                return user_by_id(conn, credential.user.user_id)
+                return user_by_id(conn, refresh_token.user.user_id)
             except NotFoundError as _nfe:
                 return None
 
         return None
 
-    def revoke_old_credential(self, credential):
+    def revoke_old_credential(self, refresh_token):
         """Revoke any old refresh token after issuing new refresh token."""
         with connection(app.config["AUTH_DB"]) as conn:
-            if credential.parent_of is not None:
-                revoke_refresh_token(conn, credential)
+            if refresh_token.parent_of is not None:
+                revoke_refresh_token(conn, refresh_token)
diff --git a/gn_auth/auth/authentication/oauth2/models/jwt_bearer_token.py b/gn_auth/auth/authentication/oauth2/models/jwt_bearer_token.py
index cca75f4..71769e1 100644
--- a/gn_auth/auth/authentication/oauth2/models/jwt_bearer_token.py
+++ b/gn_auth/auth/authentication/oauth2/models/jwt_bearer_token.py
@@ -1,5 +1,7 @@
 """Implement model for JWTBearerToken"""
 import uuid
+import time
+from typing import Optional
 
 from authlib.oauth2.rfc7523 import JWTBearerToken as _JWTBearerToken
 
@@ -28,3 +30,21 @@ class JWTBearerToken(_JWTBearerToken):
     def check_client(self, client):
         """Check that the client is right."""
         return self.client.get_client_id() == client.get_client_id()
+
+
+    def get_expires_in(self) -> Optional[int]:
+        """Return the number of seconds the token is valid for since issue.
+
+        If `None`, the token never expires."""
+        if "exp" in self:
+            return self['exp'] - self['iat']
+        return None
+
+
+    def is_expired(self):
+        """Check whether the token is expired.
+
+        If there is no 'exp' member, assume this token will never expire."""
+        if "exp" in self:
+            return self["exp"] < time.time()
+        return False
diff --git a/gn_auth/auth/authentication/oauth2/models/oauth2client.py b/gn_auth/auth/authentication/oauth2/models/oauth2client.py
index 615d0ee..dfe5d79 100644
--- a/gn_auth/auth/authentication/oauth2/models/oauth2client.py
+++ b/gn_auth/auth/authentication/oauth2/models/oauth2client.py
@@ -3,6 +3,7 @@ import json
 import logging
 import datetime
 from uuid import UUID
+from urllib.parse import urlparse
 from functools import cached_property
 from dataclasses import asdict, dataclass
 from typing import Any, Sequence, Optional
@@ -12,8 +13,8 @@ from requests.exceptions import JSONDecodeError
 from authlib.jose import KeySet, JsonWebKey
 from authlib.oauth2.rfc6749 import ClientMixin
 from pymonad.maybe import Just, Maybe, Nothing
+from gn_libs.debug import make_peeker
 
-from gn_auth.debug import __pk__, getLogger
 from gn_auth.auth.db import sqlite3 as db
 from gn_auth.auth.errors import NotFoundError
 from gn_auth.auth.authentication.users import (User,
@@ -22,6 +23,10 @@ from gn_auth.auth.authentication.users import (User,
                                                same_password)
 
 
+logger = logging.getLogger(__name__)
+__pk__ = make_peeker(logger)
+
+
 @dataclass(frozen=True)
 class OAuth2Client(ClientMixin):
     """
@@ -62,36 +67,29 @@ class OAuth2Client(ClientMixin):
 
     def jwks(self) -> KeySet:
         """Return this client's KeySet."""
-        logger = getLogger(__name__)
         jwksuri = self.client_metadata.get("public-jwks-uri")
-        ### ----- DEBUG: Remove this section ----- ###
-        import os
-        from pathlib import Path
-        ca_bundle = Path(os.environ.get("REQUESTS_CA_BUNDLE"))
-        __pk__(f"{ca_bundle} exists?", ca_bundle.exists())
-        ### ----- DEBUG: Remove this section ----- ###
-        if not bool(__pk__(
-                f"CLIENT'S ({self.client_id}) JWKs URI =======> ", jwksuri)):
+        __pk__(f"PUBLIC JWKs link for client {self.client_id}", jwksuri)
+        if not bool(jwksuri):
             logger.debug("No Public JWKs URI set for client!")
-            return __pk__("Return empty KeySet since URI is not set =====>",
-                          KeySet([]))
+            return KeySet([])
         try:
             ## IMPORTANT: This can cause a deadlock if the client is working in
             ##            single-threaded mode, i.e. can only serve one request
             ##            at a time.
             return KeySet([JsonWebKey.import_key(key)
-                           for key in requests.get(jwksuri).json()["jwks"]])
+                           for key in requests.get(
+                                   jwksuri,
+                                   timeout=300,
+                                   allow_redirects=True).json()["jwks"]])
         except requests.ConnectionError as _connerr:
             logger.debug(
                 "Could not connect to provided URI: %s", jwksuri, exc_info=True)
         except JSONDecodeError as _jsonerr:
-            logger.debug(
-                "Could not convert response to JSON", exc_info=True)
+            logger.debug("Could not convert response to JSON", exc_info=True)
         except Exception as _exc:# pylint: disable=[broad-except]
             logger.debug(
                 "Error retrieving the JWKs for the client.", exc_info=True)
-        return __pk__("Return empty KeySet after failure =====>",
-                      KeySet([]))
+        return KeySet([])
 
 
     def check_endpoint_auth_method(self, method: str, endpoint: str) -> bool:
@@ -141,7 +139,9 @@ class OAuth2Client(ClientMixin):
         """
         Check whether the given `redirect_uri` is one of the expected ones.
         """
-        return redirect_uri in self.redirect_uris
+        uri = urlparse(redirect_uri)._replace(
+            query="")._replace(fragment="").geturl()
+        return uri in self.redirect_uris
 
     @cached_property
     def response_types(self) -> Sequence[str]:
diff --git a/gn_auth/auth/authentication/oauth2/resource_server.py b/gn_auth/auth/authentication/oauth2/resource_server.py
index 9c885e2..edab02c 100644
--- a/gn_auth/auth/authentication/oauth2/resource_server.py
+++ b/gn_auth/auth/authentication/oauth2/resource_server.py
@@ -1,4 +1,5 @@
 """Protect the resources endpoints"""
+import logging
 from datetime import datetime, timezone, timedelta
 
 from flask import current_app as app
@@ -16,6 +17,9 @@ from gn_auth.auth.authentication.oauth2.models.jwt_bearer_token import (
 from gn_auth.auth.authentication.oauth2.models.oauth2token import (
     token_by_access_token)
 
+logger = logging.getLogger(__name__)
+
+
 class BearerTokenValidator(_BearerTokenValidator):
     """Extends `authlib.oauth2.rfc6750.BearerTokenValidator`"""
     def authenticate_token(self, token_string: str):
@@ -43,6 +47,11 @@ class JWTBearerTokenValidator(_JWTBearerTokenValidator):
         self._last_jwks_update = datetime.now(tz=timezone.utc)
         self._refresh_frequency = timedelta(hours=int(
             extra_attributes.get("jwt_refresh_frequency_hours", 6)))
+        self.claims_options = {
+            'exp': {'essential': False},
+            'client_id': {'essential': True},
+            'grant_type': {'essential': True},
+        }
 
     def __refresh_jwks__(self):
         now = datetime.now(tz=timezone.utc)
@@ -61,7 +70,7 @@ class JWTBearerTokenValidator(_JWTBearerTokenValidator):
                 claims.validate()
                 return claims
             except JoseError as error:
-                app.logger.debug('Authenticate token failed. %r', error)
+                logger.debug('Authenticate token failed. %r', error)
 
         return None
 
diff --git a/gn_auth/auth/authentication/oauth2/server.py b/gn_auth/auth/authentication/oauth2/server.py
index a8109b7..8ac5106 100644
--- a/gn_auth/auth/authentication/oauth2/server.py
+++ b/gn_auth/auth/authentication/oauth2/server.py
@@ -3,12 +3,12 @@ import uuid
 from typing import Callable
 from datetime import datetime
 
-from flask import Flask, current_app
-from authlib.jose import jwt, KeySet
+from flask import Flask, current_app, request as flask_request
+from authlib.jose import KeySet
+from authlib.oauth2.rfc6749 import OAuth2Request
 from authlib.oauth2.rfc6749.errors import InvalidClientError
 from authlib.integrations.flask_oauth2 import AuthorizationServer
-from authlib.oauth2.rfc6749 import OAuth2Request
-from authlib.integrations.flask_helpers import create_oauth_request
+from authlib.integrations.flask_oauth2.requests import FlaskOAuth2Request
 
 from gn_auth.auth.db import sqlite3 as db
 from gn_auth.auth.jwks import (
@@ -16,13 +16,9 @@ from gn_auth.auth.jwks import (
     jwks_directory,
     newest_jwk_with_rotation)
 
+from .models.jwt_bearer_token import JWTBearerToken
 from .models.oauth2client import client as fetch_client
 from .models.oauth2token import OAuth2Token, save_token
-from .models.jwtrefreshtoken import (
-    JWTRefreshToken,
-    link_child_token,
-    save_refresh_token,
-    load_refresh_token)
 
 from .grants.password_grant import PasswordGrant
 from .grants.refresh_token_grant import RefreshTokenGrant
@@ -34,6 +30,8 @@ from .endpoints.introspection import IntrospectionEndpoint
 
 from .resource_server import require_oauth, JWTBearerTokenValidator
 
+_TWO_HOURS_ = 2 * 60 * 60
+
 
 def create_query_client_func() -> Callable:
     """Create the function that loads the client."""
@@ -50,54 +48,32 @@ def create_query_client_func() -> Callable:
 
     return __query_client__
 
-def create_save_token_func(token_model: type, app: Flask) -> Callable:
+def create_save_token_func(token_model: type) -> Callable:
     """Create the function that saves the token."""
+    def __ignore_token__(token, request):# pylint: disable=[unused-argument]
+        """Ignore the token: i.e. Do not save it."""
+
     def __save_token__(token, request):
-        _jwt = jwt.decode(
-            token["access_token"],
-            newest_jwk_with_rotation(
-                jwks_directory(app),
-                int(app.config["JWKS_ROTATION_AGE_DAYS"])))
-        _token = token_model(
-            token_id=uuid.UUID(_jwt["jti"]),
-            client=request.client,
-            user=request.user,
-            **{
-                "refresh_token": None,
-                "revoked": False,
-                "issued_at": datetime.now(),
-                **token
-            })
         with db.connection(current_app.config["AUTH_DB"]) as conn:
-            save_token(conn, _token)
-            old_refresh_token = load_refresh_token(
+            save_token(
                 conn,
-                request.form.get("refresh_token", "nosuchtoken")
-            )
-            new_refresh_token = JWTRefreshToken(
-                    token=_token.refresh_token,
+                token_model(
+                    **token,
+                    token_id=uuid.uuid4(),
                     client=request.client,
                     user=request.user,
-                    issued_with=uuid.UUID(_jwt["jti"]),
-                    issued_at=datetime.fromtimestamp(_jwt["iat"]),
-                    expires=datetime.fromtimestamp(
-                        old_refresh_token.then(
-                            lambda _tok: _tok.expires.timestamp()
-                        ).maybe((int(_jwt["iat"]) +
-                                 RefreshTokenGrant.DEFAULT_EXPIRES_IN),
-                                lambda _expires: _expires)),
-                    scope=_token.get_scope(),
+                    issued_at=datetime.now(),
                     revoked=False,
-                    parent_of=None)
-            save_refresh_token(conn, new_refresh_token)
-            old_refresh_token.then(lambda _tok: link_child_token(
-                conn, _tok.token, new_refresh_token.token))
+                    expires_in=_TWO_HOURS_))
 
-    return __save_token__
+    return {
+        OAuth2Token: __save_token__,
+        JWTBearerToken: __ignore_token__
+    }[token_model]
 
 def make_jwt_token_generator(app):
     """Make token generator function."""
-    def __generator__(# pylint: disable=[too-many-arguments]
+    def __generator__(# pylint: disable=[too-many-arguments, too-many-positional-arguments]
             grant_type,
             client,
             user=None,
@@ -106,15 +82,17 @@ def make_jwt_token_generator(app):
             include_refresh_token=True
     ):
         return JWTBearerTokenGenerator(
-            newest_jwk_with_rotation(
+            secret_key=newest_jwk_with_rotation(
                 jwks_directory(app),
-                int(app.config["JWKS_ROTATION_AGE_DAYS"]))).__call__(
-                        grant_type,
-                        client,
-                        user,
-                        scope,
-                        JWTBearerTokenGenerator.DEFAULT_EXPIRES_IN,
-                        include_refresh_token)
+                int(app.config["JWKS_ROTATION_AGE_DAYS"])),
+            issuer=flask_request.host_url,
+            alg="RS256").__call__(
+                grant_type=grant_type,
+                client=client,
+                user=user,
+                scope=scope,
+                expires_in=expires_in,
+                include_refresh_token=include_refresh_token)
     return __generator__
 
 
@@ -124,8 +102,16 @@ class JsonAuthorizationServer(AuthorizationServer):
 
     def create_oauth2_request(self, request):
         """Create an OAuth2 Request from the flask request."""
-        res = create_oauth_request(request, OAuth2Request, True)
-        return res
+        match flask_request.headers.get("Content-Type"):
+            case "application/json":
+                req = OAuth2Request(flask_request.method,
+                                     flask_request.url,
+                                     flask_request.get_json(),
+                                     flask_request.headers)
+            case _:
+                req = FlaskOAuth2Request(flask_request)
+
+        return req
 
 
 def setup_oauth2_server(app: Flask) -> None:
@@ -153,7 +139,7 @@ def setup_oauth2_server(app: Flask) -> None:
     server.init_app(
         app,
         query_client=create_query_client_func(),
-        save_token=create_save_token_func(OAuth2Token, app))
+        save_token=create_save_token_func(JWTBearerToken))
     app.config["OAUTH2_SERVER"] = server
 
     ## Set up the token validators
diff --git a/gn_auth/auth/authentication/oauth2/views.py b/gn_auth/auth/authentication/oauth2/views.py
index d0b55b4..8cc123f 100644
--- a/gn_auth/auth/authentication/oauth2/views.py
+++ b/gn_auth/auth/authentication/oauth2/views.py
@@ -1,5 +1,6 @@
 """Endpoints for the oauth2 server"""
 import uuid
+import logging
 import traceback
 from urllib.parse import urlparse
 
@@ -27,8 +28,10 @@ from .endpoints.revocation import RevocationEndpoint
 from .endpoints.introspection import IntrospectionEndpoint
 
 
+logger = logging.getLogger(__name__)
 auth = Blueprint("auth", __name__)
 
+
 @auth.route("/delete-client/<uuid:client_id>", methods=["GET", "POST"])
 def delete_client(client_id: uuid.UUID):
     """Delete an OAuth2 client."""
@@ -44,7 +47,7 @@ def authorise():
                               or str(uuid.uuid4()))
         client = server.query_client(client_id)
         if not bool(client):
-            flash("Invalid OAuth2 client.", "alert-danger")
+            flash("Invalid OAuth2 client.", "alert alert-danger")
 
         if request.method == "GET":
             def __forgot_password_table_exists__(conn):
@@ -77,7 +80,7 @@ def authorise():
             try:
                 email = validate_email(
                     form.get("user:email"), check_deliverability=False)
-                user = user_by_email(conn, email["email"])
+                user = user_by_email(conn, email["email"])  # type: ignore
                 if valid_login(conn, user, form.get("user:password", "")):
                     if not user.verified:
                         return redirect(
@@ -88,15 +91,15 @@ def authorise():
                                     email=email["email"]),
                             code=307)
                     return server.create_authorization_response(request=request, grant_user=user)
-                flash(email_passwd_msg, "alert-danger")
+                flash(email_passwd_msg, "alert alert-danger")
                 return redirect_response # type: ignore[return-value]
             except EmailNotValidError as _enve:
-                app.logger.debug(traceback.format_exc())
-                flash(email_passwd_msg, "alert-danger")
+                logger.debug(traceback.format_exc())
+                flash(email_passwd_msg, "alert alert-danger")
                 return redirect_response # type: ignore[return-value]
             except NotFoundError as _nfe:
-                app.logger.debug(traceback.format_exc())
-                flash(email_passwd_msg, "alert-danger")
+                logger.debug(traceback.format_exc())
+                flash(email_passwd_msg, "alert alert-danger")
                 return redirect_response # type: ignore[return-value]
 
         return with_db_connection(__authorise__)
diff --git a/gn_auth/auth/authentication/users.py b/gn_auth/auth/authentication/users.py
index 140ce36..fded79f 100644
--- a/gn_auth/auth/authentication/users.py
+++ b/gn_auth/auth/authentication/users.py
@@ -1,6 +1,6 @@
 """User-specific code and data structures."""
 import datetime
-from typing import Tuple
+from typing import Tuple, Union
 from uuid import UUID, uuid4
 from dataclasses import dataclass
 
@@ -26,7 +26,7 @@ class User:
         return self.user_id
 
     @staticmethod
-    def from_sqlite3_row(row: sqlite3.Row):
+    def from_sqlite3_row(row: Union[sqlite3.Row, dict]):
         """Generate a user from a row in an SQLite3 resultset"""
         return User(user_id=UUID(row["user_id"]),
                     email=row["email"],