about summary refs log tree commit diff
diff options
context:
space:
mode:
authorFrederick Muriuki Muriithi2023-05-12 11:57:12 +0300
committerFrederick Muriuki Muriithi2023-05-12 11:57:12 +0300
commita175ea7ea3b0c85ca4e9c0909833f6842474e225 (patch)
treedb6c123723328e81a47f107cf44e82e9800d53ba
parent96fb589371dd1700f0d0f7abc367098d6820c37a (diff)
downloadgenenetwork2-a175ea7ea3b0c85ca4e9c0909833f6842474e225.tar.gz
auth: Integrate sessions with auth
Rework the sessions to do what was handled by the soon-to-be-obsolete
`wqflask.user_sessions` module.

This is necessary in order to retain the expected functionality of the
user collections, especially:

* anonymous user collections
* authenticated user collections
* import of anonymous collections when user authenticates
-rw-r--r--wqflask/wqflask/oauth2/checks.py28
-rw-r--r--wqflask/wqflask/oauth2/client.py76
-rw-r--r--wqflask/wqflask/oauth2/request_utils.py2
-rw-r--r--wqflask/wqflask/oauth2/session.py74
-rw-r--r--wqflask/wqflask/oauth2/toplevel.py15
-rw-r--r--wqflask/wqflask/oauth2/ui.py7
-rw-r--r--wqflask/wqflask/oauth2/users.py23
7 files changed, 170 insertions, 55 deletions
diff --git a/wqflask/wqflask/oauth2/checks.py b/wqflask/wqflask/oauth2/checks.py
index c60ab1de..3b6d2471 100644
--- a/wqflask/wqflask/oauth2/checks.py
+++ b/wqflask/wqflask/oauth2/checks.py
@@ -3,33 +3,41 @@ from functools import wraps
 from urllib.parse import urljoin
 
 from authlib.integrations.requests_client import OAuth2Session
-from flask import flash, request, session, url_for, redirect, current_app
+from flask import (
+    flash, request, url_for, redirect, current_app, session as flask_session)
+
+from . import session
 
 def user_logged_in():
     """Check whether the user has logged in."""
-    return bool(session.get("oauth2_token", False))
+    suser = session.session_info()["user"]
+    return suser["token"].is_right() and suser["logged_in"]
 
 def require_oauth2(func):
     """Decorator for ensuring user is logged in."""
     @wraps(func)
     def __token_valid__(*args, **kwargs):
         """Check that the user is logged in and their token is valid."""
-        if user_logged_in():
-            config = current_app.config
+        config = current_app.config
+        def __clear_session__(_no_token):
+            session.clear_session_info()
+            flask_session.pop("oauth2_token", None)
+            flask_session.pop("user_details", None)
+            flash("You need to be logged in.", "alert-warning")
+            return redirect("/")
+
+        def __with_token__(token):
             client = OAuth2Session(
                 config["OAUTH2_CLIENT_ID"], config["OAUTH2_CLIENT_SECRET"],
-                token=session["oauth2_token"])
+                token=token)
             resp = client.get(
                 urljoin(config["GN_SERVER_URL"], "oauth2/user"))
             user_details = resp.json()
             if not user_details.get("error", False):
                 return func(*args, **kwargs)
 
-            session.pop("oauth2_token", None)
-            session.pop("user_details", None)
+            return clear_session_info(token)
 
-        flash("You need to be logged in.", "alert-warning")
-        # return redirect(url_for("oauth2.user.login", next=request.endpoint))
-        return redirect("/")
+        return session.user_token().either(__clear_session__, __with_token__)
 
     return __token_valid__
diff --git a/wqflask/wqflask/oauth2/client.py b/wqflask/wqflask/oauth2/client.py
index efa862f2..249d158d 100644
--- a/wqflask/wqflask/oauth2/client.py
+++ b/wqflask/wqflask/oauth2/client.py
@@ -3,50 +3,70 @@ import requests
 from typing import Any, Optional
 from urllib.parse import urljoin
 
