Diffstat (limited to 'scripts')
 scripts/insert_data.py | 146 +++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 146 insertions(+), 0 deletions(-)
diff --git a/scripts/insert_data.py b/scripts/insert_data.py
new file mode 100644
index 0000000..5e596ff
--- /dev/null
+++ b/scripts/insert_data.py
@@ -0,0 +1,146 @@
+"""Insert means/averages or standard-error data into the database."""
+import sys
+import argparse
+from typing import Tuple
+
+import MySQLdb as mdb
+from redis import Redis
+from MySQLdb.cursors import DictCursor
+
+from quality_control.parsing import take
+from qc_app.db_utils import database_connection
+from quality_control.file_utils import open_file
+from qc_app.check_connections import check_db, check_redis
+
+def translate_alias(heading):
+    "Translate strain aliases in the file headings to canonical strain names"
+    translations = {"B6": "C57BL/6J", "D2": "DBA/2J"}
+    return translations.get(heading, heading)
+
+def read_file_headings(filepath):
+    "Get the file headings"
+    with open_file(filepath) as input_file:
+        for line_number, line_contents in enumerate(input_file):
+            if line_number == 0:
+                return tuple(
+                    translate_alias(heading.strip())
+                    for heading in line_contents.split("\t"))
+
+def read_file_contents(filepath):
+    "Get the data rows of the file, skipping the headings line"
+    with open_file(filepath) as input_file:
+        for line_number, line_contents in enumerate(input_file):
+            if line_number == 0:
+                continue
+            yield tuple(
+                field.strip() for field in line_contents.split("\t"))
+
+def strains_info(dbconn: mdb.Connection, strain_names: Tuple[str, ...]) -> dict:
+    "Retrieve information for the strains"
+    with dbconn.cursor(cursorclass=DictCursor) as cursor:
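+        # Build one "%s" placeholder per strain name for the IN clause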
+        query = (
+            "SELECT * FROM Strain WHERE Name IN "
+            f"({', '.join(['%s']*len(strain_names))})")
+        cursor.execute(query, tuple(strain_names))
+        return {strain["Name"]: strain for strain in cursor.fetchall()}
+
+def read_means(filepath, headings, strain_info):
+    "Yield one database row per (ProbeSet, Strain) pair from the means file"
+    for row in (
+            dict(zip(headings, line))
+            for line in read_file_contents(filepath)):
+        for sname in headings[1:]:
+            yield {
+                "ProbeSetId": int(row["ProbeSetID"]),
+                "StrainId": strain_info[sname]["Id"],
+                "ProbeSetDataValue": float(row[sname])
+            }
+
+def last_data_id(dbconn: mdb.Connection) -> int:
+    "Get the last data Id from the database (0 if the table is empty)"
+    with dbconn.cursor() as cursor:
+        cursor.execute("SELECT MAX(Id) FROM ProbeSetData")
+        # MAX(Id) is NULL on an empty table; fall back to 0 in that case
+        return int(cursor.fetchone()[0] or 0)
+
+def insert_means(
+        filepath: str, dataset_id: int, dbconn: mdb.Connection,
+        rconn: Redis) -> int:
+    "Insert the means/averages data into the database"
+    print("INSERTING MEANS/AVERAGES DATA.")
+    headings = read_file_headings(filepath)
+    strains = strains_info(dbconn, headings[1:])
+    means_query = (
+        "INSERT INTO ProbeSetData "
+        "VALUES(%(ProbeSetDataId)s, %(StrainId)s, %(ProbeSetDataValue)s)")
+    xref_query = (
+        "INSERT INTO ProbeSetXRef(ProbeSetFreezeId, ProbeSetId, DataId) "
+        "VALUES (%(ProbeSetFreezeId)s, %(ProbeSetId)s, %(ProbeSetDataId)s)")
+    the_means = (
+        {"ProbeSetFreezeId": dataset_id, "ProbeSetDataId": data_id, **mean}
+        for data_id, mean in
+        enumerate(
+            read_means(filepath, headings, strains),
+            start=(last_data_id(dbconn)+1)))
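+    # Data Ids are allocated sequentially after the current maximum, so the
+    # same Id ties each ProbeSetData row to its ProbeSetXRef entry; rows are
+    # inserted in chunks of 1000 to keep each executemany() call bounded.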
+    with dbconn.cursor(cursorclass=DictCursor) as cursor:
+        while True:
+            means = tuple(take(the_means, 1000))
+            if not bool(means):
+                break
+            print(
+                f"\nEXECUTING QUERIES:\n\t* {means_query}\n\t* {xref_query}\n"
+                f"with parameters\n\t{means}")
+            cursor.executemany(means_query, means)
+            cursor.executemany(xref_query, means)
+    return 0
+
+def insert_se(
+        filepath: str, dataset_id: int, dbconn: mdb.Connection,
+        rconn: Redis) -> int:
+    "Insert the standard-error data into the database"
+    print("INSERTING STANDARD ERROR DATA...")
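+    # NOTE: standard-error insertion is not implemented yet; this is a stub.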
+    return 0
+
+if __name__ == "__main__":
+    def cli_args():
+        "Parse the command-line arguments and check service connections"
+        parser = argparse.ArgumentParser(
+            prog="InsertData", description=(
+                "Script to insert data from an 'averages' file into the "
+                "database."))
+        parser.add_argument(
+            "filetype", help="type of data to insert.",
+            choices=("average", "standard-error"))
+        parser.add_argument(
+            "filepath", help="path to the file with the 'averages' data.")
+        parser.add_argument(
+            "species_id", help="Identifier for the species in the database.",
+            type=int)
+        parser.add_argument(
+            "dataset_id", help="Identifier for the dataset in the database.",
+            type=int)
+        parser.add_argument(
+            "database_uri",
+            help="URL to be used to initialise the connection to the database")
+        parser.add_argument(
+            "redisuri",
+            help="URL to initialise connection to redis",
+            nargs="?", default="redis:///")
+
+        args = parser.parse_args()
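+        # Fail early if the database or redis server cannot be reached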
+        check_db(args.database_uri)
+        check_redis(args.redisuri)
+        return args
+
+    insert_fns = {
+        "average": insert_means,
+        "standard-error": insert_se
+    }
+
+    def main():
+        "Run the insertion function that matches the requested file type"
+        args = cli_args()
+        with Redis.from_url(args.redisuri, decode_responses=True) as rconn:
+            with database_connection(args.database_uri) as dbconn:
+                return insert_fns[args.filetype](
+                    args.filepath, args.dataset_id, dbconn, rconn)
+
+        return 2
+
+    sys.exit(main())
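+
+# Example invocation (the file path, identifiers and connection URIs below
+# are placeholders, not values taken from any real deployment):
+#   python3 scripts/insert_data.py average /path/to/means.tsv \
+#       1 12 "mysql://user:password@localhost/db_webqtl" "redis:///"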