import json
import math
from redis import Redis
Redis = Redis()
from gn2.base.trait import create_trait, retrieve_sample_data
from gn2.base import data_set, webqtlCaseData
from gn2.utility import corr_result_helpers
from gn2.wqflask.oauth2.collections import num_collections
from scipy import stats
import numpy as np
import logging
logger = logging.getLogger(__name__)
class CorrScatterPlot:
    """Page that displays a correlation scatterplot with a line fitted to it.

    Pairs the sample values of two traits and prepares two plots:
    the raw values with a least-squares line, and the rank-transformed
    values with a least-squares line fitted to the ranks. Everything the
    page needs is collected in ``self.js_data`` (also exposed as
    ``self.jsdata``).
    """

    def __init__(self, params):
        """Load both traits and compute the regression data for the plot.

        Args:
            params: request parameters. Required keys: ``dataset_1``,
                ``dataset_2``, ``trait_1``, ``trait_2`` and ``method``.
                Optional keys: ``effect`` (its presence sets
                ``effect_plot``), ``dataid`` (Redis key of cached sample
                values that replace one trait's data) and ``cached_trait``
                (``"trait_2"`` applies the cached values to the second
                trait; otherwise they replace the first trait's data).
        """
        self.dataset_1 = self._create_dataset(
            params['dataset_1'], params['trait_1'])
        self.dataset_2 = self._create_dataset(
            params['dataset_2'], params['trait_2'])
        self.trait_1 = create_trait(
            name=params['trait_1'], dataset=self.dataset_1)
        self.trait_2 = create_trait(
            name=params['trait_2'], dataset=self.dataset_2)
        self.method = params['method']
        self.effect_plot = 'effect' in params

        if 'dataid' in params:
            cached_data = self._load_cached_sample_data(params['dataid'])
            if 'cached_trait' in params and params['cached_trait'] == 'trait_2':
                trait_1_data, trait_2_data = self.trait_1.data, cached_data
            else:
                trait_1_data, trait_2_data = cached_data, self.trait_2.data
        else:
            trait_1_data, trait_2_data = self.trait_1.data, self.trait_2.data

        samples_1, samples_2, num_overlap = corr_result_helpers.normalize_values_with_samples(
            trait_1_data, trait_2_data)

        self.indIDs = list(samples_1.keys())
        vals_1 = [case.value for case in samples_1.values()]
        vals_2 = [case.value for case in samples_2.values()]
        self.data = [vals_1, vals_2]

        # Least-squares fit on the raw values.
        slope, intercept, r_value, p_value, _ = stats.linregress(vals_1, vals_2)
        x_range = self._padded_range(vals_1)
        y_range = self._padded_range(vals_2)
        intercept_coords = get_intercept_coords(
            slope, intercept, x_range, y_range)

        # The same fit on the ranks of the values.
        rx = stats.rankdata(vals_1)
        ry = stats.rankdata(vals_2)
        self.rdata = [rx.tolist(), ry.tolist()]
        srslope, srintercept, srr_value, srp_value, _ = stats.linregress(rx, ry)
        # Both axes of the rank plot use the padded range of the first trait's ranks.
        sr_range = self._padded_range(rx)
        sr_intercept_coords = get_intercept_coords(
            srslope, srintercept, sr_range, sr_range)

        # Kept as the strings "True"/"False" because that is what js_data consumers receive.
        self.collections_exist = "True" if num_collections() > 0 else "False"

        self.js_data = dict(
            data=self.data,
            effect_plot=self.effect_plot,
            rdata=self.rdata,
            indIDs=self.indIDs,
            trait_1=self.trait_1.dataset.name + ": " + str(self.trait_1.name),
            trait_2=self.trait_2.dataset.name + ": " + str(self.trait_2.name),
            samples_1=samples_1,
            samples_2=samples_2,
            num_overlap=num_overlap,
            vals_1=vals_1,
            vals_2=vals_2,
            x_range=x_range,
            y_range=y_range,
            sr_range=sr_range,
            intercept_coords=intercept_coords,
            sr_intercept_coords=sr_intercept_coords,
            slope=slope,
            slope_string=self._format_slope(slope),
            intercept=intercept,
            r_value=r_value,
            p_value=p_value,
            srslope=srslope,
            srslope_string=self._format_slope(srslope),
            srintercept=srintercept,
            srr_value=srr_value,
            srp_value=srp_value
        )
        self.jsdata = self.js_data

    @staticmethod
    def _create_dataset(dataset_name, trait_name):
        """Create the dataset object for one trait.

        Temp datasets are created by type rather than by name; their group
        is the third "_"-separated field of ``trait_name``.
        """
        if "Temp" in dataset_name:
            group_name = trait_name.split("_")[2]
            return data_set.create_dataset(
                dataset_name="Temp", dataset_type="Temp", group_name=group_name)
        return data_set.create_dataset(dataset_name)

    @staticmethod
    def _load_cached_sample_data(data_id):
        """Load JSON-encoded sample values stored in Redis under ``data_id``.

        Returns a dict mapping each sample name to a ``webqtlCaseData``.
        Entries whose value is the string "x" are skipped (presumably
        placeholders for missing values).
        """
        cached_values = json.loads(Redis.get(data_id))
        return {
            name: webqtlCaseData.webqtlCaseData(name, float(value))
            for name, value in cached_values.items()
            if value != "x"
        }

    @staticmethod
    def _padded_range(values):
        """Return ``[min, max]`` of ``values`` widened by 10% of their spread on each side."""
        low, high = min(values), max(values)
        padding = (high - low) * 0.1
        return [low - padding, high + padding]

    @staticmethod
    def _format_slope(slope):
        """Format a slope with three decimals.

        Slopes whose magnitude is below 0.001 use scientific notation. The
        test is on the magnitude so that ordinary negative slopes keep the
        plain fixed-point form.
        """
        if abs(slope) < 0.001:
            return '%.3E' % slope
        return '%.3f' % slope
def get_intercept_coords(slope, intercept, x_range, y_range):
    """Return the end points of the fitted line ``y = slope * x + intercept``.

    The line is evaluated at both ends of ``x_range``.

    Args:
        slope: slope of the fitted line.
        intercept: y-intercept of the fitted line.
        x_range: ``[x_min, x_max]`` interval over which the line is drawn.
        y_range: not used; accepted so existing callers keep working.

    Returns:
        ``[[x_min, y(x_min)], [x_max, y(x_max)]]``.
    """
    # Use the range ends directly. Solving back for x via (y - intercept) / slope
    # adds rounding error and divides by zero when the fitted line is flat.
    x_start, x_end = x_range[0], x_range[1]
    return [
        [x_start, slope * x_start + intercept],
        [x_end, slope * x_end + intercept],
    ]