diff options
-rw-r--r-- | gn2/tests/unit/wqflask/oauth2/__init__.py | 0 | ||||
-rw-r--r-- | gn2/tests/unit/wqflask/oauth2/test_tokens.py | 37 | ||||
-rw-r--r-- | gn2/wqflask/oauth2/tokens.py | 59 |
3 files changed, 96 insertions, 0 deletions
diff --git a/gn2/tests/unit/wqflask/oauth2/__init__.py b/gn2/tests/unit/wqflask/oauth2/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/gn2/tests/unit/wqflask/oauth2/__init__.py diff --git a/gn2/tests/unit/wqflask/oauth2/test_tokens.py b/gn2/tests/unit/wqflask/oauth2/test_tokens.py new file mode 100644 index 00000000..ee527f51 --- /dev/null +++ b/gn2/tests/unit/wqflask/oauth2/test_tokens.py @@ -0,0 +1,37 @@ +"""Test oauth2 jwt tokens""" +from gn2.wqflask.oauth2.tokens import JWTToken + + +JWT_BEARER_TOKEN = b"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCIs\ +ImN0eSI6Impzb24ifQ.eyJpc3MiOiJHTjIiLCJ\ +zdWIiOiIxMjM0IiwiYXVkIjoiR04yIiwiZXhwI\ +joiMTIzNDUifQ.ETSr_7O4ZWLac5l4pinO9Xeb\ +mzTO7xp_LvbgxjnskDc" + + +def test_encode_token(): + """Test encoding a jwt token.""" + token = JWTToken( + key="secret", + registered_claims={ + "iss": "GN2", + "sub": "1234", + "aud": "GN2", + "exp": "12345", + } + ) + assert token.encode() == JWT_BEARER_TOKEN + assert token.bearer_token == { + "Authorization": f"Bearer {JWT_BEARER_TOKEN}" + } + + +def test_decode_token(): + """Test decoding a jwt token.""" + claims = JWTToken.decode(JWT_BEARER_TOKEN, "secret") + assert claims == { + 'iss': 'GN2', + 'sub': '1234', + 'aud': 'GN2', + 'exp': '12345' + } diff --git a/gn2/wqflask/oauth2/tokens.py b/gn2/wqflask/oauth2/tokens.py new file mode 100644 index 00000000..e0ee814b --- /dev/null +++ b/gn2/wqflask/oauth2/tokens.py @@ -0,0 +1,59 @@ +"""This file contains functions/classes related to dealing with JWTs""" +from dataclasses import dataclass +from dataclasses import field +from authlib.jose import jwt + + +@dataclass +class JWTToken: + """Class for constructing a JWT according to RFC7519 + +https://datatracker.ietf.org/doc/html/rfc7519 + + """ + key: str + private_claims: dict = field(default_factory=lambda: {}) + public_claims: dict = field(default_factory=lambda: {}) + jose_header: dict = field( + default_factory=lambda: { + "alg": "HS256", + "typ": "jwt", + "cty": "json", + }) + registered_claims: dict = field( + default_factory={ + "iss": "", # Issuer Claim + "iat": "", # Issued At + "sub": "", # Subject Claim + "aud": "", # Audience Claim + "exp": "", # Expiration Time Claim + "jti": "", # Unique Identifier for this token + }) + + def __post__init__(self): + match self.jose_header.get("alg"): + case "HS256": + self.key = self.key + case _: + with open(self.key, "rb")as f_: + self.key = f_.read() + + def encode(self): + """Encode the JWT""" + payload = self.registered_claims \ + | self.private_claims \ + | self.public_claims \ + | self.registered_claims + return jwt.encode(self.jose_header, payload, self.key) + + @property + def bearer_token(self) -> dict: + """Return a header that contains this tokens Bearer Token""" + return { + "Authorization": f"Bearer {self.encode()}" + } + + @staticmethod + def decode(token, key) -> str: + """Decode the JWT""" + return jwt.decode(token, key) |