From 43d1bb7f6cd2b5890d5b3eb7c357caafda25a35c Mon Sep 17 00:00:00 2001
From: Alexander Kabui
Date: Tue, 16 Mar 2021 10:36:58 +0300
Subject: Refactor/clean up correlations (#4)

* initial commit for Refactor/clean-up-correlation

* add python scipy dependency

* initial commit for sample correlation

* initial commit for sample correlation endpoint

* initial commit for integration and unittest

* initial commit for registering  correlation blueprint

* add and modify unittest and integration tests for correlation

* Add compute compute_all_sample_corr   method for correlation

* add scipy to requirement txt file

* add tissue correlation for trait list

* add unittest for tissue correlation

* add lit correlation for trait list

* add unittests for lit correlation for trait list

* modify lit correlarion for trait list

* add unittests for lit correlation for trait list

* add correlation metho  in dynamic url

* add file format for expected structure input  while doing sample correlation

* modify input data structure -> add  trait id

* update tests for sample r correlation

* add compute all lit correlation method

* add endpoint for computing lit_corr

* add unit and integration tests for computing lit corr

* add /api/correlation/tissue_corr/{corr_method} endpoint for tissue correlation

* add unittest and integration tests for tissue correlation

Co-authored-by: BonfaceKilz <bonfacemunyoki@gmail.com>---
 gn3/api/correlation.py                             |  77 ++--
 gn3/computations/correlations.py                   | 305 ++++++++++++++++
 guix.scm                                           |   1 +
 requirements.txt                                   |  13 +-
 tests/integration/test_correlation.py              | 118 ++++--
 .../correlation_test_data/target_dataset.json      | 230 ++++++++++++
 .../correlation_test_data/this_trait_data.json     |  76 ++++
 tests/unit/computations/test_correlation.py        | 399 +++++++++++++++++++++
 8 files changed, 1149 insertions(+), 70 deletions(-)
 create mode 100644 gn3/computations/correlations.py
 create mode 100644 tests/unit/computations/correlation_test_data/target_dataset.json
 create mode 100644 tests/unit/computations/correlation_test_data/this_trait_data.json
 create mode 100644 tests/unit/computations/test_correlation.py

diff --git a/gn3/api/correlation.py b/gn3/api/correlation.py
index 217b7ce..56b8381 100644
--- a/gn3/api/correlation.py
+++ b/gn3/api/correlation.py
@@ -1,44 +1,63 @@
-"""Endpoints for computing correlation"""
-import time
-from flask import Blueprint
+"""Endpoints for running correlations"""
+from unittest import mock
+
 from flask import jsonify
+from flask import Blueprint
 from flask import request
-from flask import g
-from sqlalchemy import create_engine
 
-from default_settings import SQL_URI
-from gn3.correlation.correlation_computations import compute_correlation
+from gn3.computations.correlations import compute_all_sample_correlation
+from gn3.computations.correlations import compute_all_lit_correlation
+from gn3.computations.correlations import compute_all_tissue_correlation
+
 
 correlation = Blueprint("correlation", __name__)
 
 
-# xtodo implement neat db setup
-@correlation.before_request
-def connect_db():
-    """add connection to db method"""
-    print("@app.before_request connect_db")
-    db_connection = getattr(g, '_database', None)
-    if db_connection is None:
-        print("Get new database connector")
-        g.db = g._database = create_engine(SQL_URI, encoding="latin1")
+@correlation.route("/sample_r/<string:corr_method>", methods=["POST"])
+def compute_sample_r(corr_method="pearson"):
+    """correlation endpoint for computing sample r correlations\
+    api expects the trait data with has the trait and also the\
+    target_dataset  data"""
+    correlation_input = request.get_json()
+
+    # xtodo move code below to compute_all_sampl correlation
+    this_trait_data = correlation_input.get("this_trait")
+    target_datasets = correlation_input.get("target_dataset")
+
+    correlation_results = compute_all_sample_correlation(corr_method=corr_method,
+                                                         this_trait=this_trait_data,
+                                                         target_dataset=target_datasets)
+
+    return jsonify({
+        "corr_results": correlation_results
+    })
+
 
-    g.initial_time = time.time()
+@correlation.route("/lit_corr/<string:species>/<int:gene_id>", methods=["POST"])
+def compute_lit_corr(species=None, gene_id=None):
+    """api endpoint for doing lit correlation.results for lit correlation\
+    are fetched from the database this is the only case where the db\
+    might be needed for actual computing of the correlation results"""
 
+    database_instance = mock.Mock()
+    target_traits_gene_ids = request.get_json()
 
-@correlation.route("/corr_compute", methods=["POST"])
-def corr_compute_page():
-    """api for doing  correlation"""
+    lit_corr_results = compute_all_lit_correlation(
+        database_instance=database_instance, trait_lists=target_traits_gene_ids,
+        species=species, gene_id=gene_id)
 
-    correlation_input = request.json
+    return jsonify(lit_corr_results)
 
-    if correlation_input is None:
-        return jsonify({"error": str("Bad request")}), 400
 
-    try:
-        corr_results = compute_correlation(
-            correlation_input_data=correlation_input)
+@correlation.route("/tissue_corr/<string:corr_method>", methods=["POST"])
+def compute_tissue_corr(corr_method="pearson"):
+    """api endpoint fr doing tissue correlation"""
+    tissue_input_data = request.get_json()
+    primary_tissue_dict = tissue_input_data["primary_tissue"]
+    target_tissues_dict_list = tissue_input_data["target_tissues"]
 
-    except Exception as error:  # pylint: disable=broad-except
-        return jsonify({"error": str(error)})
+    results = compute_all_tissue_correlation(primary_tissue_dict=primary_tissue_dict,
+                                             target_tissues_dict_list=target_tissues_dict_list,
+                                             corr_method=corr_method)
 
-    return {"correlation_results": corr_results}
+    return jsonify(results)
\ No newline at end of file
diff --git a/gn3/computations/correlations.py b/gn3/computations/correlations.py
new file mode 100644
index 0000000..21f5929
--- /dev/null
+++ b/gn3/computations/correlations.py
@@ -0,0 +1,305 @@
+"""module contains code for correlations"""
+from typing import List
+from typing import Tuple
+from typing import Optional
+from typing import Callable
+
+import scipy.stats  # type: ignore
+
+
+def compute_sum(rhs: int, lhs: int)-> int:
+    """initial tests to compute  sum  of two numbers"""
+    return rhs + lhs
+
+
+def normalize_values(a_values: List, b_values: List)->Tuple[List[float], List[float], int]:
+    """
+    Trim two lists of values to contain only the values they both share
+
+    Given two lists of sample values, trim each list so that it contains
+    only the samples that contain a value in both lists. Also returns
+    the number of such samples.
+
+    >>> normalize_values([2.3, None, None, 3.2, 4.1, 5], [3.4, 7.2, 1.3, None, 6.2, 4.1])
+    ([2.3, 4.1, 5], [3.4, 6.2, 4.1], 3)
+
+    """
+    a_new = []
+    b_new = []
+    for a_val, b_val in zip(a_values, b_values):
+        if (a_val and b_val is not None):
+            a_new.append(a_val)
+            b_new.append(b_val)
+    return a_new, b_new, len(a_new)
+
+
+def compute_corr_coeff_p_value(primary_values: List, target_values: List, corr_method: str)->\
+        Tuple[float, float]:
+    """given array like inputs calculate the primary and target_value
+     methods ->pearson,spearman and biweight mid correlation
+     return value is rho and p_value
+    """
+    corr_mapping = {
+        "bicor": do_bicor,
+        "pearson": scipy.stats.pearsonr,
+        "spearman": scipy.stats.spearmanr
+    }
+
+    use_corr_method = corr_mapping.get(corr_method, "spearman")
+
+    corr_coeffient, p_val = use_corr_method(primary_values, target_values)
+
+    return (corr_coeffient, p_val)
+
+
+def compute_sample_r_correlation(corr_method: str, trait_vals, target_samples_vals)->\
+        Optional[Tuple[float, float, int]]:
+    """Given a primary trait values and target trait values
+    calculate the correlation coeff and p value"""
+
+    sanitized_traits_vals, sanitized_target_vals,\
+        num_overlap = normalize_values(trait_vals, target_samples_vals)
+
+    if num_overlap > 5:
+
+        (corr_coeffient, p_value) =\
+            compute_corr_coeff_p_value(primary_values=sanitized_traits_vals,
+                                       target_values=sanitized_target_vals,
+                                       corr_method=corr_method)
+
+        # xtodo check if corr_coefficient is None should use numpy.isNan scipy.isNan is deprecated
+        if corr_coeffient is not None:
+            return (corr_coeffient, p_value, num_overlap)
+
+    return None
+
+
+def do_bicor(x_val, y_val) -> Tuple[float, float]:
+    """not implemented method for doing biweight mid correlation
+    use  astropy stats package :not packaged in guix
+    """
+
+    return (x_val, y_val)
+
+
+def filter_shared_sample_keys(this_samplelist, target_samplelist)->Tuple[List, List]:
+    """given primary and target samplelist for two base and target\
+    trait select filter the values using the shared keys"""
+    this_vals = []
+    target_vals = []
+
+    for key, value in target_samplelist.items():
+        if key in this_samplelist:
+            target_vals.append(value)
+            this_vals.append(this_samplelist[key])
+
+    return (this_vals, target_vals)
+
+
+def compute_all_sample_correlation(this_trait, target_dataset, corr_method="pearson")->List:
+    """given a trait data samplelist and target__datasets compute all sample correlation"""
+
+    this_trait_samples = this_trait["trait_sample_data"]
+
+    corr_results = []
+
+    for target_trait in target_dataset:
+        trait_id = target_trait.get("trait_id")
+        target_trait_data = target_trait["trait_sample_data"]
+        this_vals, target_vals = filter_shared_sample_keys(
+            this_trait_samples, target_trait_data)
+
+        sample_correlation = compute_sample_r_correlation(
+            corr_method=corr_method, trait_vals=this_vals, target_samples_vals=target_vals)
+
+        if sample_correlation is not None:
+            (corr_coeffient, p_value, num_overlap) = sample_correlation
+
+        else:
+            continue
+
+        corr_result = {"corr_coeffient": corr_coeffient,
+                       "p_value": p_value,
+                       "num_overlap": num_overlap}
+
+        corr_results.append({trait_id: corr_result})
+
+    return corr_results
+
+
+def tissue_lit_corr_for_probe_type(corr_type: str, top_corr_results):
+    """function that does either lit_corr_for_trait_list or tissue_corr\
+    _for_trait list depending on whether both dataset and target_dataset are\
+    both set to probet"""
+
+    corr_results = {"lit": 1}
+
+    if corr_type not in ("lit", "literature"):
+
+        corr_results["top_corr_results"] = top_corr_results
+        # run lit_correlation for  the given  top_corr_results
+    if corr_type == "tissue":
+        # run lit correlation the given top corr results
+        pass
+    if corr_type == "sample":
+        pass
+        # run sample r correlation for the given top  results
+
+    return corr_results
+
+
+def tissue_correlation_for_trait_list(primary_tissue_vals: List,
+                                      target_tissues_values: List,
+                                      corr_method: str,
+                                      compute_corr_p_value: Callable =
+                                      compute_corr_coeff_p_value)->dict:
+    """given a primary tissue values for a trait and the target tissues values\
+    compute the correlation_cooeff and p value  the input required are arrays\
+    output - > List containing Dicts with corr_coefficient value,P_value and\
+    also the tissue numbers is len(primary) == len(target)"""
+
+    # ax :todo assertion that lenggth one one target tissue ==primary_tissue
+
+    (tissue_corr_coeffient, p_value) = compute_corr_p_value(
+        primary_values=primary_tissue_vals,
+        target_values=target_tissues_values,
+        corr_method=corr_method)
+
+    lit_corr_result = {
+        "tissue_corr": tissue_corr_coeffient,
+        "p_value": p_value,
+        "tissue_number": len(primary_tissue_vals)
+    }
+
+    return lit_corr_result
+
+
+def fetch_lit_correlation_data(database,
+                               input_mouse_gene_id: Optional[str],
+                               gene_id: str,
+                               mouse_gene_id: Optional[str] = None)->Tuple[str, float]:
+    """given input trait mouse gene id and mouse gene id fetch the lit\
+    corr_data"""
+    if mouse_gene_id is not None and ";" not in mouse_gene_id:
+        query = """
+        SELECT VALUE
+        FROM  LCorrRamin3
+        WHERE GeneId1='%s' and
+        GeneId2='%s'
+        """
+
+        query_values = (str(mouse_gene_id), str(input_mouse_gene_id))
+
+        results = database.execute(
+            query_formatter(query, *query_values)).fetchone()
+
+        lit_corr_results = results if results is not None else database.execute(
+            query_formatter(query, *tuple(reversed(query_values)))).fetchone()
+
+        lit_results = (gene_id, lit_corr_results.val)\
+            if lit_corr_results else (gene_id, 0)
+        return lit_results
+
+    return (gene_id, 0)
+
+
+def lit_correlation_for_trait_list(database,
+                                   target_trait_lists: List,
+                                   species: Optional[str] = None,
+                                   trait_gene_id: Optional[str] = None)->List:
+    """given species,base trait gene id fetch the lit corr results from the db\
+    output is float for lit corr results """
+    fetched_lit_corr_results = []
+
+    this_trait_mouse_gene_id = map_to_mouse_gene_id(
+        database=database, species=species, gene_id=trait_gene_id)
+
+    for trait in target_trait_lists:
+        target_trait_gene_id = trait.get("gene_id")
+        if target_trait_gene_id:
+            target_mouse_gene_id = map_to_mouse_gene_id(
+                database=database, species=species, gene_id=target_trait_gene_id)
+
+            fetched_corr_data = fetch_lit_correlation_data(
+                database=database, input_mouse_gene_id=this_trait_mouse_gene_id,
+                gene_id=target_trait_gene_id, mouse_gene_id=target_mouse_gene_id)
+
+            dict_results = dict(
+                zip(("gene_id", "lit_corr"), fetched_corr_data))
+            fetched_lit_corr_results.append(dict_results)
+
+    return fetched_lit_corr_results
+
+
+def query_formatter(query_string: str, * query_values):
+    """formatter query string given the unformatted query string\
+    and the respectibe values.Assumes number of placeholders is
+    equal to the number of query values """
+    results = query_string % (query_values)
+
+    return results
+
+
+def map_to_mouse_gene_id(database, species: Optional[str], gene_id: Optional[str])->Optional[str]:
+    """given a species which is not mouse map the gene_id\
+    to respective mouse gene id"""
+    # AK:xtodo move the code for checking nullity out of thing functions bug while\
+    # method for string
+    if None in (species, gene_id):
+        return None
+    if species == "mouse":
+        return gene_id
+
+    query = """SELECT mouse
+                FROM GeneIDXRef
+                WHERE '%s' = '%s'"""
+
+    query_values = (species, gene_id)
+
+    results = database.execute(
+        query_formatter(query, *query_values)).fetchone()
+
+    mouse_gene_id = results.mouse if results is not None else None
+
+    return mouse_gene_id
+
+
+def compute_all_lit_correlation(database_instance, trait_lists: List, species: str, gene_id):
+    """function that acts as an abstraction for lit_correlation_for_trait_list"""
+    # xtodo to be refactored
+
+    lit_results = lit_correlation_for_trait_list(database=database_instance,
+                                                 target_trait_lists=trait_lists,
+                                                 species=species,
+                                                 trait_gene_id=gene_id
+                                                 )
+
+    return {
+        "lit_results": lit_results
+    }
+
+
+def compute_all_tissue_correlation(primary_tissue_dict: dict,
+                                   target_tissues_dict_list: List,
+                                   corr_method: str):
+    """function acts as an abstraction for tissue_correlation_for_trait_list\
+    required input are target tissue object and primary tissue trait """
+
+    tissues_results = {}
+
+    primary_tissue_vals = primary_tissue_dict["tissue_values"]
+
+    target_tissues_list = target_tissues_dict_list
+
+    for target_tissue_obj in target_tissues_list:
+        trait_id = target_tissue_obj.get("trait_id")
+
+        target_tissue_vals = target_tissue_obj.get("tissue_values")
+
+        tissue_result = tissue_correlation_for_trait_list(primary_tissue_vals=primary_tissue_vals,
+                                                          target_tissues_values=target_tissue_vals,
+                                                          corr_method=corr_method)
+
+        tissues_results[trait_id] = tissue_result
+
+    return tissues_results
diff --git a/guix.scm b/guix.scm
index 45bb3fa..503694c 100644
--- a/guix.scm
+++ b/guix.scm
@@ -73,6 +73,7 @@
                        ("python-flask" ,python-flask)
                        ("python-pylint" python-pylint)
                        ("python-numpy" ,python-numpy)
+                       ("python-scipy" ,python-scipy)
                        ("python-mypy" ,python-mypy)
                        ("python-mypy-extensions" ,python-mypy-extensions)
                        ("python-redis" ,python-redis)
diff --git a/requirements.txt b/requirements.txt
index e495e19..e4dc881 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,9 +4,16 @@ Flask==1.1.2
 itsdangerous==1.1.0
 Jinja2==2.11.3
 MarkupSafe==1.1.1
-mysqlclient==2.0.1
+mccabe==0.6.1
+mypy==0.790
+mypy-extensions==0.4.3
 numpy==1.20.1
+pycparser==2.20
+pylint==2.5.3
+redis==3.5.3
 scipy==1.6.0
-SQLAlchemy==1.3.20
-sqlalchemy-stubs==0.4
+six==1.15.0
+toml==0.10.2
+typed-ast==1.4.2
+typing-extensions==3.7.4.3
 Werkzeug==1.0.1
diff --git a/tests/integration/test_correlation.py b/tests/integration/test_correlation.py
index 33e0de9..488a8a4 100644
--- a/tests/integration/test_correlation.py
+++ b/tests/integration/test_correlation.py
@@ -1,57 +1,99 @@
-"""Integration tests for correlation api"""
-
-import os
-import json
-import unittest
+"""module contains integration tests for correlation"""
+from unittest import TestCase
 from unittest import mock
-
 from gn3.app import create_app
 
 
-def file_path(relative_path):
-    """getting abs path for file """
-    dir_name = os.path.dirname(os.path.abspath(__file__))
-    split_path = relative_path.split("/")
-    new_path = os.path.join(dir_name, *split_path)
-    return new_path
+class CorrelationIntegrationTest(TestCase):
+    """class for correlation integration tests"""
 
-
-class CorrelationAPITest(unittest.TestCase):
-    # currently disable
-    """Test cases for the Correlation API"""
     def setUp(self):
         self.app = create_app().test_client()
 
-        with open(file_path("correlation_data.json")) as json_file:
-            self.correlation_data = json.load(json_file)
+    def test_fail(self):
+        """initial method for class that fails"""
+        self.assertEqual(2, 2)
+
+    @mock.patch("gn3.api.correlation.compute_all_sample_correlation")
+    def test_sample_r_correlation(self, mock_compute_samples):
+        """Test /api/correlation/sample_r/{method}"""
+        this_trait_data = {
+            "trait_id": "1455376_at",
+            "trait_sample_data": {
+                "C57BL/6J": "6.138",
+                "DBA/2J": "6.266",
+                "B6D2F1": "6.434",
+                "D2B6F1": "6.55",
+                "BXS2": "6.7"
+            }}
+
+        traits_dataset = [
+            {
+                "trait_id": "14192_at",
+                "trait_sample_data": {
+                    "DBA/2J": "7.13",
+                    "D2B6F1": "5.65",
+                    "BXD2": "1.46"
+                }
+            }
+        ]
+
+        correlation_input_data = {"corr_method": "pearson",
+                                  "this_trait": this_trait_data,
+                                  "target_dataset": traits_dataset}
+
+        expected_results = [
+            {
+                "sample_r": "-0.407",
+                "p_value": "6.234e-04"
+            },
+            {
+                "sample_r": "0.398",
+                "sample_p": "8.614e-04"
+            }
+        ]
+
+        mock_compute_samples.return_value = expected_results
+
+        api_response = {
+            "corr_results": expected_results
+        }
+
+        response = self.app.post("/api/correlation/sample_r/pearson",
+                                 json=correlation_input_data, follow_redirects=True)
 
-        with open(file_path("expected_corr_results.json")) as results_file:
-            self.correlation_results = json.load(results_file)
+        self.assertEqual(response.status_code, 200)
+        self.assertEqual(response.get_json(), api_response)
 
-    def tearDown(self):
-        self.correlation_data = ""
+    @mock.patch("gn3.api.correlation.compute_all_lit_correlation")
+    def test_lit_correlation(self, mock_compute_corr):
+        """Test api/correlation/lit_corr/{species}/{gene_id}"""
 
-        self.correlation_results = ""
+        mock_compute_corr.return_value = []
 
-    @mock.patch("gn3.api.correlation.compute_correlation")
-    def test_corr_compute(self, compute_corr):
-        """Test that the correct response in correlation"""
+        post_data = [{"gene_id": 8, "lit_corr": 1}, {
+            "gene_id": 12, "lit_corr": 0.3}]
 
-        compute_corr.return_value = self.correlation_results
-        response = self.app.post("/api/correlation/corr_compute",
-                                 json=self.correlation_data,
-                                 follow_redirects=True)
+        response = self.app.post(
+            "/api/correlation/lit_corr/mouse/16", json=post_data, follow_redirects=True)
 
+        self.assertEqual(mock_compute_corr.call_count, 1)
         self.assertEqual(response.status_code, 200)
 
-    @mock.patch("gn3.api.correlation.compute_correlation")
-    def test_corr_compute_failed_request(self, compute_corr):
-        """test taht cormpute requests fails """
+    @mock.patch("gn3.api.correlation.compute_all_tissue_correlation")
+    def test_tissue_correlation(self, mock_tissue_corr):
+        """Test api/correlation/tissue_corr/{corr_method}"""
+        mock_tissue_corr.return_value = {}
 
-        compute_corr.return_value = self.correlation_results
+        primary_dict = {"trait_id": "1449593_at", "tissue_values": [1, 2, 3]}
 
-        response = self.app.post("/api/correlation/corr_compute",
-                                 json=None,
-                                 follow_redirects=True)
+        target_tissue_dict_list = [
+            {"trait_id": "1449593_at", "tissue_values": [1, 2, 3]}]
 
-        self.assertEqual(response.status_code, 400)
+        tissue_corr_input_data = {"primary_tissue": primary_dict,
+                                  "target_tissues": target_tissue_dict_list}
+
+        response = self.app.post("/api/correlation/tissue_corr/spearman",
+                                 json=tissue_corr_input_data, follow_redirects=True)
+
+        self.assertEqual(response.status_code, 200)
diff --git a/tests/unit/computations/correlation_test_data/target_dataset.json b/tests/unit/computations/correlation_test_data/target_dataset.json
new file mode 100644
index 0000000..f6757b6
--- /dev/null
+++ b/tests/unit/computations/correlation_test_data/target_dataset.json
@@ -0,0 +1,230 @@
+[
+   {
+      "trait_id":"1425637_at",
+      "sample_data":{
+         "BXD1":7.081,
+         "BXD2":6.912,
+         "BXD5":7.153,
+         "BXD6":6.92,
+         "BXD8":6.886,
+         "BXD9":7.406,
+         "BXD11":6.917,
+         "BXD12":6.914,
+         "BXD13":6.964,
+         "BXD15":6.863,
+         "BXD16":7.06,
+         "BXD19":7.002,
+         "BXD20":7.158,
+         "BXD21":7.039,
+         "BXD22":7.036,
+         "BXD23":6.962,
+         "BXD24":6.946,
+         "BXD27":7.084,
+         "BXD28":7.154,
+         "BXD29":6.932,
+         "BXD31":6.994,
+         "BXD32":6.846,
+         "BXD33":7.078,
+         "BXD34":6.94,
+         "BXD38":6.992,
+         "BXD39":7.048,
+         "BXD40":7.14,
+         "BXD42":6.98,
+         "BXD43":7.072,
+         "BXD44":7.045,
+         "BXD45":6.739,
+         "BXD48":7.07,
+         "BXD48a":6.998,
+         "BXD50":7.053,
+         "BXD51":6.922,
+         "BXD55":6.782,
+         "BXD60":7.042,
+         "BXD61":6.887,
+         "BXD62":6.86,
+         "BXD63":6.815,
+         "BXD64":7.424,
+         "BXD65":7.216,
+         "BXD65a":6.934,
+         "BXD65b":6.893,
+         "BXD66":6.935,
+         "BXD67":6.985,
+         "BXD68":7.044,
+         "BXD69":6.908,
+         "BXD70":6.864,
+         "BXD73":7.074,
+         "BXD73a":6.986,
+         "BXD74":6.914,
+         "BXD75":6.98,
+         "BXD76":6.772,
+         "BXD77":7.121,
+         "BXD79":6.829,
+         "BXD83":7.018,
+         "BXD84":6.948,
+         "BXD85":7.112,
+         "BXD86":6.858,
+         "BXD87":6.865,
+         "BXD89":7.034,
+         "BXD90":6.901,
+         "BXD93":6.97,
+         "BXD94":7.112,
+         "BXD98":6.954,
+         "BXD99":6.912,
+         "C57BL/6J":7.121,
+         "DBA/2J":6.821,
+         "B6D2F1":6.998,
+         "D2B6F1":6.967
+      }
+   },
+   {
+      "trait_id":"1455376_at",
+      "trait_sample_data":{
+         "BXD1":10.929,
+         "BXD2":11.279,
+         "BXD5":11.941,
+         "BXD6":11.407,
+         "BXD8":12.048,
+         "BXD9":11.694,
+         "BXD11":11.534,
+         "BXD12":11.048,
+         "BXD13":12.274,
+         "BXD15":12.077,
+         "BXD16":11.91,
+         "BXD19":11.797,
+         "BXD20":11.67,
+         "BXD21":12.062,
+         "BXD22":12.49,
+         "BXD23":11.957,
+         "BXD24":11.766,
+         "BXD27":13.026,
+         "BXD28":12.184,
+         "BXD29":11.792,
+         "BXD31":12.36,
+         "BXD32":10.608,
+         "BXD33":11.817,
+         "BXD34":11.213,
+         "BXD38":11.212,
+         "BXD39":12.023,
+         "BXD40":12.892,
+         "BXD42":11.518,
+         "BXD43":12.306,
+         "BXD44":11.932,
+         "BXD45":10.982,
+         "BXD48":12.055,
+         "BXD48a":12.572,
+         "BXD50":11.696,
+         "BXD51":11.828,
+         "BXD55":10.523,
+         "BXD60":11.403,
+         "BXD61":11.378,
+         "BXD62":11.887,
+         "BXD63":11.776,
+         "BXD64":12.37,
+         "BXD65":11.122,
+         "BXD65a":10.853,
+         "BXD65b":11.46,
+         "BXD66":11.546,
+         "BXD67":12.198,
+         "BXD68":13.21,
+         "BXD69":11.581,
+         "BXD70":12.338,
+         "BXD73":11.876,
+         "BXD73a":11.75,
+         "BXD74":11.898,
+         "BXD75":11.718,
+         "BXD76":11.926,
+         "BXD77":12.326,
+         "BXD79":12.052,
+         "BXD83":11.478,
+         "BXD84":11.494,
+         "BXD85":11.435,
+         "BXD86":11.476,
+         "BXD87":11.456,
+         "BXD89":11.547,
+         "BXD90":12.452,
+         "BXD93":12.921,
+         "BXD94":11.892,
+         "BXD98":12.614,
+         "BXD99":13.142,
+         "C57BL/6J":12.138,
+         "DBA/2J":11.394,
+         "B6D2F1":11.615,
+         "D2B6F1":11.918
+      }
+   },
+   {
+      "trait_id":"1444351_at",
+      "trait_sample_data":{
+         "BXD1":17.847,
+         "BXD2":15.262,
+         "BXD5":18.054,
+         "BXD6":17.24,
+         "BXD8":15.735,
+         "BXD9":17.876,
+         "BXD11":17.359,
+         "BXD12":17.906,
+         "BXD13":16.084,
+         "BXD15":17.173,
+         "BXD16":15.941,
+         "BXD19":17.721,
+         "BXD20":17.548,
+         "BXD21":17.242,
+         "BXD22":17.012,
+         "BXD23":17.139,
+         "BXD24":17.904,
+         "BXD27":17.008,
+         "BXD28":17.441,
+         "BXD29":17.606,
+         "BXD31":17.35,
+         "BXD32":17.859,
+         "BXD33":17.453,
+         "BXD34":15.924,
+         "BXD38":17.271,
+         "BXD39":18.034,
+         "BXD40":17.844,
+         "BXD42":17.444,
+         "BXD43":17.676,
+         "BXD44":17.71,
+         "BXD45":17.059,
+         "BXD48":17.334,
+         "BXD48a":17.398,
+         "BXD50":17.343,
+         "BXD51":17.514,
+         "BXD55":14.995,
+         "BXD60":18.03,
+         "BXD61":17.628,
+         "BXD62":17.431,
+         "BXD63":16.96,
+         "BXD64":18.199,
+         "BXD65":17.593,
+         "BXD65a":17.49,
+         "BXD65b":17.268,
+         "BXD66":16.602,
+         "BXD67":17.306,
+         "BXD68":17.167,
+         "BXD69":17.706,
+         "BXD70":17.287,
+         "BXD73":17.412,
+         "BXD73a":16.224,
+         "BXD74":16.873,
+         "BXD75":17.202,
+         "BXD76":16.934,
+         "BXD77":17.926,
+         "BXD79":16.55,
+         "BXD83":17.042,
+         "BXD84":17.134,
+         "BXD85":18.021,
+         "BXD86":17.194,
+         "BXD87":17.075,
+         "BXD89":17.511,
+         "BXD90":17.168,
+         "BXD93":17.817,
+         "BXD94":18.04,
+         "BXD98":16.744,
+         "BXD99":17.304,
+         "C57BL/6J":17.084,
+         "DBA/2J":17.316,
+         "B6D2F1":16.964,
+         "D2B6F1":17.086
+      }
+   }
+]
\ No newline at end of file
diff --git a/tests/unit/computations/correlation_test_data/this_trait_data.json b/tests/unit/computations/correlation_test_data/this_trait_data.json
new file mode 100644
index 0000000..7c57fdb
--- /dev/null
+++ b/tests/unit/computations/correlation_test_data/this_trait_data.json
@@ -0,0 +1,76 @@
+{
+  "trait_id":"1457784_at",
+  "trait_sample_data":{
+  "BXD1": 6.03,
+  "BXD2": 6.001,
+  "BXD5": 6.154,
+  "BXD6": 6.179,
+  "BXD8": 6.2,
+  "BXD9": 6.062,
+  "BXD11": 6.12,
+  "BXD12": 6.159,
+  "BXD13": 6.153,
+  "BXD15": 6.144,
+  "BXD16": 6.212,
+  "BXD19": 6.206,
+  "BXD20": 6.008,
+  "BXD21": 6.062,
+  "BXD22": 6.042,
+  "BXD23": 6.135,
+  "BXD24": 6.144,
+  "BXD27": 6.316,
+  "BXD28": 6.14,
+  "BXD29": 6.222,
+  "BXD31": 6.211,
+  "BXD32": 5.984,
+  "BXD33": 6.128,
+  "BXD34": 6.086,
+  "BXD38": 6.342,
+  "BXD39": 6.111,
+  "BXD40": 6.136,
+  "BXD42": 6.201,
+  "BXD43": 5.934,
+  "BXD44": 6.116,
+  "BXD45": 6.226,
+  "BXD48": 6.228,
+  "BXD48a": 6.16,
+  "BXD50": 5.92,
+  "BXD51": 6.227,
+  "BXD55": 6.137,
+  "BXD60": 5.932,
+  "BXD61": 6.18,
+  "BXD62": 6.188,
+  "BXD63": 6.134,
+  "BXD64": 6.102,
+  "BXD65": 6.258,
+  "BXD65a": 6.031,
+  "BXD65b": 6.088,
+  "BXD66": 6.07,
+  "BXD67": 6.275,
+  "BXD68": 6.116,
+  "BXD69": 6.031,
+  "BXD70": 6.14,
+  "BXD73": 6.089,
+  "BXD73a": 6.195,
+  "BXD74": 5.971,
+  "BXD75": 5.972,
+  "BXD76": 6.125,
+  "BXD77": 6.107,
+  "BXD79": 6.288,
+  "BXD83": 6.119,
+  "BXD84": 6.102,
+  "BXD85": 5.959,
+  "BXD86": 6.249,
+  "BXD87": 6.172,
+  "BXD89": 6.13,
+  "BXD90": 6.162,
+  "BXD93": 6.19,
+  "BXD94": 6.068,
+  "BXD98": 6.137,
+  "BXD99": 6.252,
+  "C57BL/6J": 6.255,
+  "DBA/2J": 6.14,
+  "B6D2F1": 6.223,
+  "D2B6F1": 6.038
+}
+}
\ No newline at end of file
diff --git a/tests/unit/computations/test_correlation.py b/tests/unit/computations/test_correlation.py
new file mode 100644
index 0000000..84b9330
--- /dev/null
+++ b/tests/unit/computations/test_correlation.py
@@ -0,0 +1,399 @@
+"""module contains the tests for correlation"""
+import unittest
+from unittest import TestCase
+from unittest import mock
+
+from collections import namedtuple
+
+from gn3.computations.correlations import normalize_values
+from gn3.computations.correlations import do_bicor
+from gn3.computations.correlations import compute_sample_r_correlation
+from gn3.computations.correlations import compute_all_sample_correlation
+from gn3.computations.correlations import filter_shared_sample_keys
+from gn3.computations.correlations import tissue_lit_corr_for_probe_type
+from gn3.computations.correlations import tissue_correlation_for_trait_list
+from gn3.computations.correlations import lit_correlation_for_trait_list
+from gn3.computations.correlations import fetch_lit_correlation_data
+from gn3.computations.correlations import query_formatter
+from gn3.computations.correlations import map_to_mouse_gene_id
+from gn3.computations.correlations import compute_all_lit_correlation
+from gn3.computations.correlations import compute_all_tissue_correlation
+
+
+class QueryableMixin:
+    """base class for db call"""
+
+    def execute(self, query_options):
+        """base method for execute"""
+        raise NotImplementedError()
+
+    def fetchone(self):
+        """base method for fetching one iten"""
+        raise NotImplementedError()
+
+    def fetchall(self):
+        """base method for fetch all items"""
+        raise NotImplementedError()
+
+
+class IllegalOperationError(Exception):
+    """custom error to raise illegal operation in db"""
+
+    def __init__(self):
+        super().__init__("Operation not permitted!")
+
+
+class DataBase(QueryableMixin):
+    """Class for creating db object"""
+
+    def __init__(self):
+        self.__query_options = None
+        self.__results = None
+
+    def execute(self, query_options):
+        """method to execute an sql query"""
+        self.__query_options = query_options
+        self.results_generator()
+        return self
+
+    def fetchone(self):
+        """method to fetch single item from the db query"""
+        if self.__results is None:
+            raise IllegalOperationError()
+
+        return self.__results[0]
+
+    def fetchall(self):
+        """method for fetching all items from db query"""
+        if self.__results is None:
+            raise IllegalOperationError()
+        return self.__results
+
+    def results_generator(self, expected_results=None):
+        """private method  for generating mock results"""
+
+        if expected_results is None:
+            self.__results = [namedtuple("lit_coeff", "val")(x*0.1)
+                              for x in range(1, 4)]
+        else:
+            self.__results = expected_results
+
+
+class TestCorrelation(TestCase):
+    """class for testing correlation functions"""
+
+    def test_normalize_values(self):
+        """function to test normalizing values """
+        results = normalize_values([2.3, None, None, 3.2, 4.1, 5],
+                                   [3.4, 7.2, 1.3, None, 6.2, 4.1])
+
+        expected_results = ([2.3, 4.1, 5], [3.4, 6.2, 4.1], 3)
+
+        self.assertEqual(results, expected_results)
+
+    def test_bicor(self):
+        """test for doing biweight mid correlation """
+
+        results = do_bicor(x_val=[1, 2, 3], y_val=[4, 5, 6])
+
+        self.assertEqual(results, ([1, 2, 3], [4, 5, 6])
+                         )
+
+    @mock.patch("gn3.computations.correlations.compute_corr_coeff_p_value")
+    @mock.patch("gn3.computations.correlations.normalize_values")
+    def test_compute_sample_r_correlation(self, norm_vals, compute_corr):
+        """test for doing sample correlation gets the cor\
+        and p value and rho value using pearson correlation"""
+        primary_values = [2.3, 4.1, 5]
+        target_values = [3.4, 6.2, 4.1]
+
+        norm_vals.return_value = ([2.3, 4.1, 5, 4.2, 4, 1.2],
+                                  [3.4, 6.2, 4, 1.1, 8, 1.1], 6)
+        compute_corr.side_effect = [(0.7, 0.3), (-1.0, 0.9), (1, 0.21)]
+
+        pearson_results = compute_sample_r_correlation(corr_method="pearson",
+                                                       trait_vals=primary_values,
+                                                       target_samples_vals=target_values)
+
+        spearman_results = compute_sample_r_correlation(corr_method="spearman",
+                                                        trait_vals=primary_values,
+                                                        target_samples_vals=target_values)
+
+        bicor_results = compute_sample_r_correlation(corr_method="bicor",
+                                                     trait_vals=primary_values,
+                                                     target_samples_vals=target_values)
+
+        self.assertEqual(bicor_results, (1, 0.21, 6))
+        self.assertEqual(pearson_results, (0.7, 0.3, 6))
+        self.assertEqual(spearman_results, (-1.0, 0.9, 6))
+
+        self.assertIsInstance(
+            pearson_results, tuple, "message")
+        self.assertIsInstance(
+            spearman_results, tuple, "message")
+
+    def test_filter_shared_sample_keys(self):
+        """function to  tests shared key between two dicts"""
+
+        this_samplelist = {
+            "C57BL/6J": "6.638",
+            "DBA/2J": "6.266",
+            "B6D2F1": "6.494",
+            "D2B6F1": "6.565",
+            "BXD2": "6.456"
+        }
+
+        target_samplelist = {
+            "DBA/2J": "1.23",
+            "D2B6F1": "6.565",
+            "BXD2": "6.456"
+
+        }
+
+        filtered_target_samplelist = ["1.23", "6.565", "6.456"]
+        filtered_this_samplelist = ["6.266", "6.565", "6.456"]
+
+        results = filter_shared_sample_keys(
+            this_samplelist=this_samplelist, target_samplelist=target_samplelist)
+
+        self.assertEqual(results, (filtered_this_samplelist,
+                                   filtered_target_samplelist))
+
+    @mock.patch("gn3.computations.correlations.compute_sample_r_correlation")
+    @mock.patch("gn3.computations.correlations.filter_shared_sample_keys")
+    def test_compute_all_sample(self, filter_shared_samples, sample_r_corr):
+        """given target dataset compute all sample r correlation"""
+
+        filter_shared_samples.return_value = (["1.23", "6.565", "6.456"], [
+            "6.266", "6.565", "6.456"])
+        sample_r_corr.return_value = ([-1.0, 0.9, 6])
+
+        this_trait_data = {
+            "trait_id": "1455376_at",
+            "trait_sample_data": {
+                "C57BL/6J": "6.638",
+                "DBA/2J": "6.266",
+                "B6D2F1": "6.494",
+                "D2B6F1": "6.565",
+                "BXD2": "6.456"
+            }}
+
+        traits_dataset = [
+            {
+                "trait_id": "1419792_at",
+                "trait_sample_data": {
+                    "DBA/2J": "1.23",
+                    "D2B6F1": "6.565",
+                    "BXD2": "6.456"
+                }
+            }
+        ]
+
+        sample_all_results = [{"1419792_at": {"corr_coeffient": -1.0,
+                                              "p_value": 0.9,
+                                              "num_overlap": 6}}]
+        # ?corr_method: str, trait_vals, target_samples_vals
+
+        self.assertEqual(compute_all_sample_correlation(
+            this_trait=this_trait_data, target_dataset=traits_dataset), sample_all_results)
+        sample_r_corr.assert_called_once_with(
+            corr_method="pearson", trait_vals=['1.23', '6.565', '6.456'],
+            target_samples_vals=['6.266', '6.565', '6.456'])
+        filter_shared_samples.assert_called_once_with(
+            this_trait_data.get("trait_sample_data"), traits_dataset[0].get("trait_sample_data"))
+
+    @unittest.skip("not implemented")
+    def test_tissue_lit_corr_for_probe_type(self):
+        """tests for doing tissue and lit correlation for  trait list\
+        if both the dataset and target dataset are probeset runs\
+        on after initial correlation has been done"""
+
+        results = tissue_lit_corr_for_probe_type(
+            corr_type="tissue", top_corr_results={})
+
+        self.assertEqual(results, (None, None))
+
+    @mock.patch("gn3.computations.correlations.compute_corr_coeff_p_value")
+    def test_tissue_correlation_for_trait_list(self, mock_compute_corr_coeff):
+        """test given a primary tissue values for a trait  and and a list of\
+        target tissues for traits  do the tissue correlation for them"""
+
+        primary_tissue_values = [1.1, 1.5, 2.3]
+        target_tissues_values = [1, 2, 3]
+        mock_compute_corr_coeff.side_effect = [(0.4, 0.9), (-0.2, 0.91)]
+        expected_tissue_results = {
+            'tissue_corr': 0.4, 'p_value': 0.9, "tissue_number": 3}
+
+        tissue_results = tissue_correlation_for_trait_list(
+            primary_tissue_values, target_tissues_values,
+            corr_method="pearson", compute_corr_p_value=mock_compute_corr_coeff)
+
+        self.assertEqual(tissue_results, expected_tissue_results)
+
+    @mock.patch("gn3.computations.correlations.fetch_lit_correlation_data")
+    @mock.patch("gn3.computations.correlations.map_to_mouse_gene_id")
+    def test_lit_correlation_for_trait_list(self, mock_mouse_gene_id, fetch_lit_data):
+        """fetch results from  db call for lit correlation given a trait list\
+        after doing correlation"""
+
+        target_trait_lists = [{"gene_id": 15},
+                              {"gene_id": 17},
+                              {"gene_id": 11}]
+        mock_mouse_gene_id.side_effect = [12, 11, 18, 16, 20]
+
+        database_instance = namedtuple("database", "execute")("fetchone")
+
+        fetch_lit_data.side_effect = [(15, 9), (17, 8), (11, 12)]
+
+        lit_results = lit_correlation_for_trait_list(
+            database=database_instance, target_trait_lists=target_trait_lists,
+            species="rat", trait_gene_id="12")
+
+        expected_results = [{"gene_id": 15, "lit_corr": 9}, {
+            "gene_id": 17, "lit_corr": 8}, {"gene_id": 11, "lit_corr": 12}]
+
+        self.assertEqual(lit_results, expected_results)
+
+    def test_fetch_lit_correlation_data(self):
+        """test for fetching lit correlation data from\
+        the database where the input and mouse geneid are none"""
+
+        database_instance = DataBase()
+        results = fetch_lit_correlation_data(database=database_instance,
+                                             gene_id="1",
+                                             input_mouse_gene_id=None,
+                                             mouse_gene_id=None)
+
+        self.assertEqual(results, ("1", 0))
+
+    def test_fetch_lit_correlation_data_db_query(self):
+        """test for fetching lit corr coefficent givent the input\
+         input trait mouse gene id and mouse gene id"""
+
+        database_instance = DataBase()
+        expected_results = ("1", 0.1)
+
+        lit_results = fetch_lit_correlation_data(database=database_instance,
+                                                 gene_id="1",
+                                                 input_mouse_gene_id="20",
+                                                 mouse_gene_id="15")
+
+        self.assertEqual(expected_results, lit_results)
+
+    def test_query_lit_correlation_for_db_empty(self):
+        """test that corr coeffient returned is 0 given the\
+        db value if corr coefficient is empty"""
+        database_instance = mock.Mock()
+        database_instance.execute.return_value.fetchone.return_value = None
+
+        lit_results = fetch_lit_correlation_data(database=database_instance,
+                                                 input_mouse_gene_id="12",
+                                                 gene_id="16",
+                                                 mouse_gene_id="12")
+
+        self.assertEqual(lit_results, ("16", 0))
+
+    def test_query_formatter(self):
+        """test for formatting a query given the query string and also the\
+        values"""
+        query = """
+        SELECT VALUE
+        FROM  LCorr
+        WHERE GeneId1='%s' and
+        GeneId2='%s'
+        """
+
+        expected_formatted_query = """
+        SELECT VALUE
+        FROM  LCorr
+        WHERE GeneId1='20' and
+        GeneId2='15'
+        """
+
+        mouse_gene_id = "20"
+        input_mouse_gene_id = "15"
+
+        query_values = (mouse_gene_id, input_mouse_gene_id)
+
+        formatted_query = query_formatter(query, *query_values)
+
+        self.assertEqual(formatted_query, expected_formatted_query)
+
+    def test_query_formatter_no_query_values(self):
+        """test for formatting a query where there are no\
+        string placeholder"""
+        query = """SELECT * FROM  USERS"""
+        formatted_query = query_formatter(query)
+
+        self.assertEqual(formatted_query, query)
+
+    def test_map_to_mouse_gene_id(self):
+        """test for converting a gene id to mouse geneid\
+        given a species which is not mouse"""
+        database_instance = mock.Mock()
+        test_data = [("Human", 14), (None, 9), ("Mouse", 15), ("Rat", 14)]
+
+        database_results = [namedtuple("mouse_id", "mouse")(val)
+                            for val in range(12, 20)]
+        results = []
+
+        database_instance.execute.return_value.fetchone.side_effect = database_results
+        expected_results = [12, None, 13, 14]
+        for (species, gene_id) in test_data:
+
+            mouse_gene_id_results = map_to_mouse_gene_id(
+                database=database_instance, species=species, gene_id=gene_id)
+            results.append(mouse_gene_id_results)
+
+        self.assertEqual(results, expected_results)
+
+    @mock.patch("gn3.computations.correlations.lit_correlation_for_trait_list")
+    def test_compute_all_lit_correlation(self, mock_lit_corr):
+        """test for compute all lit correlation which acts\
+        as an abstraction for lit_correlation_for_trait_list
+        and is used in the api/correlation/lit"""
+
+        database = mock.Mock()
+
+        expected_mocked_lit_results = [{"gene_id": 11, "lit_corr": 9}, {
+            "gene_id": 17, "lit_corr": 8}]
+
+        mock_lit_corr.side_effect = expected_mocked_lit_results
+
+        lit_correlation_results = compute_all_lit_correlation(
+            database_instance=database, trait_lists=[{"gene_id": 11}],
+            species="rat", gene_id=12)
+
+        expected_results = {
+            "lit_results": {"gene_id": 11, "lit_corr": 9}
+        }
+
+        self.assertEqual(lit_correlation_results, expected_results)
+
+    @mock.patch("gn3.computations.correlations.tissue_correlation_for_trait_list")
+    def test_compute_all_tissue_correlation(self, mock_tissue_corr):
+        """test for compute all tissue corelation which abstracts
+        api calling the tissue_correlation for trait_list"""
+
+        primary_tissue_dict = {"trait_id": "1419792_at",
+                               "tissue_values": [1, 2, 3, 4, 5]}
+
+        target_tissue_dict = [{"trait_id": "1418702_a_at", "tissue_values": [1, 2, 3]},
+                              {"trait_id": "1412_at", "tissue_values": [1, 2, 3]}]
+
+        mock_tissue_corr.side_effect = [{"tissue_corr": -0.5, "p_value": 0.9, "tissue_number": 3},
+                                        {"tissue_corr": 1.11, "p_value": 0.2, "tissue_number": 3}]
+
+        expected_results = {"1418702_a_at":
+                            {"tissue_corr": -0.5, "p_value": 0.9, "tissue_number": 3},
+                            "1412_at":
+                            {"tissue_corr": 1.11, "p_value": 0.2, "tissue_number": 3}}
+
+        results = compute_all_tissue_correlation(
+            primary_tissue_dict=primary_tissue_dict,
+            target_tissues_dict_list=target_tissue_dict,
+            corr_method="pearson")
+
+        self.assertEqual(mock_tissue_corr.call_count, 2)
+
+        self.assertEqual(results, expected_results)
-- 
cgit v1.2.3