author     Alexander Kabui    2022-03-18 13:49:19 +0300
committer  GitHub             2022-03-18 13:49:19 +0300
commit     205ebb2d9501f431984359b15467cf573803b0a4 (patch)
tree       e1882e773f89294d15c8d3cc139fa0234838bc2a
parent     13d00e885600157cc253e9d03f26e712fed17346 (diff)
parent     6359dc2bf8973991072634e6a2b8d6a8a038166a (diff)
download   genenetwork2-205ebb2d9501f431984359b15467cf573803b0a4.tar.gz
Merge pull request #671 from Alexanderlacuna/feature/gn3-pca
Replace pca rpy2 code
-rw-r--r--  wqflask/wqflask/correlation_matrix/show_corr_matrix.py  160
1 file changed, 38 insertions(+), 122 deletions(-)
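
For context: before this change, calculate_pca shelled out to R via rpy2 (stats.princomp) and hand-rolled the eigenvector bookkeeping; afterwards the PCA runs in gn3.computations.pca. A minimal sketch of what a numpy/scikit-learn compute_pca equivalent could look like. The dict keys "components" and "scores" match how the result is consumed in the diff below; the scikit-learn usage here is an assumption, not gn3's confirmed implementation:

    # Hypothetical sketch of a compute_pca equivalent; gn3's real
    # implementation may differ. Assumes scikit-learn is installed.
    import numpy as np
    from sklearn.decomposition import PCA

    def compute_pca_sketch(corr_matrix):
        """Run PCA on a correlation matrix and return the pieces
        show_corr_matrix consumes: factor loadings and scores."""
        matrix = np.asarray(corr_matrix)
        pca = PCA()
        scores = pca.fit_transform(matrix)
        return {
            "pca": pca,
            "components": pca.components_,  # factor loadings
            "scores": scores,
        }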
diff --git a/wqflask/wqflask/correlation_matrix/show_corr_matrix.py b/wqflask/wqflask/correlation_matrix/show_corr_matrix.py
index e7b16e77..88d62045 100644
--- a/wqflask/wqflask/correlation_matrix/show_corr_matrix.py
+++ b/wqflask/wqflask/correlation_matrix/show_corr_matrix.py
@@ -19,27 +19,25 @@
 # This module is used by GeneNetwork project (www.genenetwork.org)
 
 import datetime
-import math
 import random
 import string
-
-import rpy2.robjects as ro
-from rpy2.robjects.packages import importr
-
 import numpy as np
 import scipy
 
-from base import data_set
+from base.data_set import create_dataset
 from base.webqtlConfig import GENERATED_TEXT_DIR
-from functools import reduce
-from functools import cmp_to_key
-from utility import webqtlUtil
-from utility import helper_functions
-from utility import corr_result_helpers
+
+
+from utility.helper_functions import get_trait_db_obs
+from utility.corr_result_helpers import normalize_values
 from utility.redis_tools import get_redis_conn
 
-Redis = get_redis_conn()
-THIRTY_DAYS = 60 * 60 * 24 * 30
+
+from gn3.computations.pca import compute_pca
+from gn3.computations.pca import process_factor_loadings_tdata
+from gn3.computations.pca import generate_pca_temp_traits
+from gn3.computations.pca import cache_pca_dataset
+
 
 class CorrelationMatrix:
 
@@ -47,11 +45,10 @@ class CorrelationMatrix:
         trait_db_list = [trait.strip()
                          for trait in start_vars['trait_list'].split(',')]
 
-        helper_functions.get_trait_db_obs(self, trait_db_list)
+        get_trait_db_obs(self, trait_db_list)
 
         self.all_sample_list = []
         self.traits = []
-        self.insufficient_shared_samples = False
         self.do_PCA = True
         # ZS: Getting initial group name before verifying all traits are in the same group in the following loop
         this_group = self.trait_list[0][1].group.name
@@ -116,7 +113,7 @@ class CorrelationMatrix:
                         if sample in self.shared_samples_list:
                             self.shared_samples_list.remove(sample)
 
-                this_trait_vals, target_vals, num_overlap = corr_result_helpers.normalize_values(
+                this_trait_vals, target_vals, num_overlap = normalize_values(
                     this_trait_vals, target_vals)
 
                 if num_overlap < self.lowest_overlap:
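
Only the import form of normalize_values changes in this hunk; its contract is worth recalling when reading the num_overlap check above. It pairs up the two value lists, drops positions where either trait lacks a measurement, and returns the filtered lists plus the overlap count. A sketch of that contract (the real helper is utility.corr_result_helpers.normalize_values, untouched by this PR):

    # Sketch of the normalize_values contract as used above; the actual
    # helper lives in utility.corr_result_helpers.
    def normalize_values_sketch(a_values, b_values):
        """Keep only positions where both traits have a value."""
        pairs = [(a, b) for a, b in zip(a_values, b_values)
                 if a is not None and b is not None]
        a_new = [a for a, _ in pairs]
        b_new = [b for _, b in pairs]
        return a_new, b_new, len(a_new)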
@@ -165,16 +162,13 @@ class CorrelationMatrix:
 
         self.pca_works = "False"
         try:
-            corr_result_eigen = np.linalg.eig(np.array(self.pca_corr_results))
-            corr_eigen_value, corr_eigen_vectors = sortEigenVectors(
-                corr_result_eigen)
 
-            if self.do_PCA == True:
+            if self.do_PCA:
                 self.pca_works = "True"
                 self.pca_trait_ids = []
-                pca = self.calculate_pca(
-                    list(range(len(self.traits))), corr_eigen_value, corr_eigen_vectors)
-                self.loadings_array = self.process_loadings()
+                pca = self.calculate_pca()
+                self.loadings_array = process_factor_loadings_tdata(
+                    factor_loadings=self.loadings, traits_num=len(self.trait_list))
             else:
                 self.pca_works = "False"
         except:
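
This hunk swaps the local process_loadings method (removed further down) for gn3's process_factor_loadings_tdata. Judging from the deleted method, it collects the first two loading components per trait, or three when more than two traits are present. A rough numpy equivalent, reconstructed from the removed code; the components-by-traits matrix orientation is an assumption:

    # Sketch mirroring the removed process_loadings(), adapted to a numpy
    # components matrix; gn3's process_factor_loadings_tdata may differ.
    import numpy as np

    def process_loadings_sketch(factor_loadings, traits_num):
        """Collect the first 2 (or 3, when more than two traits) loading
        components for each trait, one row per trait."""
        n_components = 3 if traits_num > 2 else 2
        loadings = np.asarray(factor_loadings)
        # rows are components, columns are traits (assumed orientation)
        return [list(loadings[:n_components, i]) for i in range(traits_num)]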
@@ -187,66 +181,31 @@ class CorrelationMatrix:
                             samples=self.all_sample_list,
                             sample_data=self.sample_data,)
 
-    def calculate_pca(self, cols, corr_eigen_value, corr_eigen_vectors):
-        base = importr('base')
-        stats = importr('stats')
-
-        corr_results_to_list = ro.FloatVector(
-            [item for sublist in self.pca_corr_results for item in sublist])
-
-        m = ro.r.matrix(corr_results_to_list, nrow=len(cols))
-        eigen = base.eigen(m)
-        pca = stats.princomp(m, cor="TRUE")
-        self.loadings = pca.rx('loadings')
-        self.scores = pca.rx('scores')
-        self.scale = pca.rx('scale')
+    def calculate_pca(self):
 
-        trait_array = zScore(self.trait_data_array)
-        trait_array_vectors = np.dot(corr_eigen_vectors, trait_array)
+        pca = compute_pca(self.pca_corr_results)
 
-        pca_traits = []
-        for i, vector in enumerate(trait_array_vectors):
-            # ZS: Check if below check is necessary
-            # if corr_eigen_value[i-1] > 100.0/len(self.trait_list):
-            pca_traits.append((vector * -1.0).tolist())
+        self.loadings = pca["components"]
+        self.scores = pca["scores"]
 
         this_group_name = self.trait_list[0][1].group.name
-        temp_dataset = data_set.create_dataset(
-            dataset_name="Temp", dataset_type="Temp", group_name=this_group_name)
+        temp_dataset = create_dataset(
+            dataset_name="Temp", dataset_type="Temp",
+            group_name=this_group_name)
         temp_dataset.group.get_samplelist()
-        for i, pca_trait in enumerate(pca_traits):
-            trait_id = "PCA" + str(i + 1) + "_" + temp_dataset.group.species + "_" + \
-                this_group_name + "_" + datetime.datetime.now().strftime("%m%d%H%M%S")
-            this_vals_string = ""
-            position = 0
-            for sample in temp_dataset.group.all_samples_ordered():
-                if sample in self.shared_samples_list:
-                    this_vals_string += str(pca_trait[position])
-                    this_vals_string += " "
-                    position += 1
-                else:
-                    this_vals_string += "x "
-            this_vals_string = this_vals_string[:-1]
 
-            Redis.set(trait_id, this_vals_string, ex=THIRTY_DAYS)
-            self.pca_trait_ids.append(trait_id)
+        pca_temp_traits = generate_pca_temp_traits(species=temp_dataset.group.species, group=this_group_name,
+                                                   traits_data=self.trait_data_array, corr_array=self.pca_corr_results,
+                                                   dataset_samples=temp_dataset.group.all_samples_ordered(),
+                                                   shared_samples=self.shared_samples_list,
+                                                   create_time=datetime.datetime.now().strftime("%m%d%H%M%S"))
 
-        return pca
+        cache_pca_dataset(redis_conn=get_redis_conn(
+        ), exp_days=60 * 60 * 24 * 30, pca_trait_dict=pca_temp_traits)
 
-    def process_loadings(self):
-        loadings_array = []
-        loadings_row = []
-        for i in range(len(self.trait_list)):
-            loadings_row = []
-            if len(self.trait_list) > 2:
-                the_range = 3
-            else:
-                the_range = 2
-            for j in range(the_range):
-                position = i + len(self.trait_list) * j
-                loadings_row.append(self.loadings[0][position])
-            loadings_array.append(loadings_row)
-        return loadings_array
+        self.pca_trait_ids = list(pca_temp_traits.keys())
+
+        return pca
 
 
 def export_corr_matrix(corr_results):
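
The deleted loop above built each temp trait id by hand and wrote it straight to Redis; generate_pca_temp_traits and cache_pca_dataset now carry that logic in gn3. A sketch of the equivalent behavior, reconstructed from the removed code (gn3's actual implementations may differ in detail):

    # Reconstruction of the removed inline loop; id format and the "x"
    # placeholder for non-shared samples follow the deleted code above.
    def generate_pca_temp_traits_sketch(species, group, pca_traits,
                                        dataset_samples, shared_samples,
                                        create_time):
        """Map generated trait ids to space-separated value strings,
        writing "x" for samples outside the shared set."""
        temp_traits = {}
        for i, pca_trait in enumerate(pca_traits):
            trait_id = f"PCA{i + 1}_{species}_{group}_{create_time}"
            position = 0
            values = []
            for sample in dataset_samples:
                if sample in shared_samples:
                    values.append(str(pca_trait[position]))
                    position += 1
                else:
                    values.append("x")
            temp_traits[trait_id] = " ".join(values)
        return temp_traits

    def cache_pca_dataset_sketch(redis_conn, exp_days, pca_trait_dict):
        """Cache each trait under its id with a TTL; note ex= takes
        seconds, and the call site passes 60 * 60 * 24 * 30 (30 days)."""
        for trait_id, vals in pca_trait_dict.items():
            redis_conn.set(trait_id, vals, ex=exp_days)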
@@ -261,11 +220,11 @@ def export_corr_matrix(corr_results):
         output_file.write("\n")
         output_file.write("Correlation ")
         for i, item in enumerate(corr_results[0]):
-            output_file.write("Trait" + str(i + 1) + ": " + \
+            output_file.write("Trait" + str(i + 1) + ": " +
                               str(item[0].dataset.name) + "::" + str(item[0].name) + "\t")
         output_file.write("\n")
         for i, row in enumerate(corr_results):
-            output_file.write("Trait" + str(i + 1) + ": " + \
+            output_file.write("Trait" + str(i + 1) + ": " +
                               str(row[0][0].dataset.name) + "::" + str(row[0][0].name) + "\t")
             for item in row:
                 output_file.write(str(item[1]) + "\t")
@@ -275,57 +234,14 @@ def export_corr_matrix(corr_results):
         output_file.write("\n")
         output_file.write("N ")
         for i, item in enumerate(corr_results[0]):
-            output_file.write("Trait" + str(i) + ": " + \
+            output_file.write("Trait" + str(i) + ": " +
                               str(item[0].dataset.name) + "::" + str(item[0].name) + "\t")
         output_file.write("\n")
         for i, row in enumerate(corr_results):
-            output_file.write("Trait" + str(i) + ": " + \
+            output_file.write("Trait" + str(i) + ": " +
                               str(row[0][0].dataset.name) + "::" + str(row[0][0].name) + "\t")
             for item in row:
                 output_file.write(str(item[2]) + "\t")
             output_file.write("\n")
 
     return corr_matrix_filename, matrix_export_path
-
-
-def zScore(trait_data_array):
-    NN = len(trait_data_array[0])
-    if NN < 10:
-        return trait_data_array
-    else:
-        i = 0
-        for data in trait_data_array:
-            N = len(data)
-            S = reduce(lambda x, y: x + y, data, 0.)
-            SS = reduce(lambda x, y: x + y * y, data, 0.)
-            mean = S / N
-            var = SS - S * S / N
-            stdev = math.sqrt(var / (N - 1))
-            if stdev == 0:
-                stdev = 1e-100
-            data2 = [(x - mean) / stdev for x in data]
-            trait_data_array[i] = data2
-            i += 1
-        return trait_data_array
-
-
-def sortEigenVectors(vector):
-    try:
-        eigenValues = vector[0].tolist()
-        eigenVectors = vector[1].T.tolist()
-        combines = []
-        i = 0
-        for item in eigenValues:
-            combines.append([eigenValues[i], eigenVectors[i]])
-            i += 1
-        sorted(combines, key=cmp_to_key(webqtlUtil.cmpEigenValue))
-        A = []
-        B = []
-        for item in combines:
-            A.append(item[0])
-            B.append(item[1])
-        sum = reduce(lambda x, y: x + y, A, 0.0)
-        A = [x * 100.0 / sum for x in A]
-        return [A, B]
-    except:
-        return []
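
The two deleted helpers are superseded by gn3's PCA pipeline: sortEigenVectors ordered the eigenpairs and rescaled eigenvalues to percentages, bookkeeping that now happens inside compute_pca, while zScore standardized each trait to zero mean and unit sample variance, skipping arrays with fewer than ten samples. A numpy equivalent of the removed zScore for reference:

    # Numpy equivalent of the removed zScore(); the n < 10 early return
    # and the epsilon guard for zero stdev mirror the deleted code.
    import numpy as np

    def z_score_sketch(trait_data_array):
        data = np.asarray(trait_data_array, dtype=float)
        if data.shape[1] < 10:
            return trait_data_array
        mean = data.mean(axis=1, keepdims=True)
        stdev = data.std(axis=1, ddof=1, keepdims=True)  # sample stdev
        stdev[stdev == 0] = 1e-100
        return ((data - mean) / stdev).tolist()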