"""OAuth2 Client model."""
import json
import datetime
from pathlib import Path
from uuid import UUID
from dataclasses import dataclass
from functools import cached_property
from typing import Sequence, Optional

from authlib.jose import KeySet, JsonWebKey
from authlib.oauth2.rfc6749 import ClientMixin
from pymonad.maybe import Just, Maybe, Nothing

from gn_auth.auth.db import sqlite3 as db
from gn_auth.auth.errors import NotFoundError
from gn_auth.auth.authentication.users import (User,
                                               fetch_users,
                                               user_by_id,
                                               same_password)


@dataclass(frozen=True)
class OAuth2Client(ClientMixin):
    """
    Client to the OAuth2 Server.

    This is defined according to the mixin at
    https://docs.authlib.org/en/latest/specs/rfc6749.html#authlib.oauth2.rfc6749.ClientMixin
    """
    client_id: UUID
    # The secret exactly as stored in the `oauth2_clients` table; it is the
    # second argument to `same_password`, so presumably a hash, not plaintext.
    client_secret: str
    client_id_issued_at: datetime.datetime
    client_secret_expires_at: datetime.datetime
    client_metadata: dict
    user: User

    def check_client_secret(self, client_secret: str) -> bool:
        """Check whether the `client_secret` matches this client."""
        return same_password(client_secret, self.client_secret)

    @cached_property
    def token_endpoint_auth_method(self) -> Sequence[str]:
        """
        Return the token endpoint authorisation method(s).

        Defaults to `["none"]` when the client metadata specifies none.
        """
        return self.client_metadata.get("token_endpoint_auth_method", ["none"])

    @cached_property
    def client_type(self) -> str:
        """
        Return the client type (defaults to "public").

        Acceptable client types:
        * public: Unable to use registered client secrets, e.g. browsers, apps
          on mobile devices.
        * confidential: able to securely authenticate with authorisation
          server e.g. being able to keep their registered client secret safe.
        """
        return self.client_metadata.get("client_type", "public")

    @cached_property
    def jwks(self) -> KeySet:
        """
        Return this client's KeySet.

        The keys are read from the files whose paths are listed under the
        `public_keys` entry of the client metadata.
        """
        def __parse_key__(keypath: Path) -> JsonWebKey:
            with open(keypath, encoding="utf-8") as _key:
                return JsonWebKey.import_key(_key.read())

        return KeySet([
            __parse_key__(Path(pth))
            for pth in self.client_metadata.get("public_keys", [])])

    def check_endpoint_auth_method(self, method: str, endpoint: str) -> bool:
        """
        Check if the client supports the given method for the given endpoint.

        Acceptable methods:
        * none: Client is a public client and does not have a client secret
        * client_secret_post: Client uses the HTTP POST parameters
        * client_secret_basic: Client uses HTTP Basic

        Only `client_secret_post` is accepted for the "token" endpoint and
        only `client_secret_basic` for the "introspection"/"revoke"
        endpoints; any other endpoint is rejected.
        """
        methods = self.token_endpoint_auth_method
        if isinstance(methods, str):
            # A bare string would turn `in` into a substring test.
            methods = (methods,)
        if endpoint == "token":
            return (method in methods
                    and method == "client_secret_post")
        if endpoint in ("introspection", "revoke"):
            return (method in methods
                    and method == "client_secret_basic")
        return False

    @cached_property
    def id(self):  # pylint: disable=[invalid-name]
        """Return the client_id."""
        return self.client_id

    @cached_property
    def grant_types(self) -> Sequence[str]:
        """
        Return the grant types that this client supports.

        Valid grant types:
        * authorisation_code
        * implicit
        * client_credentials
        * password
        """
        return self.client_metadata.get("grant_types", [])

    def check_grant_type(self, grant_type: str) -> bool:
        """
        Validate that client can handle the given grant types
        """
        return grant_type in self.grant_types

    @cached_property
    def redirect_uris(self) -> Sequence[str]:
        """Return the redirect_uris that this client supports."""
        return self.client_metadata.get('redirect_uris', [])

    def check_redirect_uri(self, redirect_uri: str) -> bool:
        """
        Check whether the given `redirect_uri` is one of the expected ones.
        """
        return redirect_uri in self.redirect_uris

    @cached_property
    def response_types(self) -> Sequence[str]:
        """Return the response_types that this client supports."""
        # NOTE: read from the singular "response_type" metadata key.
        return self.client_metadata.get("response_type", [])

    def check_response_type(self, response_type: str) -> bool:
        """Check whether this client supports `response_type`."""
        return response_type in self.response_types

    @cached_property
    def scope(self) -> Sequence[str]:
        """
        Return valid scopes for this client (deduplicated, unordered).

        The `scope` metadata may be either a list of scope names or a single
        space-separated string (the RFC 7591 client-metadata form).
        """
        raw_scope = self.client_metadata.get("scope", [])
        if isinstance(raw_scope, str):
            # Avoid splitting the string into individual characters.
            raw_scope = raw_scope.split()
        return tuple(set(raw_scope))

    def get_allowed_scope(self, scope: str) -> str:
        """
        Return the scopes in `scope` that are supported by this client.

        `scope` is a space-separated string; the result is a space-separated,
        sorted, deduplicated string ("" when `scope` is empty).
        """
        if not bool(scope):
            return ""

        requested = scope.split()
        return " ".join(sorted(set(
            scp for scp in requested if scp in self.scope)))

    def get_client_id(self):
        """Return this client's identifier."""
        return self.client_id

    def get_default_redirect_uri(self) -> str:
        """Return the default redirect uri"""
        return self.client_metadata.get("default_redirect_uri", "")


