aboutsummaryrefslogtreecommitdiff
# Copyright (C) University of Tennessee Health Science Center, Memphis, TN.
#
# This program is free software: you can redistribute it and/or modify it
# under the terms of the GNU Affero General Public License
# as published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU Affero General Public License for more details.
#
# This program is available from Source Forge: at GeneNetwork Project
# (sourceforge.net/projects/genenetwork/).
#
# Contact Dr. Robert W. Williams at rwilliams@uthsc.edu
#
#
# This module is used by GeneNetwork project (www.genenetwork.org)

import datetime
import random
import string
import numpy as np
import scipy

from gn2.base.data_set import create_dataset
from gn2.base.webqtlConfig import GENERATED_TEXT_DIR


from gn2.utility.helper_functions import get_trait_db_obs
from gn2.utility.corr_result_helpers import normalize_values
from gn2.utility.redis_tools import get_redis_conn


from gn3.computations.pca import compute_pca
from gn3.computations.pca import process_factor_loadings_tdata
from gn3.computations.pca import generate_pca_temp_traits
from gn3.computations.pca import cache_pca_dataset
from gn3.computations.pca import generate_scree_plot_data


class CorrelationMatrix:

    def __init__(self, start_vars):
        trait_db_list = [trait.strip()
                         for trait in start_vars['trait_list'].split(',')]

        get_trait_db_obs(self, trait_db_list)

        self.all_sample_list = []
        self.traits = []
        self.do_PCA = True
        # ZS: Getting initial group name before verifying all traits are in the same group in the following loop
        this_group = self.trait_list[0][1].group.name
        for trait_db in self.trait_list:
            this_group = trait_db[1].group.name
            this_trait = trait_db[0]
            self.traits.append(this_trait)
            this_sample_data = this_trait.data

            for sample in this_sample_data:
                if sample not in self.all_sample_list:
                    self.all_sample_list.append(sample)

        self.sample_data = []
        for trait_db in self.trait_list:
            this_trait = trait_db[0]
            this_sample_data = this_trait.data

            this_trait_vals = []
            for sample in self.all_sample_list:
                if sample in this_sample_data:
                    this_trait_vals.append(this_sample_data[sample].value)
                else:
                    this_trait_vals.append('')
            self.sample_data.append(this_trait_vals)

        # Shouldn't do PCA if there are more traits than observations/samples
        if len(this_trait_vals) < len(self.trait_list) or len(self.trait_list) < 3:
            self.do_PCA = False

        # ZS: Variable set to the lowest overlapping samples in order to notify user, or 8, whichever is lower (since 8 is when we want to display warning)
        self.lowest_overlap = 8

        self.corr_results = []
        self.pca_corr_results = []
        self.scree_data = []
        self.shared_samples_list = self.all_sample_list
        for trait_db in self.trait_list:
            this_trait = trait_db[0]
            this_db = trait_db[1]

            this_db_samples = this_db.group.all_samples_ordered()
            this_sample_data = this_trait.data

            corr_result_row = []
            pca_corr_result_row = []
            is_spearman = False  # ZS: To determine if it's above or below the diagonal
            for target in self.trait_list:
                target_trait = target[0]
                target_db = target[1]
                target_samples = target_db.group.all_samples_ordered()
                target_sample_data = target_trait.data

                this_trait_vals = []
                target_vals = []
                for index, sample in enumerate(target_samples):
                    if (sample in this_sample_data) and (sample in target_sample_data):
                        sample_value = this_sample_data[sample].value
                        target_sample_value = target_sample_data[sample].value
                        this_trait_vals.append(sample_value)
                        target_vals.append(target_sample_value)
                    else:
                        if sample in self.shared_samples_list:
                            self.shared_samples_list.remove(sample)

                this_trait_vals, target_vals, num_overlap = normalize_values(
                    this_trait_vals, target_vals)

                if num_overlap < self.lowest_overlap:
                    self.lowest_overlap = num_overlap
                if num_overlap < 2:
                    corr_result_row.append([target_trait, 0, num_overlap])
                    pca_corr_result_row.append(0)
                else:
                    pearson_r, pearson_p = scipy.stats.pearsonr(
                        this_trait_vals, target_vals)
                    if is_spearman == False:
                        sample_r, sample_p = pearson_r, pearson_p
                        if sample_r > 0.999:
                            is_spearman = True
                    else:
                        sample_r, sample_p = scipy.stats.spearmanr(
                            this_trait_vals, target_vals)

                    corr_result_row.append(
                        [target_trait, sample_r, num_overlap])
                    pca_corr_result_row.append(pearson_r)

            self.corr_results.append(corr_result_row)
            self.pca_corr_results.append(pca_corr_result_row)

        self.export_filename, self.export_filepath = export_corr_matrix(
            self.corr_results)

        self.trait_data_array = []
        for trait_db in self.trait_list:
            this_trait = trait_db[0]
            this_db = trait_db[1]
            this_db_samples = this_db.group.all_samples_ordered()
            this_sample_data = this_trait.data

            this_trait_vals = []
            for index, sample in enumerate(this_db_samples):
                if (sample in this_sample_data) and (sample in self.shared_samples_list):
                    sample_value = this_sample_data[sample].value
                    this_trait_vals.append(sample_value)
            self.trait_data_array.append(this_trait_vals)

        groups = []
        for sample in self.all_sample_list:
            groups.append(1)

        self.pca_works = "False"
        try:

            if self.do_PCA:
                self.pca_works = "True"
                self.pca_trait_ids = []
                pca = self.calculate_pca()
                self.loadings_array = process_factor_loadings_tdata(
                    factor_loadings=self.loadings, traits_num=len(self.trait_list))

            else:
                self.pca_works = "False"
        except:
            self.pca_works = "False"

        self.js_data = dict(traits=[trait.name for trait in self.traits],
                            groups=groups,
                            scree_data = self.scree_data,
                            cols=list(range(len(self.traits))),
                            rows=list(range(len(self.traits))),
                            samples=self.all_sample_list,
                            sample_data=self.sample_data,)

    def calculate_pca(self):

        pca = compute_pca(self.pca_corr_results)

        self.loadings = pca["components"]
        self.scores = pca["scores"]
        self.pca_obj = pca["pca"]

        this_group_name = self.trait_list[0][1].group.name
        temp_dataset = create_dataset(
            dataset_name="Temp", dataset_type="Temp",
            group_name=this_group_name)
        temp_dataset.group.get_samplelist(redis_conn=get_redis_conn())

        pca_temp_traits = generate_pca_temp_traits(species=temp_dataset.group.species, group=this_group_name,
                                                   traits_data=self.trait_data_array, corr_array=self.pca_corr_results,
                                                   dataset_samples=temp_dataset.group.all_samples_ordered(),
                                                   shared_samples=self.shared_samples_list,
                                                   create_time=datetime.datetime.now().strftime("%m%d%H%M%S"))

        cache_pca_dataset(redis_conn=get_redis_conn(
        ), exp_days=60 * 60 * 24 * 30, pca_trait_dict=pca_temp_traits)

        self.pca_trait_ids = list(pca_temp_traits.keys())

        x_coord, y_coord = generate_scree_plot_data(
            list(self.pca_obj.explained_variance_ratio_))

        self.scree_data = {
            "x_coord": x_coord,
            "y_coord": y_coord
        }
        return pca


