From 3c47afe8092ff9f5bc152320e724c07973c1941a Mon Sep 17 00:00:00 2001 From: Frederick Muriuki Muriithi Date: Tue, 19 Jul 2022 15:55:05 +0300 Subject: Save standard error data. Fix linting and typing errors. --- scripts/insert_data.py | 136 ++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 112 insertions(+), 24 deletions(-) (limited to 'scripts/insert_data.py') diff --git a/scripts/insert_data.py b/scripts/insert_data.py index 5e596ff..53bf4bd 100644 --- a/scripts/insert_data.py +++ b/scripts/insert_data.py @@ -2,28 +2,32 @@ import sys import argparse from typing import Tuple +from functools import reduce import MySQLdb as mdb from redis import Redis from MySQLdb.cursors import DictCursor from quality_control.parsing import take -from qc_app.db_utils import database_connection from quality_control.file_utils import open_file +from qc_app.db_utils import database_connection from qc_app.check_connections import check_db, check_redis def translate_alias(heading): + "Translate strain aliases into canonical names" translations = {"B6": "C57BL/6J", "D2": "DBA/2J"} return translations.get(heading, heading) -def read_file_headings(filepath): +def read_file_headings(filepath) -> Tuple[str, ...]: "Get the file headings" with open_file(filepath) as input_file: - for line_number, line_contents in enumerate(input_file): - if line_number == 0: - return tuple( - translate_alias(heading.strip()) - for heading in line_contents.split("\t")) + for line_contents in input_file: + headings = tuple( + translate_alias(heading.strip()) + for heading in line_contents.split("\t")) + break + + return headings def read_file_contents(filepath): "Get the file contents" @@ -35,16 +39,20 @@ def read_file_contents(filepath): yield tuple( field.strip() for field in line_contents.split("\t")) -def strains_info(dbconn: mdb.Connection, strain_names: Tuple[str, ...]) -> dict: +def strains_info( + dbconn: mdb.Connection, strain_names: Tuple[str, ...], + speciesid: int) -> dict: "Retrieve information for the strains" with dbconn.cursor(cursorclass=DictCursor) as cursor: query = ( "SELECT * FROM Strain WHERE Name IN " - f"({', '.join(['%s']*len(strain_names))})") - cursor.execute(query, tuple(strain_names)) + f"({', '.join(['%s']*len(strain_names))}) " + "AND SpeciesId = %s") + cursor.execute(query, tuple(strain_names) + (speciesid,)) return {strain["Name"]: strain for strain in cursor.fetchall()} -def read_means(filepath, headings, strain_info): +def read_datavalues(filepath, headings, strain_info): + "Read data values from file" for row in ( dict(zip(headings, line)) for line in read_file_contents(filepath)): @@ -52,7 +60,7 @@ def read_means(filepath, headings, strain_info): yield { "ProbeSetId": int(row["ProbeSetID"]), "StrainId": strain_info[sname]["Id"], - "ProbeSetDataValue": float(row[sname]) + "DataValue": float(row[sname]) } def last_data_id(dbconn: mdb.Connection) -> int: @@ -61,24 +69,65 @@ def last_data_id(dbconn: mdb.Connection) -> int: cursor.execute("SELECT MAX(Id) FROM ProbeSetData") return int(cursor.fetchone()[0]) +def check_strains(headings_strains, db_strains): + "Check strains in headings exist in database" + from_db = tuple(db_strains.keys()) + not_in_db = tuple( + strain for strain in headings_strains if strain not in from_db) + if len(not_in_db) == 0: + return True + + str_not_in_db = "', '".join(not_in_db) + print( + (f"ERROR: The strain(s) '{str_not_in_db}' w(as|ere) not found in the " + "database."), + file=sys.stderr) + sys.exit(1) + +def annotationinfo( + dbconn: mdb.Connection, platformid: int, datasetid: int) -> dict: + "Get annotation information from the database." + # This is somewhat slow. Look into optimising the behaviour + def __organise_annotations__(accm, item): + names_dict = ( + {**accm[0], item["Name"]: item} if bool(item["Name"]) else accm[0]) + targs_dict = ( + {**accm[1], item["TargetId"]: item} + if bool(item["TargetId"]) else accm[1]) + return (names_dict, targs_dict) + + query = ( + "SELECT ProbeSet.Name, ProbeSet.ChipId, ProbeSet.TargetId, " + "ProbeSetXRef.DataId, ProbeSetXRef.ProbeSetFreezeId " + "FROM ProbeSet INNER JOIN ProbeSetXRef " + "ON ProbeSet.Id=ProbeSetXRef.ProbeSetId " + "WHERE ProbeSet.ChipId=%s AND ProbeSetXRef.ProbeSetFreezeId=%s") + with dbconn.cursor(cursorclass=DictCursor) as cursor: + cursor.execute(query, (platformid, datasetid)) + annot_dicts = reduce(# type: ignore[var-annotated] + __organise_annotations__, cursor.fetchall(), ({}, {})) + return {**annot_dicts[0], **annot_dicts[1]} + return {} + def insert_means( - filepath: str, dataset_id: int, dbconn: mdb.Connection, - rconn: Redis) -> int: + filepath: str, speciesid: int, datasetid: int, dbconn: mdb.Connection, + rconn: Redis) -> int: # pylint: disable=[unused-argument] "Insert the means/averages data into the database" print("INSERTING MEANS/AVERAGES DATA.") headings = read_file_headings(filepath) - strains = strains_info(dbconn, headings[1:]) + strains = strains_info(dbconn, headings[1:], speciesid) + check_strains(headings[1:], strains) means_query = ( "INSERT INTO ProbeSetData " - "VALUES(%(ProbeSetDataId)s, %(StrainId)s, %(ProbeSetDataValue)s)") + "VALUES(%(ProbeSetDataId)s, %(StrainId)s, %(DataValue)s)") xref_query = ( "INSERT INTO ProbeSetXRef(ProbeSetFreezeId, ProbeSetId, DataId) " "VALUES (%(ProbeSetFreezeId)s, %(ProbeSetId)s, %(ProbeSetDataId)s)") the_means = ( - {"ProbeSetFreezeId": dataset_id, "ProbeSetDataId": data_id, **mean} + {"ProbeSetFreezeId": datasetid, "ProbeSetDataId": data_id, **mean} for data_id, mean in enumerate( - read_means(filepath, headings, strains), + read_datavalues(filepath, headings, strains), start=(last_data_id(dbconn)+1))) with dbconn.cursor(cursorclass=DictCursor) as cursor: while True: @@ -92,15 +141,42 @@ def insert_means( cursor.executemany(xref_query, means) return 0 -def insert_se( - filepath: str, dataset_id: int, dbconn: mdb.Connection, - rconn: Redis) -> int: +def insert_se(# pylint: disable = [too-many-arguments] + filepath: str, speciesid: int, platformid: int, datasetid: int, + dbconn: mdb.Connection, rconn: Redis) -> int: # pylint: disable=[unused-argument] "Insert the standard-error data into the database" print("INSERTING STANDARD ERROR DATA...") + headings = read_file_headings(filepath) + strains = strains_info(dbconn, headings[1:], speciesid) + check_strains(headings[1:], strains) + se_query = ( + "INSERT INTO ProbeSetSE " + "VALUES(%(DataId)s, %(StrainId)s, %(DataValue)s)") + annotations = annotationinfo(dbconn, platformid, datasetid) + if not bool(annotations): + print( + (f"ERROR: No annotations found for platform {platformid} and " + f"dataset {datasetid}. Quiting!"), + file=sys.stderr) + return 1 + + se_values = ( + {"DataId": annotations[str(item["ProbeSetId"])]["DataId"], **item} + for item in read_datavalues(filepath, headings, strains)) + with dbconn.cursor(cursorclass=DictCursor) as cursor: + while True: + serrors = tuple(take(se_values, 1000)) + if not bool(serrors): + break + print( + f"\nEXECUTING QUERY:\n\t* {se_query}\n\tWITH PARAMETERS:\n\t" + f"{serrors}") + cursor.executemany(se_query, serrors) return 0 if __name__ == "__main__": def cli_args(): + "Compute the CLI arguments" parser = argparse.ArgumentParser( prog="InsertData", description=( "Script to insert data from an 'averages' file into the " @@ -111,10 +187,13 @@ if __name__ == "__main__": parser.add_argument( "filepath", help="path to the file with the 'averages' data.") parser.add_argument( - "species_id", help="Identifier for the species in the database.", + "speciesid", help="Identifier for the species in the database.", type=int) parser.add_argument( - "dataset_id", help="Identifier for the dataset in the database.", + "platformid", help="Identifier for the platform in the database.", + type=int) + parser.add_argument( + "datasetid", help="Identifier for the dataset in the database.", type=int) parser.add_argument( "database_uri", @@ -134,12 +213,21 @@ if __name__ == "__main__": "standard-error": insert_se } + extract_args = { + "average": lambda args, dbconn, rconn: ( + args.filepath, args.speciesid, args.datasetid, dbconn, rconn), + "standard-error": lambda args, dbconn, rconn: ( + args.filepath, args.speciesid, args.platformid, args.datasetid, + dbconn, rconn), + } + def main(): + "Main entry point" args = cli_args() with Redis.from_url(args.redisuri, decode_responses=True) as rconn: with database_connection(args.database_uri) as dbconn: return insert_fns[args.filetype]( - args.filepath, args.dataset_id, dbconn, rconn) + *extract_args[args.filetype](args, dbconn, rconn)) return 2 -- cgit v1.2.3