1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
|
"""module that calls the gn3 api's to do the correlation """
import json
import time
from functools import wraps
from gn2.utility.tools import SQL_URI
from gn2.wqflask.correlation import correlation_functions
from gn2.base import data_set
from gn2.base.trait import create_trait
from gn2.base.trait import retrieve_sample_data
from gn3.db_utils import database_connection
from gn3.commands import run_sample_corr_cmd
from gn3.computations.correlations import map_shared_keys_to_values
from gn3.computations.correlations import compute_all_lit_correlation
from gn3.computations.correlations import compute_tissue_correlation
from gn3.computations.correlations import fast_compute_all_sample_correlation
def create_target_this_trait(start_vars):
    """Build the primary dataset, primary trait and target dataset for a correlation.

    ``start_vars`` must provide ``dataset``, ``corr_dataset`` and ``trait_id``;
    ``group`` is additionally read when ``dataset`` is ``"Temp"``.

    Returns the 4-tuple ``(this_dataset, this_trait, target_dataset, sample_data)``
    where ``sample_data`` is always an empty tuple.
    """
    if start_vars['dataset'] != "Temp":
        this_dataset = data_set.create_dataset(
            dataset_name=start_vars['dataset'])
    else:
        # "Temp" datasets get their type and group passed explicitly.
        this_dataset = data_set.create_dataset(
            dataset_name="Temp",
            dataset_type="Temp",
            group_name=start_vars['group'])

    target_dataset = data_set.create_dataset(
        dataset_name=start_vars['corr_dataset'])
    this_trait = create_trait(name=start_vars['trait_id'],
                              dataset=this_dataset)
    return this_dataset, this_trait, target_dataset, ()
def test_process_data(this_trait, dataset, start_vars):
    """Collect the user-entered sample values for the selected sample group.

    ``start_vars["corr_samples_group"]`` selects the samples:

    * ``"samples_primary"`` -- the group's samplelist plus its parents and F1s.
    * ``"samples_other"`` -- samples in ``this_trait.data`` that are not in the
      group's samplelist (parents and F1s are *not* excluded).
    * anything else (All Samples) -- the primary samples, followed by every
      sample in ``this_trait.data`` that is not a primary sample.

    Returns a dict of sample name -> float as built by ``process_samples``.
    """
    corr_samples_group = start_vars["corr_samples_group"]

    group = dataset.group
    parents = group.parlist if group.parlist is not None else []
    f1s = group.f1list if group.f1list is not None else []
    # Build a fresh list: extending group.samplelist with += would mutate the
    # group's own list and make it grow on every call.
    primary_samples = list(group.samplelist or []) + list(parents) + list(f1s)

    sample_data = {}

    # If either BXD/whatever Only or All Samples, include all of that group's
    # primary samples.
    if corr_samples_group != 'samples_other':
        sample_data.update(process_samples(start_vars, primary_samples))

    # If either Non-BXD/whatever or All Samples, add samples from
    # this_trait.data, excluding the primary samples. Merge rather than
    # reassign so that, for All Samples, the primary samples collected above
    # are kept.
    if corr_samples_group != 'samples_primary':
        excluded = primary_samples
        if corr_samples_group == 'samples_other':
            # Parents and F1s are allowed in the "other" selection.
            parents_and_f1s = set(parents) | set(f1s)
            excluded = [x for x in primary_samples if x not in parents_and_f1s]
        sample_data.update(process_samples(
            start_vars, list(this_trait.data.keys()), excluded))

    return sample_data
def process_samples(start_vars, sample_names=None, excluded_samples=None):
    """Parse the user-entered sample values into a ``{sample: float}`` dict.

    ``start_vars["sample_vals"]`` is a JSON object mapping sample names to
    values. A value equal to ``"x"`` (case-insensitive, surrounding whitespace
    ignored) is skipped.

    Arguments:
    start_vars -- form input; only ``sample_vals`` is read
    sample_names -- samples to keep, in output order; when empty or None every
                    sample in ``sample_vals`` is considered, in its JSON order
    excluded_samples -- samples to drop even if otherwise selected

    Returns a dict mapping ``str(sample)`` to the parsed float value.
    Raises ``ValueError`` if a kept value cannot be parsed as a float.
    """
    sample_vals_dict = json.loads(start_vars["sample_vals"])
    # A set keeps exclusion O(1) per sample instead of a list scan.
    excluded = set(excluded_samples or ())
    candidates = sample_names if sample_names else sample_vals_dict.keys()

    sample_data = {}
    for sample in candidates:
        if sample not in sample_vals_dict or sample in excluded:
            continue
        # str() tolerates JSON numbers as well as strings.
        val = str(sample_vals_dict[sample])
        if val.strip().lower() != "x":
            sample_data[str(sample)] = float(val)
    return sample_data
