aboutsummaryrefslogtreecommitdiff
path: root/uploader/samples/models.py
diff options
context:
space:
mode:
Diffstat (limited to 'uploader/samples/models.py')
-rw-r--r--uploader/samples/models.py104
1 files changed, 104 insertions, 0 deletions
diff --git a/uploader/samples/models.py b/uploader/samples/models.py
new file mode 100644
index 0000000..d7d5384
--- /dev/null
+++ b/uploader/samples/models.py
@@ -0,0 +1,104 @@
+"""Functions for handling samples."""
+import csv
+from typing import Iterator
+
+import MySQLdb as mdb
+from MySQLdb.cursors import DictCursor
+
+from functional_tools import take
+
+def samples_by_species_and_population(
+ conn: mdb.Connection,
+ species_id: int,
+ population_id: int
+) -> tuple[dict, ...]:
+ """Fetch the samples by their species and population."""
+ with conn.cursor(cursorclass=DictCursor) as cursor:
+ cursor.execute(
+ "SELECT iset.InbredSetId, s.* FROM InbredSet AS iset "
+ "INNER JOIN StrainXRef AS sxr ON iset.InbredSetId=sxr.InbredSetId "
+ "INNER JOIN Strain AS s ON sxr.StrainId=s.Id "
+ "WHERE s.SpeciesId=%(species_id)s "
+ "AND iset.InbredSetId=%(population_id)s",
+ {"species_id": species_id, "population_id": population_id})
+ return tuple(cursor.fetchall())
+
+
+def read_samples_file(filepath, separator: str, firstlineheading: bool, **kwargs) -> Iterator[dict]:
+ """Read the samples file."""
+ with open(filepath, "r", encoding="utf-8") as inputfile:
+ reader = csv.DictReader(
+ inputfile,
+ fieldnames=(
+ None if firstlineheading
+ else ("Name", "Name2", "Symbol", "Alias")),
+ delimiter=separator,
+ quotechar=kwargs.get("quotechar", '"'))
+ for row in reader:
+ yield row
+
+
+def save_samples_data(conn: mdb.Connection,
+ speciesid: int,
+ file_data: Iterator[dict]):
+ """Save the samples to DB."""
+ data = ({**row, "SpeciesId": speciesid} for row in file_data)
+ total = 0
+ with conn.cursor() as cursor:
+ while True:
+ batch = take(data, 5000)
+ if len(batch) == 0:
+ break
+ cursor.executemany(
+ "INSERT INTO Strain(Name, Name2, SpeciesId, Symbol, Alias) "
+ "VALUES("
+ " %(Name)s, %(Name2)s, %(SpeciesId)s, %(Symbol)s, %(Alias)s"
+ ") ON DUPLICATE KEY UPDATE Name=Name",
+ batch)
+ total += len(batch)
+ print(f"\tSaved {total} samples total so far.")
+
+
+def cross_reference_samples(conn: mdb.Connection,
+ species_id: int,
+ population_id: int,
+ strain_names: Iterator[str]):
+ """Link samples to their population."""
+ with conn.cursor(cursorclass=DictCursor) as cursor:
+ cursor.execute(
+ "SELECT MAX(OrderId) AS loid FROM StrainXRef WHERE InbredSetId=%s",
+ (population_id,))
+ last_order_id = (cursor.fetchone()["loid"] or 10)
+ total = 0
+ while True:
+ batch = take(strain_names, 5000)
+ if len(batch) == 0:
+ break
+ params_str = ", ".join(["%s"] * len(batch))
+ ## This query is slow -- investigate.
+ cursor.execute(
+ "SELECT s.Id FROM Strain AS s LEFT JOIN StrainXRef AS sx "
+ "ON s.Id = sx.StrainId WHERE s.SpeciesId=%s AND s.Name IN "
+ f"({params_str}) AND sx.StrainId IS NULL",
+ (species_id,) + tuple(batch))
+ strain_ids = (sid["Id"] for sid in cursor.fetchall())
+ params = tuple({
+ "pop_id": population_id,
+ "strain_id": strain_id,
+ "order_id": last_order_id + (order_id * 10),
+ "mapping": "N",
+ "pedigree": None
+ } for order_id, strain_id in enumerate(strain_ids, start=1))
+ cursor.executemany(
+ "INSERT INTO StrainXRef( "
+ " InbredSetId, StrainId, OrderId, Used_for_mapping, PedigreeStatus"
+ ")"
+ "VALUES ("
+ " %(pop_id)s, %(strain_id)s, %(order_id)s, %(mapping)s, "
+ " %(pedigree)s"
+ ")",
+ params)
+ last_order_id += (len(params) * 10)
+ total += len(batch)
+ print(f"\t{total} total samples cross-referenced to the population "
+ "so far.")