aboutsummaryrefslogtreecommitdiff
path: root/scripts/insert_data.py
blob: 5e596ffc7bdd5236e428bd90097efb4f5c89a22b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
"""Insert means/averages or standard-error data into the database."""
import sys
import argparse
from typing import Tuple

import MySQLdb as mdb
from redis import Redis
from MySQLdb.cursors import DictCursor

from quality_control.parsing import take
from qc_app.db_utils import database_connection
from quality_control.file_utils import open_file
from qc_app.check_connections import check_db, check_redis

def translate_alias(heading):
    """Expand a short strain alias found in a file heading.

    "B6" becomes "C57BL/6J" and "D2" becomes "DBA/2J"; every other heading
    is handed back unchanged.
    """
    aliases = {"B6": "C57BL/6J", "D2": "DBA/2J"}
    return aliases[heading] if heading in aliases else heading

def read_file_headings(filepath):
    """Return the column headings taken from the first line of *filepath*.

    The first line is split on tab characters. Each heading is
    whitespace-stripped and then passed through `translate_alias`.
    If the file contains no lines at all, None is returned.
    """
    with open_file(filepath) as input_file:
        # Only the very first line matters: return while handling it.
        for first_line in input_file:
            stripped = (heading.strip() for heading in first_line.split("\t"))
            return tuple(translate_alias(heading) for heading in stripped)
    return None

def read_file_contents(filepath):
    """Yield the data rows of *filepath*, skipping the heading line.

    Each row is split on tab characters and every field is
    whitespace-stripped, giving a tuple of strings per line.

    Blank (whitespace-only) lines are skipped. Without this, a stray empty
    line, such as extra newlines at end-of-file, would produce a one-field
    ``("",)`` row. That row cannot be matched against the headings and makes
    downstream parsing (e.g. ``int("")``) fail partway through an import.
    """
    with open_file(filepath) as input_file:
        for line_number, line_contents in enumerate(input_file):
            # Line 0 holds the headings; it is read by read_file_headings.
            if line_number == 0 or not line_contents.strip():
                continue
            yield tuple(field.strip() for field in line_contents.split("\t"))

def strains_info(dbconn: mdb.Connection, strain_names: Tuple[str, ...]) -> dict:
    """Retrieve the ``Strain`` rows whose ``Name`` is in *strain_names*.

    Returns a dict mapping each found strain name to its full row (a dict,
    because the cursor uses ``DictCursor``). Names with no matching row are
    simply absent from the result.

    An empty *strain_names* returns ``{}`` without touching the database,
    because ``... IN ()`` is a SQL syntax error.

    NOTE(review): the query does not filter by species. If ``Strain.Name``
    is not unique, the last row returned for a name wins in the dict.
    """
    names = tuple(strain_names)
    if not names:
        return {}
    # One ``%s`` placeholder per name. The names themselves are passed as
    # query parameters and are never interpolated into the SQL text.
    placeholders = ", ".join(["%s"] * len(names))
    query = f"SELECT * FROM Strain WHERE Name IN ({placeholders})"
    with dbconn.cursor(cursorclass=DictCursor) as cursor:
        cursor.execute(query, names)
        return {strain["Name"]: strain for strain in cursor.fetchall()}

def read_means(filepath, headings, strain_info):
    """Yield one data record per (row, strain) pair found in *filepath*.

    Each data row is keyed by *headings*. Every heading after the first is
    treated as a strain name and looked up in *strain_info*, a mapping of
    strain name to a strain record that has an "Id" key.

    Each yielded dict has these keys:
      - "ProbeSetId": the row's "ProbeSetID" column, as an int
      - "StrainId": the strain's "Id" from *strain_info*
      - "ProbeSetDataValue": the strain's column value, as a float
    """
    for fields in read_file_contents(filepath):
        row = dict(zip(headings, fields))
        for strain_name in headings[1:]:
            probeset_id = int(row["ProbeSetID"])
            strain_id = strain_info[strain_name]["Id"]
            yield {
                "ProbeSetId": probeset_id,
                "StrainId": strain_id,
                "ProbeSetDataValue": float(row[strain_name]),
            }

def last_data_id(dbconn: mdb.Connection) -> int:
    """Return the largest ``Id`` currently stored in ``ProbeSetData``.

    Returns 0 when the table is empty. In that case ``MAX(Id)`` yields SQL
    NULL (``None`` in Python), and ``int(None)`` would raise TypeError.
    Callers add 1 to this value, so new ids then start at 1.
    """
    with dbconn.cursor() as cursor:
        cursor.execute("SELECT MAX(Id) FROM ProbeSetData")
        max_id = cursor.fetchone()[0]
    return 0 if max_id is None else int(max_id)

