about summary refs log tree commit diff
diff options
context:
space:
mode:
authorFrederick Muriuki Muriithi2023-05-08 16:31:38 +0300
committerFrederick Muriuki Muriithi2023-05-09 13:15:47 +0300
commit5526f0316c2714d30e47a90f81e0ff686a29042f (patch)
tree64b6422984a6e3ce8bee3850b47a16c822677073
parentf2c09dc2dc2528c75fcf5b80aa4b530a0b5eef08 (diff)
downloadgenenetwork3-5526f0316c2714d30e47a90f81e0ff686a29042f.tar.gz
auth: Implement "Authorization Code Flow" auth/implement-authorization-code-flow
Implement the "Authorization Code Flow" for the authentication of users.

* gn3/auth/authentication/oauth2/grants/authorisation_code_grant.py: query and
  save the authorisation code.
* gn3/auth/authentication/oauth2/models/authorization_code.py: Implement the
  `AuthorisationCode` model
* gn3/auth/authentication/oauth2/models/oauth2client.py: Fix typo
* gn3/auth/authentication/oauth2/server.py: Register the
  `AuthorisationCodeGrant` grant with the server.
* gn3/auth/authentication/oauth2/views.py: Implement `/authorise` endpoint
* gn3/templates/base.html: New HTML Templates of authorisation UI
* gn3/templates/common-macros.html: New HTML Templates of authorisation UI
* gn3/templates/oauth2/authorise-user.html: New HTML Templates of
  authorisation UI
* main.py: Allow both "code" and "token" response types.
-rw-r--r--gn3/auth/authentication/oauth2/grants/authorisation_code_grant.py48
-rw-r--r--gn3/auth/authentication/oauth2/models/authorization_code.py93
-rw-r--r--gn3/auth/authentication/oauth2/models/oauth2client.py2
-rw-r--r--gn3/auth/authentication/oauth2/server.py11
-rw-r--r--gn3/auth/authentication/oauth2/views.py52
-rw-r--r--gn3/templates/base.html17
-rw-r--r--gn3/templates/common-macros.html7
-rw-r--r--gn3/templates/oauth2/authorise-user.html40
-rw-r--r--main.py7
9 files changed, 263 insertions, 14 deletions
diff --git a/gn3/auth/authentication/oauth2/grants/authorisation_code_grant.py b/gn3/auth/authentication/oauth2/grants/authorisation_code_grant.py
index d398192..f80d02e 100644
--- a/gn3/auth/authentication/oauth2/grants/authorisation_code_grant.py
+++ b/gn3/auth/authentication/oauth2/grants/authorisation_code_grant.py
@@ -1,24 +1,45 @@
 """Classes and function for Authorisation Code flow."""
 import uuid
+import string
+import random
 from typing import Optional
+from datetime import datetime
 
 from flask import current_app as app
 from authlib.oauth2.rfc6749 import grants
+from authlib.oauth2.rfc7636 import create_s256_code_challenge
 
 from gn3.auth import db
+from gn3.auth.db_utils import with_db_connection
 from gn3.auth.authentication.users import User
 
+from ..models.oauth2client import OAuth2Client
+from ..models.authorization_code import (
+    AuthorisationCode, authorisation_code, save_authorisation_code)
+
 class AuthorisationCodeGrant(grants.AuthorizationCodeGrant):
     """Implement the 'Authorisation Code' grant."""
-    TOKEN_ENDPOINT_AUTH_METHODS = ["client_secret_basic", "client_secret_post"]
+    TOKEN_ENDPOINT_AUTH_METHODS: list[str] = [
+        "client_secret_basic", "client_secret_post"]
+    AUTHORIZATION_CODE_LENGTH: int = 48
+    TOKEN_ENDPOINT_HTTP_METHODS = ['POST']
+    GRANT_TYPE = "authorization_code"
+    RESPONSE_TYPES = {'code'}
 
     def save_authorization_code(self, code, request):
         """Persist the authorisation code to database."""
-        raise Exception("NOT IMPLEMENTED!", self, code, request)
+        client = request.client
+        nonce = "".join(random.sample(string.ascii_letters + string.digits,
+                                      k=self.AUTHORIZATION_CODE_LENGTH))
+        return __save_authorization_code__(AuthorisationCode(
+            uuid.uuid4(), code, client, request.redirect_uri, request.scope,
+            nonce, int(datetime.now().timestamp()),
+            create_s256_code_challenge(app.config["SECRET_KEY"]),
+            "S256", request.user))
 
     def query_authorization_code(self, code, client):
         """Retrieve the code from the database."""
