diff options
Diffstat (limited to 'scripts/rqtl2/install_genotypes.py')
-rw-r--r-- | scripts/rqtl2/install_genotypes.py | 136 |
1 files changed, 78 insertions, 58 deletions
diff --git a/scripts/rqtl2/install_genotypes.py b/scripts/rqtl2/install_genotypes.py index 68ae365..20a19da 100644 --- a/scripts/rqtl2/install_genotypes.py +++ b/scripts/rqtl2/install_genotypes.py @@ -1,11 +1,11 @@ """Load genotypes from R/qtl2 bundle into the database.""" import sys +import argparse import traceback -from pathlib import Path from zipfile import ZipFile from functools import reduce from typing import Iterator, Optional -from logging import Logger, getLogger, StreamHandler +from logging import Logger, getLogger import MySQLdb as mdb from MySQLdb.cursors import DictCursor @@ -19,10 +19,15 @@ from scripts.rqtl2.entry import build_main from scripts.rqtl2.cli_parser import add_common_arguments from scripts.cli_parser import init_cli_parser, add_global_data_arguments -def insert_markers(dbconn: mdb.Connection, - speciesid: int, - markers: tuple[str, ...], - pmapdata: Optional[Iterator[dict]]) -> int: +__MODULE__ = "scripts.rqtl2.install_genotypes" + +def insert_markers( + dbconn: mdb.Connection, + speciesid: int, + markers: tuple[str, ...], + pmapdata: Optional[Iterator[dict]], + _logger: Logger +) -> int: """Insert genotype and genotype values into the database.""" mdata = reduce(#type: ignore[var-annotated] lambda acc, row: ({#type: ignore[arg-type, return-value] @@ -45,12 +50,15 @@ def insert_markers(dbconn: mdb.Connection, "marker": marker, "chr": mdata.get(marker, {}).get("chr"), "pos": mdata.get(marker, {}).get("pos") - } for marker in markers}.items())) + } for marker in markers}.values())) return cursor.rowcount -def insert_individuals(dbconn: mdb.Connection, - speciesid: int, - individuals: tuple[str, ...]) -> int: +def insert_individuals( + dbconn: mdb.Connection, + speciesid: int, + individuals: tuple[str, ...], + _logger: Logger +) -> int: """Insert individuals/samples into the database.""" with dbconn.cursor() as cursor: cursor.executemany( @@ -61,10 +69,13 @@ def insert_individuals(dbconn: mdb.Connection, for individual in individuals)) return cursor.rowcount -def cross_reference_individuals(dbconn: mdb.Connection, - speciesid: int, - populationid: int, - individuals: tuple[str, ...]) -> int: +def cross_reference_individuals( + dbconn: mdb.Connection, + speciesid: int, + populationid: int, + individuals: tuple[str, ...], + _logger: Logger +) -> int: """Cross reference any inserted individuals.""" with dbconn.cursor(cursorclass=DictCursor) as cursor: paramstr = ", ".join(["%s"] * len(individuals)) @@ -80,11 +91,13 @@ def cross_reference_individuals(dbconn: mdb.Connection, tuple(ids)) return cursor.rowcount -def insert_genotype_data(dbconn: mdb.Connection, - speciesid: int, - genotypes: tuple[dict, ...], - individuals: tuple[str, ...]) -> tuple[ - int, tuple[dict, ...]]: +def insert_genotype_data( + dbconn: mdb.Connection, + speciesid: int, + genotypes: tuple[dict, ...], + individuals: tuple[str, ...], + _logger: Logger +) -> tuple[int, tuple[dict, ...]]: """Insert the genotype data values into the database.""" with dbconn.cursor(cursorclass=DictCursor) as cursor: paramstr = ", ".join(["%s"] * len(individuals)) @@ -120,11 +133,14 @@ def insert_genotype_data(dbconn: mdb.Connection, "markerid": row["markerid"] } for row in data) -def cross_reference_genotypes(dbconn: mdb.Connection, - speciesid: int, - datasetid: int, - dataids: tuple[dict, ...], - gmapdata: Optional[Iterator[dict]]) -> int: +def cross_reference_genotypes( + dbconn: mdb.Connection, + speciesid: int, + datasetid: int, + dataids: tuple[dict, ...], + gmapdata: Optional[Iterator[dict]], + _logger: Logger +) -> int: """Cross-reference the data to the relevant dataset.""" _rows, markers, mdata = reduce(#type: ignore[var-annotated] lambda acc, row: (#type: ignore[return-value,arg-type] @@ -140,31 +156,43 @@ def cross_reference_genotypes(dbconn: mdb.Connection, (tuple(), tuple(), {})) with dbconn.cursor(cursorclass=DictCursor) as cursor: - paramstr = ", ".join(["%s"] * len(markers)) - cursor.execute("SELECT Id, Name FROM Geno " - f"WHERE SpeciesId=%s AND Name IN ({paramstr})", - (speciesid,) + markers) - markersdict = {row["Id"]: row["Name"] for row in cursor.fetchall()} - cursor.executemany( + markersdict = {} + if len(markers) > 0: + paramstr = ", ".join(["%s"] * len(markers)) + insertparams = (speciesid,) + markers + selectquery = ("SELECT Id, Name FROM Geno " + f"WHERE SpeciesId=%s AND Name IN ({paramstr})") + _logger.debug( + "The select query was\n\t%s\n\nwith the parameters\n\t%s", + selectquery, + (speciesid,) + markers) + cursor.execute(selectquery, insertparams) + markersdict = {row["Id"]: row["Name"] for row in cursor.fetchall()} + + insertquery = ( "INSERT INTO GenoXRef(GenoFreezeId, GenoId, DataId, cM) " "VALUES(%(datasetid)s, %(markerid)s, %(dataid)s, %(pos)s) " - "ON DUPLICATE KEY UPDATE GenoFreezeId=GenoFreezeId", - tuple({ - **row, - "datasetid": datasetid, - "pos": mdata.get(markersdict.get( - row.get("markerid"), {}), {}).get("pos") - } for row in dataids)) + "ON DUPLICATE KEY UPDATE GenoFreezeId=GenoFreezeId") + insertparams = tuple({ + **row, + "datasetid": datasetid, + "pos": mdata.get(markersdict.get( + row.get("markerid"), "nosuchkey"), {}).get("pos") + } for row in dataids) + _logger.debug( + "The insert query was\n\t%s\n\nwith the parameters\n\t%s", + insertquery, insertparams) + cursor.executemany(insertquery, insertparams) return cursor.rowcount def install_genotypes(#pylint: disable=[too-many-arguments, too-many-locals] dbconn: mdb.Connection, - speciesid: int, - populationid: int, - datasetid: int, - rqtl2bundle: Path, - logger: Logger = getLogger()) -> int: + args: argparse.Namespace, + logger: Logger = getLogger(__name__) +) -> int: """Load any existing genotypes into the database.""" + (speciesid, populationid, datasetid, rqtl2bundle) = ( + args.speciesid, args.populationid, args.datasetid, args.rqtl2bundle) count = 0 with ZipFile(str(rqtl2bundle.absolute()), "r") as zfile: try: @@ -189,20 +217,22 @@ def install_genotypes(#pylint: disable=[too-many-arguments, too-many-locals] speciesid, tuple(key for key in batch[0].keys() if key != "id"), (rqtl2.file_data(zfile, "pmap", cdata) if "pmap" in cdata - else None)) + else None), + logger) individuals = tuple(row["id"] for row in batch) - insert_individuals(dbconn, speciesid, individuals) + insert_individuals(dbconn, speciesid, individuals, logger) cross_reference_individuals( - dbconn, speciesid, populationid, individuals) + dbconn, speciesid, populationid, individuals, logger) _num_rows, dataids = insert_genotype_data( - dbconn, speciesid, batch, individuals) + dbconn, speciesid, batch, individuals, logger) cross_reference_genotypes( dbconn, speciesid, datasetid, dataids, (rqtl2.file_data(zfile, "gmap", cdata) - if "gmap" in cdata else None)) + if "gmap" in cdata else None), + logger) count = count + len(batch) except rqtl2.InvalidFormat as exc: logger.error(str(exc)) @@ -224,15 +254,5 @@ if __name__ == "__main__": return parser.parse_args() - thelogger = getLogger("install_genotypes") - thelogger.addHandler(StreamHandler(stream=sys.stderr)) - main = build_main( - cli_args(), - lambda dbconn, args: install_genotypes(dbconn, - args.speciesid, - args.populationid, - args.datasetid, - args.rqtl2bundle), - thelogger, - "INFO") + main = build_main(cli_args(), install_genotypes, __MODULE__) sys.exit(main()) |