aboutsummaryrefslogtreecommitdiff
path: root/gn_auth
diff options
context:
space:
mode:
Diffstat (limited to 'gn_auth')
-rw-r--r--gn_auth/auth/authentication/oauth2/models/oauth2client.py119
1 files changed, 76 insertions, 43 deletions
diff --git a/gn_auth/auth/authentication/oauth2/models/oauth2client.py b/gn_auth/auth/authentication/oauth2/models/oauth2client.py
index ea86772..b140a16 100644
--- a/gn_auth/auth/authentication/oauth2/models/oauth2client.py
+++ b/gn_auth/auth/authentication/oauth2/models/oauth2client.py
@@ -1,17 +1,26 @@
"""OAuth2 Client model."""
import json
import datetime
+
from uuid import UUID
-from typing import Sequence, Optional, NamedTuple
+from dataclasses import dataclass
+from functools import cached_property
+from typing import Sequence, Optional
+from authlib.oauth2.rfc6749 import ClientMixin
from pymonad.maybe import Just, Maybe, Nothing
from gn_auth.auth.db import sqlite3 as db
-from gn_auth.auth.authentication.users import User, users, user_by_id, same_password
+from gn_auth.auth.authentication.users import (User,
+ users,
+ user_by_id,
+ same_password)
from gn_auth.auth.authorisation.errors import NotFoundError
-class OAuth2Client(NamedTuple):
+
+@dataclass(frozen=True)
+class OAuth2Client(ClientMixin):
"""
Client to the OAuth2 Server.
@@ -29,12 +38,12 @@ class OAuth2Client(NamedTuple):
"""Check whether the `client_secret` matches this client."""
return same_password(client_secret, self.client_secret)
- @property
+ @cached_property
def token_endpoint_auth_method(self) -> str:
"""Return the token endpoint authorisation method."""
return self.client_metadata.get("token_endpoint_auth_method", ["none"])
- @property
+ @cached_property
def client_type(self) -> str:
"""
Return the token endpoint authorisation method.
@@ -64,12 +73,12 @@ class OAuth2Client(NamedTuple):
and method == "client_secret_basic")
return False
- @property
- def id(self):# pylint: disable=[invalid-name]
+ @cached_property
+ def id(self): # pylint: disable=[invalid-name]
"""Return the client_id."""
return self.client_id
- @property
+ @cached_property
def grant_types(self) -> Sequence[str]:
"""
Return the grant types that this client supports.
@@ -88,7 +97,7 @@ class OAuth2Client(NamedTuple):
"""
return grant_type in self.grant_types
- @property
+ @cached_property
def redirect_uris(self) -> Sequence[str]:
"""Return the redirect_uris that this client supports."""
return self.client_metadata.get('redirect_uris', [])
@@ -99,7 +108,7 @@ class OAuth2Client(NamedTuple):
"""
return redirect_uri in self.redirect_uris
- @property
+ @cached_property
def response_types(self) -> Sequence[str]:
"""Return the response_types that this client supports."""
return self.client_metadata.get("response_type", [])
@@ -108,13 +117,14 @@ class OAuth2Client(NamedTuple):
"""Check whether this client supports `response_type`."""
return response_type in self.response_types
- @property
+ @cached_property
def scope(self) -> Sequence[str]:
"""Return valid scopes for this client."""
return tuple(set(self.client_metadata.get("scope", [])))
def get_allowed_scope(self, scope: str) -> str:
- """Return list of scopes in `scope` that are supported by this client."""
+ """Return list of scopes in `scope` that are supported by this
+ client."""
if not bool(scope):
return ""
requested = scope.split()
@@ -129,33 +139,39 @@ class OAuth2Client(NamedTuple):
"""Return the default redirect uri"""
return self.client_metadata.get("default_redirect_uri", "")
+
def client(conn: db.DbConnection, client_id: UUID,
user: Optional[User] = None) -> Maybe:
"""Retrieve a client by its ID"""
with db.cursor(conn) as cursor:
cursor.execute(
- "SELECT * FROM oauth2_clients WHERE client_id=?", (str(client_id),))
+ "SELECT * FROM oauth2_clients WHERE client_id=?",
+ (str(client_id),))
result = cursor.fetchone()
the_user = user
if result:
if not bool(the_user):
try:
the_user = user_by_id(conn, result["user_id"])
- except NotFoundError as _nfe:
+ except NotFoundError:
the_user = None
return Just(
- OAuth2Client(UUID(result["client_id"]),
- result["client_secret"],
- datetime.datetime.fromtimestamp(
- result["client_id_issued_at"]),
- datetime.datetime.fromtimestamp(
- result["client_secret_expires_at"]),
- json.loads(result["client_metadata"]),
- the_user))# type: ignore[arg-type]
-
+ OAuth2Client(
+ client_id=UUID(result["client_id"]),
+ client_secret=result["client_secret"],
+ client_id_issued_at=datetime.datetime.fromtimestamp(
+ result["client_id_issued_at"]
+ ),
+ client_secret_expires_at=datetime.datetime.fromtimestamp(
+ result["client_secret_expires_at"]
+ ),
+ client_metadata=json.loads(result["client_metadata"]),
+ user=the_user) # type: ignore[arg-type]
+ )
return Nothing
+
def client_by_id_and_secret(conn: db.DbConnection, client_id: UUID,
client_secret: str) -> OAuth2Client:
"""Retrieve a client by its ID and secret"""
@@ -166,16 +182,25 @@ def client_by_id_and_secret(conn: db.DbConnection, client_id: UUID,
row = cursor.fetchone()
if bool(row) and same_password(client_secret, row["client_secret"]):
return OAuth2Client(
- client_id, client_secret,
- datetime.datetime.fromtimestamp(row["client_id_issued_at"]),
- datetime.datetime.fromtimestamp(
- row["client_secret_expires_at"]),
- json.loads(row["client_metadata"]),
- user_by_id(conn, UUID(row["user_id"])))
+ client_id=client_id,
+ client_secret=client_secret,
+ client_id_issued_at=datetime.datetime.fromtimestamp(
+ row["client_id_issued_at"]
+ ),
+ client_secret_expires_at=datetime.datetime.fromtimestamp(
+ row["client_secret_expires_at"]
+ ),
+ client_metadata=json.loads(row["client_metadata"]),
+ user=user_by_id(conn, UUID(row["user_id"]))
+ )
+ raise NotFoundError(
+ "Could not find client with the given credentials."
+ )
- raise NotFoundError("Could not find client with the given credentials.")
-def save_client(conn: db.DbConnection, the_client: OAuth2Client) -> OAuth2Client:
+def save_client(
+ conn: db.DbConnection, the_client: OAuth2Client
+) -> OAuth2Client:
"""Persist the client details into the database."""
with db.cursor(conn) as cursor:
query = (
@@ -204,6 +229,7 @@ def save_client(conn: db.DbConnection, the_client: OAuth2Client) -> OAuth2Client
})
return the_client
+
def oauth2_clients(conn: db.DbConnection) -> tuple[OAuth2Client, ...]:
"""Fetch a list of all OAuth2 clients."""
with db.cursor(conn) as cursor:
@@ -211,19 +237,26 @@ def oauth2_clients(conn: db.DbConnection) -> tuple[OAuth2Client, ...]:
clients_rs = cursor.fetchall()
the_users = {
usr.user_id: usr for usr in users(
- conn, tuple({UUID(result["user_id"]) for result in clients_rs}))
+ conn, tuple({UUID(result["user_id"])
+ for result in clients_rs}))
}
- return tuple(OAuth2Client(UUID(result["client_id"]),
- result["client_secret"],
- datetime.datetime.fromtimestamp(
- result["client_id_issued_at"]),
- datetime.datetime.fromtimestamp(
- result["client_secret_expires_at"]),
- json.loads(result["client_metadata"]),
- the_users[UUID(result["user_id"])])
- for result in clients_rs)
-
-def delete_client(conn: db.DbConnection, the_client: OAuth2Client) -> OAuth2Client:
+ return tuple(
+ OAuth2Client(
+ client_id=UUID(result["client_id"]),
+ client_secret=result["client_secret"],
+ client_id_issued_at=datetime.datetime.fromtimestamp(
+ result["client_id_issued_at"]),
+ client_secret_expires_at=datetime.datetime.fromtimestamp(
+ result["client_secret_expires_at"]),
+ client_metadata=json.loads(result["client_metadata"]),
+ user=the_users[UUID(result["user_id"])]
+ ) for result in clients_rs
+ )
+
+
+def delete_client(
+ conn: db.DbConnection, the_client: OAuth2Client
+) -> OAuth2Client:
"""Delete the given client from the database"""
with db.cursor(conn) as cursor:
params = (str(the_client.client_id),)