about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--gn3/auth/authorisation/users/admin/ui.py24
-rw-r--r--gn3/auth/authorisation/users/admin/views.py32
-rw-r--r--gn3/session.py60
3 files changed, 78 insertions, 38 deletions
diff --git a/gn3/auth/authorisation/users/admin/ui.py b/gn3/auth/authorisation/users/admin/ui.py
index 7357136..242c7a6 100644
--- a/gn3/auth/authorisation/users/admin/ui.py
+++ b/gn3/auth/authorisation/users/admin/ui.py
@@ -1,28 +1,12 @@
 """UI utilities for the auth system."""
 from functools import wraps
-from datetime import datetime, timezone
-from flask import flash, session, request, url_for, redirect
+from flask import flash, url_for, redirect
 
 from gn3.auth.authentication.users import User
 from gn3.auth.db_utils import with_db_connection
 from gn3.auth.authorisation.roles.models import user_roles
 
-SESSION_KEY = "session_details"
-
-def __session_expired__():
-    """Check whether the session has expired."""
-    return datetime.now(tz=timezone.utc) >= session[SESSION_KEY]["expires"]
-
-def logged_in(func):
-    """Verify the user is logged in."""
-    @wraps(func)
-    def __logged_in__(*args, **kwargs):
-        if bool(session.get(SESSION_KEY)) and not __session_expired__():
-            return func(*args, **kwargs)
-        flash("You need to be logged in to access that page.", "alert-danger")
-        return redirect(url_for(
-            "oauth2.admin.login", next=request.url_rule.endpoint))
-    return __logged_in__
+from gn3.session import logged_in, session_user, clear_session_info
 
 def is_admin(func):
     """Verify user is a system admin."""
@@ -32,12 +16,12 @@ def is_admin(func):
         admin_roles = [
             role for role in with_db_connection(
                 lambda conn: user_roles(
-                    conn, User(**session[SESSION_KEY]["user"])))
+                    conn, User(**session_user())))
             if role.role_name == "system-administrator"]
         if len(admin_roles) > 0:
             return func(*args, **kwargs)
         flash("Expected a system administrator.", "alert-danger")
         flash("You have been logged out of the system.", "alert-info")
-        session.pop(SESSION_KEY)
+        clear_session_info()
         return redirect(url_for("oauth2.admin.login"))
     return __admin__