-        raise Exception("NOT IMPLEMENTED!", self, code, client)
+        return __query_authorization_code__(code, client)
 
     def delete_authorization_code(self, authorization_code):# pylint: disable=[no-self-use]
         """Delete the authorisation code."""
@@ -36,10 +57,29 @@ class AuthorisationCodeGrant(grants.AuthorizationCodeGrant):
             "WHERE authorisation_code.code=?")
         with db.connection(app.config["AUTH_DB"]) as conn:
             with db.cursor(conn) as cursor:
-                cursor.execute(query, (str(authorization_code.user_id),))
+                cursor.execute(query, (str(authorization_code.code),))
                 res = cursor.fetchone()
                 if res:
                     return User(
                         uuid.UUID(res["user_id"]), res["email"], res["name"])
 
         return None
+
+def __query_authorization_code__(
+        code: str, client: OAuth2Client) -> AuthorisationCode:
+    """A helper function that creates a new database connection.
+
+    This is found to be necessary since the `AuthorizationCodeGrant` class(es)
+    do not have a way to pass the database connection."""
+    def __auth_code__(conn) -> str:
+        the_code = authorisation_code(conn, code, client)
+        return the_code.maybe(None, lambda cde: cde) # type: ignore[misc, arg-type, return-value]
+
+    return with_db_connection(__auth_code__)
+
+def __save_authorization_code__(code: AuthorisationCode) -> AuthorisationCode:
+    """A helper function that creates a new database connection.
+
+    This is found to be necessary since the `AuthorizationCodeGrant` class(es)
+    do not have a way to pass the database connection."""
+    return with_db_connection(lambda conn: save_authorisation_code(conn, code))
diff --git a/gn3/auth/authentication/oauth2/models/authorization_code.py b/gn3/auth/authentication/oauth2/models/authorization_code.py
new file mode 100644
index 0000000..f282814
--- /dev/null
+++ b/gn3/auth/authentication/oauth2/models/authorization_code.py
@@ -0,0 +1,93 @@
+"""Model and functions for handling the Authorisation Code"""
+from uuid import UUID
+from datetime import datetime
+from typing import NamedTuple
+
+from pymonad.maybe import Just, Maybe, Nothing
+
+from gn3.auth import db
+
+from .oauth2client import OAuth2Client
+
+from ...users import User, user_by_id
+
+__5_MINUTES__ = 300 # in seconds
+
+class AuthorisationCode(NamedTuple):
+    """
+    The AuthorisationCode model for the auth(entic|oris)ation system.
+    """
+    # Instance variables
+    code_id: UUID
+    code: str
+    client: OAuth2Client
+    redirect_uri: str
+    scope: str
+    nonce: str
+    auth_time: int
+    code_challenge: str
+    code_challenge_method: str
+    user: User
+
+    @property
+    def response_type(self) -> str:
+        """
+        For authorisation code flow, the response_type type MUST always be
+        'code'.
+        """
+        return "code"
+
+    def is_expired(self):
+        """Check whether the code is expired."""
+        return self.auth_time + __5_MINUTES__ < datetime.now().timestamp()
+
+    def get_redirect_uri(self):
+        """Get the redirect URI"""
+        return self.redirect_uri
+
+    def get_scope(self):
+        """Return the assigned scope for this AuthorisationCode."""
+        return self.scope
+
+    def get_nonce(self):
+        """Get the one-time use token."""
+        return self.nonce
+
+def authorisation_code(conn: db.DbConnection ,
+                       code: str,
+                       client: OAuth2Client) -> Maybe[AuthorisationCode]:
+    """
+    Retrieve the authorisation code object that corresponds to `code` and the
+    given OAuth2 client.
+    """
+    with db.cursor(conn) as cursor:
+        query = ("SELECT * FROM authorisation_code "
+                 "WHERE code=:code AND client_id=:client_id")
+        cursor.execute(
+            query, {"code": code, "client_id": str(client.client_id)})
+        result = cursor.fetchone()
+        if result:
+            return Just(AuthorisationCode(
+                UUID(result["code_id"]), result["code"], client,
+                result["redirect_uri"], result["scope"], result["nonce"],
+                int(result["auth_time"]), result["code_challenge"],
+                result["code_challenge_method"],
+                user_by_id(conn, UUID(result["user_id"]))))
+        return Nothing
+
+def save_authorisation_code(conn: db.DbConnection,
+                            auth_code: AuthorisationCode) -> AuthorisationCode:
+    """Persist the `auth_code` into the database."""
+    with db.cursor(conn) as cursor:
+        cursor.execute(
+            "INSERT INTO authorisation_code VALUES("
+            ":code_id, :code, :client_id, :redirect_uri, :scope, :nonce, "
+            ":auth_time, :code_challenge, :code_challenge_method, :user_id"
+            ")",
+            {
+                **auth_code._asdict(),
+                "code_id": str(auth_code.code_id),
+                "client_id": str(auth_code.client.client_id),
+                "user_id": str(auth_code.user.user_id)
+            })
+        return auth_code
diff --git a/gn3/auth/authentication/oauth2/models/oauth2client.py b/gn3/auth/authentication/oauth2/models/oauth2client.py
index da5ff75..b7d37be 100644
--- a/gn3/auth/authentication/oauth2/models/oauth2client.py
+++ b/gn3/auth/authentication/oauth2/models/oauth2client.py
@@ -102,7 +102,7 @@ class OAuth2Client(NamedTuple):
     @property
     def response_types(self) -> Sequence[str]:
         """Return the response_types that this client supports."""
