about summary refs log tree commit diff
path: root/gn_auth
diff options
context:
space:
mode:
Diffstat (limited to 'gn_auth')
-rw-r--r--gn_auth/auth/authorisation/data/views.py101
-rw-r--r--gn_auth/wsgi.py42
2 files changed, 100 insertions, 43 deletions
diff --git a/gn_auth/auth/authorisation/data/views.py b/gn_auth/auth/authorisation/data/views.py
index 584b239..1184d63 100644
--- a/gn_auth/auth/authorisation/data/views.py
+++ b/gn_auth/auth/authorisation/data/views.py
@@ -59,7 +59,7 @@ def list_species() -> Response:
 
 @data.route("/authorisation", methods=["POST"])
 @require_json
-def authorisation() -> Response:
+def authorisation() -> Response:# pylint: disable=[too-many-locals]
     """Retrieve the authorisation level for datasets/traits for the user."""
     # Access endpoint with something like:
     # curl -X POST http://127.0.0.1:8081/auth/data/authorisation \
@@ -104,10 +104,22 @@ def authorisation() -> Response:
                         authconn, _dset_traits["ProbeSet"]))
             for _rrow in _rtypes
         }
-        if len(_all_resources.keys()) == 0:
+        if (len(_all_resources.keys()) == 0 and
+                len(_dset_traits.get("Temp", tuple())) == 0):
             raise NotFoundError(
                 "No resource(s) found for specified trait(s). Do(es) the "
                 "trait(s) actually exist?")
+
+        # Handle Temp traits specially - they should be public/anonymous resources
+        if len(_dset_traits.get("Temp", tuple())) > 0:
+            # Create a synthetic public resource for Temp traits
+            # Use a predictable ID to identify synthetic temp resources
+            temp_resource_id = "gn-auth-temp-traits"
+            _all_resources[temp_resource_id] = {
+                "resource_id": temp_resource_id,
+                "resource_data": tuple(f"{dset}::{trait}" for dset, trait in _dset_traits["Temp"])
+            }
+
         _resource_ids = tuple(_all_resources.keys())
 
 
@@ -125,42 +137,55 @@ def authorisation() -> Response:
             }
 
         _paramstr = ", ".join(["?"] * len(_resource_ids))
