about summary refs log tree commit diff
path: root/gn2/wqflask
diff options
context:
space:
mode:
Diffstat (limited to 'gn2/wqflask')
-rw-r--r--gn2/wqflask/__init__.py26
-rw-r--r--gn2/wqflask/oauth2/toplevel.py7
2 files changed, 22 insertions, 11 deletions
diff --git a/gn2/wqflask/__init__.py b/gn2/wqflask/__init__.py
index 6b6c48ac..f6e9ef53 100644
--- a/gn2/wqflask/__init__.py
+++ b/gn2/wqflask/__init__.py
@@ -45,13 +45,6 @@ from gn2.wqflask.startup import (
     startup_errors,
     check_mandatory_configs)
 
-app = Flask(__name__)
-
-
-# See http://flask.pocoo.org/docs/config/#configuring-from-files
-# Note no longer use the badly named WQFLASK_OVERRIDES (nyi)
-app.config.from_object('gn2.default_settings')
-app.config.from_envvar('GN2_SETTINGS')
 
 def numcoll():
     """Handle possible errors."""
@@ -60,6 +53,21 @@ def numcoll():
     except Exception as _exc:
         return "ERROR"
 
+
+def parse_ssl_key(app: Flask, keyconfig: str):
+    """Parse key file paths into objects"""
+    with open(app.config[keyconfig]) as _sslkey:
+        app.config[keyconfig] = JsonWebKey.import_key(_sslkey.read())
+
+
+
+app = Flask(__name__)
+
+# See http://flask.pocoo.org/docs/config/#configuring-from-files
+# Note no longer use the badly named WQFLASK_OVERRIDES (nyi)
+app.config.from_object('gn2.default_settings')
+app.config.from_envvar('GN2_SETTINGS')
+
 app.jinja_env.globals.update(
     undefined=jinja2.StrictUndefined,
     numify=formatting.numify,
@@ -108,8 +116,8 @@ except StartupError as serr:
 
 server_session = Session(app)
 
-with open(app.config["SSL_KEY_PAIR_PRIVATE_KEY"]) as _sslkey:
-    app.config["JWT_PRIVATE_KEY"] = JsonWebKey.import_key(_sslkey.read())
+parse_ssl_key(app, "SSL_PRIVATE_KEY")
+parse_ssl_key(app, "AUTH_SERVER_SSL_PUBLIC_KEY")
 
 @app.before_request
 def before_request():
diff --git a/gn2/wqflask/oauth2/toplevel.py b/gn2/wqflask/oauth2/toplevel.py
index bc32e80e..a1e9196d 100644
--- a/gn2/wqflask/oauth2/toplevel.py
+++ b/gn2/wqflask/oauth2/toplevel.py
@@ -46,7 +46,7 @@ def authorisation_code():
     code = request.args.get("code", "")
     if bool(code):
         base_url = urlparse(request.base_url, scheme=request.scheme)
-        jwtkey = app.config["JWT_PRIVATE_KEY"]
+        jwtkey = app.config["SSL_PRIVATE_KEY"]
         issued = datetime.datetime.now()
         request_data = {
             "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
@@ -56,7 +56,10 @@ def authorisation_code():
                 urlunparse(base_url),
                 url_for("oauth2.toplevel.authorisation_code")),
             "assertion": jwt.encode(
-                header={"alg": "RS256", "typ": "jwt", "kid": jwtkey.kid},
+                header={
+                    "alg": "RS256",
+                    "typ": "jwt",
+                    "kid": jwtkey.as_dict()["kid"]},
                 payload={
                     "iss": str(oauth2_clientid()),
                     "sub": request.args["user_id"],