about summary refs log tree commit diff
path: root/gn_auth
diff options
context:
space:
mode:
authorFrederick Muriuki Muriithi2024-05-02 04:55:11 +0300
committerFrederick Muriuki Muriithi2024-05-02 05:40:56 +0300
commit8e79d0f7b8faba92bebb27c345563ebc1bd6e945 (patch)
tree120847fa6d0dea2590b3a27bd50f33da2e7f9d85 /gn_auth
parent137648f5db940b8a6d65db31eee231ef1ce5d761 (diff)
downloadgn-auth-8e79d0f7b8faba92bebb27c345563ebc1bd6e945.tar.gz
Add error checking to form input data.
Diffstat (limited to 'gn_auth')
-rw-r--r--gn_auth/auth/authorisation/users/admin/views.py67
1 files changed, 67 insertions, 0 deletions
diff --git a/gn_auth/auth/authorisation/users/admin/views.py b/gn_auth/auth/authorisation/users/admin/views.py
index 7ffee95..dc7b8c6 100644
--- a/gn_auth/auth/authorisation/users/admin/views.py
+++ b/gn_auth/auth/authorisation/users/admin/views.py
@@ -3,7 +3,9 @@ import uuid
 import json
 import random
 import string
+from typing import Optional
 from functools import partial
+from urllib.parse import urlparse
 from datetime import datetime, timezone, timedelta
 
 from email_validator import validate_email, EmailNotValidError
@@ -40,6 +42,9 @@ from .ui import is_admin
 
 admin = Blueprint("admin", __name__)
 
+class RegisterClientError(Exception):
+    """Error to raise in case of client registration issues"""
+
 @admin.before_request
 def update_expires():
     """Update session expiration."""
@@ -115,6 +120,53 @@ def __response_types__(grant_types: tuple[str, ...]) -> tuple[str, ...]:
         in (types for grant, types in resps.items() if grant in grant_types)
         for resp_typ in types_list))
 
+def check_string(form, inputname: str, errormessage: str) -> Optional[str]:
+    """Check that an input expecting a string has an actual value."""
+    if not bool(form.get(inputname, "").strip()):
+        return errormessage
+    return None
+
+def check_list(form, inputname: str, errormessage: str) -> Optional[str]:
+    """Check that an input expecting a list has at least one value."""
+    _list = [item for item in form.getlist(inputname) if bool(item.strip())]
+    if not bool(_list):
+        return errormessage
+    return None
+
+def uri_valid(value: str) -> bool:
+    """Check that the `value` is a valid URI"""
+    uri = urlparse(value)
+    return (bool(uri.scheme) and bool(uri.netloc))
+
+def check_register_client_form(form):
+    """Check that all expected data is provided."""
+    errors = (check_list(form,
+                         "scope[]",
+                         "You need to select at least one scope option."),)
+
+    errors = errors + (check_string(
+        form,
+        "client_name",
+        "You need to provide a name for the client being registered."),)
+
+    errors = errors + (check_string(
+        form,
+        "redirect_uri",
+        "You need to provide the main redirect uri."),)
+
+    if not uri_valid(form.get("redirect_uri", "")):
+        errors = errors + ("The provided redirect URI is not a valid URI.",)
+
+    errors = errors + (check_list(
+        form,
+        "scope[]",
+        "You need to select at least one scope option."),)
+
+    errors = tuple(item for item in errors if item is not None)
+    if bool(errors):
+        raise RegisterClientError(errors)
+
+
 @admin.route("/register-client", methods=["GET", "POST"])
 @is_admin
 def register_client():
@@ -134,6 +186,13 @@ def register_client():
 
     form = request.form
     raw_client_secret = random_string()
+    try:
+        check_register_client_form(form)
+    except RegisterClientError as _rce:
+        for error_message in _rce.args:
+            flash(error_message, "alert-danger")
+        return redirect(url_for("oauth2.admin.register_client"))
+
     default_redirect_uri = form["redirect_uri"]
     grant_types = form.getlist("grants[]")
     client = OAuth2Client(
@@ -190,6 +249,14 @@ def view_client(client_id: uuid.UUID):
 def edit_client():
     """Edit the details of the given client."""
     form = request.form
+    try:
+        check_register_client_form(form)
+    except RegisterClientError as _rce:
+        for error_message in _rce.args:
+            flash(error_message, "alert-danger")
+        return redirect(url_for("oauth2.admin.view_client",
+                                client_id=form["client_id"]))
+
     the_client = with_db_connection(partial(
         oauth2_client, client_id=uuid.UUID(form["client_id"])))
     if the_client.is_nothing():