aboutsummaryrefslogtreecommitdiff
path: root/gn2/maintenance/generate_probesetfreeze_file.py
blob: 00c2cddf50037dbcaee36f3a9795f86931bbf0f8 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
#!/usr/bin/python

import sys

# sys.path.insert(0, "..") - why?

import os
import collections
import csv

from gn2.base import webqtlConfig

from pprint import pformat as pf

from gn2.utility.tools import get_setting
from gn2.wqflask.database import database_connection


def show_progress(process, counter):
    """Print a progress line for every 1000th step.

    Args:
        process: Label identifying the running task; printed before the count.
        counter: Current step number. A line is printed only when it is a
            multiple of 1000 (which includes 0).
    """
    if counter % 1000:
        # Not on a 1000-step boundary: stay quiet.
        return
    print(f"{process}: {counter}")


def get_strains(cursor, group_name="BXD"):
    """Return the names of all strains that belong to an inbred set.

    Args:
        cursor: An open DB-API cursor on the database holding the
            ``Strain``, ``StrainXRef`` and ``InbredSet`` tables.
        group_name: Value matched against ``InbredSet.Name``. Defaults to
            ``"BXD"``, the group this function previously hard-coded.

    Returns:
        A list of ``Strain.Name`` values, in the order the database
        returned them. The list is also printed to stdout.
    """
    # Query parameters are passed as a one-element tuple: DB-API drivers
    # expect a sequence (or mapping), and a bare string is itself a sequence
    # of characters, so a driver that iterates its arguments would treat it
    # as one parameter per character.
    cursor.execute("""select Strain.Name
                      from Strain, StrainXRef, InbredSet
                      where Strain.Id = StrainXRef.StrainId and
                            StrainXRef.InbredSetId = InbredSet.Id
                            and InbredSet.Name=%s;
                """, (group_name,))

    strains = [row[0] for row in cursor.fetchall()]
    print("strains:", pf(strains))
    for strain in strains:
        print(" -", strain)

    return strains


def get_probeset_vals(cursor, dataset_name):
    """Load every probeset's per-strain values for one dataset.

    First fetches the (Id, Name) pairs of all probesets linked to the
    ``ProbeSetFreeze`` row named ``dataset_name``, then runs one further
    query per probeset to fetch its (strain name, value) rows.

    Args:
        cursor: An open DB-API cursor on the database holding the
            ``ProbeSet*`` and ``Strain`` tables.
        dataset_name: Value matched against ``ProbeSetFreeze.Name``.

    Returns:
        An ``OrderedDict`` keyed by ``ProbeSet.Name``. Each value is an
        ``OrderedDict`` mapping ``Strain.Name`` to ``ProbeSetData.value``.
        If two probesets share a name, the later one replaces the earlier.
    """
    # Parameters go in a tuple: a bare string is a sequence of characters
    # and may be expanded into one parameter per character by the driver.
    cursor.execute(""" select ProbeSet.Id, ProbeSet.Name
                from ProbeSetXRef,
                     ProbeSetFreeze,
                     ProbeSet
                where ProbeSetXRef.ProbeSetFreezeId = ProbeSetFreeze.Id and
                      ProbeSetFreeze.Name = %s and
                      ProbeSetXRef.ProbeSetId = ProbeSet.Id;
            """, (dataset_name,))

    # Materialise the probeset list before the cursor is reused below.
    probesets = cursor.fetchall()

    print("Fetched probesets")

    probeset_vals = collections.OrderedDict()

    for counter, probeset in enumerate(probesets):
        probeset_id, probeset_name = probeset[0], probeset[1]
        cursor.execute(""" select Strain.Name, ProbeSetData.value
                       from ProbeSetData, ProbeSetXRef, ProbeSetFreeze, Strain
                       where ProbeSetData.Id = ProbeSetXRef.DataId
                       and ProbeSetData.StrainId = Strain.Id
                       and ProbeSetXRef.ProbeSetId = %s
                       and ProbeSetFreeze.Id = ProbeSetXRef.ProbeSetFreezeId
                       and ProbeSetFreeze.Name = %s;
                """, (probeset_id, dataset_name))

        val_dic = collections.OrderedDict()
        for row in cursor.fetchall():
            # row is (Strain.Name, ProbeSetData.value)
            val_dic[row[0]] = row[1]

        probeset_vals[probeset_name] = val_dic
        show_progress("Querying DB", counter)

    return probeset_vals