def export_corr_matrix(corr_results):
    corr_matrix_filename = "corr_matrix_" + \
        ''.join(random.choice(string.ascii_uppercase + string.digits)
                for _ in range(6))
    matrix_export_path = "{}{}.csv".format(
        GENERATED_TEXT_DIR, corr_matrix_filename)
    with open(matrix_export_path, "w+") as output_file:
        output_file.write(
            "Time/Date: " + datetime.datetime.now().strftime("%x / %X") + "\n")
        output_file.write("\n")
        output_file.write("Correlation ")
        for i, item in enumerate(corr_results[0]):
            output_file.write("Trait" + str(i + 1) + ": " +
                              str(item[0].dataset.name) + "::" + str(item[0].name) + "\t")
        output_file.write("\n")
        for i, row in enumerate(corr_results):
            output_file.write("Trait" + str(i + 1) + ": " +
                              str(row[0][0].dataset.name) + "::" + str(row[0][0].name) + "\t")
            for item in row:
                output_file.write(str(item[1]) + "\t")
            output_file.write("\n")

        output_file.write("\n")
        output_file.write("\n")
        output_file.write("N ")
        for i, item in enumerate(corr_results[0]):
            output_file.write("Trait" + str(i) + ": " +
                              str(item[0].dataset.name) + "::" + str(item[0].name) + "\t")
        output_file.write("\n")
        for i, row in enumerate(corr_results):
            output_file.write("Trait" + str(i) + ": " +
                              str(row[0][0].dataset.name) + "::" + str(row[0][0].name) + "\t")
            for item in row:
                output_file.write(str(item[2]) + "\t")
            output_file.write("\n")

    return corr_matrix_filename, matrix_export_path