about summary refs log tree commit diff
path: root/gn3
diff options
context:
space:
mode:
authorArun Isaac2021-11-11 16:10:35 +0530
committerBonfaceKilz2021-11-11 20:49:30 +0300
commitec1d2180d99e0cde1dc181ee9ed79e86cf1a5675 (patch)
tree61162273cc9b0b75577573b8b2082a1660cb39c1 /gn3
parent4e790f08000825931cb5edec1738d2b7d073f73e (diff)
downloadgenenetwork3-ec1d2180d99e0cde1dc181ee9ed79e86cf1a5675.tar.gz
Reimplement correlations2.compute_correlation using pearsonr.
correlations2.compute_correlation computes the Pearson correlation
coefficient. Outsource this computation to scipy.stats.pearsonr. When the
inputs are constant, the Pearson correlation coefficient does not exist and is
represented by NaN. Update the tests to reflect this.

* gn3/computations/correlations2.py: Remove import of sqrt from math.
(compute_correlation): Reimplement using scipy.stats.pearsonr.
* tests/unit/computations/test_correlation.py: Import math.
(TestCorrelation.test_compute_correlation): When inputs are constant, set
expected correlation coefficient to NaN.
Diffstat (limited to 'gn3')
-rw-r--r--gn3/computations/correlations2.py21
1 files changed, 4 insertions, 17 deletions
diff --git a/gn3/computations/correlations2.py b/gn3/computations/correlations2.py
index 69921b1..d0222ae 100644
--- a/gn3/computations/correlations2.py
+++ b/gn3/computations/correlations2.py
@@ -6,7 +6,7 @@ FUNCTIONS:
 compute_correlation:
     TODO: Describe what the function does..."""
 
-from math import sqrt
+from scipy import stats
 ## From GN1: mostly for clustering and heatmap generation
 
 def __items_with_values(dbdata, userdata):
@@ -16,24 +16,11 @@ def __items_with_values(dbdata, userdata):
     return tuple(zip(*filtered)) if filtered else ([], [])
 
 def compute_correlation(dbdata, userdata):
-    """Compute some form of correlation.
+    """Compute the Pearson correlation coefficient.
 
     This is extracted from
     https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/utility/webqtlUtil.py#L622-L647
     """
     x_items, y_items = __items_with_values(dbdata, userdata)
-    if len(x_items) < 6:
-        return (0.0, len(x_items))
-    meanx = sum(x_items)/len(x_items)
-    meany = sum(y_items)/len(y_items)
-    def cal_corr_vals(acc, item):
-        xitem, yitem = item
-        return [
-            acc[0] + ((xitem - meanx) * (yitem - meany)),
-            acc[1] + ((xitem - meanx) * (xitem - meanx)),
-            acc[2] + ((yitem - meany) * (yitem - meany))]
-    xyd, sxd, syd = reduce(cal_corr_vals, zip(x_items, y_items), [0.0, 0.0, 0.0])
-    try:
-        return ((xyd/(sqrt(sxd)*sqrt(syd))), len(x_items))
-    except ZeroDivisionError:
-        return(0, len(x_items))
+    correlation = stats.pearsonr(x_items, y_items)[0] if len(x_items) >= 6 else 0
+    return (correlation, len(x_items))