about summary refs log tree commit diff
diff options
context:
space:
mode:
authorFrederick Muriuki Muriithi2022-07-19 15:55:05 +0300
committerFrederick Muriuki Muriithi2022-07-19 15:55:05 +0300
commit3c47afe8092ff9f5bc152320e724c07973c1941a (patch)
tree30f0f54a2ecac03a0bb67dda6b1ea7f0580af005
parentc52570a4069abb6b8953e486adb326392ce6714c (diff)
downloadgn-uploader-3c47afe8092ff9f5bc152320e724c07973c1941a.tar.gz
Save standard error data. Fix linting and typing errors.
-rw-r--r--quality_control/file_utils.py3
-rw-r--r--scripts/insert_data.py136
2 files changed, 113 insertions, 26 deletions
diff --git a/quality_control/file_utils.py b/quality_control/file_utils.py
index fdce1e1..002f22c 100644
--- a/quality_control/file_utils.py
+++ b/quality_control/file_utils.py
@@ -1,10 +1,9 @@
 "Common file utilities"
 from typing import Union
 from pathlib import Path
-from io import TextIOWrapper
 from zipfile import ZipFile, is_zipfile
 
-def open_file(filepath: Union[str, Path]) -> Union[ZipFile, TextIOWrapper]:
+def open_file(filepath: Union[str, Path]):
     "Transparently open both TSV and ZIP files"
     if not is_zipfile(filepath):
         return open(filepath, encoding="utf-8")
diff --git a/scripts/insert_data.py b/scripts/insert_data.py
index 5e596ff..53bf4bd 100644
--- a/scripts/insert_data.py
+++ b/scripts/insert_data.py
@@ -2,28 +2,32 @@
 import sys
 import argparse
 from typing import Tuple
+from functools import reduce
 
 import MySQLdb as mdb
 from redis import Redis
 from MySQLdb.cursors import DictCursor
 
 from quality_control.parsing import take
-from qc_app.db_utils import database_connection
 from quality_control.file_utils import open_file
+from qc_app.db_utils import database_connection
 from qc_app.check_connections import check_db, check_redis
 
 def translate_alias(heading):
+    "Translate strain aliases into canonical names"
     translations = {"B6": "C57BL/6J", "D2": "DBA/2J"}
     return translations.get(heading, heading)
 
-def read_file_headings(filepath):
+def read_file_headings(filepath) -> Tuple[str, ...]:
     "Get the file headings"
     with open_file(filepath) as input_file:
-        for line_number, line_contents in enumerate(input_file):
-            if line_number == 0:
-                return tuple(
-                    translate_alias(heading.strip())
-                    for heading in line_contents.split("\t"))
+        for line_contents in input_file:
+            headings = tuple(
+                translate_alias(heading.strip())
+                for heading in line_contents.split("\t"))
+            break
+
+    return headings
 
 def read_file_contents(filepath):
     "Get the file contents"
@@ -35,16 +39,20 @@ def read_file_contents(filepath):
                 yield tuple(
                     field.strip() for field in line_contents.split("\t"))
 
-def strains_info(dbconn: mdb.Connection, strain_names: Tuple[str, ...]) -> dict:
+def strains_info(
+        dbconn: mdb.Connection, strain_names: Tuple[str, ...],
+        speciesid: int) -> dict:
     "Retrieve information for the strains"
     with dbconn.cursor(cursorclass=DictCursor) as cursor:
         query = (
             "SELECT * FROM Strain WHERE Name IN "
-            f"({', '.join(['%s']*len(strain_names))})")
-        cursor.execute(query, tuple(strain_names))
+            f"({', '.join(['%s']*len(strain_names))}) "
+            "AND SpeciesId = %s")
+        cursor.execute(query, tuple(strain_names) + (speciesid,))
         return {strain["Name"]: strain for strain in cursor.fetchall()}
 