def client(conn: db.DbConnection,
           client_id: UUID,
           user: Optional[User] = None) -> Maybe:
    """
    Retrieve a client by its ID.

    Returns `Just(OAuth2Client)` when a row exists, otherwise `Nothing`. When
    `user` is not supplied, the owner is looked up by the row's `user_id`; if
    that lookup raises `NotFoundError` the client's `user` is `None`.
    """
    with db.cursor(conn) as cursor:
        cursor.execute(
            "SELECT * FROM oauth2_clients WHERE client_id=?",
            (str(client_id),))
        result = cursor.fetchone()
        _user = user
        if result:
            if not bool(_user):
                try:
                    _user = user_by_id(conn, result["user_id"])
                except NotFoundError:
                    _user = None

            return Just(
                OAuth2Client(
                    client_id=UUID(result["client_id"]),
                    client_secret=result["client_secret"],
                    client_id_issued_at=datetime.datetime.fromtimestamp(
                        result["client_id_issued_at"]),
                    client_secret_expires_at=datetime.datetime.fromtimestamp(
                        result["client_secret_expires_at"]),
                    client_metadata=json.loads(result["client_metadata"]),
                    user=_user)  # type: ignore[arg-type]
            )

    return Nothing


def client_by_id_and_secret(
        conn: db.DbConnection,
        client_id: UUID,
        client_secret: str
) -> OAuth2Client:
    """
    Retrieve a client by its ID and (plaintext) secret.

    Raises `NotFoundError` when no client has `client_id` or when
    `client_secret` does not match the stored secret.
    """
    with db.cursor(conn) as cursor:
        cursor.execute(
            "SELECT * FROM oauth2_clients WHERE client_id=?",
            (str(client_id),))
        row = cursor.fetchone()
        if bool(row) and same_password(client_secret, row["client_secret"]):
            return OAuth2Client(
                client_id=UUID(row["client_id"]),
                # Keep the *stored* secret, not the caller's plaintext, so the
                # object agrees with `client()`, `check_client_secret` and
                # `save_client` (which would otherwise persist plaintext).
                client_secret=row["client_secret"],
                client_id_issued_at=datetime.datetime.fromtimestamp(
                    row["client_id_issued_at"]),
                client_secret_expires_at=datetime.datetime.fromtimestamp(
                    row["client_secret_expires_at"]),
                client_metadata=json.loads(row["client_metadata"]),
                user=user_by_id(conn, UUID(row["user_id"]))
            )

        raise NotFoundError(
            "Could not find client with the given credentials.")


def save_client(
        conn: db.DbConnection,
        the_client: OAuth2Client
) -> OAuth2Client:
    """
    Persist the client details into the database.

    Inserts a new row, or updates every column of the existing row when one
    with the same `client_id` already exists. Returns `the_client` unchanged.
    """
    with db.cursor(conn) as cursor:
        query = (
            "INSERT INTO oauth2_clients "
            "(client_id, client_secret, client_id_issued_at, "
            "client_secret_expires_at, client_metadata, user_id) "
            "VALUES "
            "(:client_id, :client_secret, :client_id_issued_at, "
            ":client_secret_expires_at, :client_metadata, :user_id) "
            "ON CONFLICT (client_id) DO UPDATE SET "
            "client_secret=:client_secret, "
            "client_id_issued_at=:client_id_issued_at, "
            "client_secret_expires_at=:client_secret_expires_at, "
            "client_metadata=:client_metadata, user_id=:user_id")
        cursor.execute(
            query,
            {
                "client_id": str(the_client.client_id),
                "client_secret": the_client.client_secret,
                "client_id_issued_at": (
                    the_client.client_id_issued_at.timestamp()),
                "client_secret_expires_at": (
                    the_client.client_secret_expires_at.timestamp()),
                "client_metadata": json.dumps(the_client.client_metadata),
                "user_id": str(the_client.user.user_id)
            })
        return the_client


def oauth2_clients(conn: db.DbConnection) -> tuple[OAuth2Client, ...]:
    """
    Fetch a list of all OAuth2 clients.

    Owners are fetched in a single `fetch_users` call. A client whose owner
    is not returned gets `user=None` (as `client()` does) instead of failing
    the whole listing.
    """
    with db.cursor(conn) as cursor:
        cursor.execute("SELECT * FROM oauth2_clients")
        clients_rs = cursor.fetchall()

    if not clients_rs:
        return tuple()

    the_users = {
        usr.user_id: usr
        for usr in fetch_users(
            conn,
            tuple({UUID(result["user_id"]) for result in clients_rs}))
    }
    return tuple(
        OAuth2Client(
            client_id=UUID(result["client_id"]),
            client_secret=result["client_secret"],
            client_id_issued_at=datetime.datetime.fromtimestamp(
                result["client_id_issued_at"]),
            client_secret_expires_at=datetime.datetime.fromtimestamp(
                result["client_secret_expires_at"]),
            client_metadata=json.loads(result["client_metadata"]),
            user=the_users.get(
                UUID(result["user_id"]))  # type: ignore[arg-type]
        )
        for result in clients_rs
    )


def delete_client(
        conn: db.DbConnection,
        the_client: OAuth2Client
) -> OAuth2Client:
    """
    Delete the given client from the database.

    Rows referencing the client in `authorisation_code` and `oauth2_tokens`
    are deleted first, then the `oauth2_clients` row. Returns `the_client`.
    """
    with db.cursor(conn) as cursor:
        params = (str(the_client.client_id),)
        cursor.execute("DELETE FROM authorisation_code WHERE client_id=?",
                       params)
        cursor.execute("DELETE FROM oauth2_tokens WHERE client_id=?", params)
        cursor.execute("DELETE FROM oauth2_clients WHERE client_id=?", params)
        return the_client