1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
|
"""OAuth2 Client model."""
import datetime
import hmac
import json
import uuid
from typing import NamedTuple, Sequence

from pymonad.maybe import Just, Maybe, Nothing

from gn3.auth import db
from gn3.auth.authentication.users import User, user_by_id
class OAuth2Client(NamedTuple):
    """
    Client to the OAuth2 Server.

    This is defined according to the mixin at
    https://docs.authlib.org/en/latest/specs/rfc6749.html#authlib.oauth2.rfc6749.ClientMixin

    Every per-client setting exposed below (grant types, redirect URIs,
    response types, scopes, ...) is read from the `client_metadata` dict.
    """
    client_id: uuid.UUID
    client_secret: str
    client_id_issued_at: datetime.datetime
    client_secret_expires_at: datetime.datetime
    # Free-form settings; the properties below look up their values here.
    client_metadata: dict
    user: User

    def check_client_secret(self, client_secret: str) -> bool:
        """
        Check whether the `client_secret` matches this client.

        The comparison uses `hmac.compare_digest` so that the time taken does
        not reveal how many leading characters of a guess were correct.
        Values that are not strings never match.
        """
        stored = self.client_secret
        if not (isinstance(stored, str) and isinstance(client_secret, str)):
            return False
        # Compare bytes: `hmac.compare_digest` rejects non-ASCII `str` values.
        return hmac.compare_digest(
            stored.encode("utf-8", "surrogatepass"),
            client_secret.encode("utf-8", "surrogatepass"))

    @property
    def token_endpoint_auth_method(self) -> Sequence[str]:
        """
        Return the token endpoint authorisation method(s).

        Defaults to `["none"]` when the metadata does not specify any.
        """
        return self.client_metadata.get("token_endpoint_auth_method", ["none"])

    @property
    def client_type(self) -> str:
        """Return the client type from the metadata ("public" if unset)."""
        return self.client_metadata.get("client_type", "public")

    def check_endpoint_auth_method(self, method: str, endpoint: str) -> bool:
        """
        Check if the client supports the given method for the given endpoint.

        Acceptable methods:
        * none: Client is a public client and does not have a client secret
        * client_secret_post: Client uses the HTTP POST parameters
        * client_secret_basic: Client uses HTTP Basic

        As implemented, the "token" endpoint only accepts
        `client_secret_post` and the "introspection" and "revoke" endpoints
        only accept `client_secret_basic`; in both cases the method must also
        be listed in `token_endpoint_auth_method`. Any other endpoint is
        rejected.
        """
        if endpoint == "token":
            return (method in self.token_endpoint_auth_method
                    and method == "client_secret_post")
        if endpoint in ("introspection", "revoke"):
            return (method in self.token_endpoint_auth_method
                    and method == "client_secret_basic")
        return False

    @property
    def id(self):# pylint: disable=[invalid-name]
        """Return the client_id."""
        return self.client_id

    @property
    def grant_types(self) -> Sequence[str]:
        """
        Return the grant types that this client supports.

        Valid grant types:
        * authorisation_code
        * implicit
        * client_credentials
        * password

        Defaults to an empty list when the metadata has no "grant_types".
        """
        return self.client_metadata.get("grant_types", [])

    def check_grant_type(self, grant_type: str) -> bool:
        """
        Validate that client can handle the given grant types
        """
        return grant_type in self.grant_types

    @property
    def redirect_uris(self) -> Sequence[str]:
        """Return the redirect_uris that this client supports."""
        return self.client_metadata.get('redirect_uris', [])

    def check_redirect_uri(self, redirect_uri: str) -> bool:
        """
        Check whether the given `redirect_uri` is one of the expected ones.
        """
        return redirect_uri in self.redirect_uris

    @property
    def response_types(self) -> Sequence[str]:
        """Return the response_types that this client supports."""
        return self.client_metadata.get("response_types", [])

    def check_response_type(self, response_type: str) -> bool:
        """Check whether this client supports `response_type`."""
        return response_type in self.response_types

    @property
    def scope(self) -> Sequence[str]:
        """
        Return valid scopes for this client.

        Duplicates are removed; the order of the result is unspecified.
        """
        return tuple(set(self.client_metadata.get("scope", [])))

    def get_allowed_scope(self, scope: str) -> str:
        """
        Return the scopes in `scope` that are supported by this client.

        `scope` is a whitespace-separated string of requested scopes. The
        result is a space-separated string of the supported ones, sorted and
        without duplicates. An empty/missing `scope` gives an empty string.
        """
        if not scope:
            return ""
        # Build the lookup set once rather than re-evaluating the `scope`
        # property for every requested item.
        supported = frozenset(self.scope)
        return " ".join(sorted(supported.intersection(scope.split())))

    def get_client_id(self):
        """Return this client's identifier."""
        return self.client_id

    def get_default_redirect_uri(self) -> str:
        """Return the default redirect uri ("" if none is configured)."""
        return self.client_metadata.get("default_redirect_uri", "")
def client(conn: db.DbConnection, client_id: uuid.UUID) -> Maybe:
    """
    Fetch the OAuth2 client identified by `client_id` from the database.

    Returns `Just(OAuth2Client)` when a matching row is found in the
    `oauth2_clients` table, `Nothing` otherwise.
    """
    with db.cursor(conn) as cursor:
        cursor.execute(
            "SELECT * FROM oauth2_clients WHERE client_id=?", (str(client_id),))
        row = cursor.fetchone()
        if not row:
            return Nothing

        # Convert the stored values, in field order, into the model's types.
        identifier = uuid.UUID(row["client_id"])
        secret = row["client_secret"]
        issued_at = datetime.datetime.fromtimestamp(
            row["client_id_issued_at"])
        expires_at = datetime.datetime.fromtimestamp(
            row["client_secret_expires_at"])
        metadata = json.loads(row["client_metadata"])
        # Unwrap the Maybe from `user_by_id`; the owner is None on Nothing.
        owner = user_by_id( # type: ignore[misc]
            conn, uuid.UUID(row["user_id"])).maybe(None, lambda usr: usr)
        return Just(OAuth2Client(
            identifier, secret, issued_at, expires_at, metadata, owner))
|