def merge_correlation_results(correlation_results, target_correlation_results):
    """Fold extra per-trait values into an existing list of correlation results.

    Both arguments are iterables of dicts keyed by trait name. For every trait
    in ``correlation_results`` that also has a truthy entry in
    ``target_correlation_results``, that entry is ``update``-d into the
    existing value dict. ``correlation_results`` is modified in place and
    returned.
    """
    # Flatten the extra results; a later duplicate trait name wins.
    extra_by_trait = {
        name: vals
        for entry in target_correlation_results
        for name, vals in entry.items()
    }
    for entry in correlation_results:
        for name in entry:
            addition = extra_by_trait.get(name)
            if addition:
                entry[name].update(addition)
    return correlation_results
def sample_for_trait_lists(corr_results, target_dataset,
                           this_trait, this_dataset, start_vars):
    """Run a Pearson sample correlation of ``this_trait`` against ``target_dataset``.

    ``corr_results`` is accepted but not used here. Sample data for both sides
    is gathered via ``fetch_sample_data``. Returns whatever
    ``run_sample_corr_cmd`` returns.
    """
    trait_payload, target_payload = fetch_sample_data(
        start_vars, this_trait, this_dataset, target_dataset)
    return run_sample_corr_cmd(corr_method="pearson",
                               this_trait=trait_payload,
                               target_dataset=target_payload)
def tissue_for_trait_lists(corr_results, this_dataset, this_trait):
    """Compute tissue correlations between ``this_trait`` and the top results.

    Arguments:
    corr_results -- iterable of dicts keyed by trait name; only the first key
                    of each non-empty dict is used
    this_dataset -- dataset whose ``retrieve_genes("Symbol")`` maps trait names
                    to gene symbols
    this_trait -- the primary trait

    Returns the result of ``compute_tissue_correlation`` (Pearson), or ``None``
    when ``get_tissue_correlation_input`` yields no input.
    """
    # Names of the traits to restrict the tissue correlation to; empty result
    # dicts are skipped instead of raising IndexError.
    top_trait_names = {next(iter(corr_result))
                       for corr_result in corr_results if corr_result}
    traits_symbol_dict = {
        trait_name: symbol
        for trait_name, symbol in this_dataset.retrieve_genes("Symbol").items()
        if trait_name in top_trait_names
    }
    tissue_input = get_tissue_correlation_input(this_trait, traits_symbol_dict)
    if tissue_input is None:
        return None

    primary_tissue_data, target_tissue_data = tissue_input
    return compute_tissue_correlation(
        primary_tissue_dict=primary_tissue_data,
        target_tissues_data=target_tissue_data,
        corr_method="pearson")
def lit_for_trait_list(corr_results, this_dataset, this_trait):
    """Compute literature correlations between ``this_trait`` and the top results.

    Arguments:
    corr_results -- iterable of dicts keyed by trait name; only the first key
                    of each non-empty dict is used
    this_dataset -- dataset supplying the trait-name -> gene-id mapping and species
    this_trait -- the primary trait; its gene id is correlated against the others

    Opens a connection to ``SQL_URI`` for the duration of the computation and
    returns the result of ``compute_all_lit_correlation``.
    """
    (this_trait_geneid, geneid_dict, species) = do_lit_correlation(
        this_trait, this_dataset)

    # Names of the traits to restrict the literature correlation to; empty
    # result dicts are skipped instead of raising IndexError.
    top_trait_names = {next(iter(corr_result))
                       for corr_result in corr_results if corr_result}
    geneid_dict = {trait_name: geneid
                   for trait_name, geneid in geneid_dict.items()
                   if trait_name in top_trait_names}

    with database_connection(SQL_URI) as conn:
        correlation_results = compute_all_lit_correlation(
            conn=conn, trait_lists=list(geneid_dict.items()),
            species=species, gene_id=this_trait_geneid)

    return correlation_results
def fetch_sample_data(start_vars, this_trait, this_dataset, target_dataset):
    """Gather the primary trait's sample values and the target dataset's data.

    ``start_vars["corr_samples_group"]`` picks the primary samples:
    ``"samples_primary"`` keeps only the group's samplelist,
    ``"samples_other"`` keeps everything except the group's samplelist, and any
    other value uses ``group.all_samples_ordered()``.

    Returns ``(this_trait_data, results)`` where ``this_trait_data`` is
    ``{"trait_sample_data": <sample dict>, "trait_id": start_vars["trait_id"]}``
    and ``results`` is ``map_shared_keys_to_values`` applied to the target
    dataset's ``samplelist`` and ``trait_data``.
    """
    samples_group = start_vars["corr_samples_group"]
    group = this_dataset.group
    if samples_group == "samples_other":
        sample_data = process_samples(
            start_vars, excluded_samples=group.samplelist)
    elif samples_group == "samples_primary":
        sample_data = process_samples(start_vars, group.samplelist)
    else:
        sample_data = process_samples(start_vars, group.all_samples_ordered())

    # The target dataset is asked for exactly the sample names kept above;
    # its samplelist/trait_data are read afterwards.
    target_dataset.get_trait_data(list(sample_data))
    # NOTE(review): the return value is unused; the call is kept in case it has
    # side effects on this_trait/this_dataset -- confirm before removing.
    retrieve_sample_data(this_trait, this_dataset)

    this_trait_data = {
        "trait_sample_data": sample_data,
        "trait_id": start_vars["trait_id"]
    }
    return (this_trait_data,
            map_shared_keys_to_values(target_dataset.samplelist,
                                      target_dataset.trait_data))
