about summary refs log tree commit diff
path: root/gn3/auth
diff options
context:
space:
mode:
Diffstat (limited to 'gn3/auth')
-rw-r--r--gn3/auth/authorisation/users/admin/ui.py24
-rw-r--r--gn3/auth/authorisation/users/admin/views.py32
2 files changed, 18 insertions, 38 deletions
diff --git a/gn3/auth/authorisation/users/admin/ui.py b/gn3/auth/authorisation/users/admin/ui.py
index 7357136..242c7a6 100644
--- a/gn3/auth/authorisation/users/admin/ui.py
+++ b/gn3/auth/authorisation/users/admin/ui.py
@@ -1,28 +1,12 @@
 """UI utilities for the auth system."""
 from functools import wraps
-from datetime import datetime, timezone
-from flask import flash, session, request, url_for, redirect
+from flask import flash, url_for, redirect
 
 from gn3.auth.authentication.users import User
 from gn3.auth.db_utils import with_db_connection
 from gn3.auth.authorisation.roles.models import user_roles
 
-SESSION_KEY = "session_details"
-
-def __session_expired__():
-    """Check whether the session has expired."""
-    return datetime.now(tz=timezone.utc) >= session[SESSION_KEY]["expires"]
-
-def logged_in(func):
-    """Verify the user is logged in."""
-    @wraps(func)
-    def __logged_in__(*args, **kwargs):
-        if bool(session.get(SESSION_KEY)) and not __session_expired__():
-            return func(*args, **kwargs)
-        flash("You need to be logged in to access that page.", "alert-danger")
-        return redirect(url_for(
-            "oauth2.admin.login", next=request.url_rule.endpoint))
-    return __logged_in__
+from gn3.session import logged_in, session_user, clear_session_info
 
 def is_admin(func):
     """Verify user is a system admin."""
@@ -32,12 +16,12 @@ def is_admin(func):
         admin_roles = [
             role for role in with_db_connection(
                 lambda conn: user_roles(
-                    conn, User(**session[SESSION_KEY]["user"])))
+                    conn, User(**session_user())))
             if role.role_name == "system-administrator"]
         if len(admin_roles) > 0:
             return func(*args, **kwargs)
         flash("Expected a system administrator.", "alert-danger")
         flash("You have been logged out of the system.", "alert-info")
-        session.pop(SESSION_KEY)
+        clear_session_info()
         return redirect(url_for("oauth2.admin.login"))
     return __admin__
diff --git a/gn3/auth/authorisation/users/admin/views.py b/gn3/auth/authorisation/users/admin/views.py
index cf6fa59..ee76354 100644
--- a/gn3/auth/authorisation/users/admin/views.py
+++ b/gn3/auth/authorisation/users/admin/views.py
@@ -8,7 +8,6 @@ from datetime import datetime, timezone, timedelta
 from email_validator import validate_email, EmailNotValidError
 from flask import (
     flash,
-    session,
     request,
     url_for,
     redirect,
@@ -16,6 +15,8 @@ from flask import (
     current_app,
     render_template)
 
+
+from gn3 import session
 from gn3.auth import db
 from gn3.auth.db_utils import with_db_connection
 
@@ -29,22 +30,17 @@ from gn3.auth.authentication.users import (
     user_by_email,
     hash_password)
 
-from .ui import SESSION_KEY, is_admin
+from .ui import is_admin
 
 admin = Blueprint("admin", __name__)
 
 @admin.before_request
 def update_expires():
     """Update session expiration."""
-    if bool(session.get(SESSION_KEY)):
-        now = datetime.now(tz=timezone.utc)
-        if now >= session[SESSION_KEY]["expires"]:
-            flash("Session has expired. Logging out...", "alert-warning")
-            session.pop(SESSION_KEY)
-            return redirect(url_for("oauth2.admin.login"))
-        # If not expired, extend expiry.
-        session[SESSION_KEY]["expires"] = now + timedelta(minutes=10)
-
+    if session.session_info() and not session.update_expiry():
+        flash("Session has expired. Logging out...", "alert-warning")
+        session.clear_session_info()
+        return redirect(url_for("oauth2.admin.login"))
     return None
 
 @admin.route("/dashboard", methods=["GET"])
@@ -71,10 +67,10 @@ def login():
         with db.connection(current_app.config["AUTH_DB"]) as conn:
             user = user_by_email(conn, email["email"])
             if valid_login(conn, user, password):
-                session[SESSION_KEY] = {
-                    "user": user._asdict(),
-                    "expires": datetime.now(tz=timezone.utc) + timedelta(minutes=10)
-                }
+                session.update_session_info(
+                    user=user._asdict(),
+                    expires=(
+                        datetime.now(tz=timezone.utc) + timedelta(minutes=10)))
                 return redirect(url_for(next_uri))
             flash(error_message, "alert-danger")
             return login_page
@@ -85,10 +81,10 @@ def login():
 @admin.route("/logout", methods=["GET"])
 def logout():
     """Log out the admin."""
-    if not bool(session.get(SESSION_KEY)):
+    if not session.session_info():
         flash("Not logged in.", "alert-info")
         return redirect(url_for("oauth2.admin.login"))
-    session.pop(SESSION_KEY)
+    session.clear_session_info()
     flash("Logged out", "alert-success")
     return redirect(url_for("oauth2.admin.login"))
 
@@ -125,7 +121,7 @@ def register_client():
             "admin/register-client.html",
             scope=current_app.config["OAUTH2_SCOPE"],
             users=with_db_connection(__list_users__),
-            current_user=session[SESSION_KEY]["user"])
+            current_user=session.session_user())
 
     form = request.form
     raw_client_secret = random_string()