-        try:
-            with require_oauth.acquire("profile group resource") as _token:
-                user = _token.user
-                cursor.execute(
-                    "SELECT ur.resource_id, r.role_id, rp.privilege_id "
-                    "FROM user_roles AS ur "
-                    "INNER JOIN roles AS r ON ur.role_id=r.role_id "
-                    "INNER JOIN role_privileges AS rp ON r.role_id=rp.role_id "
-                    "WHERE ur.user_id = ? "
-                    f"AND ur.resource_id IN ({_paramstr})",
-                    (str(user.user_id),) + _resource_ids
-                )
-                _privileges_by_resource: dict[str, tuple[str, ...]] = reduce(
-                    lambda acc, curr: {
-                        **acc,
-                        curr["resource_id"]: (
-                            acc.get(curr["resource_id"], tuple())
-                            + (curr["privilege_id"],))
-                    },
-                    cursor.fetchall(),
-                    {})
-        except _HTTPException as exc:
-            err_msg = json.loads(exc.body)
-            if err_msg["error"] == "missing_authorization":
-                cursor.execute(
-                    "SELECT rsc.resource_id "
-                    "FROM resources AS rsc "
-                    "WHERE rsc.public = '1' "
-                    f"AND rsc.resource_id IN ({_paramstr}) ",
-                    _resource_ids)
-                _privileges_by_resource = {
-                    row["resource_id"]: ('group:resource:view-resource',)
-                    for row in cursor.fetchall()
-                }
-            else:
-                raise exc from None
+        _privileges_by_resource: dict[str, tuple[str, ...]] = {}
+
+        # Separate synthetic temp resources from real resources
+        temp_resource_id = "gn-auth-temp-traits"
+        real_resource_ids = tuple(rid for rid in _resource_ids if rid != temp_resource_id)
+
+        # Query privileges only for real resources
+        if len(real_resource_ids) > 0:
+            real_paramstr = ", ".join(["?"] * len(real_resource_ids))
+            try:
+                with require_oauth.acquire("profile group resource") as _token:
+                    user = _token.user
+                    cursor.execute(
+                        "SELECT ur.resource_id, r.role_id, rp.privilege_id "
+                        "FROM user_roles AS ur "
+                        "INNER JOIN roles AS r ON ur.role_id=r.role_id "
+                        "INNER JOIN role_privileges AS rp ON r.role_id=rp.role_id "
+                        "WHERE ur.user_id = ? "
+                        f"AND ur.resource_id IN ({real_paramstr})",
+                        (str(user.user_id),) + real_resource_ids
+                    )
+                    _privileges_by_resource = reduce(
+                        lambda acc, curr: {
+                            **acc,
+                            curr["resource_id"]: (
+                                acc.get(curr["resource_id"], tuple())
+                                + (curr["privilege_id"],))
+                        },
+                        cursor.fetchall(),
+                        {})
+            except _HTTPException as exc:
+                err_msg = json.loads(exc.body)
+                if err_msg["error"] == "missing_authorization":
+                    cursor.execute(
+                        "SELECT rsc.resource_id "
+                        "FROM resources AS rsc "
+                        "WHERE rsc.public = '1' "
+                        f"AND rsc.resource_id IN ({real_paramstr}) ",
+                        real_resource_ids)
+                    _privileges_by_resource = {
+                        row["resource_id"]: ('group:resource:view-resource',)
+                        for row in cursor.fetchall()
+                    }
+                else:
+                    raise exc from None
+
+        # Temp resources are always publicly viewable
+        if temp_resource_id in _resource_ids:
+            _privileges_by_resource[temp_resource_id] = ('group:resource:view-resource',)
 
         return jsonify({
             "authorisation": [{
diff --git a/gn_auth/wsgi.py b/gn_auth/wsgi.py
index bc90210..a5af37e 100644
--- a/gn_auth/wsgi.py
+++ b/gn_auth/wsgi.py
@@ -315,7 +315,7 @@ _DEFAULT_SCOPES_ = (
 )
 
 
-def __create_one_client__(
+def __create_one_client__(# pylint: disable=[too-many-arguments, too-many-positional-arguments]
         conn,
         client_name: str,
         owner_user,
@@ -369,8 +369,15 @@ def __create_one_client__(
               help="URI to the client's public JWKS (optional)")
 @click.option("--output", "output_path", type=click.Path(), default=None,
               help="Write credentials as JSON to this file (default: stdout)")
-def create_oauth2_client(client_name, owner_id, redirect_uris, scopes,
-                         grant_types, jwks_uri, output_path):
+def create_oauth2_client(# pylint: disable=[too-many-arguments, too-many-positional-arguments]
+        client_name,
+        owner_id,
+        redirect_uris,
+        scopes,
+        grant_types,
+        jwks_uri,
+        output_path
+):
     """Create an OAuth2 client with specified parameters.
 
     Scopes and grant types default to the full standard set if not provided.
@@ -405,7 +412,7 @@ def create_test_oauth2_client(session_timestamp, users_file, owner_role,
     owner. Client name and secret are auto-generated using the session
     timestamp. Output is written with 0600 permissions.
     """
-    with open(users_file) as f:
+    with open(users_file, encoding="utf8") as f:
         users_data = json.load(f)
 
     owner_record = next(
@@ -445,7 +452,7 @@ def delete_oauth2_client(credentials_path):
     Reads the client_id from the given credentials file and removes the
     client and all associated tokens from the database.
     """
-    with open(credentials_path) as f:
+    with open(credentials_path, encoding="utf8") as f:
         data = json.load(f)
 
     client_id_str = data.get("client", {}).get("client_id")
@@ -462,6 +469,31 @@ def delete_oauth2_client(credentials_path):
         delete_client(conn, the_client.value)
         print(f"Deleted OAuth2 client {client_id}.")
 
+
+@app.cli.command()
+@click.option("--credentials", "credentials_path", required=True,
+              type=click.Path(exists=True),
+              help="Credentials file produced by create-test-users")
+def delete_test_users(credentials_path):
+    """Delete ephemeral test users using a credentials file.
+
+    Reads the credentials file produced by create-test-users and deletes
+    all listed users unconditionally, bypassing policy checks. Intended
+    for CI test teardown.
+    """
+    with open(credentials_path, encoding="utf8") as f:
+        data = json.load(f)
+
+    user_ids = tuple(
+        uuid.UUID(u["user_id"]) for u in data.get("users", []))
+    if not user_ids:
+        print("No users found in credentials file.", file=sys.stderr)
+        sys.exit(1)
+
+    with db.connection(app.config["AUTH_DB"]) as conn:
+        deleted = delete_users_by_id(conn, user_ids)
+        print(f"Deleted {deleted} user(s).")
+
 ##### END: CLI Commands #####
 
 if __name__ == '__main__':