-from flask import session, current_app as app
+from flask import jsonify, current_app as app
 from pymonad.maybe import Just, Maybe, Nothing
 from pymonad.either import Left, Right, Either
 from authlib.integrations.requests_client import OAuth2Session
 
+from wqflask.oauth2 import session
+
 SCOPE = "profile group role resource register-client user introspect migrate-data"
 
 def oauth2_client():
     config = app.config
-    return OAuth2Session(
-        config["OAUTH2_CLIENT_ID"], config["OAUTH2_CLIENT_SECRET"],
-        scope=SCOPE, token_endpoint_auth_method="client_secret_post",
-        token=session.get("oauth2_token"))
+    def __client__(token) -> OAuth2Session:
+        return OAuth2Session(
+            config["OAUTH2_CLIENT_ID"], config["OAUTH2_CLIENT_SECRET"],
+            scope=SCOPE, token_endpoint_auth_method="client_secret_post",
+            token=token)
+    return session.user_token().either(
+        lambda _notok: __client__(None),
+        lambda token: __client__(token))
+
+def __no_token__(_err) -> Left:
+    """Handle situation where request is attempted with no token."""
+    resp = requests.models.Response()
+    resp._content = json.dumps({
+        "error": "AuthenticationError",
+        "error-description": ("You need to authenticate to access requested "
+                              "information.")}).encode("utf-8")
+    resp.status_code = 400
+    return Left(resp)
 
 def oauth2_get(uri_path: str, data: dict = {}, **kwargs) -> Either:
-    token = session.get("oauth2_token")
-    config = app.config
-    client = OAuth2Session(
-        config["OAUTH2_CLIENT_ID"], config["OAUTH2_CLIENT_SECRET"],
-        token=token, scope=SCOPE)
-    resp = client.get(
-        urljoin(config["GN_SERVER_URL"], uri_path),
-        data=data,
-        **kwargs)
-    if resp.status_code == 200:
-        return Right(resp.json())
+    def __get__(token) -> Either:
+        config = app.config
+        client = OAuth2Session(
+            config["OAUTH2_CLIENT_ID"], config["OAUTH2_CLIENT_SECRET"],
+            token=token, scope=SCOPE)
+        resp = client.get(
+            urljoin(config["GN_SERVER_URL"], uri_path),
+            data=data,
+            **kwargs)
+        if resp.status_code == 200:
+            return Right(resp.json())
 
-    return Left(resp)
+        return Left(resp)
+
+    return session.user_token().either(__no_token__, __get__)
 
 def oauth2_post(
         uri_path: str, data: Optional[dict] = None, json: Optional[dict] = None,
         **kwargs) -> Either:
-    token = session.get("oauth2_token")
-    config = app.config
-    client = OAuth2Session(
-        config["OAUTH2_CLIENT_ID"], config["OAUTH2_CLIENT_SECRET"],
-        token=token, scope=SCOPE)
-    resp = client.post(
-        urljoin(config["GN_SERVER_URL"], uri_path), data=data, json=json,
-        **kwargs)
-    if resp.status_code == 200:
-        return Right(resp.json())
+    def __post__(token) -> Either:
+        config = app.config
+        client = OAuth2Session(
+            config["OAUTH2_CLIENT_ID"], config["OAUTH2_CLIENT_SECRET"],
+            token=token, scope=SCOPE)
+        resp = client.post(
+            urljoin(config["GN_SERVER_URL"], uri_path), data=data, json=json,
+            **kwargs)
+        if resp.status_code == 200:
+            return Right(resp.json())
 
-    return Left(resp)
+        return Left(resp)
+
+    return session.user_token().either(__no_token__, __post__)
 
 def no_token_get(uri_path: str, **kwargs) -> Either:
     config = app.config