-def read_means(filepath, headings, strain_info):
+def read_datavalues(filepath, headings, strain_info):
+    "Read data values from file"
     for row in (
             dict(zip(headings, line))
             for line in read_file_contents(filepath)):
@@ -52,7 +60,7 @@ def read_means(filepath, headings, strain_info):
             yield {
                 "ProbeSetId": int(row["ProbeSetID"]),
                 "StrainId": strain_info[sname]["Id"],
-                "ProbeSetDataValue": float(row[sname])
+                "DataValue": float(row[sname])
             }
 
 def last_data_id(dbconn: mdb.Connection) -> int:
@@ -61,24 +69,65 @@ def last_data_id(dbconn: mdb.Connection) -> int:
         cursor.execute("SELECT MAX(Id) FROM ProbeSetData")
         return int(cursor.fetchone()[0])
 
+def check_strains(headings_strains, db_strains):
+    "Check strains in headings exist in database"
+    from_db = tuple(db_strains.keys())
+    not_in_db = tuple(
+        strain for strain in headings_strains if strain not in from_db)
+    if len(not_in_db) == 0:
+        return True
+
+    str_not_in_db = "', '".join(not_in_db)
+    print(
+        (f"ERROR: The strain(s) '{str_not_in_db}' w(as|ere) not found in the "
+         "database."),
+        file=sys.stderr)
+    sys.exit(1)
+
+def annotationinfo(
+        dbconn: mdb.Connection, platformid: int, datasetid: int) -> dict:
+    "Get annotation information from the database."
+    # This is somewhat slow. Look into optimising the behaviour
+    def __organise_annotations__(accm, item):
+        names_dict = (
+            {**accm[0], item["Name"]: item} if bool(item["Name"]) else accm[0])
+        targs_dict = (
+            {**accm[1], item["TargetId"]: item}
+            if bool(item["TargetId"]) else accm[1])
+        return (names_dict, targs_dict)
+
+    query = (
+        "SELECT ProbeSet.Name, ProbeSet.ChipId, ProbeSet.TargetId, "
+        "ProbeSetXRef.DataId, ProbeSetXRef.ProbeSetFreezeId "
+        "FROM ProbeSet INNER JOIN ProbeSetXRef "
+        "ON ProbeSet.Id=ProbeSetXRef.ProbeSetId "
+        "WHERE ProbeSet.ChipId=%s AND ProbeSetXRef.ProbeSetFreezeId=%s")
+    with dbconn.cursor(cursorclass=DictCursor) as cursor:
+        cursor.execute(query, (platformid, datasetid))
+        annot_dicts = reduce(# type: ignore[var-annotated]
+            __organise_annotations__, cursor.fetchall(), ({}, {}))
+        return {**annot_dicts[0], **annot_dicts[1]}
+    return {}
+
 def insert_means(
-        filepath: str, dataset_id: int, dbconn: mdb.Connection,
-        rconn: Redis) -> int:
+        filepath: str, speciesid: int, datasetid: int, dbconn: mdb.Connection,
+        rconn: Redis) -> int: # pylint: disable=[unused-argument]
     "Insert the means/averages data into the database"
     print("INSERTING MEANS/AVERAGES DATA.")
     headings = read_file_headings(filepath)
-    strains = strains_info(dbconn, headings[1:])
+    strains = strains_info(dbconn, headings[1:], speciesid)
+    check_strains(headings[1:], strains)
     means_query = (
         "INSERT INTO ProbeSetData "
-        "VALUES(%(ProbeSetDataId)s, %(StrainId)s, %(ProbeSetDataValue)s)")
+        "VALUES(%(ProbeSetDataId)s, %(StrainId)s, %(DataValue)s)")
     xref_query = (
         "INSERT INTO ProbeSetXRef(ProbeSetFreezeId, ProbeSetId, DataId) "
         "VALUES (%(ProbeSetFreezeId)s, %(ProbeSetId)s, %(ProbeSetDataId)s)")
     the_means = (
-        {"ProbeSetFreezeId": dataset_id, "ProbeSetDataId": data_id, **mean}
+        {"ProbeSetFreezeId": datasetid, "ProbeSetDataId": data_id, **mean}
         for data_id, mean in
         enumerate(
-            read_means(filepath, headings, strains),
+            read_datavalues(filepath, headings, strains),
             start=(last_data_id(dbconn)+1)))
     with dbconn.cursor(cursorclass=DictCursor) as cursor:
         while True:
