about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--gn3/api/case_attributes.py (renamed from gn3/case_attributes.py)62
-rw-r--r--gn3/app.py2
2 files changed, 47 insertions, 17 deletions
diff --git a/gn3/case_attributes.py b/gn3/api/case_attributes.py
index 7646c92..7053e6a 100644
--- a/gn3/case_attributes.py
+++ b/gn3/api/case_attributes.py
@@ -5,7 +5,6 @@ import json
 import uuid
 import tempfile
 import lmdb
-import pickle
 from typing import Union
 
 from pathlib import Path
@@ -42,6 +41,7 @@ CATTR_DIFFS_DIR = "case-attribute-diffs"
 
 class NoDiffError(ValueError):
     """Raised if there is no difference between the old and new data."""
+
     def __init__(self):
         """Initialise exception."""
         super().__init__(
@@ -50,7 +50,8 @@ class NoDiffError(ValueError):
 
 class CAJSONEncoder(json.JSONEncoder):
     """Encoder for CaseAttribute-specific data"""
-    def default(self, obj): # pylint: disable=[arguments-renamed]
+
+    def default(self, obj):  # pylint: disable=[arguments-renamed]
         """Default encoder"""
         if isinstance(obj, datetime):
             return obj.isoformat()
@@ -58,6 +59,7 @@ class CAJSONEncoder(json.JSONEncoder):
             return str(obj)
         return json.JSONEncoder.default(self, obj)
 
+
 def required_access(
         token: dict,
         inbredset_id: int,
@@ -86,7 +88,8 @@ def required_access(
                     urljoin(current_app.config["AUTH_SERVER_URL"],
                             "auth/resource/authorisation"),
                     json={"resource-ids": [resource_id]},
-                    headers={"Authorization": f"Bearer {token['access_token']}"},
+                    headers={
+                        "Authorization": f"Bearer {token['access_token']}"},
                     timeout=300)
                 if auth.status_code == 200:
                     privs = tuple(priv["privilege_id"]
@@ -101,8 +104,6 @@ def required_access(
         f"User does not have the privileges {access_levels}")
 
 
-
-
 def __inbredset_group__(conn, inbredset_id):
     """Return InbredSet group's top-level details."""
     with conn.cursor(cursorclass=DictCursor) as cursor:
@@ -111,6 +112,7 @@ def __inbredset_group__(conn, inbredset_id):
             {"inbredset_id": inbredset_id})
         return dict(cursor.fetchone())
 
+
 def __inbredset_strains__(conn, inbredset_id):
     """Return all samples/strains for given InbredSet group."""
     with conn.cursor(cursorclass=DictCursor) as cursor:
@@ -121,6 +123,7 @@ def __inbredset_strains__(conn, inbredset_id):
             {"inbredset_id": inbredset_id})
         return tuple(dict(row) for row in cursor.fetchall())
 
+
 def __case_attribute_labels_by_inbred_set__(conn, inbredset_id):
     """Return the case-attribute labels/names for the given InbredSet group."""
     with conn.cursor(cursorclass=DictCursor) as cursor:
@@ -129,18 +132,21 @@ def __case_attribute_labels_by_inbred_set__(conn, inbredset_id):
             {"inbredset_id": inbredset_id})
         return tuple(dict(row) for row in cursor.fetchall())
 
+
 @caseattr.route("/<int:inbredset_id>", methods=["GET"])
 def inbredset_group(inbredset_id: int) -> Response:
     """Retrieve InbredSet group's details."""
     with database_connection(current_app.config["SQL_URI"]) as conn:
         return jsonify(__inbredset_group__(conn, inbredset_id))
 
+
 @caseattr.route("/<int:inbredset_id>/strains", methods=["GET"])
 def inbredset_strains(inbredset_id: int) -> Response:
     """Retrieve ALL strains/samples relating to a specific InbredSet group."""
     with database_connection(current_app.config["SQL_URI"]) as conn:
         return jsonify(__inbredset_strains__(conn, inbredset_id))
 
+
 @caseattr.route("/<int:inbredset_id>/names", methods=["GET"])
 def inbredset_case_attribute_names(inbredset_id: int) -> Response:
     """Retrieve ALL case-attributes for a specific InbredSet group."""
@@ -148,6 +154,7 @@ def inbredset_case_attribute_names(inbredset_id: int) -> Response:
         return jsonify(
             __case_attribute_labels_by_inbred_set__(conn, inbredset_id))
 
+
 def __by_strain__(accumulator, item):
     attr = {item["CaseAttributeName"]: item["CaseAttributeValue"]}
     strain_name = item["StrainName"]
@@ -166,13 +173,14 @@ def __by_strain__(accumulator, item):
         **accumulator,
         strain_name: {
             **{
-                key: value for key,value in item.items()
+                key: value for key, value in item.items()
                 if key in ("StrainName", "StrainName2", "Symbol", "Alias")
             },
             "case-attributes": attr
         }
     }
 