def trim_strains(strains, probeset_vals):
    """Keep only the strains that have a value in the first probeset.

    Only the first probeset in ``probeset_vals`` is inspected; the other
    probesets are not checked.

    Args:
        strains: Iterable of strain names, in the desired output order.
        probeset_vals: Mapping of probeset name to a mapping of strain
            name to value. May be empty.

    Returns:
        A list of the names from ``strains`` that are keys of the first
        probeset's mapping, in their original order. Empty when
        ``probeset_vals`` is empty.
    """
    # Take the first value lazily instead of copying every value into a
    # list; the default also avoids an IndexError on an empty mapping.
    first_probeset = next(iter(probeset_vals.values()), {})
    print("\n**** first_probeset is:", pf(first_probeset))

    trimmed_strains = []
    for strain in strains:
        print("\n**** strain is:", pf(strain))
        if strain in first_probeset:
            trimmed_strains.append(strain)

    print("trimmed_strains:", pf(trimmed_strains))
    return trimmed_strains


def write_data_matrix_file(strains, probeset_vals, filename):
    """Write the probeset-by-strain value matrix to a CSV file.

    The first row is ``ID`` followed by the strain names. Each later row
    is a probeset name followed by that probeset's value for every strain,
    in the same order. All fields are quoted.

    Args:
        strains: Strain names that become the value columns.
        probeset_vals: Mapping of probeset name to a mapping of strain
            name to value; iterated in its own order.
        filename: Path of the file to create or overwrite.

    Raises:
        KeyError: If a probeset has no value for one of ``strains``.
    """
    # csv.writer emits str, so the file must be opened in text mode (binary
    # mode raises TypeError on Python 3). newline="" leaves line endings to
    # the csv module, as its documentation requires.
    with open(filename, "w", newline="", encoding="utf-8") as fh:
        csv_writer = csv.writer(fh, delimiter=",", quoting=csv.QUOTE_ALL)
        csv_writer.writerow(['ID'] + list(strains))
        for counter, (probeset, strain_vals) in enumerate(
                probeset_vals.items()):
            row_data = [probeset]
            row_data.extend(strain_vals[strain] for strain in strains)
            csv_writer.writerow(row_data)
            show_progress("Writing", counter)


def main():
    """Export one hard-coded dataset's probeset values to a CSV matrix.

    Opens a connection using the ``SQL_URI`` setting, loads the strain
    list and all probeset values for ``Eye_AXBXA_1008_RankInv``, trims the
    strains to those present in the first probeset, and writes the matrix
    to a fixed file under ``~/gene/wqflask/maintenance/``.
    """
    dataset_name = "Eye_AXBXA_1008_RankInv"
    output_dir = "~/gene/wqflask/maintenance/"
    output_name = ("ProbeSetFreezeId_210_FullName_Eye_AXBXA_Illumina_V6.2"
                   "(Oct08)_RankInv_Beta.txt")
    filename = os.path.expanduser(output_dir + output_name)

    sql_uri = get_setting("SQL_URI")
    with database_connection(sql_uri) as conn, conn.cursor() as cursor:
        strains = get_strains(cursor)
        print("Getting probset_vals")
        probeset_vals = get_probeset_vals(cursor, dataset_name)
        print("Finished getting probeset_vals")
        # The file is written while the cursor and connection are still open.
        write_data_matrix_file(
            trim_strains(strains, probeset_vals), probeset_vals, filename)


# Run the export only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    main()