aboutsummaryrefslogtreecommitdiff
path: root/gn_auth/auth/authentication/oauth2/models/authorization_code.py
blob: be5fdadb0a5dedc78ac1cdace0c9ae214e8925f5 (about) (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
"""Model and functions for handling the Authorisation Code"""
from datetime import datetime
from dataclasses import dataclass, asdict
from functools import cached_property
from uuid import UUID
from authlib.oauth2.rfc6749 import AuthorizationCodeMixin


from pymonad.tools import monad_from_none_or_value
from pymonad.maybe import Just, Maybe, Nothing

from gn_auth.auth.db import sqlite3 as db

from .oauth2client import OAuth2Client

from ...users import User, user_by_id


EXPIRY_IN_SECONDS = 5 * 60  # lifetime of an authorisation code: five minutes


# pylint: disable=[too-many-instance-attributes]
@dataclass(frozen=True)
class AuthorisationCode(AuthorizationCodeMixin):
    """
    Immutable record of an authorisation code issued to an OAuth2 client.

    The fields mirror the columns of the `authorisation_code` table; the
    `get_*` methods are the accessors authlib's `AuthorizationCodeMixin`
    interface calls on a code object.
    """
    code_id: UUID  # identifier of this code record
    code: str  # the code value itself
    client: OAuth2Client  # client the code belongs to
    redirect_uri: str
    scope: str
    nonce: str
    auth_time: int  # seconds since the epoch; checked by `is_expired`
    code_challenge: str
    code_challenge_method: str
    user: User

    @cached_property
    def response_type(self) -> str:
        """Always 'code': this model only serves the authorisation code flow."""
        return "code"

    def is_expired(self):
        """
        Return True once more than `EXPIRY_IN_SECONDS` have elapsed since
        `auth_time`.
        """
        expires_at = self.auth_time + EXPIRY_IN_SECONDS
        return datetime.now().timestamp() > expires_at

    def get_redirect_uri(self):
        """Return the redirect URI recorded when the code was issued."""
        return self.redirect_uri

    def get_scope(self):
        """Return the scope granted with this code."""
        return self.scope


def authorisation_code(conn: db.DbConnection,
                       code: str,
                       client: OAuth2Client) -> Maybe[AuthorisationCode]:
    """
    Look up the authorisation code `code` that was issued to `client`.

    Returns `Just` wrapping the matching `AuthorisationCode` when a row exists
    in the `authorisation_code` table, and `Nothing` otherwise.
    """
    def _to_authorisation_code(row) -> AuthorisationCode:
        # Build the model from a fetched row; the owning user is resolved
        # with a second lookup on the same connection.
        return AuthorisationCode(
            code_id=UUID(row["code_id"]),
            code=row["code"],
            client=client,
            redirect_uri=row["redirect_uri"],
            scope=row["scope"],
            nonce=row["nonce"],
            auth_time=int(row["auth_time"]),
            code_challenge=row["code_challenge"],
            code_challenge_method=row["code_challenge_method"],
            user=user_by_id(conn, UUID(row["user_id"])))

    with db.cursor(conn) as cursor:
        cursor.execute(
            ("SELECT * FROM authorisation_code "
             "WHERE code=:code AND client_id=:client_id"),
            {"code": code, "client_id": str(client.client_id)})
        row = cursor.fetchone()
        return monad_from_none_or_value(
            Nothing, Just, row).then(_to_authorisation_code)

def save_authorisation_code(conn: db.DbConnection,
                            auth_code: AuthorisationCode) -> AuthorisationCode:
    """
    Persist `auth_code` as a new row of the `authorisation_code` table.

    The target columns are listed explicitly so the insert does not depend on
    the physical column order of the table (an `INSERT ... VALUES` without a
    column list breaks as soon as columns are added or reordered).

    :param conn: database connection to write through.
    :param auth_code: the authorisation code to store.
    :returns: `auth_code`, unchanged.
    """
    with db.cursor(conn) as cursor:
        cursor.execute(
            "INSERT INTO authorisation_code("
            "code_id, code, client_id, redirect_uri, scope, nonce, "
            "auth_time, code_challenge, code_challenge_method, user_id"
            ") VALUES("
            ":code_id, :code, :client_id, :redirect_uri, :scope, :nonce, "
            ":auth_time, :code_challenge, :code_challenge_method, :user_id"
            ")",
            {
                **asdict(auth_code),
                # UUIDs are bound as strings, and the nested client and user
                # objects are reduced to their identifiers.
                "code_id": str(auth_code.code_id),
                "client_id": str(auth_code.client.client_id),
                "user_id": str(auth_code.user.user_id)
            })
        return auth_code