aboutsummaryrefslogtreecommitdiff
path: root/gn_auth/auth/authentication/oauth2
diff options
context:
space:
mode:
Diffstat (limited to 'gn_auth/auth/authentication/oauth2')
-rw-r--r--gn_auth/auth/authentication/oauth2/grants/jwt_bearer_grant.py53
-rw-r--r--gn_auth/auth/authentication/oauth2/models/jwt_bearer_token.py20
-rw-r--r--gn_auth/auth/authentication/oauth2/models/oauth2client.py16
-rw-r--r--gn_auth/auth/authentication/oauth2/resource_server.py5
-rw-r--r--gn_auth/auth/authentication/oauth2/server.py82
5 files changed, 115 insertions, 61 deletions
diff --git a/gn_auth/auth/authentication/oauth2/grants/jwt_bearer_grant.py b/gn_auth/auth/authentication/oauth2/grants/jwt_bearer_grant.py
index 27783ac..c200ce6 100644
--- a/gn_auth/auth/authentication/oauth2/grants/jwt_bearer_grant.py
+++ b/gn_auth/auth/authentication/oauth2/grants/jwt_bearer_grant.py
@@ -1,8 +1,12 @@
"""JWT as Authorisation Grant"""
import uuid
+import time
+from typing import Optional
from flask import current_app as app
+from authlib.jose import jwt
+from authlib.common.encoding import to_native
from authlib.common.security import generate_token
from authlib.oauth2.rfc7523.jwt_bearer import JWTBearerGrant as _JWTBearerGrant
from authlib.oauth2.rfc7523.token import (
@@ -10,7 +14,8 @@ from authlib.oauth2.rfc7523.token import (
from gn_auth.debug import __pk__
from gn_auth.auth.db.sqlite3 import with_db_connection
-from gn_auth.auth.authentication.users import user_by_id
+from gn_auth.auth.authentication.users import User, user_by_id
+from gn_auth.auth.authentication.oauth2.models.oauth2client import OAuth2Client
class JWTBearerTokenGenerator(_JWTBearerTokenGenerator):
@@ -24,8 +29,20 @@ class JWTBearerTokenGenerator(_JWTBearerTokenGenerator):
self, grant_type, client, expires_in=None, user=None, scope=None
):
"""Post process data to prevent JSON serialization problems."""
- tokendata = super().get_token_data(
- grant_type, client, expires_in, user, scope)
+ issued_at = int(time.time())
+ tokendata = {
+ "scope": self.get_allowed_scope(client, scope),
+ "grant_type": grant_type,
+ "iat": issued_at,
+ "client_id": client.get_client_id()
+ }
+ if isinstance(expires_in, int) and expires_in > 0:
+ tokendata["exp"] = issued_at + expires_in
+ if self.issuer:
+ tokendata["iss"] = self.issuer
+ if user:
+ tokendata["sub"] = self.get_sub_value(user)
+
return {
**{
key: str(value) if key.endswith("_id") else value
@@ -36,6 +53,36 @@ class JWTBearerTokenGenerator(_JWTBearerTokenGenerator):
"oauth2_client_id": str(client.client_id)
}
+ def generate(# pylint: disable=[too-many-arguments]
+ self,
+ grant_type: str,
+ client: OAuth2Client,
+ user: Optional[User] = None,
+ scope: Optional[str] = None,
+ expires_in: Optional[int] = None
+ ) -> dict:
+ """Generate a bearer token for OAuth 2.0 authorization token endpoint.
+
+ :param client: the client that making the request.
+ :param grant_type: current requested grant_type.
+ :param user: current authorized user.
+ :param expires_in: if provided, use this value as expires_in.
+ :param scope: current requested scope.
+ :return: Token dict
+ """
+
+ token_data = self.get_token_data(grant_type, client, expires_in, user, scope)
+ access_token = jwt.encode({"alg": self.alg}, token_data, key=self.secret_key, check=False)
+ token = {
+ "token_type": "Bearer",
+ "access_token": to_native(access_token)
+ }
+ if expires_in:
+ token["expires_in"] = expires_in
+ if scope:
+ token["scope"] = scope
+ return token
+
def __call__(# pylint: disable=[too-many-arguments]
self, grant_type, client, user=None, scope=None, expires_in=None,
diff --git a/gn_auth/auth/authentication/oauth2/models/jwt_bearer_token.py b/gn_auth/auth/authentication/oauth2/models/jwt_bearer_token.py
index cca75f4..71769e1 100644
--- a/gn_auth/auth/authentication/oauth2/models/jwt_bearer_token.py
+++ b/gn_auth/auth/authentication/oauth2/models/jwt_bearer_token.py
@@ -1,5 +1,7 @@
"""Implement model for JWTBearerToken"""
import uuid
+import time
+from typing import Optional
from authlib.oauth2.rfc7523 import JWTBearerToken as _JWTBearerToken
@@ -28,3 +30,21 @@ class JWTBearerToken(_JWTBearerToken):
def check_client(self, client):
"""Check that the client is right."""
return self.client.get_client_id() == client.get_client_id()
+
+
+ def get_expires_in(self) -> Optional[int]:
+ """Return the number of seconds the token is valid for since issue.
+
+ If `None`, the token never expires."""
+ if "exp" in self:
+ return self['exp'] - self['iat']
+ return None
+
+
+ def is_expired(self):
+ """Check whether the token is expired.
+
+ If there is no 'exp' member, assume this token will never expire."""
+ if "exp" in self:
+ return self["exp"] < time.time()
+ return False
diff --git a/gn_auth/auth/authentication/oauth2/models/oauth2client.py b/gn_auth/auth/authentication/oauth2/models/oauth2client.py
index df5d564..c7e1c90 100644
--- a/gn_auth/auth/authentication/oauth2/models/oauth2client.py
+++ b/gn_auth/auth/authentication/oauth2/models/oauth2client.py
@@ -1,6 +1,5 @@
"""OAuth2 Client model."""
import json
-import logging
import datetime
from uuid import UUID
from functools import cached_property
@@ -8,11 +7,13 @@ from dataclasses import asdict, dataclass
from typing import Any, Sequence, Optional
import requests
+from flask import current_app as app
from requests.exceptions import JSONDecodeError
from authlib.jose import KeySet, JsonWebKey
from authlib.oauth2.rfc6749 import ClientMixin
from pymonad.maybe import Just, Maybe, Nothing
+from gn_auth.debug import __pk__
from gn_auth.auth.db import sqlite3 as db
from gn_auth.auth.errors import NotFoundError
from gn_auth.auth.authentication.users import (User,
@@ -62,23 +63,26 @@ class OAuth2Client(ClientMixin):
def jwks(self) -> KeySet:
"""Return this client's KeySet."""
jwksuri = self.client_metadata.get("public-jwks-uri")
+ __pk__(f"PUBLIC JWKs link for client {self.client_id}", jwksuri)
if not bool(jwksuri):
- logging.debug("No Public JWKs URI set for client!")
+ app.logger.debug("No Public JWKs URI set for client!")
return KeySet([])
try:
## IMPORTANT: This can cause a deadlock if the client is working in
## single-threaded mode, i.e. can only serve one request
## at a time.
return KeySet([JsonWebKey.import_key(key)
- for key in requests.get(jwksuri).json()["jwks"]])
+ for key in requests.get(
+ jwksuri,
+ allow_redirects=True).json()["jwks"]])
except requests.ConnectionError as _connerr:
- logging.debug(
+ app.logger.debug(
"Could not connect to provided URI: %s", jwksuri, exc_info=True)
except JSONDecodeError as _jsonerr:
- logging.debug(
+ app.logger.debug(
"Could not convert response to JSON", exc_info=True)
except Exception as _exc:# pylint: disable=[broad-except]
- logging.debug(
+ app.logger.debug(
"Error retrieving the JWKs for the client.", exc_info=True)
return KeySet([])
diff --git a/gn_auth/auth/authentication/oauth2/resource_server.py b/gn_auth/auth/authentication/oauth2/resource_server.py
index 9c885e2..8ecf923 100644
--- a/gn_auth/auth/authentication/oauth2/resource_server.py
+++ b/gn_auth/auth/authentication/oauth2/resource_server.py
@@ -43,6 +43,11 @@ class JWTBearerTokenValidator(_JWTBearerTokenValidator):
self._last_jwks_update = datetime.now(tz=timezone.utc)
self._refresh_frequency = timedelta(hours=int(
extra_attributes.get("jwt_refresh_frequency_hours", 6)))
+ self.claims_options = {
+ 'exp': {'essential': False},
+ 'client_id': {'essential': True},
+ 'grant_type': {'essential': True},
+ }
def __refresh_jwks__(self):
now = datetime.now(tz=timezone.utc)
diff --git a/gn_auth/auth/authentication/oauth2/server.py b/gn_auth/auth/authentication/oauth2/server.py
index a8109b7..e6d1e01 100644
--- a/gn_auth/auth/authentication/oauth2/server.py
+++ b/gn_auth/auth/authentication/oauth2/server.py
@@ -3,8 +3,8 @@ import uuid
from typing import Callable
from datetime import datetime
-from flask import Flask, current_app
-from authlib.jose import jwt, KeySet
+from flask import Flask, current_app, request as flask_request
+from authlib.jose import KeySet
from authlib.oauth2.rfc6749.errors import InvalidClientError
from authlib.integrations.flask_oauth2 import AuthorizationServer
from authlib.oauth2.rfc6749 import OAuth2Request
@@ -16,13 +16,9 @@ from gn_auth.auth.jwks import (
jwks_directory,
newest_jwk_with_rotation)
+from .models.jwt_bearer_token import JWTBearerToken
from .models.oauth2client import client as fetch_client
from .models.oauth2token import OAuth2Token, save_token
-from .models.jwtrefreshtoken import (
- JWTRefreshToken,
- link_child_token,
- save_refresh_token,
- load_refresh_token)
from .grants.password_grant import PasswordGrant
from .grants.refresh_token_grant import RefreshTokenGrant
@@ -34,6 +30,8 @@ from .endpoints.introspection import IntrospectionEndpoint
from .resource_server import require_oauth, JWTBearerTokenValidator
+_TWO_HOURS_ = 2 * 60 * 60
+
def create_query_client_func() -> Callable:
"""Create the function that loads the client."""
@@ -50,50 +48,28 @@ def create_query_client_func() -> Callable:
return __query_client__
-def create_save_token_func(token_model: type, app: Flask) -> Callable:
+def create_save_token_func(token_model: type) -> Callable:
"""Create the function that saves the token."""
+ def __ignore_token__(token, request):# pylint: disable=[unused-argument]
+ """Ignore the token: i.e. Do not save it."""
+
def __save_token__(token, request):
- _jwt = jwt.decode(
- token["access_token"],
- newest_jwk_with_rotation(
- jwks_directory(app),
- int(app.config["JWKS_ROTATION_AGE_DAYS"])))
- _token = token_model(
- token_id=uuid.UUID(_jwt["jti"]),
- client=request.client,
- user=request.user,
- **{
- "refresh_token": None,
- "revoked": False,
- "issued_at": datetime.now(),
- **token
- })
with db.connection(current_app.config["AUTH_DB"]) as conn:
- save_token(conn, _token)
- old_refresh_token = load_refresh_token(
+ save_token(
conn,
- request.form.get("refresh_token", "nosuchtoken")
- )
- new_refresh_token = JWTRefreshToken(
- token=_token.refresh_token,
+ token_model(
+ **token,
+ token_id=uuid.uuid4(),
client=request.client,
user=request.user,
- issued_with=uuid.UUID(_jwt["jti"]),
- issued_at=datetime.fromtimestamp(_jwt["iat"]),
- expires=datetime.fromtimestamp(
- old_refresh_token.then(
- lambda _tok: _tok.expires.timestamp()
- ).maybe((int(_jwt["iat"]) +
- RefreshTokenGrant.DEFAULT_EXPIRES_IN),
- lambda _expires: _expires)),
- scope=_token.get_scope(),
+ issued_at=datetime.now(),
revoked=False,
- parent_of=None)
- save_refresh_token(conn, new_refresh_token)
- old_refresh_token.then(lambda _tok: link_child_token(
- conn, _tok.token, new_refresh_token.token))
+ expires_in=_TWO_HOURS_))
- return __save_token__
+ return {
+ OAuth2Token: __save_token__,
+ JWTBearerToken: __ignore_token__
+ }[token_model]
def make_jwt_token_generator(app):
"""Make token generator function."""
@@ -106,15 +82,17 @@ def make_jwt_token_generator(app):
include_refresh_token=True
):
return JWTBearerTokenGenerator(
- newest_jwk_with_rotation(
+ secret_key=newest_jwk_with_rotation(
jwks_directory(app),
- int(app.config["JWKS_ROTATION_AGE_DAYS"]))).__call__(
- grant_type,
- client,
- user,
- scope,
- JWTBearerTokenGenerator.DEFAULT_EXPIRES_IN,
- include_refresh_token)
+ int(app.config["JWKS_ROTATION_AGE_DAYS"])),
+ issuer=flask_request.host_url,
+ alg="RS256").__call__(
+ grant_type=grant_type,
+ client=client,
+ user=user,
+ scope=scope,
+ expires_in=expires_in,
+ include_refresh_token=include_refresh_token)
return __generator__
@@ -153,7 +131,7 @@ def setup_oauth2_server(app: Flask) -> None:
server.init_app(
app,
query_client=create_query_client_func(),
- save_token=create_save_token_func(OAuth2Token, app))
+ save_token=create_save_token_func(JWTBearerToken))
app.config["OAUTH2_SERVER"] = server
## Set up the token validators