def compute_correlation(start_vars, method="pearson", compute_all=False):
    """Compute correlations using GN3 API

    Keyword arguments:
    start_vars -- All input from form; includes things like the trait/dataset names.
                  ``corr_type`` is required; ``corr_return_results`` defaults to 100.
    method -- Correlation method to be used (pearson, spearman, or bicor). A
              ``corr_sample_method`` entry in ``start_vars`` takes precedence;
              this argument is used only when the form omits it.
    compute_all -- Include sample, tissue, and literature correlations (when applicable)

    Delegates to ``compute_correlation_rust`` and returns its result.
    """
    # Imported inside the function, as before -- presumably to avoid an import
    # cycle between the correlation modules (TODO confirm).
    from gn2.wqflask.correlation.rust_correlation import compute_correlation_rust

    corr_type = start_vars['corr_type']
    # Form value wins; the ``method`` argument is the fallback instead of being
    # silently discarded.
    method = start_vars.get('corr_sample_method', method)
    corr_return_results = int(start_vars.get("corr_return_results", 100))
    return compute_correlation_rust(
        start_vars, corr_type, method, corr_return_results, compute_all)
def compute_corr_for_top_results(start_vars,
                                 correlation_results,
                                 this_trait,
                                 this_dataset,
                                 target_dataset,
                                 corr_type):
    """Enrich the top correlation results with the other correlation kinds.

    Only applies when both ``this_dataset`` and ``target_dataset`` have type
    ``"ProbeSet"``; otherwise ``correlation_results`` is returned untouched.
    For each of tissue, literature and sample correlation (in that order) that
    differs from ``corr_type``, the extra values are computed for the current
    results and, when non-empty, merged into the result dicts in place via
    ``merge_correlation_results``.

    Returns the (possibly updated) ``correlation_results``.
    """
    both_probeset = (this_dataset.type == "ProbeSet"
                     and target_dataset.type == "ProbeSet")
    if not both_probeset:
        return correlation_results

    # (correlation kind, callable computing it for the current results)
    extra_correlations = (
        ("tissue", lambda results: tissue_for_trait_lists(
            results, this_dataset, this_trait)),
        ("lit", lambda results: lit_for_trait_list(
            results, this_dataset, this_trait)),
        ("sample", lambda results: sample_for_trait_lists(
            results, target_dataset, this_trait, this_dataset, start_vars)),
    )
    for kind, compute_extra in extra_correlations:
        if corr_type == kind:
            # The primary correlation is already of this kind.
            continue
        extra = compute_extra(correlation_results)
        if extra:
            correlation_results = merge_correlation_results(
                correlation_results, extra)
    return correlation_results
def do_lit_correlation(this_trait, this_dataset):
    """Collect the inputs needed for a literature correlation.

    Returns ``(trait_geneid, geneid_dict, species)``: the primary trait's
    ``geneid``, the mapping from ``this_dataset.retrieve_genes("GeneId")``, and
    the dataset group's species lower-cased (left unchanged when falsy).
    """
    geneid_dict = this_dataset.retrieve_genes("GeneId")
    raw_species = this_dataset.group.species
    species = raw_species.lower() if raw_species else raw_species
    return this_trait.geneid, geneid_dict, species
def get_tissue_correlation_input(this_trait, trait_symbol_dict):
    """Gets tissue expression values for the primary trait and target tissues values.

    Arguments:
    this_trait -- primary trait; its ``symbol`` and ``name`` are read
    trait_symbol_dict -- mapping of target trait name -> gene symbol

    Returns ``(primary_tissue_data, target_tissue_data)``, or ``None`` when the
    primary trait has no symbol or no tissue values are found for it.
    """
    symbol = this_trait.symbol
    if not symbol:
        # No symbol to look up: skip the tissue-value query entirely rather
        # than querying for [None]/[""] and discarding the result.
        return None

    primary_trait_tissue_vals_dict = correlation_functions.get_trait_symbol_and_tissue_values(
        symbol_list=[symbol])
    # Lookups into the returned dict use the lower-cased symbol as the key.
    symbol_key = symbol.lower()
    if symbol_key not in primary_trait_tissue_vals_dict:
        return None

    corr_result_tissue_vals_dict = correlation_functions.get_trait_symbol_and_tissue_values(
        symbol_list=list(trait_symbol_dict.values()))
    primary_tissue_data = {
        "this_id": this_trait.name,
        "tissue_values": primary_trait_tissue_vals_dict[symbol_key],
    }
    target_tissue_data = {
        "trait_symbol_dict": trait_symbol_dict,
        "symbol_tissue_vals_dict": corr_result_tissue_vals_dict,
    }
    return (primary_tissue_data, target_tissue_data)
|