about summary refs log tree commit diff
path: root/uploader/samples/models.py
diff options
context:
space:
mode:
Diffstat (limited to 'uploader/samples/models.py')
-rw-r--r--uploader/samples/models.py103
1 files changed, 103 insertions, 0 deletions
diff --git a/uploader/samples/models.py b/uploader/samples/models.py
new file mode 100644
index 0000000..1e9293f
--- /dev/null
+++ b/uploader/samples/models.py
@@ -0,0 +1,103 @@
+"""Functions for handling samples."""
+import csv
+from typing import Iterator
+
+import MySQLdb as mdb
+from MySQLdb.cursors import DictCursor
+
+from functional_tools import take
+
+def samples_by_species_and_population(
+        conn: mdb.Connection,
+        species_id: int,
+        population_id: int
+) -> tuple[dict, ...]:
+    """Fetch the samples by their species and population."""
+    with conn.cursor(cursorclass=DictCursor) as cursor:
+        cursor.execute(
+            "SELECT InbredSet.InbredSetId, Strain.* FROM InbredSet "
+            "INNER JOIN StrainXRef ON InbredSet.InbredSetId=StrainXRef.InbredSetId "
+            "INNER JOIN Strain ON StrainXRef.StrainId=Strain.Id "
+            "WHERE Strain.SpeciesId=%(species_id)s "
+            "AND InbredSet.InbredSetId=%(population_id)s",
+            {"species_id": species_id, "population_id": population_id})
+        return tuple(cursor.fetchall())
+
+
+def read_samples_file(filepath, separator: str, firstlineheading: bool, **kwargs) -> Iterator[dict]:
+    """Read the samples file."""
+    with open(filepath, "r", encoding="utf-8") as inputfile:
+        reader = csv.DictReader(
+            inputfile,
+            fieldnames=(
+                None if firstlineheading
+                else ("Name", "Name2", "Symbol", "Alias")),
+            delimiter=separator,
+            quotechar=kwargs.get("quotechar", '"'))
+        yield from reader
+
+
+def save_samples_data(conn: mdb.Connection,
+                      speciesid: int,
+                      file_data: Iterator[dict]):
+    """Save the samples to DB."""
+    data = ({**row, "SpeciesId": speciesid} for row in file_data)
+    total = 0
+    with conn.cursor() as cursor:
+        while True:
+            batch = take(data, 5000)
+            if len(batch) == 0:
+                break
+            cursor.executemany(
+                "INSERT INTO Strain(Name, Name2, SpeciesId, Symbol, Alias) "
+                "VALUES("
+                "    %(Name)s, %(Name2)s, %(SpeciesId)s, %(Symbol)s, %(Alias)s"
+                ") ON DUPLICATE KEY UPDATE Name=Name",
+                batch)
+            total += len(batch)
+            print(f"\tSaved {total} samples total so far.")
+
+
+def cross_reference_samples(conn: mdb.Connection,
+                            species_id: int,
+                            population_id: int,
+                            strain_names: Iterator[str]):
+    """Link samples to their population."""
+    with conn.cursor(cursorclass=DictCursor) as cursor:
+        cursor.execute(
+            "SELECT MAX(OrderId) AS loid FROM StrainXRef WHERE InbredSetId=%s",
+            (population_id,))
+        last_order_id = (cursor.fetchone()["loid"] or 10)
+        total = 0
+        while True:
+            batch = take(strain_names, 5000)
+            if len(batch) == 0:
+                break
+            params_str = ", ".join(["%s"] * len(batch))
+            ## This query is slow -- investigate.
+            cursor.execute(
+                "SELECT s.Id FROM Strain AS s LEFT JOIN StrainXRef AS sx "
+                "ON s.Id = sx.StrainId WHERE s.SpeciesId=%s AND s.Name IN "
+                f"({params_str}) AND sx.StrainId IS NULL",
+                (species_id,) + tuple(batch))
+            strain_ids = (sid["Id"] for sid in cursor.fetchall())
+            params = tuple({
+                "pop_id": population_id,
+                "strain_id": strain_id,
+                "order_id": last_order_id + (order_id * 10),
+                "mapping": "N",
+                "pedigree": None
+            } for order_id, strain_id in enumerate(strain_ids, start=1))
+            cursor.executemany(
+                "INSERT INTO StrainXRef( "
+                "  InbredSetId, StrainId, OrderId, Used_for_mapping, PedigreeStatus"
+                ")"
+                "VALUES ("
+                "  %(pop_id)s, %(strain_id)s, %(order_id)s, %(mapping)s, "
+                "  %(pedigree)s"
+                ")",
+                params)
+            last_order_id += (len(params) * 10)
+            total += len(batch)
+            print(f"\t{total} total samples cross-referenced to the population "
+                  "so far.")