+
 def __case_attribute_values_by_inbred_set__(
         conn: Connection, inbredset_id: int) -> tuple[dict, ...]:
     """
@@ -195,12 +203,14 @@ def __case_attribute_values_by_inbred_set__(
         return tuple(
             reduce(__by_strain__, cursor.fetchall(), {}).values())
 
+
 @caseattr.route("/<int:inbredset_id>/values", methods=["GET"])
 def inbredset_case_attribute_values(inbredset_id: int) -> Response:
     """Retrieve the group's (InbredSet's) case-attribute values."""
     with database_connection(current_app.config["SQL_URI"]) as conn:
         return jsonify(__case_attribute_values_by_inbred_set__(conn, inbredset_id))
 
+
 def __process_orig_data__(fieldnames, cadata, strains) -> tuple[dict, ...]:
     """Process data from database and return tuple of dicts."""
     data = {item["StrainName"]: item for item in cadata}
@@ -214,18 +224,20 @@ def __process_orig_data__(fieldnames, cadata, strains) -> tuple[dict, ...]:
             }
         } for strain in strains)
 
+
 def __process_edit_data__(fieldnames, form_data) -> tuple[dict, ...]:
     """Process data from form and return tuple of dicts."""
     def __process__(acc, strain_cattrs):
         strain, cattrs = strain_cattrs
         return acc + ({
             "Sample": strain, **{
-            field: cattrs["case-attributes"].get(field, "")
-            for field in fieldnames[1:]
+                field: cattrs["case-attributes"].get(field, "")
+                for field in fieldnames[1:]
             }
         },)
     return reduce(__process__, form_data.items(), tuple())
 
+
 def __write_csv__(fieldnames, data):
     """Write the given `data` to a csv file and return the path to the file."""
     fdesc, filepath = tempfile.mkstemp(".csv")
@@ -237,6 +249,7 @@ def __write_csv__(fieldnames, data):
 
     return filepath
 
+
 def __compute_diff__(
         fieldnames: tuple[str, ...],
         original_data: tuple[dict, ...],
@@ -252,6 +265,7 @@ def __compute_diff__(
         return json.loads(diff_results["output"])
     return {}
 
+
 def __queue_diff__(conn: Connection, diff_data, diff_data_dir: Path) -> Path:
     """
     Queue diff for future processing.
@@ -278,6 +292,7 @@ def __queue_diff__(conn: Connection, diff_data, diff_data_dir: Path) -> Path:
         return filepath
     raise NoDiffError
 
+
 def __save_diff__(conn: Connection, diff_data: dict, status: EditStatus) -> int:
     """Save to the database."""
     with conn.cursor() as cursor:
@@ -295,6 +310,7 @@ def __save_diff__(conn: Connection, diff_data: dict, status: EditStatus) -> int:
             })
         return diff_data.get("db_id") or cursor.lastrowid
 
+
 def __parse_diff_json__(json_str):
     """Parse the json string to python objects."""
     raw_diff = json.loads(json_str)
@@ -309,11 +325,13 @@ def __parse_diff_json__(json_str):
                     if raw_diff.get("created") else None)
     }
 
+
 def __load_diff__(diff_filename):
     """Load the diff."""
     with open(diff_filename, encoding="utf8") as diff_file:
         return __parse_diff_json__(diff_file.read())
 
+
 def __apply_additions__(
         cursor, inbredset_id: int, additions_diff) -> None:
     """Apply additions: creates new case attributes."""
@@ -327,6 +345,7 @@ def __apply_additions__(
             "desc": diff["description"]
         } for diff in additions_diff))
 
+
 def __apply_modifications__(
         cursor, inbredset_id: int, modifications_diff, fieldnames) -> None:
     """Apply modifications: changes values of existing case attributes."""
@@ -387,6 +406,7 @@ def __apply_modifications__(
             for cattr in (key for key in row.keys() if key != "Sample")
             if not bool(row[cattr].strip())))
 
+
 def __apply_deletions__(
         cursor, inbredset_id: int, deletions_diff) -> None:
     """Apply deletions: delete existing case attributes and their values."""
@@ -404,6 +424,7 @@ def __apply_deletions__(
         "InbredSetId=:inbredset_id AND CaseAttributeId=:case_attribute_id",
         params)
 
