about summary refs log tree commit diff
diff options
context:
space:
mode:
author: BonfaceKilz, 2022-03-12 16:59:44 +0300
committer: BonfaceKilz, 2022-03-12 17:38:12 +0300
commit: f27f8470e79857c9c088e230a141995c3127640b (patch)
tree: 9886a91efb1dd0755163448758f89ed46aad304d
parent: 2014c6c166a7659f30f36c829c09d84f97297b88 (diff)
download: genenetwork3-f27f8470e79857c9c088e230a141995c3127640b.tar.gz
Fix pylint issues
-rw-r--r-- gn3/csvcmp.py 64
-rw-r--r-- gn3/db/sample_data.py 93
2 files changed, 85 insertions, 72 deletions
diff --git a/gn3/csvcmp.py b/gn3/csvcmp.py
index ac09cc3..82d74d0 100644
--- a/gn3/csvcmp.py
+++ b/gn3/csvcmp.py
@@ -1,50 +1,59 @@
+"""This module contains functions for manipulating and working with csv
+texts"""
 import json
 import os
 import uuid
 from gn3.commands import run_cmd
 
 
-def extract_strain_name(csv_header, data, seek="Strain Name"):
+def extract_strain_name(csv_header, data, seek="Strain Name") -> str:
+    """Extract a strain's name given a csv header"""
     for column, value in zip(csv_header.split(","), data.split(",")):
         if seek in column:
             return value
     return ""
 
 
-def create_dirs_if_not_exists(dirs: list):
+def create_dirs_if_not_exists(dirs: list) -> None:
+    """Create directories from a list"""
     for dir_ in dirs:
         if not os.path.exists(dir_):
             os.makedirs(dir_)
 
 
 def remove_insignificant_edits(diff_data, epsilon=0.001):
-    _mod = []
+    """Remove or ignore edits that are not within ε"""
+    __mod = []
     for mod in diff_data.get("Modifications"):
         original = mod.get("Original").split(",")
         current = mod.get("Current").split(",")
-        for i, (x, y) in enumerate(zip(original, current)):
-            if (x.replace('.', '').isdigit() and
-                y.replace('.', '').isdigit() and
-                    abs(float(x) - float(y)) < epsilon):
-                current[i] = x
+        for i, (_x, _y) in enumerate(zip(original, current)):
+            if (
+                _x.replace(".", "").isdigit()
+                and _y.replace(".", "").isdigit()
+                and abs(float(_x) - float(_y)) < epsilon
+            ):
+                current[i] = _x
         if not (__o := ",".join(original)) == (__c := ",".join(current)):
-            _mod.append({
-                "Original": __o,
-                "Current": __c,
-            })
-    diff_data['Modifications'] = _mod
+            __mod.append(
+                {
+                    "Original": __o,
+                    "Current": __c,
+                }
+            )
+    diff_data["Modifications"] = __mod
     return diff_data
 
 
-def csv_diff(base_csv, delta_csv, tmp_dir="/tmp"):
+def csv_diff(base_csv, delta_csv, tmp_dir="/tmp") -> dict:
+    """Diff 2 csv strings"""
     base_csv_list = base_csv.strip().split("\n")
     delta_csv_list = delta_csv.strip().split("\n")
 
-    base_csv_header, delta_csv_header, header = "", "", ""
+    base_csv_header, delta_csv_header = "", ""
     for i, line in enumerate(base_csv_list):
         if line.startswith("Strain Name,Value,SE,Count"):
-            header = line
-            base_csv_header, delta_csv_header= line, delta_csv_list[i]
+            base_csv_header, delta_csv_header = line, delta_csv_list[i]
             break
     longest_header = max(base_csv_header, delta_csv_header)
 
@@ -53,22 +62,21 @@ def csv_diff(base_csv, delta_csv, tmp_dir="/tmp"):
             base_csv = base_csv.replace("Strain Name,Value,SE,Count",
                                         longest_header, 1)
         else:
-            delta_csv = delta_csv.replace("Strain Name,Value,SE,Count",
-                                          longest_header, 1)
+            delta_csv = delta_csv.replace(
+                "Strain Name,Value,SE,Count", longest_header, 1
+            )
     file_name1 = os.path.join(tmp_dir, str(uuid.uuid4()))
     file_name2 = os.path.join(tmp_dir, str(uuid.uuid4()))
 
-    with open(file_name1, "w") as f_:
+    with open(file_name1, "w", encoding="utf-8") as _f:
         _l = len(longest_header.split(","))
