"""Insert means/averages or standard-error data into the database.""" import sys import string import random import logging import argparse from functools import reduce from typing import Tuple, Iterator import MySQLdb as mdb from redis import Redis from MySQLdb.cursors import DictCursor from gn_libs.mysqldb import database_connection from functional_tools import take from quality_control.file_utils import open_file from uploader.check_connections import check_db, check_redis # Set up logging stderr_handler = logging.StreamHandler(stream=sys.stderr) root_logger = logging.getLogger() root_logger.addHandler(stderr_handler) root_logger.setLevel("WARNING") def random_string(count: int = 10) -> str: """Generate a random, alphanumeric string.""" return "".join(random.choices( string.digits + string.ascii_letters, k=count)) def translate_alias(heading): "Translate strain aliases into canonical names" translations = {"B6": "C57BL/6J", "D2": "DBA/2J"} return translations.get(heading, heading) def read_file_headings(filepath) -> Tuple[str, ...]: "Get the file headings" with open_file(filepath) as input_file: for line_contents in input_file: headings = tuple( translate_alias(heading.strip()) for heading in line_contents.split("\t")) break return headings def read_file_contents(filepath): "Get the file contents" with open_file(filepath) as input_file: for line_number, line_contents in enumerate(input_file): if line_number == 0: continue if line_number > 0: yield tuple( field.strip() for field in line_contents.split("\t")) def strains_info( dbconn: mdb.Connection, strain_names: Tuple[str, ...], speciesid: int) -> dict: "Retrieve information for the strains" with dbconn.cursor(cursorclass=DictCursor) as cursor: query = ( "SELECT * FROM Strain WHERE Name IN " f"({', '.join(['%s']*len(strain_names))}) " "AND SpeciesId = %s") cursor.execute(query, tuple(strain_names) + (speciesid,)) return {strain["Name"]: strain for strain in cursor.fetchall()} def read_datavalues(filepath, headings, strain_info): """Read numerical, data values from the file.""" id_key = headings[0] return { str(row[id_key]): tuple({ "ProbeSetName": str(row[id_key]), "StrainId": strain_info[sname]["Id"], "DataValue": float(row[sname]) } for sname in headings[1:]) for row in (dict(zip(headings, line)) for line in read_file_contents(filepath)) } def read_probesets(filepath, headings): """Read the ProbeSet names.""" id_key = headings[0] for row in (dict(zip(headings, line)) for line in read_file_contents(filepath)): yield {"Name": str(row[id_key])} def last_data_id(dbconn: mdb.Connection) -> int: "Get the last id from the database" with dbconn.cursor() as cursor: cursor.execute("SELECT MAX(Id) FROM ProbeSetData") return int(cursor.fetchone()[0]) def check_strains(headings_strains, db_strains): "Check strains in headings exist in database" from_db = tuple(db_strains.keys()) not_in_db = tuple( strain for strain in headings_strains if strain not in from_db) if len(not_in_db) == 0: return True str_not_in_db = "', '".join(not_in_db) print( (f"ERROR: The strain(s) '{str_not_in_db}' w(as|ere) not found in the " "database."), file=sys.stderr) sys.exit(1) def annotationinfo( dbconn: mdb.Connection, platformid: int, datasetid: int ) -> dict[str, dict]: "Get annotation information from the database." # This is somewhat slow. Look into optimising the behaviour def __organise_annotations__(accm, item): names_dict = ( {**accm[0], item["Name"]: item} if bool(item["Name"]) else accm[0]) targs_dict = ( {**accm[1], item["TargetId"]: item} if bool(item["TargetId"]) else accm[1]) return (names_dict, targs_dict) query = ( "SELECT ProbeSet.Name, ProbeSet.ChipId, ProbeSet.TargetId, " "ProbeSetXRef.DataId, ProbeSetXRef.ProbeSetFreezeId " "FROM ProbeSet INNER JOIN ProbeSetXRef " "ON ProbeSet.Id=ProbeSetXRef.ProbeSetId " "WHERE ProbeSet.ChipId=%s AND ProbeSetXRef.ProbeSetFreezeId=%s") with dbconn.cursor(cursorclass=DictCursor) as cursor: cursor.execute(query, (platformid, datasetid)) annot_dicts = reduce(# type: ignore[var-annotated] __organise_annotations__, cursor.fetchall(), ({}, {})) return {**annot_dicts[0], **annot_dicts[1]} return {} def __format_query__(query, params): "Format the query for output" def __param_str__(param): return "', '".join(str(elt) for elt in param) idx = query.find("VALUES") idx = query.find("%") fields = tuple( elt.replace("%(", "").replace(")s", "").replace(")", "").strip() for elt in query[idx:-1].split(",")) values = (tuple(param[field] for field in fields) for param in params) values_str = ", ".join( f"('{__param_str__(value_tup)}')" for value_tup in values) insert_str = query[:idx].replace( "INSERT INTO ", "INSERT INTO\n\t") return f"{insert_str}\nVALUES\n\t{values_str};" def insert_probesets(filepath: str, dbconn: mdb.Connection, platform_id: int, headings: tuple[str, ...], session_rand_str: str) -> tuple[str, ...]: """Save new ProbeSets into the database.""" probeset_query = ( "INSERT INTO ProbeSet(ChipId, Name) " "VALUES (%(ChipId)s, %(Name)s) ") the_probesets = ({ **row, "Name": f"{row['Name']}{session_rand_str}", "ChipId": platform_id } for row in read_probesets(filepath, headings)) probeset_names: tuple[str, ...] = tuple() with dbconn.cursor(cursorclass=DictCursor) as cursor: while True: probeset_params = tuple(take(the_probesets, 10000)) if not bool(probeset_params): break print(__format_query__(probeset_query, probeset_params)) print() cursor.executemany(probeset_query, probeset_params) probeset_names = probeset_names + tuple( row["Name"] for row in probeset_params) return probeset_names def probeset_ids(dbconn: mdb.Connection, chip_id: int, probeset_names: tuple[str, ...]) -> Iterator[tuple[str, int]]: """Fetch the IDs of the probesets with the given names.""" with dbconn.cursor() as cursor: params_str = ", ".join(["%s"] * len(probeset_names)) cursor.execute( "SELECT Name, Id FROM ProbeSet " "WHERE ChipId=%s " f"AND Name IN ({params_str})", (chip_id,) + probeset_names) while True: row = cursor.fetchone() if not bool(row): break yield row def insert_means(# pylint: disable=[too-many-locals, too-many-arguments] filepath: str, speciesid: int, platform_id: int, datasetid: int, dbconn: mdb.Connection, rconn: Redis) -> int: # pylint: disable=[unused-argument] "Insert the means/averages data into the database" headings = read_file_headings(filepath) strains = strains_info(dbconn, headings[1:], speciesid) check_strains(headings[1:], strains) means_query = ( "INSERT INTO ProbeSetData " "VALUES(%(ProbeSetDataId)s, %(StrainId)s, %(DataValue)s)") xref_query = ( "INSERT INTO ProbeSetXRef(ProbeSetFreezeId, ProbeSetId, DataId) " "VALUES(%(ProbeSetFreezeId)s, %(ProbeSetId)s, %(ProbeSetDataId)s)") # A random string to avoid over-write chances. # This is needed because the `ProbeSet` table is defined with # UNIQUE KEY `ProbeSetId` (`ChipId`,`Name`) # which means that we cannot have 2 (or more) ProbeSets which share both # the name and chip_id (platform) at the same time. rand_str = f"::RAND_{random_string()}" pset_ids = { name[0:name.index("::RAND_")]: pset_id for name, pset_id in probeset_ids( dbconn, platform_id, insert_probesets( filepath, dbconn, platform_id, headings, rand_str)) } the_means = ({ **mean, "ProbeSetFreezeId": datasetid, "ProbeSetDataId": data_id, "ChipId": platform_id, "ProbeSetId": pset_ids[mean["ProbeSetName"]] } for data_id, mean in enumerate(( item for sublist in read_datavalues(filepath, headings, strains).values() for item in sublist), start=(last_data_id(dbconn)+1))) with dbconn.cursor(cursorclass=DictCursor) as cursor: while True: means = tuple(take(the_means, 10000)) if not bool(means): break print(__format_query__(means_query, means)) print() print(__format_query__(xref_query, means)) cursor.executemany(means_query, means) cursor.executemany(xref_query, means) return 0 def insert_se(# pylint: disable = [too-many-arguments,too-many-locals] filepath: str, speciesid: int, platformid: int, datasetid: int, dbconn: mdb.Connection, rconn: Redis) -> int: # pylint: disable=[unused-argument] "Insert the standard-error data into the database" headings = read_file_headings(filepath) strains = strains_info(dbconn, headings[1:], speciesid) check_strains(headings[1:], strains) se_query = ( "INSERT INTO ProbeSetSE " "VALUES(%(DataId)s, %(StrainId)s, %(DataValue)s)") annotations = annotationinfo(dbconn, platformid, datasetid) if not bool(annotations): with dbconn.cursor(cursorclass=DictCursor) as cursor: cursor.execute( ("SELECT " "gc.GeneChipName AS platformname, pf.Name AS studyname, " "psf.FullName AS datasetname " "FROM GeneChip AS gc INNER JOIN ProbeFreeze AS pf " "ON gc.Id=pf.ChipId INNER JOIN ProbeSetFreeze AS psf " "ON pf.Id=psf.ProbeFreezeId " "WHERE gc.Id=%s AND psf.Id=%s"), (platformid, datasetid)) errorinfo = cursor.fetchone() print(("ERROR: No annotations found for the " f"'{errorinfo['datasetname']}' dataset (Id: {datasetid}) " f"under the '{errorinfo['studyname']}' study linked to the " f"'{errorinfo['platformname']}' platform (Id: {platformid})." "\n\n" " Please verify you selected the correct platform, " "study and dataset for the standard-error file(s) you were " "trying to upload."), file=sys.stderr) return 1 namemappings = { _key[0:_key.find("::RAND_")]: _key for _key in annotations.keys() } se_values = ( {"DataId": annotations[namemappings[str(item["ProbeSetName"])]]["DataId"], **item} for item in ( row for psrows in read_datavalues(filepath, headings, strains).values() for row in psrows)) with dbconn.cursor(cursorclass=DictCursor) as cursor: try: while True: serrors = tuple(take(se_values, 1000)) if not bool(serrors): break print(__format_query__(se_query, serrors)) cursor.executemany(se_query, serrors) except KeyError as kerr: print( ( f"The following name(s) or identifier(s) do(es) not " "exist in the database and did not exist in the original " f"values file: {', '.join(kerr.args)}"), file=sys.stderr) return 1 return 0 if __name__ == "__main__": def cli_args(): "Compute the CLI arguments" parser = argparse.ArgumentParser( prog="InsertData", description=( "Script to insert data from an 'averages' file into the " "database.")) parser.add_argument( "filetype", help="type of data to insert.", choices=("average", "standard-error")) parser.add_argument( "filepath", help="path to the file with the 'averages' data.") parser.add_argument( "speciesid", help="Identifier for the species in the database.", type=int) parser.add_argument( "platformid", help="Identifier for the platform in the database.", type=int) parser.add_argument( "datasetid", help="Identifier for the dataset in the database.", type=int) parser.add_argument( "database_uri", help="URL to be used to initialise the connection to the database") parser.add_argument( "redisuri", help="URL to initialise connection to redis", default="redis:///") args = parser.parse_args() check_db(args.database_uri) check_redis(args.redisuri) return args insert_fns = { "average": insert_means, "standard-error": insert_se } extract_args = { "average": lambda args, dbconn, rconn: ( args.filepath, args.speciesid, args.platformid, args.datasetid, dbconn, rconn), "standard-error": lambda args, dbconn, rconn: ( args.filepath, args.speciesid, args.platformid, args.datasetid, dbconn, rconn), } def main(): "Main entry point" args = cli_args() with Redis.from_url(args.redisuri, decode_responses=True) as rconn: with database_connection(args.database_uri) as dbconn: return insert_fns[args.filetype]( *extract_args[args.filetype](args, dbconn, rconn)) return 2 sys.exit(main())