aboutsummaryrefslogtreecommitdiff
path: root/uploader/samples/models.py
blob: d7d5384b561fe074dd6aab07f65770662eeecafb (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
"""Functions for handling samples."""
import csv
from typing import Iterator

import MySQLdb as mdb
from MySQLdb.cursors import DictCursor

from functional_tools import take

def samples_by_species_and_population(
        conn: mdb.Connection,
        species_id: int,
        population_id: int
) -> tuple[dict, ...]:
    """Return the samples linked to a given population of a given species.

    Each returned row is a dict holding the population's ``InbredSetId``
    together with every column of the matching ``Strain`` row. A strain
    is included when it is linked to the population through
    ``StrainXRef`` and its ``SpeciesId`` equals ``species_id``.
    """
    query = (
        "SELECT iset.InbredSetId, s.* FROM InbredSet AS iset "
        "INNER JOIN StrainXRef AS sxr ON iset.InbredSetId=sxr.InbredSetId "
        "INNER JOIN Strain AS s ON sxr.StrainId=s.Id "
        "WHERE s.SpeciesId=%(species_id)s "
        "AND iset.InbredSetId=%(population_id)s")
    query_params = {"species_id": species_id, "population_id": population_id}
    with conn.cursor(cursorclass=DictCursor) as cursor:
        cursor.execute(query, query_params)
        found = cursor.fetchall()
        return tuple(found)


def read_samples_file(
        filepath,
        separator: str,
        firstlineheading: bool,
        **kwargs
) -> Iterator[dict]:
    """Lazily read the rows of a delimited samples file as dicts.

    Args:
        filepath: Path of the samples file; it is opened as UTF-8 text.
        separator: The field delimiter (e.g. a comma or a tab).
        firstlineheading: When true, the first line supplies the column
            names. Otherwise the columns are taken to be, in order,
            ``Name``, ``Name2``, ``Symbol`` and ``Alias``.
        **kwargs: Only ``quotechar`` is consulted (default ``'"'``); any
            other keys are ignored.

    Yields:
        One dict per data row, as produced by :class:`csv.DictReader`
        (so rows shorter than the header get ``None`` for the missing
        columns, and surplus fields are collected under the ``None`` key).
    """
    # The csv module must be given files opened with newline="": otherwise
    # line endings embedded in quoted fields get translated/mis-parsed.
    with open(filepath, "r", encoding="utf-8", newline="") as inputfile:
        reader = csv.DictReader(
            inputfile,
            fieldnames=(
                None if firstlineheading
                else ("Name", "Name2", "Symbol", "Alias")),
            delimiter=separator,
            quotechar=kwargs.get("quotechar", '"'))
        yield from reader


def save_samples_data(conn: mdb.Connection,
                      speciesid: int,
                      file_data: Iterator[dict]):
    """Insert samples into the ``Strain`` table in batches of 5000 rows.

    Every row from ``file_data`` is tagged with ``SpeciesId=speciesid``
    before being written. When an insert hits a duplicate key, the
    ``ON DUPLICATE KEY UPDATE Name=Name`` clause turns it into a no-op
    update, leaving the existing row as it is. A running total is printed
    after each batch.

    NOTE(review): no commit is issued here; presumably the caller manages
    the transaction -- confirm.
    """
    insert_query = (
        "INSERT INTO Strain(Name, Name2, SpeciesId, Symbol, Alias) "
        "VALUES("
        "    %(Name)s, %(Name2)s, %(SpeciesId)s, %(Symbol)s, %(Alias)s"
        ") ON DUPLICATE KEY UPDATE Name=Name")
    tagged_rows = (dict(row, SpeciesId=speciesid) for row in file_data)
    saved_so_far = 0
    with conn.cursor() as cursor:
        batch = take(tagged_rows, 5000)
        while len(batch) > 0:
            cursor.executemany(insert_query, batch)
            saved_so_far += len(batch)
            print(f"\tSaved {saved_so_far} samples total so far.")
            batch = take(tagged_rows, 5000)


def cross_reference_samples(conn: mdb.Connection,
                            species_id: int,
                            population_id: int,
                            strain_names: Iterator[str]):
    """Link samples to their population.

    Processes ``strain_names`` in batches of 5000. For each batch it
    selects the ``Strain`` rows of species ``species_id`` whose names are
    in the batch and that are not yet linked to population
    ``population_id``. It then inserts one ``StrainXRef`` row per selected
    strain.

    Each new link gets:
    - an ``OrderId`` continuing in steps of 10 after the population's
      current maximum ``OrderId``, or after 10 if that maximum is NULL/0;
    - ``Used_for_mapping`` set to ``"N"``;
    - a NULL ``PedigreeStatus``.

    A running total of links created is printed after each batch.

    NOTE(review): no commit is issued here; presumably the caller manages
    the transaction -- confirm.
    """
    # ``take`` is used to pull successive batches from the names; iter()
    # is a no-op on an iterator and makes a passed list/tuple behave the
    # same way instead of possibly re-reading its start every batch.
    names = iter(strain_names)
    with conn.cursor(cursorclass=DictCursor) as cursor:
        cursor.execute(
            "SELECT MAX(OrderId) AS loid FROM StrainXRef WHERE InbredSetId=%s",
            (population_id,))
        # MAX() yields NULL for a population with no links yet.
        last_order_id = (cursor.fetchone()["loid"] or 10)
        total = 0
        while True:
            batch = take(names, 5000)
            if len(batch) == 0:
                break
            params_str = ", ".join(["%s"] * len(batch))
            ## This query is slow -- investigate.
            # The InbredSetId test belongs in the ON clause (not WHERE):
            # only links to *this* population mean "already linked", so
            # strains linked solely to other populations are still
            # selected and get linked here too.
            cursor.execute(
                "SELECT s.Id FROM Strain AS s LEFT JOIN StrainXRef AS sx "
                "ON s.Id = sx.StrainId AND sx.InbredSetId=%s "
                "WHERE s.SpeciesId=%s AND s.Name IN "
                f"({params_str}) AND sx.StrainId IS NULL",
                (population_id, species_id) + tuple(batch))
            strain_ids = (sid["Id"] for sid in cursor.fetchall())
            params = tuple({
                "pop_id": population_id,
                "strain_id": strain_id,
                "order_id": last_order_id + (order_id * 10),
                "mapping": "N",
                "pedigree": None
            } for order_id, strain_id in enumerate(strain_ids, start=1))
            if params:
                cursor.executemany(
                    "INSERT INTO StrainXRef( "
                    "  InbredSetId, StrainId, OrderId, Used_for_mapping, PedigreeStatus"
                    ")"
                    "VALUES ("
                    "  %(pop_id)s, %(strain_id)s, %(order_id)s, %(mapping)s, "
                    "  %(pedigree)s"
                    ")",
                    params)
            last_order_id += (len(params) * 10)
            # Count links actually created, not names looked up: unknown or
            # already-linked names produce no StrainXRef row.
            total += len(params)
            print(f"\t{total} total samples cross-referenced to the population "
                  "so far.")