-        f_.write(fill_csv(csv_text=base_csv,
-                          width=_l))
-    with open(file_name2, "w") as f_:
-        f_.write(fill_csv(delta_csv,
-                          width=_l))
+        _f.write(fill_csv(csv_text=base_csv, width=_l))
+    with open(file_name2, "w", encoding="utf-8") as _f:
+        _f.write(fill_csv(delta_csv, width=_l))
 
     # Now we can run the diff!
     _r = run_cmd(cmd=('"csvdiff '
-                      f'{file_name1} {file_name2} '
+                      f"{file_name1} {file_name2} "
                       '--format json"'))
     if _r.get("code") == 0:
         _r = json.loads(_r.get("output"))
@@ -86,6 +94,7 @@ def csv_diff(base_csv, delta_csv, tmp_dir="/tmp"):
 
 
 def fill_csv(csv_text, width, value="x"):
+    """Fill a csv text with 'value' if it's length is less than width"""
     data = []
     for line in csv_text.strip().split("\n"):
         if line.startswith("Strain") or line.startswith("#"):
@@ -95,6 +104,5 @@ def fill_csv(csv_text, width, value="x"):
             for i, val in enumerate(_n):
                 if not val.strip():
                     _n[i] = value
-            data.append(
-                ",".join(_n + [value] * (width - len(_n))))
+            data.append(",".join(_n + [value] * (width - len(_n))))
     return "\n".join(data)
diff --git a/gn3/db/sample_data.py b/gn3/db/sample_data.py
index b2d6aed..e3daa21 100644
--- a/gn3/db/sample_data.py
+++ b/gn3/db/sample_data.py
@@ -1,7 +1,8 @@
-from gn3.csvcmp import extract_strain_name
-from typing import Any, Tuple
+"""Module containing functions that work with sample data"""
+from typing import Any, Tuple, Dict, Callable
 
 import MySQLdb
