about summary refs log tree commit diff
path: root/gn3
diff options
context:
space:
mode:
Diffstat (limited to 'gn3')
-rw-r--r--gn3/case_attributes.py49
1 files changed, 31 insertions, 18 deletions
diff --git a/gn3/case_attributes.py b/gn3/case_attributes.py
index 89b3c63..04472c9 100644
--- a/gn3/case_attributes.py
+++ b/gn3/case_attributes.py
@@ -1,11 +1,12 @@
 """Implement case-attribute manipulations."""
 import os
 import csv
+import json
 import tempfile
 from functools import reduce
 
 from MySQLdb.cursors import DictCursor
-from flask import jsonify, Response, Blueprint, current_app
+from flask import jsonify, request, Response, Blueprint, current_app
 
 from gn3.commands import run_cmd
 
@@ -26,7 +27,7 @@ def __inbredset_group__(conn, inbredset_id):
             {"inbredset_id": inbredset_id})
         return dict(cursor.fetchone())
 
-def __inbred_set_strains__(conn, inbredset_id):
+def __inbredset_strains__(conn, inbredset_id):
     """Return all samples/strains for given InbredSet group."""
     with conn.cursor(cursorclass=DictCursor) as cursor:
         cursor.execute(
@@ -54,7 +55,7 @@ def inbredset_group(inbredset_id: int) -> Response:
 def inbredset_strains(inbredset_id: int) -> Response:
     """Retrieve ALL strains/samples relating to a specific InbredSet group."""
     with database_connection(current_app.config["SQL_URI"]) as conn:
-        return jsonify(__inbred_set_strains__(conn, inbredset_id))
+        return jsonify(__inbredset_strains__(conn, inbredset_id))
 
 @caseattr.route("/<int:inbredset_id>/names", methods=["GET"])
 def inbredset_case_attribute_names(inbredset_id: int) -> Response:
@@ -115,35 +116,44 @@ def inbredset_case_attribute_values(inbredset_id: int) -> Response:
     with database_connection(current_app.config["SQL_URI"]) as conn:
         return jsonify(__case_attribute_values_by_inbred_set__(conn, inbredset_id))
 
-def __process_orig_data__(data) -> tuple[dict, ...]:
+def __process_orig_data__(fieldnames, cadata, strains) -> tuple[dict, ...]:
     """Process data from database and return tuple of dicts."""
+    data = {item["StrainName"]: item for item in cadata}
     return tuple(
         {
-            "Strain": row["StrainName"],
+            "Strain": strain["Name"],
             **{
-                key: row["case-attributes"][key]
-                for key in sorted(row["case-attributes"].keys())
+                key: data.get(
+                    strain["Name"], {}).get("case-attributes", {}).get(key, "")
+                for key in fieldnames[1:]
             }
-        } for row in data)
+        } for strain in strains)
 
-def __process_edit_data__(form_data) -> tuple[dict, ...]:
+def __process_edit_data__(fieldnames, form_data) -> tuple[dict, ...]:
     """Process data from form and return tuple of dicts."""
-    raise NotImplementedError
+    def __process__(acc, strain_cattrs):
+        strain, cattrs = strain_cattrs
+        return acc + ({
+            "Strain": strain, **{
+            field: cattrs["case-attributes"].get(field, "")
+            for field in fieldnames[1:]
+            }
+        },)
+    return reduce(__process__, form_data.items(), tuple())
 
 def __write_csv__(fieldnames, data):
     """Write the given `data` to a csv file and return the path to the file."""
     fd, filepath = tempfile.mkstemp(".csv")
     os.close(fd)
     with open(filepath, "w", encoding="utf-8") as csvfile:
-        writer = csv.DictWriter(filename, fieldnames=fieldnames, dialect="unix")
+        writer = csv.DictWriter(csvfile, fieldnames=fieldnames, dialect="unix")
         writer.writeheader()
         writer.writerows(data)
 
     return filepath
 
-def __compute_diff__(calabels: tuple[str, ...], original_data: tuple[dict, ...], edit_data: tuple[dict, ...]):
+def __compute_diff__(fieldnames: tuple[str, ...], original_data: tuple[dict, ...], edit_data: tuple[dict, ...]):
     """Return the diff of the data."""
-    fieldnames = ["Strain"] + sorted(calabels) # Make first column the strain.
     basefilename = __write_csv__(fieldnames, original_data)
     deltafilename = __write_csv__(fieldnames, edit_data)
     diff_results = run_cmd(json.dumps(
@@ -207,13 +217,16 @@ def edit_case_attributes(inbredset_id: int) -> Response:
           database_connection(current_app.config["SQL_URI"]) as conn):
         # TODO: Check user has "edit case attribute privileges"
         user = the_token.user
+        fieldnames = (["Strain"] + sorted(
+            attr["Name"] for attr in
+            __case_attribute_labels_by_inbred_set__(conn, inbredset_id)))
         diff_filename = __queue_diff__(conn, user, __compute_diff__(
-            (["Strain"] + sorted(
-                attr["Name"] for attr in
-                __case_attribute_labels_by_inbred_set__(conn, inbredset_id))),
+            fieldnames,
             __process_orig_data__(
-                __case_attribute_values_by_inbred_set__(conn, inbredset_id)),
-            __process_edit_data__(request.form)))
+                fieldnames,
+                __case_attribute_values_by_inbred_set__(conn, inbredset_id),
+                __inbredset_strains__(conn, inbredset_id)),
+            __process_edit_data__(fieldnames, request.json["edit-data"])))
         try:
             __apply_diff__(conn, user, diff_filename)
             return jsonify({