aboutsummaryrefslogtreecommitdiff
path: root/gn_auth/auth/authentication/oauth2/models/jwtrefreshtoken.py
blob: 40b1554bda08a4287bec7d5f6011f9252f2616b1 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
"""
Refresh tokens for JWTs

Refresh tokens are not supported directly by JWTs. This therefore provides a
form of extension to JWTs.
"""
import uuid
import datetime
from typing import Optional
from dataclasses import dataclass

from authlib.oauth2.rfc6749 import TokenMixin, InvalidGrantError

from pymonad.maybe import Just, Maybe, Nothing
from pymonad.tools import monad_from_none_or_value

from gn_auth.auth.db import sqlite3 as db
from gn_auth.auth.authentication.users import User, user_by_id

from gn_auth.auth.authentication.oauth2.models.oauth2client import (
    OAuth2Client,
    client as fetch_client)

@dataclass(frozen=True)
class JWTRefreshToken(TokenMixin):  # pylint: disable=[too-many-instance-attributes]
    """An immutable refresh token accompanying a JWT.

    Implements authlib's `TokenMixin` interface so that the OAuth2 machinery
    can query expiry, scope, revocation and client ownership.
    """
    # The opaque refresh-token string.
    token: str
    # The client this token was issued to.
    client: OAuth2Client
    # The user this token was issued for.
    user: User
    # NOTE(review): presumably the ID of the JWT this was issued with - confirm.
    issued_with: uuid.UUID
    # Issue time; compared with naive `datetime.now()`, so expected to be naive.
    issued_at: datetime.datetime
    # Expiry time; same naive-datetime expectation as `issued_at`.
    expires: datetime.datetime
    # Scope string granted to this token.
    scope: str
    # Whether this token has been revoked.
    revoked: bool
    # Token string of the child token linked to this one, if any.
    parent_of: Optional[str] = None

    def is_expired(self):
        """Return True once the current local time has reached `expires`."""
        current_time = datetime.datetime.now()
        return current_time >= self.expires

    def get_scope(self):
        """Return the scope string granted to this token."""
        return self.scope

    def get_expires_in(self):
        """Return the token's total lifetime in seconds, as a float.

        The lifetime is measured from `issued_at` to `expires`, not from the
        current time.
        """
        lifetime = self.expires - self.issued_at
        return lifetime.total_seconds()

    def is_revoked(self):
        """Return the stored revocation flag."""
        return self.revoked

    def check_client(self, client: OAuth2Client) -> bool:
        """Return True when `client` has the same ID as the issuing client."""
        issuing_client_id = self.client.client_id
        return client.client_id == issuing_client_id


def revoke_refresh_token(conn: db.DbConnection, token: JWTRefreshToken) -> None:
    """Revoke a refresh token together with every token descended from it.

    Refresh tokens are chained through the `parent_of` column: a token's
    `parent_of` holds the token string of the child linked to it (this is
    what `link_child_token` writes). Revoking `token` therefore also revokes
    that child, the child's own child, and so on down the chain.

    Only the database rows are updated. The `token` object itself is a frozen
    dataclass and is left unchanged.

    Parameters:
        conn: An open database connection.
        token: The refresh token at the top of the chain to revoke.
    """
    # A recursive CTE collects `token` plus all of its descendants by
    # following `parent_of` links. `UNION` (rather than `UNION ALL`) drops
    # already-visited tokens, so a malformed cycle cannot recurse forever.
    # `revoked=1` matches how `save_refresh_token` stores the boolean flag.
    with db.cursor(conn) as cursor:
        cursor.execute(
            ("WITH RECURSIVE descendants(token) AS ("
             "SELECT :token "
             "UNION "
             "SELECT jrt.parent_of FROM jwt_refresh_tokens AS jrt "
             "INNER JOIN descendants ON jrt.token=descendants.token "
             "WHERE jrt.parent_of IS NOT NULL"
             ") "
             "UPDATE jwt_refresh_tokens SET revoked=1 "
             "WHERE token IN (SELECT token FROM descendants)"),
            {"token": token.token})


