aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorFrederick Muriuki Muriithi2024-05-09 08:12:16 +0300
committerFrederick Muriuki Muriithi2024-05-13 06:27:46 +0300
commita22dbcba7b28b75c13aa25bdd36583ade5fe3747 (patch)
tree84ba00ba66bceda8136116d54988f46344ea1c8d
parent0aa19052d00b2e7d0d6edd5905314484480e7ea2 (diff)
downloadgn-auth-a22dbcba7b28b75c13aa25bdd36583ade5fe3747.tar.gz
Link old refresh token to newly issued refresh token
We need to track the "lineage" of refresh tokens in order to detect possible stolen tokens and mitigate damage.
-rw-r--r--gn_auth/auth/authentication/oauth2/server.py53
1 files changed, 46 insertions, 7 deletions
diff --git a/gn_auth/auth/authentication/oauth2/server.py b/gn_auth/auth/authentication/oauth2/server.py
index 8b65aa9..75d6e1b 100644
--- a/gn_auth/auth/authentication/oauth2/server.py
+++ b/gn_auth/auth/authentication/oauth2/server.py
@@ -13,7 +13,11 @@ from gn_auth.auth.db import sqlite3 as db
from .models.oauth2client import client
from .models.oauth2token import OAuth2Token, save_token
-from .models.jwtrefreshtoken import JWTRefreshToken, save_refresh_token
+from .models.jwtrefreshtoken import (
+ JWTRefreshToken,
+ link_child_token,
+ save_refresh_token,
+ load_refresh_token)
from .grants.password_grant import PasswordGrant
from .grants.refresh_token_grant import RefreshTokenGrant
@@ -25,6 +29,7 @@ from .endpoints.introspection import IntrospectionEndpoint
from .resource_server import require_oauth, BearerTokenValidator
+
def create_query_client_func() -> Callable:
"""Create the function that loads the client."""
def __query_client__(client_id: uuid.UUID):
@@ -56,20 +61,53 @@ def create_save_token_func(token_model: type, jwtkey: jwk) -> Callable:
})
with db.connection(current_app.config["AUTH_DB"]) as conn:
save_token(conn, _token)
- save_refresh_token(
+ old_refresh_token = load_refresh_token(
conn,
- JWTRefreshToken(
+ request.form.get("refresh_token", "nosuchtoken")
+ )
+ new_refresh_token = JWTRefreshToken(
token=_token.refresh_token,
client=request.client,
user=request.user,
issued_with=uuid.UUID(_jwt["jti"]),
issued_at=datetime.datetime.fromtimestamp(_jwt["iat"]),
- expires=datetime.datetime.fromtimestamp(_jwt["iat"]),
+ expires=datetime.datetime.fromtimestamp(
+ old_refresh_token.then(
+ lambda _tok: _tok.expires.timestamp()
+ ).maybe((int(_jwt["iat"]) +
+ RefreshTokenGrant.DEFAULT_EXPIRES_IN),
+ lambda _expires: _expires)),
+ scope=_token.get_scope(),
revoked=False,
- parent_of=None))
+ parent_of=None)
+ save_refresh_token(conn, new_refresh_token)
+ old_refresh_token.then(lambda _tok: link_child_token(
+ conn, _tok.token, new_refresh_token.token))
return __save_token__
+
+def make_jwt_token_generator(app):
+ """Make token generator function."""
+ _gen = JWTBearerTokenGenerator(app.config["SSL_PRIVATE_KEY"])
+ def __generator__(# pylint: disable=[too-many-arguments]
+ grant_type,
+ client,
+ user=None,
+ scope=None,
+ expires_in=None,# pylint: disable=[unused-argument]
+ include_refresh_token=True
+ ):
+ return _gen.__call__(
+ grant_type,
+ client,
+ user,
+ scope,
+ JWTBearerTokenGenerator.DEFAULT_EXPIRES_IN,
+ include_refresh_token)
+ return __generator__
+
+
def setup_oauth2_server(app: Flask) -> None:
"""Set's up the oauth2 server for the flask application."""
server = AuthorizationServer()
@@ -81,9 +119,10 @@ def setup_oauth2_server(app: Flask) -> None:
server.register_grant(AuthorisationCodeGrant)
server.register_grant(JWTBearerGrant)
+ jwttokengenerator = make_jwt_token_generator(app)
server.register_token_generator(
- "urn:ietf:params:oauth:grant-type:jwt-bearer",
- JWTBearerTokenGenerator(app.config["SSL_PRIVATE_KEY"]))
+ "urn:ietf:params:oauth:grant-type:jwt-bearer", jwttokengenerator)
+ server.register_token_generator("refresh_token", jwttokengenerator)
server.register_grant(RefreshTokenGrant)
# register endpoints