aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--gn3/computations/correlations2.py21
-rw-r--r--tests/unit/computations/test_correlation.py5
2 files changed, 7 insertions, 19 deletions
diff --git a/gn3/computations/correlations2.py b/gn3/computations/correlations2.py
index 69921b1..d0222ae 100644
--- a/gn3/computations/correlations2.py
+++ b/gn3/computations/correlations2.py
@@ -6,7 +6,7 @@ FUNCTIONS:
compute_correlation:
TODO: Describe what the function does..."""
-from math import sqrt
+from scipy import stats
## From GN1: mostly for clustering and heatmap generation
def __items_with_values(dbdata, userdata):
@@ -16,24 +16,11 @@ def __items_with_values(dbdata, userdata):
return tuple(zip(*filtered)) if filtered else ([], [])
def compute_correlation(dbdata, userdata):
- """Compute some form of correlation.
+ """Compute the Pearson correlation coefficient.
This is extracted from
https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/utility/webqtlUtil.py#L622-L647
"""
x_items, y_items = __items_with_values(dbdata, userdata)
- if len(x_items) < 6:
- return (0.0, len(x_items))
- meanx = sum(x_items)/len(x_items)
- meany = sum(y_items)/len(y_items)
- def cal_corr_vals(acc, item):
- xitem, yitem = item
- return [
- acc[0] + ((xitem - meanx) * (yitem - meany)),
- acc[1] + ((xitem - meanx) * (xitem - meanx)),
- acc[2] + ((yitem - meany) * (yitem - meany))]
- xyd, sxd, syd = reduce(cal_corr_vals, zip(x_items, y_items), [0.0, 0.0, 0.0])
- try:
- return ((xyd/(sqrt(sxd)*sqrt(syd))), len(x_items))
- except ZeroDivisionError:
- return(0, len(x_items))
+ correlation = stats.pearsonr(x_items, y_items)[0] if len(x_items) >= 6 else 0
+ return (correlation, len(x_items))
diff --git a/tests/unit/computations/test_correlation.py b/tests/unit/computations/test_correlation.py
index e6cf198..d60dd62 100644
--- a/tests/unit/computations/test_correlation.py
+++ b/tests/unit/computations/test_correlation.py
@@ -4,6 +4,7 @@ from unittest import mock
import unittest
from collections import namedtuple
+import math
from numpy.testing import assert_almost_equal
from gn3.computations.correlations import normalize_values
@@ -471,10 +472,10 @@ class TestCorrelation(TestCase):
[None, None, None, None, None, None, None, None, None, 0],
(0.0, 1)],
[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
- (0, 10)],
+ (math.nan, 10)],
[[9.87, 9.87, 9.87, 9.87, 9.87, 9.87, 9.87, 9.87, 9.87, 9.87],
[9.87, 9.87, 9.87, 9.87, 9.87, 9.87, 9.87, 9.87, 9.87, 9.87],
- (0.9999999999999998, 10)],
+ (math.nan, 10)],
[[9.3, 2.2, 5.4, 7.2, 6.4, 7.6, 3.8, 1.8, 8.4, 0.2],
[0.6, 3.97, 5.82, 8.21, 1.65, 4.55, 6.72, 9.5, 7.33, 2.34],
(-0.12720361919462056, 10)],