diff --git a/gn3/auth/authorisation/users/admin/views.py b/gn3/auth/authorisation/users/admin/views.py
index cf6fa59..ee76354 100644
--- a/gn3/auth/authorisation/users/admin/views.py
+++ b/gn3/auth/authorisation/users/admin/views.py
@@ -8,7 +8,6 @@ from datetime import datetime, timezone, timedelta
 from email_validator import validate_email, EmailNotValidError
 from flask import (
     flash,
-    session,
     request,
     url_for,
     redirect,
@@ -16,6 +15,8 @@ from flask import (
     current_app,
     render_template)
 
+
+from gn3 import session
 from gn3.auth import db
 from gn3.auth.db_utils import with_db_connection
 
@@ -29,22 +30,17 @@ from gn3.auth.authentication.users import (
     user_by_email,
     hash_password)
 
-from .ui import SESSION_KEY, is_admin
+from .ui import is_admin
 
 admin = Blueprint("admin", __name__)
 
 @admin.before_request
 def update_expires():
     """Update session expiration."""
-    if bool(session.get(SESSION_KEY)):
-        now = datetime.now(tz=timezone.utc)
-        if now >= session[SESSION_KEY]["expires"]:
-            flash("Session has expired. Logging out...", "alert-warning")
-            session.pop(SESSION_KEY)
-            return redirect(url_for("oauth2.admin.login"))
-        # If not expired, extend expiry.
-        session[SESSION_KEY]["expires"] = now + timedelta(minutes=10)
-
+    if session.session_info() and not session.update_expiry():
+        flash("Session has expired. Logging out...", "alert-warning")
+        session.clear_session_info()
+        return redirect(url_for("oauth2.admin.login"))
     return None
 
 @admin.route("/dashboard", methods=["GET"])
@@ -71,10 +67,10 @@ def login():
         with db.connection(current_app.config["AUTH_DB"]) as conn:
             user = user_by_email(conn, email["email"])
             if valid_login(conn, user, password):
-                session[SESSION_KEY] = {
-                    "user": user._asdict(),
-                    "expires": datetime.now(tz=timezone.utc) + timedelta(minutes=10)
-                }
+                session.update_session_info(
+                    user=user._asdict(),
+                    expires=(
+                        datetime.now(tz=timezone.utc) + timedelta(minutes=10)))
                 return redirect(url_for(next_uri))
             flash(error_message, "alert-danger")
             return login_page
@@ -85,10 +81,10 @@ def login():
 @admin.route("/logout", methods=["GET"])
 def logout():
     """Log out the admin."""
-    if not bool(session.get(SESSION_KEY)):
+    if not session.session_info():
         flash("Not logged in.", "alert-info")
         return redirect(url_for("oauth2.admin.login"))
-    session.pop(SESSION_KEY)
+    session.clear_session_info()
     flash("Logged out", "alert-success")
     return redirect(url_for("oauth2.admin.login"))
 
@@ -125,7 +121,7 @@ def register_client():
             "admin/register-client.html",
             scope=current_app.config["OAUTH2_SCOPE"],
             users=with_db_connection(__list_users__),
-            current_user=session[SESSION_KEY]["user"])
+            current_user=session.session_user())
 
     form = request.form
     raw_client_secret = random_string()
diff --git a/gn3/session.py b/gn3/session.py
new file mode 100644
index 0000000..f4f53a0
--- /dev/null
+++ b/gn3/session.py
@@ -0,0 +1,60 @@
+"""Handle any GN3 sessions."""
+from functools import wraps
+from datetime import datetime, timezone, timedelta
+
+from flask import flash, request, session, url_for, redirect
+
+__SESSION_KEY__ = "GN::3::session_details"
+
+def __session_expired__():
+    """Check whether the session has expired."""
+    return datetime.now(tz=timezone.utc) >= session[__SESSION_KEY__]["expires"]
+
+def logged_in(func):
+    """Verify the user is logged in."""
+    @wraps(func)
+    def __logged_in__(*args, **kwargs):
+        if bool(session.get(__SESSION_KEY__)) and not __session_expired__():
+            return func(*args, **kwargs)
+        flash("You need to be logged in to access that page.", "alert-danger")
+        return redirect(url_for(
+            "oauth2.admin.login", next=request.url_rule.endpoint))
+    return __logged_in__
+
+def session_info():
+    """Retrieve the session information."""
+    return session.get(__SESSION_KEY__, False)
+
+def session_user():
+    """Retrieve session user."""
+    info = session_info()
+    return info and info["user"]
+
+def clear_session_info():
+    """Clear any session info."""
+    try:
+        session.pop(__SESSION_KEY__)
+    except KeyError as _keyerr:
+        pass
+
+def session_expired() -> bool:
+    """
+    Check whether the session has expired. Will always return `True` if no
+    session currently exists.
+    """
+    if bool(session.get(__SESSION_KEY__)):
+        now = datetime.now(tz=timezone.utc)
+        return now >= session[__SESSION_KEY__]["expires"]
+    return True
+
+def update_expiry() -> bool:
+    """Update the session expiry and return a boolean indicating success."""
+    if not session_expired():
+        now = datetime.now(tz=timezone.utc)
+        session[__SESSION_KEY__]["expires"] = now + timedelta(minutes=10)
+        return True
+    return False
+
+def update_session_info(**info):
+    """Update the session information."""
+    session[__SESSION_KEY__] = info