+from gn3.csvcmp import extract_strain_name
 
 
 _MAP = {
@@ -14,6 +15,10 @@ _MAP = {
 def __extract_actions(original_data: str,
                       updated_data: str,
                       csv_header: str) -> dict:
+    """Return a dictionary containing elements that need to be deleted, inserted,
+or updated.
+
+    """
     original_data = original_data.strip().split(",")
     updated_data = updated_data.strip().split(",")
     csv_header = csv_header.strip().split(",")
@@ -28,7 +33,7 @@ def __extract_actions(original_data: str,
             strain_name = _o
         if _o == _u:  # No change
             continue
-        elif _o and _u == "x":  # Deletion
+        if _o and _u == "x":  # Deletion
             result["delete"]["data"].append(_o)
             result["delete"]["csv_header"].append(_h)
         elif _o == "x" and _u:  # Insert
@@ -84,23 +89,23 @@ def get_trait_csv_sample_data(conn: Any,
         if not case_attr_columns:
             return ("Strain Name,Value,SE,Count\n" +
                     "\n".join(csv_data.keys()))
-        else:
-            columns = sorted(case_attr_columns)
-            csv = ("Strain Name,Value,SE,Count," +
-                   ",".join(columns) + "\n")
-            for key, value in csv_data.items():
-                if not value:
-                    csv += (key + (len(case_attr_columns) * ",x") + "\n")
-                else:
-                    vals = [str(value.get(column, "x")) for column in columns]
-                    csv += (key + "," + ",".join(vals) + "\n")
-            return csv
+        columns = sorted(case_attr_columns)
+        csv = ("Strain Name,Value,SE,Count," +
+               ",".join(columns) + "\n")
+        for key, value in csv_data.items():
+            if not value:
+                csv += (key + (len(case_attr_columns) * ",x") + "\n")
+            else:
+                vals = [str(value.get(column, "x")) for column in columns]
+                csv += (key + "," + ",".join(vals) + "\n")
+        return csv
     return "No Sample Data Found"
 
 
 def get_sample_data_ids(conn: Any, publishxref_id: int,
                         phenotype_id: int,
                         strain_name: str) -> Tuple:
+    """Get the strain_id, publishdata_id and inbredset_id for a given strain"""
     strain_id, publishdata_id, inbredset_id = None, None, None
     with conn.cursor() as cursor:
         cursor.execute("SELECT st.id, pd.Id, pf.InbredSetId "
@@ -125,7 +130,8 @@ def get_sample_data_ids(conn: Any, publishxref_id: int,
     return (strain_id, publishdata_id, inbredset_id)
 
 
-def update_sample_data(conn: Any,  # pylint: disable=[R0913]
+# pylint: disable=[R0913, R0914]
+def update_sample_data(conn: Any,
                        trait_name: str,
                        original_data: str,
                        updated_data: str,
@@ -136,12 +142,14 @@ def update_sample_data(conn: Any,  # pylint: disable=[R0913]
     def __update_data(conn, table, value):
         if value and value != "x":
             with conn.cursor() as cursor:
-                sub_query = (" = %s AND ".join(_MAP.get(table)[:2]) + " = %s")
+                sub_query = (" = %s AND ".join(_MAP.get(table)[:2]) +
+                             " = %s")
                 _val = _MAP.get(table)[-1]
                 cursor.execute((f"UPDATE {table} SET {_val} = %s "
                                 f"WHERE {sub_query}"),
                                (value, strain_id, data_id))
                 return cursor.rowcount
+        return 0
 
     def __update_case_attribute(conn, value, strain_id,
                                 case_attr, inbredset_id):
@@ -177,6 +185,7 @@ def update_sample_data(conn: Any,  # pylint: disable=[R0913]
         if __actions.get("update"):
             _csv_header = __actions["update"]["csv_header"]
             _data = __actions["update"]["data"]
+            # pylint: disable=[E1101]
             for header, value in zip(_csv_header.split(","),
                                      _data.split(",")):
                 header = header.strip()
@@ -208,9 +217,9 @@ def update_sample_data(conn: Any,  # pylint: disable=[R0913]
                 phenotype_id=phenotype_id)
             if _rowcount:
                 count += 1
-    except Exception as e:  # pylint: disable=[C0103, W0612]
+    except Exception as _e:
         conn.rollback()
-        raise MySQLdb.Error
+        raise MySQLdb.Error(_e) from _e
     conn.commit()
     return count
 
@@ -223,28 +232,24 @@ def delete_sample_data(conn: Any,
     """Given the right parameters, delete sample-data from the relevant
     tables."""
     def __delete_data(conn, table):
-        if value and value != "x":
-            sub_query = (" = %s AND ".join(_MAP.get(table)[:2]) + " = %s")
-            with conn.cursor() as cursor:
-                cursor.execute((f"DELETE FROM {table} "
-                                f"WHERE {sub_query}"),
-                               (strain_id, data_id))
-                return cursor.rowcount
-        return 0
+        sub_query = (" = %s AND ".join(_MAP.get(table)[:2]) + " = %s")
+        with conn.cursor() as cursor:
+            cursor.execute((f"DELETE FROM {table} "
+                            f"WHERE {sub_query}"),
+                           (strain_id, data_id))
+            return cursor.rowcount
 
     def __delete_case_attribute(conn, strain_id,
                                 case_attr, inbredset_id):
-        if value != "x":
-            with conn.cursor() as cursor:
-                cursor.execute(
-                    "DELETE FROM CaseAttributeXRefNew "
-                    "WHERE StrainId = %s AND CaseAttributeId = "
-                    "(SELECT CaseAttributeId FROM "
-                    "CaseAttribute WHERE Name = %s) "
-                    "AND InbredSetId = %s",
-                    (strain_id, case_attr, inbredset_id))
-                return cursor.rowcount
-        return 0
+        with conn.cursor() as cursor:
+            cursor.execute(
+                "DELETE FROM CaseAttributeXRefNew "
+                "WHERE StrainId = %s AND CaseAttributeId = "
+                "(SELECT CaseAttributeId FROM "
+                "CaseAttribute WHERE Name = %s) "
+                "AND InbredSetId = %s",
+                (strain_id, case_attr, inbredset_id))
+            return cursor.rowcount
 
     strain_id, data_id, inbredset_id = get_sample_data_ids(
         conn=conn, publishxref_id=trait_name,
@@ -260,9 +265,8 @@ def delete_sample_data(conn: Any,
     count = 0
 
     try:
-        for header, value in zip(csv_header.split(","), data.split(",")):
+        for header in csv_header.split(","):
             header = header.strip()
-            value = value.strip()
             if header in none_case_attrs:
                 count += none_case_attrs.get(header)()
             else:
@@ -271,14 +275,15 @@ def delete_sample_data(conn: Any,
                     strain_id=strain_id,
                     case_attr=header,
                     inbredset_id=inbredset_id)
-    except Exception as e:  # pylint: disable=[C0103, W0612]
+    except Exception as _e:
         conn.rollback()
-        raise MySQLdb.Error
+        raise MySQLdb.Error(_e) from _e
     conn.commit()
     return count
 
 
-def insert_sample_data(conn: Any,  # pylint: disable=[R0913]
+# pylint: disable=[R0913, R0914]
+def insert_sample_data(conn: Any,
                        trait_name: str,
                        data: str,
                        csv_header: str,
@@ -355,6 +360,6 @@ def insert_sample_data(conn: Any,  # pylint: disable=[R0913]
                     case_attr=header,
                     value=value)
         return count
-    except Exception as e:  # pylint: disable=[C0103, W0612]
+    except Exception as _e:
         conn.rollback()
-        raise MySQLdb.Error
+        raise MySQLdb.Error(_e) from _e