import json

from redis import Redis

# Module-level Redis connection, used to fetch cached trait data (see 'dataid' below)
redis_conn = Redis()

from gn2.base.trait import create_trait
from gn2.base import data_set, webqtlCaseData
from gn2.utility import corr_result_helpers
from gn2.wqflask.oauth2.collections import num_collections

from scipy import stats

import logging
logger = logging.getLogger(__name__)

class CorrScatterPlot:
    """Page that displays a correlation scatterplot with a line fitted to it"""

    def __init__(self, params):
        if "Temp" in params['dataset_1']:
            temp_group = params['trait_1'].split("_")[2]
            self.dataset_1 = data_set.create_dataset(
                dataset_name="Temp", dataset_type="Temp", group_name=temp_group)
        else:
            self.dataset_1 = data_set.create_dataset(params['dataset_1'])
        if "Temp" in params['dataset_2']:
            temp_group = params['trait_2'].split("_")[2]
            self.dataset_2 = data_set.create_dataset(
                dataset_name="Temp", dataset_type="Temp", group_name=temp_group)
        else:
            self.dataset_2 = data_set.create_dataset(params['dataset_2'])

        self.trait_1 = create_trait(
            name=params['trait_1'], dataset=self.dataset_1)
        self.trait_2 = create_trait(
            name=params['trait_2'], dataset=self.dataset_2)

        self.method = params['method']

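        # Primary samples are the group's sample list plus any parent and F1 lines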
        primary_samples = self.dataset_1.group.samplelist
        if self.dataset_1.group.parlist is not None:
            primary_samples += self.dataset_1.group.parlist
        if self.dataset_1.group.f1list is not None:
            primary_samples += self.dataset_1.group.f1list

        self.effect_plot = 'effect' in params

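        # 'dataid' points to sample values cached in Redis; when present, those
        # values replace the stored data for one of the two traits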
        if 'dataid' in params:
            trait_data_dict = json.loads(redis_conn.get(params['dataid']))
            # Wrap each non-missing ("x") sample value in a webqtlCaseData object
            trait_data = {
                key: webqtlCaseData.webqtlCaseData(key, float(value))
                for key, value in trait_data_dict.items()
                if value != "x"
            }
            trait_1_data = trait_data
            trait_2_data = self.trait_2.data
            # Check if the cached data should be used for the second trait instead
            if 'cached_trait' in params:
                if params['cached_trait'] == 'trait_2':
                    trait_2_data = trait_data
                    trait_1_data = self.trait_1.data
            samples_1, samples_2, num_overlap = corr_result_helpers.normalize_values_with_samples(
                trait_1_data, trait_2_data)
        else:
            samples_1, samples_2, num_overlap = corr_result_helpers.normalize_values_with_samples(
                self.trait_1.data, self.trait_2.data)

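        # Collect the overlapping sample IDs and the two value vectors for plotting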
        self.data = []
        self.indIDs = list(samples_1.keys())
        vals_1 = [samples_1[sample].value for sample in samples_1]
        self.data.append(vals_1)
        vals_2 = [samples_2[sample].value for sample in samples_2]
        self.data.append(vals_2)

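        # Fit a least-squares regression line to the raw values (Pearson)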
        slope, intercept, r_value, p_value, std_err = stats.linregress(
            vals_1, vals_2)

        # Use scientific notation for very small slopes so they don't render as 0.000
        if abs(slope) < 0.001:
            slope_string = '%.3E' % slope
        else:
            slope_string = '%.3f' % slope

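        # Pad both axis ranges by 10% so points aren't drawn on the plot edges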
        x_buffer = (max(vals_1) - min(vals_1)) * 0.1
        y_buffer = (max(vals_2) - min(vals_2)) * 0.1

        x_range = [min(vals_1) - x_buffer, max(vals_1) + x_buffer]
        y_range = [min(vals_2) - y_buffer, max(vals_2) + y_buffer]

        intercept_coords = get_intercept_coords(
            slope, intercept, x_range, y_range)

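        # Spearman rank version: rank each vector and regress the ranks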
        rx = stats.rankdata(vals_1)
        ry = stats.rankdata(vals_2)
        self.rdata = []
        self.rdata.append(rx.tolist())
        self.rdata.append(ry.tolist())
        srslope, srintercept, srr_value, srp_value, srstd_err = stats.linregress(
            rx, ry)

        if abs(srslope) < 0.001:
            srslope_string = '%.3E' % srslope
        else:
            srslope_string = '%.3f' % srslope

        # Both axes are ranks over the same overlapping samples, so one padded range suffices
        rank_buffer = (max(rx) - min(rx)) * 0.1
        sr_range = [min(rx) - rank_buffer, max(rx) + rank_buffer]

        sr_intercept_coords = get_intercept_coords(
            srslope, srintercept, sr_range, sr_range)

        self.collections_exist = "False"
        if num_collections() > 0:
                self.collections_exist = "True"

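        # Bundle everything the client-side scatterplot code needs into one dict
        # that gets serialized to the template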
        self.js_data = dict(
            data=self.data,
            effect_plot=self.effect_plot,
            rdata=self.rdata,
            indIDs=self.indIDs,
            trait_1=self.trait_1.dataset.name + ": " + str(self.trait_1.name),
            trait_2=self.trait_2.dataset.name + ": " + str(self.trait_2.name),
            samples_1=samples_1,
            samples_2=samples_2,
            num_overlap=num_overlap,
            vals_1=vals_1,
            vals_2=vals_2,
            x_range=x_range,
            y_range=y_range,
            sr_range=sr_range,
            intercept_coords=intercept_coords,
            sr_intercept_coords=sr_intercept_coords,

            slope=slope,
            slope_string=slope_string,
            intercept=intercept,
            r_value=r_value,
            p_value=p_value,

            srslope=srslope,
            srslope_string=srslope_string,
            srintercept=srintercept,
            srr_value=srr_value,
            srp_value=srp_value
        )
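        # Alias kept in case templates reference either attribute name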
        self.jsdata = self.js_data


def get_intercept_coords(slope, intercept, x_range, y_range):
    """Return the two endpoints of the fitted regression line across x_range.

    y_range is accepted for call-site compatibility but is not needed: the
    endpoints are simply the line evaluated at the two ends of x_range.
    """
    y1 = slope * x_range[0] + intercept
    y2 = slope * x_range[1] + intercept

    return [[x_range[0], y1], [x_range[1], y2]]