def save_refresh_token(conn: db.DbConnection, token: JWTRefreshToken) -> None:
    """Persist `token` into the `jwt_refresh_tokens` table.

    A new row is inserted for an unseen token string. If a row with the same
    token string already exists, only its `parent_of` column is overwritten;
    every other column of the existing row is kept as-is.
    """
    insert_or_relink = (
        "INSERT INTO jwt_refresh_tokens(token, client_id, user_id, "
        "issued_with, issued_at, expires, scope, revoked, parent_of) "
        "VALUES(:token, :client_id, :user_id, :issued_with, :issued_at, "
        ":expires, :scope, :revoked, :parent_of) "
        "ON CONFLICT (token) DO UPDATE SET parent_of=:parent_of")
    with db.cursor(conn) as cursor:
        # UUIDs are stored as text and datetimes as POSIX timestamps.
        row_values = {
            "token": token.token,
            "client_id": str(token.client.client_id),
            "user_id": str(token.user.user_id),
            "issued_with": str(token.issued_with),
            "issued_at": token.issued_at.timestamp(),
            "expires": token.expires.timestamp(),
            "scope": token.get_scope(),
            "revoked": token.revoked,
            "parent_of": token.parent_of,
        }
        cursor.execute(insert_or_relink, row_values)


def load_refresh_token(conn: db.DbConnection, token: str) -> Maybe:
    """Look up a refresh token by its token string.

    Returns `Just(JWTRefreshToken)` when a matching row exists in
    `jwt_refresh_tokens`, otherwise `Nothing`.
    """
    def _row_to_token(row):
        """Build a `JWTRefreshToken` from one `jwt_refresh_tokens` row."""
        owner = user_by_id(conn, uuid.UUID(row["user_id"]))
        right_now = datetime.datetime.now()
        maybe_client = fetch_client(
            conn, uuid.UUID(row["client_id"]), user=owner)
        # Stand-in used when the token's client can no longer be fetched.
        # Its ID is freshly random, so in practice it will not match any
        # real client when ownership is checked.
        stand_in_client = OAuth2Client(
            uuid.uuid4(), "secret", right_now, right_now, {}, owner)
        return JWTRefreshToken(
            token=row["token"],
            client=maybe_client.maybe(stand_in_client, lambda found: found),
            user=owner,
            issued_with=uuid.UUID(row["issued_with"]),
            # Timestamps are stored as POSIX seconds; restore naive datetimes.
            issued_at=datetime.datetime.fromtimestamp(row["issued_at"]),
            expires=datetime.datetime.fromtimestamp(row["expires"]),
            scope=row["scope"],
            revoked=bool(int(row["revoked"])),
            parent_of=row["parent_of"]
        )

    with db.cursor(conn) as cursor:
        cursor.execute("SELECT * FROM jwt_refresh_tokens WHERE token=:token",
                       {"token": token})
        first_row = cursor.fetchone()
        return monad_from_none_or_value(
            Nothing, Just, first_row).then(_row_to_token)


def link_child_token(conn: db.DbConnection, parenttoken: str, childtoken: str):
    """Record `childtoken` in the `parent_of` column of `parenttoken`.

    The parent is first loaded via `load_refresh_token` to confirm that it
    exists.

    Raises:
        InvalidGrantError: if no refresh token `parenttoken` is found.
    """
    parent_found = load_refresh_token(conn, parenttoken).maybe(
        False, lambda _loaded: True)
    if not parent_found:
        raise InvalidGrantError("Token not found.")

    with db.cursor(conn) as cursor:
        cursor.execute(
            "UPDATE jwt_refresh_tokens SET parent_of=:childtoken "
            "WHERE token=:parenttoken",
            {"childtoken": childtoken, "parenttoken": parenttoken})


def is_refresh_token_valid(token: JWTRefreshToken, client: OAuth2Client) -> bool:
    """Decide whether `client` may currently use `token`.

    The checks run in order and stop at the first failure:
      1. the token was issued to a client with the same ID as `client`;
      2. the token has not expired;
      3. the token has not been revoked.
    """
    if token.client.client_id != client.client_id:
        return False
    if token.is_expired():
        return False
    return not token.revoked