diff options
Diffstat (limited to 'uploader/samples/models.py')
-rw-r--r-- | uploader/samples/models.py | 104 |
1 files changed, 104 insertions, 0 deletions
"""Functions for handling samples."""
import csv
from typing import Iterator

import MySQLdb as mdb
from MySQLdb.cursors import DictCursor

from functional_tools import take

# Number of rows sent to the database per round-trip.
_BATCH_SIZE = 5000


def samples_by_species_and_population(
        conn: mdb.Connection,
        species_id: int,
        population_id: int
) -> tuple[dict, ...]:
    """Fetch the samples by their species and population.

    Joins `InbredSet` -> `StrainXRef` -> `Strain` and keeps the rows whose
    strain belongs to `species_id` and whose population (InbredSet) is
    `population_id`.

    Returns a tuple of dicts, one per matching strain; each dict holds the
    `InbredSetId` plus every column of the `Strain` row.
    """
    with conn.cursor(cursorclass=DictCursor) as cursor:
        cursor.execute(
            "SELECT iset.InbredSetId, s.* FROM InbredSet AS iset "
            "INNER JOIN StrainXRef AS sxr ON iset.InbredSetId=sxr.InbredSetId "
            "INNER JOIN Strain AS s ON sxr.StrainId=s.Id "
            "WHERE s.SpeciesId=%(species_id)s "
            "AND iset.InbredSetId=%(population_id)s",
            {"species_id": species_id, "population_id": population_id})
        return tuple(cursor.fetchall())


def read_samples_file(
        filepath,
        separator: str,
        firstlineheading: bool,
        **kwargs
) -> Iterator[dict]:
    """Read the samples file, lazily yielding one dict per record.

    filepath: path of the delimited text file; it is read as UTF-8.
    separator: the field delimiter.
    firstlineheading: when true, the first line of the file provides the
        field names; otherwise the columns are taken to be, in order,
        "Name", "Name2", "Symbol" and "Alias".
    kwargs: only `quotechar` is consulted (defaults to '"').

    This is a generator: the file is opened on first iteration and closed
    once the generator is exhausted or closed.
    """
    fieldnames = (
        None if firstlineheading
        else ("Name", "Name2", "Symbol", "Alias"))
    # `newline=""` is required by the csv module so that line endings
    # embedded inside quoted fields are handed to the parser untouched.
    with open(filepath, "r", encoding="utf-8", newline="") as inputfile:
        yield from csv.DictReader(
            inputfile,
            fieldnames=fieldnames,
            delimiter=separator,
            quotechar=kwargs.get("quotechar", '"'))


def save_samples_data(conn: mdb.Connection,
                      speciesid: int,
                      file_data: Iterator[dict]):
    """Save the samples to DB.

    Every row of `file_data` is tagged with `speciesid` and inserted into the
    `Strain` table in batches. Each row must provide the keys "Name",
    "Name2", "Symbol" and "Alias". Rows that collide with an existing unique
    key are left as they are (the `ON DUPLICATE KEY UPDATE Name=Name` clause
    is a no-op update).

    Progress is printed after each batch. No commit is issued here.
    NOTE(review): presumably the caller commits the transaction -- confirm.
    """
    data = ({**row, "SpeciesId": speciesid} for row in file_data)
    total = 0
    with conn.cursor() as cursor:
        # `data` is a generator, so each `take` resumes where the previous
        # one stopped; an empty batch means the input is exhausted.
        while batch := take(data, _BATCH_SIZE):
            cursor.executemany(
                "INSERT INTO Strain(Name, Name2, SpeciesId, Symbol, Alias) "
                "VALUES("
                " %(Name)s, %(Name2)s, %(SpeciesId)s, %(Symbol)s, %(Alias)s"
                ") ON DUPLICATE KEY UPDATE Name=Name",
                batch)
            total += len(batch)
            print(f"\tSaved {total} samples total so far.")


def cross_reference_samples(conn: mdb.Connection,
                            species_id: int,
                            population_id: int,
                            strain_names: Iterator[str]):
    """Link samples to their population.

    For each batch of `strain_names`, looks up the strains of `species_id`
    with those names that have no `StrainXRef` row, and inserts a
    `StrainXRef` row tying each of them to `population_id`.

    New rows get `OrderId` values in steps of 10 following the population's
    current maximum (a missing maximum is treated as 10), are marked as not
    used for mapping ("N") and get a NULL pedigree status.

    Progress is printed after each batch. No commit is issued here.
    """
    # Normalise to a one-shot iterator so successive `take` calls advance
    # through the names. NOTE(review): assumes `take` consumes from the
    # iterator it is given; with a re-iterable argument (e.g. a list) the
    # same first batch could be returned forever -- confirm against
    # `functional_tools.take`.
    names = iter(strain_names)
    with conn.cursor(cursorclass=DictCursor) as cursor:
        cursor.execute(
            "SELECT MAX(OrderId) AS loid FROM StrainXRef WHERE InbredSetId=%s",
            (population_id,))
        last_order_id = (cursor.fetchone()["loid"] or 10)
        total = 0
        while batch := take(names, _BATCH_SIZE):
            # Only "%s" placeholders are interpolated into the SQL text; the
            # names themselves are passed as bound parameters.
            params_str = ", ".join(["%s"] * len(batch))
            ## This query is slow -- investigate.
            # NOTE(review): the NULL check excludes strains that are linked
            # to *any* population, not just `population_id` -- confirm that
            # this is the intended behaviour.
            cursor.execute(
                "SELECT s.Id FROM Strain AS s LEFT JOIN StrainXRef AS sx "
                "ON s.Id = sx.StrainId WHERE s.SpeciesId=%s AND s.Name IN "
                f"({params_str}) AND sx.StrainId IS NULL",
                (species_id,) + tuple(batch))
            strain_ids = (sid["Id"] for sid in cursor.fetchall())
            params = tuple({
                "pop_id": population_id,
                "strain_id": strain_id,
                "order_id": last_order_id + (order_id * 10),
                "mapping": "N",
                "pedigree": None
            } for order_id, strain_id in enumerate(strain_ids, start=1))
            cursor.executemany(
                "INSERT INTO StrainXRef( "
                " InbredSetId, StrainId, OrderId, Used_for_mapping, PedigreeStatus"
                ")"
                "VALUES ("
                " %(pop_id)s, %(strain_id)s, %(order_id)s, %(mapping)s, "
                " %(pedigree)s"
                ")",
                params)
            # Continue numbering after the ids handed out in this batch.
            last_order_id += (len(params) * 10)
            # NOTE(review): this counts names processed, which can exceed
            # the number of rows actually inserted (`len(params)`).
            total += len(batch)
            print(f"\t{total} total samples cross-referenced to the population "
                  "so far.")