diff --git a/wqflask/wqflask/oauth2/request_utils.py b/wqflask/wqflask/oauth2/request_utils.py
index ac21e223..ef8ee9fd 100644
--- a/wqflask/wqflask/oauth2/request_utils.py
+++ b/wqflask/wqflask/oauth2/request_utils.py
@@ -23,7 +23,7 @@ def raise_unimplemented():
 
 def user_details():
     return oauth2_get("oauth2/user").either(
-        handle_error("oauth2.login"),
+        lambda err: {},
         lambda usr_dets: usr_dets)
 
 def process_error(error: Response,
diff --git a/wqflask/wqflask/oauth2/session.py b/wqflask/wqflask/oauth2/session.py
new file mode 100644
index 00000000..011d95f3
--- /dev/null
+++ b/wqflask/wqflask/oauth2/session.py
@@ -0,0 +1,74 @@
+"""Deal with user sessions"""
+from uuid import UUID, uuid4
+from typing import Any, TypedDict
+
+from flask import request, session
+from pymonad.either import Left, Right, Either
+
+class UserDetails(TypedDict):
+    """Session information relating specifically to the user."""
+    user_id: UUID
+    name: str
+    token: Either
+    logged_in: bool
+
+class SessionInfo(TypedDict):
+    """All Session information we save."""
+    session_id: UUID
+    user: UserDetails
+    anon_id: UUID
+    user_agent: str
+    ip_addr: str
+
+__SESSION_KEY__ = "session_info" # Do not use this outside this module!!
+
+def clear_session_info():
+    """Clears the session."""
+    session.pop(__SESSION_KEY__)
+
+def save_session_info(sess_info: SessionInfo) -> SessionInfo:
+    """Save `session_info`."""
+    # TODO: if it is an existing session, verify that certain important security
+    #       bits have not changed before saving.
+    # old_session_info = session.get(__SESSION_KEY__)
+    # if bool(old_session_info):
+    #     if old_session_info["user_agent"] == request.headers.get("User-Agent"):
+    #         session[__SESSION_KEY__] = sess_info
+    #         return sess_info
+    #     # request session verification
+    #     return verify_session(sess_info)
+    # New session
+    session[__SESSION_KEY__] = sess_info
+    return sess_info
+
+def session_info() -> SessionInfo:
+    """Retrieve the session information"""
+    anon_id = uuid4()
+    return save_session_info(
+        session.get(__SESSION_KEY__, {
+            "session_id": uuid4(),
+            "user": {
+                "user_id": anon_id,
+                "name": "Anonymous User",     
+                "token": Left("INVALID-TOKEN"),
+                "logged_in": False
+            },
+            "anon_id": anon_id,
+            "user_agent": request.headers.get("User-Agent"),
+            "ip_addr": request.environ.get("HTTP_X_FORWARDED_FOR",
+                                           request.remote_addr)
+        }))
+
+def set_user_token(token: str) -> SessionInfo:
+    """Set the user's token."""
+    info = session_info()
+    return save_session_info({
+        **info, "user": {**info["user"], "token": Right(token)}})
+
+def set_user_details(userdets: UserDetails) -> SessionInfo:
+    """Set the user details information"""
+    return save_session_info({**session_info(), "user": userdets})
+
+def user_token() -> Either:
+    """Retrieve the user token."""
+    return session_info()["user"]["token"]
diff --git a/wqflask/wqflask/oauth2/toplevel.py b/wqflask/wqflask/oauth2/toplevel.py
index 109ed06c..ef9ce3db 100644
--- a/wqflask/wqflask/oauth2/toplevel.py
+++ b/wqflask/wqflask/oauth2/toplevel.py
@@ -1,12 +1,14 @@
 """Authentication endpoints."""
+from uuid import UUID
 from urllib.parse import urljoin
 from flask import (
-    flash, request, session, Blueprint, url_for, redirect, render_template,
+    flash, request, Blueprint, url_for, redirect, render_template,
     current_app as app)
 
+from . import session
 from .client import SCOPE, no_token_post
-from .request_utils import process_error
 from .checks import require_oauth2, user_logged_in
+from .request_utils import user_details, process_error
 
 toplevel = Blueprint("toplevel", __name__)
 
@@ -25,7 +27,14 @@ def authorisation_code():
         return redirect("/")
 
     def __success__(token):
-        session["oauth2_token"] = token
+        session.set_user_token(token)
+        udets = user_details()
+        session.set_user_details({
+            "user_id": UUID(udets["user_id"]),
+            "name": udets["name"],
+            "token": session.user_token(),
+            "logged_in": True
+        })
         return redirect(url_for("oauth2.user.user_profile"))
 
     code = request.args.get("code", "")
diff --git a/wqflask/wqflask/oauth2/ui.py b/wqflask/wqflask/oauth2/ui.py
index abf30f4e..315aae2b 100644
--- a/wqflask/wqflask/oauth2/ui.py
+++ b/wqflask/wqflask/oauth2/ui.py
@@ -2,13 +2,13 @@
 from flask import session, render_template
 
 from .client import oauth2_get
+from .checks import user_logged_in
 from .request_utils import process_error
 
 def render_ui(templatepath: str, **kwargs):
     """Handle repetitive UI rendering stuff."""
-    logged_in = lambda: ("oauth2_token" in session and bool(session["oauth2_token"]))
     roles = kwargs.get("roles", tuple()) # Get roles if already provided
-    if logged_in() and not bool(roles): # If not, try fetching them
+    if user_logged_in() and not bool(roles): # If not, try fetching them
         roles_results = oauth2_get("oauth2/user/roles").either(
             lambda err: {"roles_error": process_error(err)},
             lambda roles: {"roles": roles})
@@ -18,7 +18,6 @@ def render_ui(templatepath: str, **kwargs):
         privilege["privilege_id"] for role in roles
         for privilege in role["privileges"])
     kwargs = {
-        **kwargs, "roles": roles, "user_privileges": user_privileges,
-        "logged_in": logged_in
+        **kwargs, "roles": roles, "user_privileges": user_privileges
     }
     return render_template(templatepath, **kwargs)
