about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--gn_auth/auth/authentication/oauth2/models/oauth2client.py119
1 files changed, 76 insertions, 43 deletions
diff --git a/gn_auth/auth/authentication/oauth2/models/oauth2client.py b/gn_auth/auth/authentication/oauth2/models/oauth2client.py
index ea86772..b140a16 100644
--- a/gn_auth/auth/authentication/oauth2/models/oauth2client.py
+++ b/gn_auth/auth/authentication/oauth2/models/oauth2client.py
@@ -1,17 +1,26 @@
 """OAuth2 Client model."""
 import json
 import datetime
+
 from uuid import UUID
-from typing import Sequence, Optional, NamedTuple
+from dataclasses import dataclass
+from functools import cached_property
+from typing import Sequence, Optional
 
+from authlib.oauth2.rfc6749 import ClientMixin
 from pymonad.maybe import Just, Maybe, Nothing
 
 from gn_auth.auth.db import sqlite3 as db
-from gn_auth.auth.authentication.users import User, users, user_by_id, same_password
+from gn_auth.auth.authentication.users import (User,
+                                               users,
+                                               user_by_id,
+                                               same_password)
 
 from gn_auth.auth.authorisation.errors import NotFoundError
 
-class OAuth2Client(NamedTuple):
+
+@dataclass(frozen=True)
+class OAuth2Client(ClientMixin):
     """
     Client to the OAuth2 Server.
 
@@ -29,12 +38,12 @@ class OAuth2Client(NamedTuple):
         """Check whether the `client_secret` matches this client."""
         return same_password(client_secret, self.client_secret)
 
-    @property
+    @cached_property
     def token_endpoint_auth_method(self) -> str:
         """Return the token endpoint authorisation method."""
         return self.client_metadata.get("token_endpoint_auth_method", ["none"])
 
-    @property
+    @cached_property
     def client_type(self) -> str:
         """
         Return the token endpoint authorisation method.
@@ -64,12 +73,12 @@ class OAuth2Client(NamedTuple):
                     and method == "client_secret_basic")
         return False
 
-    @property
-    def id(self):# pylint: disable=[invalid-name]
+    @cached_property
+    def id(self):  # pylint: disable=[invalid-name]
         """Return the client_id."""
         return self.client_id
 
-    @property
+    @cached_property
     def grant_types(self) -> Sequence[str]:
         """
         Return the grant types that this client supports.
@@ -88,7 +97,7 @@ class OAuth2Client(NamedTuple):
         """
         return grant_type in self.grant_types
 
-    @property
+    @cached_property
     def redirect_uris(self) -> Sequence[str]:
         """Return the redirect_uris that this client supports."""
         return self.client_metadata.get('redirect_uris', [])
