"""OAuth2 Client model."""
import json
import datetime
from uuid import UUID
from functools import cached_property
from dataclasses import asdict, dataclass
from typing import Any, Sequence, Optional
import requests
from flask import current_app as app
from requests.exceptions import JSONDecodeError
from authlib.jose import KeySet, JsonWebKey
from authlib.oauth2.rfc6749 import ClientMixin
from pymonad.maybe import Just, Maybe, Nothing
from gn_auth.debug import __pk__
from gn_auth.auth.db import sqlite3 as db
from gn_auth.auth.errors import NotFoundError
from gn_auth.auth.authentication.users import (User,
fetch_users,
user_by_id,
same_password)
@dataclass(frozen=True)
class OAuth2Client(ClientMixin):
"""
Client to the OAuth2 Server.
This is defined according to the mixin at
https://docs.authlib.org/en/latest/specs/rfc6749.html#authlib.oauth2.rfc6749.ClientMixin
"""
client_id: UUID
client_secret: str
client_id_issued_at: datetime.datetime
client_secret_expires_at: datetime.datetime
client_metadata: dict
user: User
def check_client_secret(self, client_secret: str) -> bool:
"""Check whether the `client_secret` matches this client."""
return same_password(client_secret, self.client_secret)
@cached_property
def token_endpoint_auth_method(self) -> str:
"""Return the token endpoint authorisation method."""
return self.client_metadata.get("token_endpoint_auth_method", ["none"])
@cached_property
def client_type(self) -> str:
"""
Return the token endpoint authorisation method.
Acceptable client types:
* public: Unable to use registered client secrets, e.g. browsers, apps
on mobile devices.
* confidential: able to securely authenticate with authorisation server
e.g. being able to keep their registered client secret safe.
"""
return self.client_metadata.get("client_type", "public")
def jwks(self) -> KeySet:
"""Return this client's KeySet."""
jwksuri = self.client_metadata.get("public-jwks-uri")
__pk__(f"PUBLIC JWKs link for client {self.client_id}", jwksuri)
if not bool(jwksuri):
app.logger.debug("No Public JWKs URI set for client!")
return KeySet([])
try:
## IMPORTANT: This can cause a deadlock if the client is working in
## single-threaded mode, i.e. can only serve one request
## at a time.
return KeySet([JsonWebKey.import_key(key)
for key in requests.get(
jwksuri,
timeout=300,
allow_redirects=True).json()["jwks"]])
except requests.ConnectionError as _connerr:
app.logger.debug(
"Could not connect to provided URI: %s", jwksuri, exc_info=True)
except JSONDecodeError as _jsonerr:
app.logger.debug(
"Could not convert response to JSON", exc_info=True)
except Exception as _exc:# pylint: disable=[broad-except]
app.logger.debug(
"Error retrieving the JWKs for the client.", exc_info=True)
return KeySet([])
def check_endpoint_auth_method(self, method: str, endpoint: str) -> bool:
"""
Check if the client supports the given method for the given endpoint.
Acceptable methods:
* none: Client is a public client and does not have a client secret
* client_secret_post: Client uses the HTTP POST parameters
* client_secret_basic: Client uses HTTP Basic
"""
if endpoint in ("token", "revoke", "introspection"):
return (method in self.token_endpoint_auth_method
and method == "client_secret_post")
return False
@cached_property
def id(self): # pylint: disable=[invalid-name]
"""Return the client_id."""
return self.client_id
@cached_property
def grant_types(self) -> Sequence[str]:
"""
Return the grant types that this client supports.
Valid grant types:
* authorisation_code
* implicit
* client_credentials
* password
"""
return self.client_metadata.get("grant_types", [])
def check_grant_type(self, grant_type: str) -> bool:
"""
Validate that client can handle the given grant types
"""
return grant_type in self.grant_types
@cached_property
def redirect_uris(self) -> Sequence[str]:
"""Return the redirect_uris that this client supports."""
return self.client_metadata.get('redirect_uris', [])
def check_redirect_uri(self, redirect_uri: str) -> bool:
"""
Check whether the given `redirect_uri` is one of the expected ones.
"""
return redirect_uri in self.redirect_uris
@cached_property
def response_types(self) -> Sequence[str]:
"""Return the response_types that this client supports."""
return self.client_metadata.get("response_type", [])
def check_response_type(self, response_type: str) -> bool:
"""Check whether this client supports `response_type`."""
return response_type in self.response_types
@cached_property
def scope(self) -> Sequence[str]:
"""Return valid scopes for this client."""
return tuple(set(self.client_metadata.get("scope", [])))
def get_allowed_scope(self, scope: str) -> str:
"""Return list of scopes in `scope` that are supported by this
client."""
if not bool(scope):
return ""
requested = scope.split()
return " ".join(sorted(set(
scp for scp in requested if scp in self.scope)))
def get_client_id(self):
"""Return this client's identifier."""
return self.client_id
def get_default_redirect_uri(self) -> str:
"""Return the default redirect uri"""
return self.client_metadata.get("default_redirect_uri", "")
def client(conn: db.DbConnection, client_id: UUID,
user: Optional[User] = None) -> Maybe:
"""Retrieve a client by its ID"""
with db.cursor(conn) as cursor:
cursor.execute(
"SELECT * FROM oauth2_clients WHERE client_id=?",
(str(client_id),))
result = cursor.fetchone()
_user = user
if result:
if not bool(_user):
try:
_user = user_by_id(conn, result["user_id"])
except NotFoundError:
_user = None
return Just(
OAuth2Client(
client_id=UUID(result["client_id"]),
client_secret=result["client_secret"],
client_id_issued_at=datetime.datetime.fromtimestamp(
result["client_id_issued_at"]
),
client_secret_expires_at=datetime.datetime.fromtimestamp(
result["client_secret_expires_at"]
),
client_metadata=json.loads(result["client_metadata"]),
user=_user) # type: ignore[arg-type]
)
return Nothing
def client_by_id_and_secret(conn: db.DbConnection, client_id: UUID,
client_secret: str) -> OAuth2Client:
"""Retrieve a client by its ID and secret"""
with db.cursor(conn) as cursor:
cursor.execute(
"SELECT * FROM oauth2_clients WHERE client_id=?",
(str(client_id),))
row = cursor.fetchone()
if bool(row) and same_password(client_secret, row["client_secret"]):
return OAuth2Client(
client_id=client_id,
client_secret=client_secret,
client_id_issued_at=datetime.datetime.fromtimestamp(
row["client_id_issued_at"]
),
client_secret_expires_at=datetime.datetime.fromtimestamp(
row["client_secret_expires_at"]
),
client_metadata=json.loads(row["client_metadata"]),
user=user_by_id(conn, UUID(row["user_id"]))
)
raise NotFoundError(
"Could not find client with the given credentials."
)
def save_client(
conn: db.DbConnection, the_client: OAuth2Client
) -> OAuth2Client:
"""Persist the client details into the database."""
with db.cursor(conn) as cursor:
query = (
"INSERT INTO oauth2_clients "
"(client_id, client_secret, client_id_issued_at, "
"client_secret_expires_at, client_metadata, user_id) "
"VALUES "
"(:client_id, :client_secret, :client_id_issued_at, "
":client_secret_expires_at, :client_metadata, :user_id) "
"ON CONFLICT (client_id) DO UPDATE SET "
"client_secret=:client_secret, "
"client_id_issued_at=:client_id_issued_at, "
"client_secret_expires_at=:client_secret_expires_at, "
"client_metadata=:client_metadata, user_id=:user_id")
cursor.execute(
query,
{
"client_id": str(the_client.client_id),
"client_secret": the_client.client_secret,
"client_id_issued_at": (
the_client.client_id_issued_at.timestamp()),
"client_secret_expires_at": (
the_client.client_secret_expires_at.timestamp()),
"client_metadata": json.dumps(the_client.client_metadata),
"user_id": str(the_client.user.user_id)
})
return the_client
def oauth2_clients(conn: db.DbConnection) -> tuple[OAuth2Client, ...]:
"""Fetch a list of all OAuth2 clients."""
with db.cursor(conn) as cursor:
cursor.execute("SELECT * FROM oauth2_clients")
clients_rs = cursor.fetchall()
the_users = {
usr.user_id: usr for usr in fetch_users(
conn, tuple({UUID(result["user_id"])
for result in clients_rs}))
}
return tuple(
OAuth2Client(
client_id=UUID(result["client_id"]),
client_secret=result["client_secret"],
client_id_issued_at=datetime.datetime.fromtimestamp(
result["client_id_issued_at"]),
client_secret_expires_at=datetime.datetime.fromtimestamp(
result["client_secret_expires_at"]),
client_metadata=json.loads(result["client_metadata"]),
user=the_users[UUID(result["user_id"])]
) for result in clients_rs
)
def delete_client(
conn: db.DbConnection, the_client: OAuth2Client
) -> OAuth2Client:
"""Delete the given client from the database"""
with db.cursor(conn) as cursor:
params = (str(the_client.client_id),)
cursor.execute("DELETE FROM authorisation_code WHERE client_id=?",
params)
cursor.execute("DELETE FROM oauth2_tokens WHERE client_id=?", params)
cursor.execute("DELETE FROM oauth2_clients WHERE client_id=?", params)
return the_client
def update_client_attribute(
client: OAuth2Client,# pylint: disable=[redefined-outer-name]
attribute: str,
value: Any
) -> OAuth2Client:
"""Return a new OAuth2Client with the given attribute updated/changed."""
attrs = {
attr: type(value)
for attr, value in asdict(client).items()
if attr != "client_id"
}
assert (
attribute in attrs.keys() and isinstance(value, attrs[attribute])), (
"Invalid attribute/value provided!")
return OAuth2Client(
client_id=client.client_id,
**{
attr: (value if attr==attribute else getattr(client, attr))
for attr in attrs
})