+
 def __apply_diff__(
         conn: Connection, auth_token, inbredset_id: int, diff_filename, the_diff) -> None:
     """
@@ -426,6 +447,7 @@ def __apply_diff__(
             f"{diff_filename.stem}-approved{diff_filename.suffix}")
         os.rename(diff_filename, new_path)
 
+
 def __reject_diff__(conn: Connection,
                     auth_token: dict,
                     inbredset_id: int,
@@ -440,13 +462,14 @@ def __reject_diff__(conn: Connection,
                     ("system:inbredset:edit-case-attribute",
                      "system:inbredset:apply-case-attribute-edit"))
     __save_diff__(conn, diff, EditStatus.rejected)
-    new_path = Path(diff_filename.parent, f"{diff_filename.stem}-rejected{diff_filename.suffix}")
+    new_path = Path(diff_filename.parent,
+                    f"{diff_filename.stem}-rejected{diff_filename.suffix}")
     os.rename(diff_filename, new_path)
     return diff_filename
 
 
 def __update_case_attributes__(
-    cursor, inbredset_id: int, modifications) -> None:
+        cursor, inbredset_id: int, modifications) -> None:
     for strain, changes in modifications.items():
         for case_attribute, value in changes.items():
             value = value.strip()
@@ -474,21 +497,23 @@ def add_case_attributes(inbredset_id: int, auth_token=None) -> Response:
     """Add a new case attribute for `InbredSetId`."""
     required_access(
         auth_token, inbredset_id, ("system:inbredset:create-case-attribute",))
-    with database_connection(current_app.config["SQL_URI"]) as conn: # pylint: disable=[unused-variable]
+    with database_connection(current_app.config["SQL_URI"]) as conn:  # pylint: disable=[unused-variable]
         raise NotImplementedError
 
+
 @caseattr.route("/<int:inbredset_id>/delete", methods=["POST"])
 @require_token
 def delete_case_attributes(inbredset_id: int, auth_token=None) -> Response:
     """Delete a case attribute from `InbredSetId`."""
     required_access(
         auth_token, inbredset_id, ("system:inbredset:delete-case-attribute",))
-    with database_connection(current_app.config["SQL_URI"]) as conn: # pylint: disable=[unused-variable]
+    with database_connection(current_app.config["SQL_URI"]) as conn:  # pylint: disable=[unused-variable]
         raise NotImplementedError
 
+
 @caseattr.route("/<int:inbredset_id>/edit", methods=["POST"])
 @require_token
-def edit_case_attributes(inbredset_id: int, auth_token = None) -> Response:
+def edit_case_attributes(inbredset_id: int, auth_token=None) -> Response:
     """Edit the case attributes for `InbredSetId` based on data received.
 
     :inbredset_id: Identifier for the population that the case attribute belongs
@@ -538,6 +563,7 @@ def edit_case_attributes(inbredset_id: int, auth_token = None) -> Response:
                             "queued for approval."),
             })
 
+
 @caseattr.route("/<int:inbredset_id>/diff/list", methods=["GET"])
 def list_diffs(inbredset_id: int) -> Response:
     """List any changes that have not been approved/rejected."""
@@ -584,25 +610,28 @@ def list_diffs(inbredset_id: int) -> Response:
                 f"{diff['json_diff_data']['user_id']}:::"
                 f"{diff['time_stamp'].isoformat()}")
         } for diff in diffs
-              if diff["json_diff_data"].get("inbredset_id") == inbredset_id),
+            if diff["json_diff_data"].get("inbredset_id") == inbredset_id),
         cls=CAJSONEncoder))
     resp.headers["Content-Type"] = "application/json"
     return resp
 
+
 @caseattr.route("/approve/<path:filename>", methods=["POST"])
 @require_token
-def approve_case_attributes_diff(filename: str, auth_token = None) -> Response:
+def approve_case_attributes_diff(filename: str, auth_token=None) -> Response:
     """Approve the changes to the case attributes in the diff."""
     diff_dir = Path(current_app.config["TMPDIR"], CATTR_DIFFS_DIR)
     diff_filename = Path(diff_dir, filename)
     the_diff = __load_diff__(diff_filename)
     with database_connection(current_app.config["SQL_URI"]) as conn:
-        __apply_diff__(conn, auth_token, the_diff["inbredset_id"], diff_filename, the_diff)
+        __apply_diff__(conn, auth_token,
+                       the_diff["inbredset_id"], diff_filename, the_diff)
         return jsonify({
             "message": "Applied the diff successfully.",
             "diff_filename": diff_filename.name
         })
 
+
 @caseattr.route("/reject/<path:filename>", methods=["POST"])
 @require_token
 def reject_case_attributes_diff(filename: str, auth_token=None) -> Response:
@@ -621,6 +650,7 @@ def reject_case_attributes_diff(filename: str, auth_token=None) -> Response:
             "diff_filename": diff_filename.name
         })
 
+
 @caseattr.route("/<int:inbredset_id>/diff/<int:diff_id>/view", methods=["GET"])
 @require_token
 def view_diff(inbredset_id: int, diff_id: int, auth_token=None) -> Response:
diff --git a/gn3/app.py b/gn3/app.py
index 74bb5ab..6b5efa4 100644
--- a/gn3/app.py
+++ b/gn3/app.py
@@ -29,7 +29,7 @@ from gn3.api.sampledata import sampledata
 from gn3.api.llm import gnqa
 from gn3.api.rqtl2 import rqtl2
 from gn3.api.streaming import streaming
-from gn3.case_attributes import caseattr
+from gn3.api.case_attributes import caseattr
 from gn3.api.lmdb_sample_data import lmdb_sample_data