@@ -99,7 +108,7 @@ class OAuth2Client(NamedTuple):
         """
         return redirect_uri in self.redirect_uris
 
-    @property
+    @cached_property
     def response_types(self) -> Sequence[str]:
         """Return the response_types that this client supports."""
         return self.client_metadata.get("response_type", [])
@@ -108,13 +117,14 @@ class OAuth2Client(NamedTuple):
         """Check whether this client supports `response_type`."""
         return response_type in self.response_types
 
-    @property
+    @cached_property
     def scope(self) -> Sequence[str]:
         """Return valid scopes for this client."""
         return tuple(set(self.client_metadata.get("scope", [])))
 
     def get_allowed_scope(self, scope: str) -> str:
-        """Return list of scopes in `scope` that are supported by this client."""
+        """Return list of scopes in `scope` that are supported by this
+        client."""
         if not bool(scope):
             return ""
         requested = scope.split()
@@ -129,33 +139,39 @@ class OAuth2Client(NamedTuple):
         """Return the default redirect uri"""
         return self.client_metadata.get("default_redirect_uri", "")
 
+
 def client(conn: db.DbConnection, client_id: UUID,
            user: Optional[User] = None) -> Maybe:
     """Retrieve a client by its ID"""
     with db.cursor(conn) as cursor:
         cursor.execute(
-            "SELECT * FROM oauth2_clients WHERE client_id=?", (str(client_id),))
+            "SELECT * FROM oauth2_clients WHERE client_id=?",
+            (str(client_id),))
         result = cursor.fetchone()
         the_user = user
         if result:
             if not bool(the_user):
                 try:
                     the_user = user_by_id(conn, result["user_id"])
-                except NotFoundError as _nfe:
+                except NotFoundError:
                     the_user = None
 
             return Just(
-                OAuth2Client(UUID(result["client_id"]),
-                             result["client_secret"],
-                             datetime.datetime.fromtimestamp(
-                                 result["client_id_issued_at"]),
-                             datetime.datetime.fromtimestamp(
-                                 result["client_secret_expires_at"]),
-                             json.loads(result["client_metadata"]),
-                             the_user))# type: ignore[arg-type]
-
+                OAuth2Client(
+                    client_id=UUID(result["client_id"]),
+                    client_secret=result["client_secret"],
+                    client_id_issued_at=datetime.datetime.fromtimestamp(
+                        result["client_id_issued_at"]
+                    ),
+                    client_secret_expires_at=datetime.datetime.fromtimestamp(
+                        result["client_secret_expires_at"]
+                    ),
+                    client_metadata=json.loads(result["client_metadata"]),
+                    user=the_user)  # type: ignore[arg-type]
+            )
     return Nothing
 
+
 def client_by_id_and_secret(conn: db.DbConnection, client_id: UUID,
                             client_secret: str) -> OAuth2Client:
     """Retrieve a client by its ID and secret"""
@@ -166,16 +182,25 @@ def client_by_id_and_secret(conn: db.DbConnection, client_id: UUID,
         row = cursor.fetchone()
         if bool(row) and same_password(client_secret, row["client_secret"]):
             return OAuth2Client(
-                client_id, client_secret,
-                datetime.datetime.fromtimestamp(row["client_id_issued_at"]),
-                datetime.datetime.fromtimestamp(
-                    row["client_secret_expires_at"]),
-                json.loads(row["client_metadata"]),
-                user_by_id(conn, UUID(row["user_id"])))
+                client_id=client_id,
+                client_secret=client_secret,
+                client_id_issued_at=datetime.datetime.fromtimestamp(
+                    row["client_id_issued_at"]
+                ),
+                client_secret_expires_at=datetime.datetime.fromtimestamp(
+                    row["client_secret_expires_at"]
+                ),
+                client_metadata=json.loads(row["client_metadata"]),
+                user=user_by_id(conn, UUID(row["user_id"]))
+            )
+        raise NotFoundError(
+            "Could not find client with the given credentials."
+        )
 
-        raise NotFoundError("Could not find client with the given credentials.")
 
-def save_client(conn: db.DbConnection, the_client: OAuth2Client) -> OAuth2Client:
+def save_client(
+        conn: db.DbConnection, the_client: OAuth2Client
+) -> OAuth2Client:
     """Persist the client details into the database."""
     with db.cursor(conn) as cursor:
         query = (
@@ -204,6 +229,7 @@ def save_client(conn: db.DbConnection, the_client: OAuth2Client) -> OAuth2Client
             })
         return the_client
 
+
 def oauth2_clients(conn: db.DbConnection) -> tuple[OAuth2Client, ...]:
     """Fetch a list of all OAuth2 clients."""
     with db.cursor(conn) as cursor:
@@ -211,19 +237,26 @@ def oauth2_clients(conn: db.DbConnection) -> tuple[OAuth2Client, ...]:
         clients_rs = cursor.fetchall()
         the_users = {
             usr.user_id: usr for usr in users(
-                conn, tuple({UUID(result["user_id"]) for result in clients_rs}))
+                conn, tuple({UUID(result["user_id"])
+                             for result in clients_rs}))
         }
-        return tuple(OAuth2Client(UUID(result["client_id"]),
-                                  result["client_secret"],
-                                  datetime.datetime.fromtimestamp(
-                                     result["client_id_issued_at"]),
-                                  datetime.datetime.fromtimestamp(
-                                     result["client_secret_expires_at"]),
-                                  json.loads(result["client_metadata"]),
-                                  the_users[UUID(result["user_id"])])
-                     for result in clients_rs)
-
-def delete_client(conn: db.DbConnection, the_client: OAuth2Client) -> OAuth2Client:
+        return tuple(
+            OAuth2Client(
+                client_id=UUID(result["client_id"]),
+                client_secret=result["client_secret"],
+                client_id_issued_at=datetime.datetime.fromtimestamp(
+                    result["client_id_issued_at"]),
+                client_secret_expires_at=datetime.datetime.fromtimestamp(
+                    result["client_secret_expires_at"]),
+                client_metadata=json.loads(result["client_metadata"]),
+                user=the_users[UUID(result["user_id"])]
+            ) for result in clients_rs
+        )
+
+
+def delete_client(
+        conn: db.DbConnection, the_client: OAuth2Client
+) -> OAuth2Client:
     """Delete the given client from the database"""
     with db.cursor(conn) as cursor:
         params = (str(the_client.client_id),)