@@ -92,15 +141,42 @@ def insert_means(
             cursor.executemany(xref_query, means)
     return 0
 
-def insert_se(
-        filepath: str, dataset_id: int, dbconn: mdb.Connection,
-        rconn: Redis) -> int:
+def insert_se(# pylint: disable = [too-many-arguments]
+        filepath: str, speciesid: int, platformid: int, datasetid: int,
+        dbconn: mdb.Connection, rconn: Redis) -> int: # pylint: disable=[unused-argument]
     "Insert the standard-error data into the database"
     print("INSERTING STANDARD ERROR DATA...")
+    headings = read_file_headings(filepath)
+    strains = strains_info(dbconn, headings[1:], speciesid)
+    check_strains(headings[1:], strains)
+    se_query = (
+        "INSERT INTO ProbeSetSE "
+        "VALUES(%(DataId)s, %(StrainId)s, %(DataValue)s)")
+    annotations = annotationinfo(dbconn, platformid, datasetid)
+    if not bool(annotations):
+        print(
+            (f"ERROR: No annotations found for platform {platformid} and "
+             f"dataset {datasetid}. Quiting!"),
+            file=sys.stderr)
+        return 1
+
+    se_values = (
+        {"DataId": annotations[str(item["ProbeSetId"])]["DataId"], **item}
+        for item in read_datavalues(filepath, headings, strains))
+    with dbconn.cursor(cursorclass=DictCursor) as cursor:
+        while True:
+            serrors = tuple(take(se_values, 1000))
+            if not bool(serrors):
+                break
+            print(
+                f"\nEXECUTING QUERY:\n\t* {se_query}\n\tWITH PARAMETERS:\n\t"
+                f"{serrors}")
+            cursor.executemany(se_query, serrors)
     return 0
 
 if __name__ == "__main__":
     def cli_args():
+        "Compute the CLI arguments"
         parser = argparse.ArgumentParser(
             prog="InsertData", description=(
                 "Script to insert data from an 'averages' file into the "
@@ -111,10 +187,13 @@ if __name__ == "__main__":
         parser.add_argument(
             "filepath", help="path to the file with the 'averages' data.")
         parser.add_argument(
-            "species_id", help="Identifier for the species in the database.",
+            "speciesid", help="Identifier for the species in the database.",
             type=int)
         parser.add_argument(
-            "dataset_id", help="Identifier for the dataset in the database.",
+            "platformid", help="Identifier for the platform in the database.",
+            type=int)
+        parser.add_argument(
+            "datasetid", help="Identifier for the dataset in the database.",
             type=int)
         parser.add_argument(
             "database_uri",
@@ -134,12 +213,21 @@ if __name__ == "__main__":
         "standard-error": insert_se
     }
 
+    extract_args = {
+        "average": lambda args, dbconn, rconn: (
+            args.filepath, args.speciesid, args.datasetid, dbconn, rconn),
+        "standard-error": lambda args, dbconn, rconn: (
+            args.filepath, args.speciesid, args.platformid, args.datasetid,
+            dbconn, rconn),
+    }
+
     def main():
+        "Main entry point"
         args = cli_args()
         with Redis.from_url(args.redisuri, decode_responses=True) as rconn:
             with database_connection(args.database_uri) as dbconn:
                 return insert_fns[args.filetype](
-                    args.filepath, args.dataset_id, dbconn, rconn)
+                    *extract_args[args.filetype](args, dbconn, rconn))
 
         return 2