diff --git a/wqflask/wqflask/oauth2/users.py b/wqflask/wqflask/oauth2/users.py
index 44ba252a..597dfb33 100644
--- a/wqflask/wqflask/oauth2/users.py
+++ b/wqflask/wqflask/oauth2/users.py
@@ -1,11 +1,13 @@
 import requests
+from uuid import UUID
 from urllib.parse import urljoin
 
 from authlib.integrations.base_client.errors import OAuthError
 from flask import (
-    flash, request, session, url_for, redirect, Response, Blueprint,
+    flash, request, url_for, redirect, Response, Blueprint,
     current_app as app)
 
+from . import session
 from .ui import render_ui
 from .checks import require_oauth2, user_logged_in
 from .client import oauth2_get, oauth2_post, oauth2_client
@@ -18,7 +20,6 @@ users = Blueprint("user", __name__)
 def user_profile():
     __id__ = lambda the_val: the_val
     usr_dets = user_details()
-    client = oauth2_client()
     def __render__(usr_dets, roles=[], **kwargs):
         return render_ui(
             "oauth2/view-user.html", user_details=usr_dets, roles=roles,
@@ -74,7 +75,14 @@ def login():
                 username=form.get("email_address"),
                 password=form.get("password"),
                 grant_type="password")
-            session["oauth2_token"] = token
+            session.set_token(token)
+            udets = user_details()
+            session.set_user_details({
+                "user_id": UUID(udets["user_id"]),
+                "name": udets["name"],
+                "token": session.user_token(),
+                "logged_in": True
+            })
         except OAuthError as _oaerr:
             flash(_oaerr.args[0], "alert-danger")
             return render_ui(
@@ -91,13 +99,10 @@ def login():
 @users.route("/logout", methods=["GET", "POST"])
 def logout():
     if user_logged_in():
-        token = session.get("oauth2_token", False)
         config = app.config
-        client = oauth2_client()
-        resp = client.revoke_token(urljoin(config["GN_SERVER_URL"], "oauth2/revoke"))
-        keys = tuple(key for key in session.keys() if not key.startswith("_"))
-        for key in keys:
-            session.pop(key, default=None)
+        resp = oauth2_client().revoke_token(
+            urljoin(config["GN_SERVER_URL"], "oauth2/revoke"))
+        session.clear_session_info()
         flash("Successfully logged out.", "alert-success")
 
     return redirect("/")