aboutsummaryrefslogtreecommitdiff
path: root/gn_auth/auth/authentication/oauth2
diff options
context:
space:
mode:
authorFrederick Muriuki Muriithi2024-05-09 08:12:16 +0300
committerFrederick Muriuki Muriithi2024-05-13 06:27:46 +0300
commita22dbcba7b28b75c13aa25bdd36583ade5fe3747 (patch)
tree84ba00ba66bceda8136116d54988f46344ea1c8d /gn_auth/auth/authentication/oauth2
parent0aa19052d00b2e7d0d6edd5905314484480e7ea2 (diff)
downloadgn-auth-a22dbcba7b28b75c13aa25bdd36583ade5fe3747.tar.gz
Link old refresh token to newly issued refresh token
We need to track the "lineage" of refresh tokens in order to detect possible stolen tokens and mitigate damage.
Diffstat (limited to 'gn_auth/auth/authentication/oauth2')
-rw-r--r--gn_auth/auth/authentication/oauth2/server.py53
1 files changed, 46 insertions, 7 deletions
diff --git a/gn_auth/auth/authentication/oauth2/server.py b/gn_auth/auth/authentication/oauth2/server.py
index 8b65aa9..75d6e1b 100644
--- a/gn_auth/auth/authentication/oauth2/server.py
+++ b/gn_auth/auth/authentication/oauth2/server.py
@@ -13,7 +13,11 @@ from gn_auth.auth.db import sqlite3 as db
from .models.oauth2client import client
from .models.oauth2token import OAuth2Token, save_token
-from .models.jwtrefreshtoken import JWTRefreshToken, save_refresh_token
+from .models.jwtrefreshtoken import (
+ JWTRefreshToken,
+ link_child_token,
+ save_refresh_token,
+ load_refresh_token)
from .grants.password_grant import PasswordGrant
from .grants.refresh_token_grant import RefreshTokenGrant
@@ -25,6 +29,7 @@ from .endpoints.introspection import IntrospectionEndpoint
from .resource_server import require_oauth, BearerTokenValidator
+
def create_query_client_func() -> Callable:
"""Create the function that loads the client."""
def __query_client__(client_id: uuid.UUID):
@@ -56,20 +61,53 @@ def create_save_token_func(token_model: type, jwtkey: jwk) -> Callable:
})
with db.connection(current_app.config["AUTH_DB"]) as conn:
save_token(conn, _token)
- save_refresh_token(
+ old_refresh_token = load_refresh_token(
conn,
- JWTRefreshToken(
+ request.form.get("refresh_token", "nosuchtoken")
+ )
+ new_refresh_token = JWTRefreshToken(
token=_token.refresh_token,
client=request.client,
user=request.user,
issued_with=uuid.UUID(_jwt["jti"]),
issued_at=datetime.datetime.fromtimestamp(_jwt["iat"]),
- expires=datetime.datetime.fromtimestamp(_jwt["iat"]),
+ expires=datetime.datetime.fromtimestamp(
+ old_refresh_token.then(
+ lambda _tok: _tok.expires.timestamp()
+ ).maybe((int(_jwt["iat"]) +
+ RefreshTokenGrant.DEFAULT_EXPIRES_IN),
+ lambda _expires: _expires)),
+ scope=_token.get_scope(),
revoked=False,
- parent_of=None))
+ parent_of=None)
+ save_refresh_token(conn, new_refresh_token)
+ old_refresh_token.then(lambda _tok: link_child_token(
+ conn, _tok.token, new_refresh_token.token))
return __save_token__
+
+def make_jwt_token_generator(app):
+ """Make token generator function."""
+ _gen = JWTBearerTokenGenerator(app.config["SSL_PRIVATE_KEY"])
+ def __generator__(# pylint: disable=[too-many-arguments]
+ grant_type,
+ client,
+ user=None,
+ scope=None,
+ expires_in=None,# pylint: disable=[unused-argument]
+ include_refresh_token=True
+ ):
+ return _gen.__call__(
+ grant_type,
+ client,
+ user,
+ scope,
+ JWTBearerTokenGenerator.DEFAULT_EXPIRES_IN,
+ include_refresh_token)
+ return __generator__
+
+
def setup_oauth2_server(app: Flask) -> None:
"""Set's up the oauth2 server for the flask application."""
server = AuthorizationServer()
@@ -81,9 +119,10 @@ def setup_oauth2_server(app: Flask) -> None:
server.register_grant(AuthorisationCodeGrant)
server.register_grant(JWTBearerGrant)
+ jwttokengenerator = make_jwt_token_generator(app)
server.register_token_generator(
- "urn:ietf:params:oauth:grant-type:jwt-bearer",
- JWTBearerTokenGenerator(app.config["SSL_PRIVATE_KEY"]))
+ "urn:ietf:params:oauth:grant-type:jwt-bearer", jwttokengenerator)
+ server.register_token_generator("refresh_token", jwttokengenerator)
server.register_grant(RefreshTokenGrant)
# register endpoints