about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--gn3/computations/correlations2.py21
-rw-r--r--tests/unit/computations/test_correlation.py5
2 files changed, 7 insertions, 19 deletions
diff --git a/gn3/computations/correlations2.py b/gn3/computations/correlations2.py
index 69921b1..d0222ae 100644
--- a/gn3/computations/correlations2.py
+++ b/gn3/computations/correlations2.py
@@ -6,7 +6,7 @@ FUNCTIONS:
 compute_correlation:
     TODO: Describe what the function does..."""
 
-from math import sqrt
+from scipy import stats
 ## From GN1: mostly for clustering and heatmap generation
 
 def __items_with_values(dbdata, userdata):
@@ -16,24 +16,11 @@ def __items_with_values(dbdata, userdata):
     return tuple(zip(*filtered)) if filtered else ([], [])
 
 def compute_correlation(dbdata, userdata):
-    """Compute some form of correlation.
+    """Compute the Pearson correlation coefficient.
 
     This is extracted from
     https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/utility/webqtlUtil.py#L622-L647
     """
     x_items, y_items = __items_with_values(dbdata, userdata)
-    if len(x_items) < 6:
-        return (0.0, len(x_items))
-    meanx = sum(x_items)/len(x_items)
-    meany = sum(y_items)/len(y_items)
-    def cal_corr_vals(acc, item):
-        xitem, yitem = item
-        return [
-            acc[0] + ((xitem - meanx) * (yitem - meany)),
-            acc[1] + ((xitem - meanx) * (xitem - meanx)),
-            acc[2] + ((yitem - meany) * (yitem - meany))]
-    xyd, sxd, syd = reduce(cal_corr_vals, zip(x_items, y_items), [0.0, 0.0, 0.0])
-    try:
-        return ((xyd/(sqrt(sxd)*sqrt(syd))), len(x_items))
-    except ZeroDivisionError:
-        return(0, len(x_items))
+    correlation = stats.pearsonr(x_items, y_items)[0] if len(x_items) >= 6 else 0
+    return (correlation, len(x_items))
diff --git a/tests/unit/computations/test_correlation.py b/tests/unit/computations/test_correlation.py
index e6cf198..d60dd62 100644
--- a/tests/unit/computations/test_correlation.py
+++ b/tests/unit/computations/test_correlation.py
@@ -4,6 +4,7 @@ from unittest import mock
 import unittest
 
 from collections import namedtuple
+import math
 from numpy.testing import assert_almost_equal
 
 from gn3.computations.correlations import normalize_values
@@ -471,10 +472,10 @@ class TestCorrelation(TestCase):
                  [None, None, None, None, None, None, None, None, None, 0],
                  (0.0, 1)],
                 [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
-                 (0, 10)],
+                 (math.nan, 10)],
                 [[9.87, 9.87, 9.87, 9.87, 9.87, 9.87, 9.87, 9.87, 9.87, 9.87],
                  [9.87, 9.87, 9.87, 9.87, 9.87, 9.87, 9.87, 9.87, 9.87, 9.87],
-                 (0.9999999999999998, 10)],
+                 (math.nan, 10)],
                 [[9.3, 2.2, 5.4, 7.2, 6.4, 7.6, 3.8, 1.8, 8.4, 0.2],
                  [0.6, 3.97, 5.82, 8.21, 1.65, 4.55, 6.72, 9.5, 7.33, 2.34],
                  (-0.12720361919462056, 10)],