def insert_means(
        filepath: str, dataset_id: int, dbconn: mdb.Connection,
        rconn: Redis) -> int:
    """Insert the means/averages data from *filepath* into the database.

    For every (probeset, strain) value read from the file:
      - one ``ProbeSetData`` row is inserted with a freshly numbered data id;
      - one ``ProbeSetXRef`` row is inserted, linking that data id and the
        probeset to *dataset_id* (the ProbeSetFreezeId).

    Data ids continue from the current ``MAX(Id)`` of ``ProbeSetData``.
    Records are sent in batches of 1000 with ``executemany``. Each batch's
    queries and parameters are printed to stdout. *rconn* is currently
    unused. Always returns 0.

    NOTE(review): ids are allocated client-side from MAX(Id)+1, so concurrent
    writers could collide. No commit happens here; presumably the
    database_connection context manager commits — confirm.
    """
    print("INSERTING MEANS/AVERAGES DATA.")
    headings = read_file_headings(filepath)
    strains = strains_info(dbconn, headings[1:])
    means_query = (
        "INSERT INTO ProbeSetData "
        "VALUES(%(ProbeSetDataId)s, %(StrainId)s, %(ProbeSetDataValue)s)")
    xref_query = (
        "INSERT INTO ProbeSetXRef(ProbeSetFreezeId, ProbeSetId, DataId) "
        "VALUES (%(ProbeSetFreezeId)s, %(ProbeSetId)s, %(ProbeSetDataId)s)")

    first_data_id = last_data_id(dbconn) + 1
    numbered_means = enumerate(
        read_means(filepath, headings, strains), start=first_data_id)
    # Lazily attach the dataset id and the data id to every mean record.
    records = (
        {"ProbeSetFreezeId": dataset_id, "ProbeSetDataId": data_id, **mean}
        for data_id, mean in numbered_means)

    batch_size = 1000
    with dbconn.cursor(cursorclass=DictCursor) as cursor:
        batch = tuple(take(records, batch_size))
        while batch:
            print(
                f"\nEXECUTING QUERIES:\n\t* {means_query}\n\t* {xref_query}\n"
                f"with parameters\n\t{batch}")
            cursor.executemany(means_query, batch)
            cursor.executemany(xref_query, batch)
            batch = tuple(take(records, batch_size))
    return 0

def insert_se(
        filepath: str, dataset_id: int, dbconn: mdb.Connection,
        rconn: Redis) -> int:
    """Insert the standard-error data into the database.

    NOTE(review): this is a placeholder. It only prints a banner and reports
    success (returns 0). Nothing is read from *filepath* and nothing is
    written through *dbconn* yet.
    """
    # TODO: implement standard-error insertion; all arguments are unused.
    banner = "INSERTING STANDARD ERROR DATA..."
    print(banner)
    exit_status = 0
    return exit_status

if __name__ == "__main__":
    def cli_args():
        """Parse the command line and check the database/redis connections.

        ``check_db`` and ``check_redis`` are run on the supplied URIs before
        the parsed arguments are returned.
        """
        parser = argparse.ArgumentParser(
            prog="InsertData", description=(
                "Script to insert data from an 'averages' file into the "
                "database."))
        parser.add_argument(
            "filetype", help="type of data to insert.",
            choices=("average", "standard-error"))
        parser.add_argument(
            "filepath", help="path to the file with the 'averages' data.")
        # NOTE(review): species_id is parsed but not used by either insert
        # function below.
        parser.add_argument(
            "species_id", help="Identifier for the species in the database.",
            type=int)
        parser.add_argument(
            "dataset_id", help="Identifier for the dataset in the database.",
            type=int)
        parser.add_argument(
            "database_uri",
            help="URL to be used to initialise the connection to the database")
        # A positional argument only falls back to its `default` when it may
        # be omitted, i.e. with nargs="?". Without it argparse makes the
        # argument mandatory and the default is never used.
        parser.add_argument(
            "redisuri",
            help="URL to initialise connection to redis",
            nargs="?",
            default="redis:///")

        args = parser.parse_args()
        check_db(args.database_uri)
        check_redis(args.redisuri)
        return args

    # Dispatch table: CLI `filetype` choice -> insertion function.
    insert_fns = {
        "average": insert_means,
        "standard-error": insert_se
    }

    def main():
        """Run the requested insertion and return the process exit status."""
        args = cli_args()
        with Redis.from_url(args.redisuri, decode_responses=True) as rconn:
            with database_connection(args.database_uri) as dbconn:
                return insert_fns[args.filetype](
                    args.filepath, args.dataset_id, dbconn, rconn)

        # Only reachable if one of the context managers above suppresses an
        # exception raised inside its body.
        return 2

    sys.exit(main())