-        return self.client_metadata.get("response_types", [])
+        return self.client_metadata.get("response_type", [])
 
     def check_response_type(self, response_type: str) -> bool:
         """Check whether this client supports `response_type`."""
diff --git a/gn3/auth/authentication/oauth2/server.py b/gn3/auth/authentication/oauth2/server.py
index 73c9340..e9946b4 100644
--- a/gn3/auth/authentication/oauth2/server.py
+++ b/gn3/auth/authentication/oauth2/server.py
@@ -5,8 +5,7 @@ from typing import Callable
 
 from flask import Flask, current_app
 from authlib.integrations.flask_oauth2 import AuthorizationServer
-# from authlib.integrations.sqla_oauth2 import (
-#     create_save_token_func, create_query_client_func)
+# from authlib.oauth2.rfc7636 import CodeChallenge
 
 from gn3.auth import db
 
@@ -14,7 +13,7 @@ from .models.oauth2client import client
 from .models.oauth2token import OAuth2Token, save_token
 
 from .grants.password_grant import PasswordGrant
-# from .grants.authorisation_code_grant import AuthorisationCodeGrant
+from .grants.authorisation_code_grant import AuthorisationCodeGrant
 
 from .endpoints.revocation import RevocationEndpoint
 from .endpoints.introspection import IntrospectionEndpoint
@@ -49,7 +48,11 @@ def setup_oauth2_server(app: Flask) -> None:
     """Set's up the oauth2 server for the flask application."""
     server = AuthorizationServer()
     server.register_grant(PasswordGrant)
-    # server.register_grant(AuthorisationCodeGrant)
+
+    # Figure out a common `code_verifier` for GN2 and GN3 and set
+    # server.register_grant(AuthorisationCodeGrant, [CodeChallenge(required=False)])
+    # below
+    server.register_grant(AuthorisationCodeGrant)
 
     # register endpoints
     server.register_endpoint(RevocationEndpoint)
diff --git a/gn3/auth/authentication/oauth2/views.py b/gn3/auth/authentication/oauth2/views.py
index 3a14a48..48a97da 100644
--- a/gn3/auth/authentication/oauth2/views.py
+++ b/gn3/auth/authentication/oauth2/views.py
@@ -1,14 +1,28 @@
 """Endpoints for the oauth2 server"""
 import uuid
+import traceback
 
-from flask import Response, Blueprint, current_app as app
+from email_validator import validate_email, EmailNotValidError
+from flask import (
+    flash,
+    request,
+    url_for,
+    redirect,
+    Response,
+    Blueprint,
+    render_template,
+    current_app as app)
 
+from gn3.auth import db
+from gn3.auth.db_utils import with_db_connection
 from gn3.auth.authorisation.errors import ForbiddenAccess
 
 from .resource_server import require_oauth
 from .endpoints.revocation import RevocationEndpoint
 from .endpoints.introspection import IntrospectionEndpoint
 
+from ..users import valid_login, NotFoundError, user_by_email
+
 auth = Blueprint("auth", __name__)
 
 @auth.route("/register-client", methods=["GET", "POST"])
@@ -24,7 +38,41 @@ def delete_client(client_id: uuid.UUID):
 @auth.route("/authorise", methods=["GET", "POST"])
 def authorise():
     """Authorise a user"""
