about summary refs log tree commit diff
path: root/gn2/wqflask/oauth2/tokens.py
blob: e0ee814b25e01789e9835a72879dd293bee72f92 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
"""This file contains functions/classes related to dealing with JWTs"""
from dataclasses import dataclass
from dataclasses import field
from authlib.jose import jwt


@dataclass
class JWTToken:
    """Class for constructing a JWT according to RFC7519

https://datatracker.ietf.org/doc/html/rfc7519

    """
    key: str
    private_claims: dict = field(default_factory=lambda: {})
    public_claims: dict = field(default_factory=lambda: {})
    jose_header: dict = field(
        default_factory=lambda: {
            "alg": "HS256",
            "typ": "jwt",
            "cty": "json",
        })
    registered_claims: dict = field(
        default_factory={
            "iss": "",  # Issuer Claim
            "iat": "",  # Issued At
            "sub": "",  # Subject Claim
            "aud": "",  # Audience Claim
            "exp": "",  # Expiration Time Claim
            "jti": "",  # Unique Identifier for this token
        })

    def __post__init__(self):
        match self.jose_header.get("alg"):
            case "HS256":
                self.key = self.key
            case _:
                with open(self.key, "rb")as f_:
                    self.key = f_.read()

    def encode(self):
        """Encode the JWT"""
        payload = self.registered_claims \
            | self.private_claims \
            | self.public_claims \
            | self.registered_claims
        return jwt.encode(self.jose_header, payload, self.key)

    @property
    def bearer_token(self) -> dict:
        """Return a header that contains this tokens Bearer Token"""
        return {
            "Authorization": f"Bearer {self.encode()}"
        }

    @staticmethod
    def decode(token, key) -> str:
        """Decode the JWT"""
        return jwt.decode(token, key)