"""Functions for handling samples.""" import csv from typing import Iterator import MySQLdb as mdb from MySQLdb.cursors import DictCursor from functional_tools import take def samples_by_species_and_population( conn: mdb.Connection, species_id: int, population_id: int ) -> tuple[dict, ...]: """Fetch the samples by their species and population.""" with conn.cursor(cursorclass=DictCursor) as cursor: cursor.execute( "SELECT iset.InbredSetId, s.* FROM InbredSet AS iset " "INNER JOIN StrainXRef AS sxr ON iset.InbredSetId=sxr.InbredSetId " "INNER JOIN Strain AS s ON sxr.StrainId=s.Id " "WHERE s.SpeciesId=%(species_id)s " "AND iset.InbredSetId=%(population_id)s", {"species_id": species_id, "population_id": population_id}) return tuple(cursor.fetchall()) def read_samples_file(filepath, separator: str, firstlineheading: bool, **kwargs) -> Iterator[dict]: """Read the samples file.""" with open(filepath, "r", encoding="utf-8") as inputfile: reader = csv.DictReader( inputfile, fieldnames=( None if firstlineheading else ("Name", "Name2", "Symbol", "Alias")), delimiter=separator, quotechar=kwargs.get("quotechar", '"')) for row in reader: yield row def save_samples_data(conn: mdb.Connection, speciesid: int, file_data: Iterator[dict]): """Save the samples to DB.""" data = ({**row, "SpeciesId": speciesid} for row in file_data) total = 0 with conn.cursor() as cursor: while True: batch = take(data, 5000) if len(batch) == 0: break cursor.executemany( "INSERT INTO Strain(Name, Name2, SpeciesId, Symbol, Alias) " "VALUES(" " %(Name)s, %(Name2)s, %(SpeciesId)s, %(Symbol)s, %(Alias)s" ") ON DUPLICATE KEY UPDATE Name=Name", batch) total += len(batch) print(f"\tSaved {total} samples total so far.") def cross_reference_samples(conn: mdb.Connection, species_id: int, population_id: int, strain_names: Iterator[str]): """Link samples to their population.""" with conn.cursor(cursorclass=DictCursor) as cursor: cursor.execute( "SELECT MAX(OrderId) AS loid FROM StrainXRef WHERE InbredSetId=%s", (population_id,)) last_order_id = (cursor.fetchone()["loid"] or 10) total = 0 while True: batch = take(strain_names, 5000) if len(batch) == 0: break params_str = ", ".join(["%s"] * len(batch)) ## This query is slow -- investigate. cursor.execute( "SELECT s.Id FROM Strain AS s LEFT JOIN StrainXRef AS sx " "ON s.Id = sx.StrainId WHERE s.SpeciesId=%s AND s.Name IN " f"({params_str}) AND sx.StrainId IS NULL", (species_id,) + tuple(batch)) strain_ids = (sid["Id"] for sid in cursor.fetchall()) params = tuple({ "pop_id": population_id, "strain_id": strain_id, "order_id": last_order_id + (order_id * 10), "mapping": "N", "pedigree": None } for order_id, strain_id in enumerate(strain_ids, start=1)) cursor.executemany( "INSERT INTO StrainXRef( " " InbredSetId, StrainId, OrderId, Used_for_mapping, PedigreeStatus" ")" "VALUES (" " %(pop_id)s, %(strain_id)s, %(order_id)s, %(mapping)s, " " %(pedigree)s" ")", params) last_order_id += (len(params) * 10) total += len(batch) print(f"\t{total} total samples cross-referenced to the population " "so far.")