-    return "WOULD AUTHORISE THE USER."
+    server = app.config["OAUTH2_SERVER"]
+    client_id = uuid.UUID(request.args.get("client_id", str(uuid.uuid4())))
+    client = server.query_client(client_id)
+    if not bool(client):
+        flash("Invalid OAuth2 client.", "alert-error")
+    if request.method == "GET":
+        client = server.query_client(request.args.get("client_id"))
+        return render_template(
+            "oauth2/authorise-user.html",
+            client=client,
+            scope=client.scope,
+            response_type="code")
+
+    form = request.form
+    def __authorise__(conn: db.DbConnection) -> Response:
+        email_passwd_msg = "Email or password is invalid!"
+        redirect_response = redirect(url_for("oauth2.auth.authorise",
+                                             client_id=client_id))
+        try:
+            email = validate_email(form.get("user:email"))
+            user = user_by_email(conn, email["email"])
+            if valid_login(conn, user, form.get("user:password", "")):
+                return server.create_authorization_response(request=request, grant_user=user)
+            flash(email_passwd_msg, "alert-error")
+            return redirect_response # type: ignore[return-value]
+        except EmailNotValidError as _enve:
+            app.logger.debug(traceback.format_exc())
+            flash(email_passwd_msg, "alert-error")
+            return redirect_response # type: ignore[return-value]
+        except NotFoundError as _nfe:
+            app.logger.debug(traceback.format_exc())
+            flash(email_passwd_msg, "alert-error")
+            return redirect_response # type: ignore[return-value]
+
+    return with_db_connection(__authorise__)
 
 @auth.route("/token", methods=["POST"])
 def token():
diff --git a/gn3/templates/base.html b/gn3/templates/base.html
new file mode 100644
index 0000000..c1070ed
--- /dev/null
+++ b/gn3/templates/base.html
@@ -0,0 +1,17 @@
+{% from "common-macros.html" import flash_messages%}
+<!DOCTYPE html>
+<html lang="en">
+  <head>
+    <meta charset="utf-8" />
+    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
+
+    <title>Genenetwork 3: {%block title%}{%endblock%}</title>
+
+    {%block css%}{%endblock%}
+  </head>
+
+  <body>
+    {%block content%}{%endblock%}
+    {%block js%}{%endblock%}
+  <body>
+</html>
diff --git a/gn3/templates/common-macros.html b/gn3/templates/common-macros.html
new file mode 100644
index 0000000..1d9f302
--- /dev/null
+++ b/gn3/templates/common-macros.html
@@ -0,0 +1,7 @@
+{%macro flash_messages()%}
+<div class="alert-messages">
+  {%for category,message in get_flashed_messages(with_categories=true)%}
+  <div class="alert {{category}}" role="alert">{{message}}</div>
+  {%endfor%}
+</div>
+{%endmacro%}
diff --git a/gn3/templates/oauth2/authorise-user.html b/gn3/templates/oauth2/authorise-user.html
new file mode 100644
index 0000000..d40379f
--- /dev/null
+++ b/gn3/templates/oauth2/authorise-user.html
@@ -0,0 +1,40 @@
+{%extends "base.html"%}
+
+{%block title%}Authorise User{%endblock%}
+
+{%block content%}
+{{flash_messages()}}
+
+<h1>Authenticate to the API Server</h1>
+
+<form method="POST" action="#">
+  <input type="hidden" name="response_type" value="{{response_type}}" />
+  <input type="hidden" name="scope" value="{{scope | join(' ')}}" />
+  <p>
+    You are authorising "{{client.client_metadata.client_name}}" to access
+    Genenetwork 3 with the following scope:
+  </p>
+  <fieldset>
+    <legend>Scope</legend>
+    {%for scp in scope%}
+    <input id="scope:{{scp}}" type="checkbox" name="scope[]" value="{{scp}}"
+	   checked="checked" disabled="disabled" />
+    <label for="scope:{{scp}}">{{scp}}</label>
+    <br />
+    {%endfor%}
+  </fieldset>
+
+  <fieldset>
+    <legend>User Credentials</legend>
+    <label for="user:email">Email</label>
+    <input type="email" name="user:email" id="user:email" required="required" />
+    <br />
+
+    <label for="user:password">Password</label>
+    <input type="password" name="user:password" id="user:password"
+	   required="required" />
+  </fieldset>
+  
+  <input type="submit" value="authorise" />
+</form>
+{%endblock%}
diff --git a/main.py b/main.py
index 6890b33..3c4b146 100644
--- a/main.py
+++ b/main.py
@@ -74,10 +74,11 @@ def init_dev_clients():
             "token_endpoint_auth_method": [
                 "client_secret_post", "client_secret_basic"],
             "client_type": "confidential",
-            "grant_types": ["password", "authorisation_code", "refresh_token"],
+            "grant_types": ["password", "authorization_code", "refresh_token"],
             "default_redirect_uri": "http://localhost:5033/oauth2/code",
-            "redirect_uris": ["http://localhost:5033/oauth2/code"],
-            "response_type": "token", # choices: ["code", "token"]
+            "redirect_uris": ["http://localhost:5033/oauth2/code",
+                              "http://localhost:5033/oauth2/token"],
+            "response_type": ["code", "token"],
             "scope": ["profile", "group", "role", "resource", "register-client",
                       "user", "migrate-data", "introspect"]
         }),