-rw-r--r--  .guix_deploy                                  8
-rw-r--r--  .pylintrc                                     6
-rw-r--r--  README.md                                     3
-rw-r--r--  gn3/api/correlation.py                       31
-rw-r--r--  gn3/authentication.py                         1
-rw-r--r--  gn3/computations/correlations.py              5
-rw-r--r--  gn3/computations/partial_correlations.py    137
-rw-r--r--  gn3/db/correlations.py                       22
-rw-r--r--  gn3/db/datasets.py                           12
-rw-r--r--  gn3/db/traits.py                             49
-rw-r--r--  mypy.ini                                     12
-rw-r--r--  scripts/wgcna_analysis.R                      8
-rw-r--r--  tests/unit/computations/test_correlation.py   1
-rw-r--r--  tests/unit/db/test_datasets.py               14
-rw-r--r--  tests/unit/db/test_traits.py                 33
15 files changed, 261 insertions, 81 deletions
diff --git a/.guix_deploy b/.guix_deploy
new file mode 100644
index 0000000..c7bbb5b
--- /dev/null
+++ b/.guix_deploy
@@ -0,0 +1,8 @@
+# Deploy script on tux01
+#
+# echo Run tests:
+# echo python -m unittest discover -v
+# echo Run service (single process):
+# echo flask run --port=8080
+
+/home/wrk/opt/guix-pull/bin/guix shell -L /home/wrk/guix-bioinformatics/ --expose=$HOME/production/genotype_files/ -C -N -Df guix.scm
diff --git a/.pylintrc b/.pylintrc
index 0bdef23..00dd6cd 100644
--- a/.pylintrc
+++ b/.pylintrc
@@ -1,3 +1,7 @@
 [SIMILARITIES]
 
-ignore-imports=yes
\ No newline at end of file
+ignore-imports=yes
+
+[MESSAGES CONTROL]
+
+disable=fixme
\ No newline at end of file
diff --git a/README.md b/README.md
index d3470ee..5669192 100644
--- a/README.md
+++ b/README.md
@@ -38,7 +38,6 @@ python3
 guix shell -C --network --expose=$HOME/genotype_files/ -Df guix.scm
 ```
 
-
 #### Using a Guix profile (or rolling back)
 
 Create a new profile with
@@ -128,6 +127,8 @@ And for the scalable production version run
 gunicorn --bind 0.0.0.0:8080 --workers 8 --keep-alive 6000 --max-requests 10 --max-requests-jitter 5 --timeout 1200 wsgi:app
 ```
 
+(see also the [.guix_deploy](./.guix_deploy) script)
+
 ## Using python-pip
 
 IMPORTANT NOTE: we do not recommend using pip tools, use Guix instead
diff --git a/gn3/api/correlation.py b/gn3/api/correlation.py
index 46121f8..1caf31f 100644
--- a/gn3/api/correlation.py
+++ b/gn3/api/correlation.py
@@ -1,13 +1,16 @@
 """Endpoints for running correlations"""
+import json
 from flask import jsonify
 from flask import Blueprint
 from flask import request
+from flask import make_response
 
 from gn3.computations.correlations import compute_all_sample_correlation
 from gn3.computations.correlations import compute_all_lit_correlation
 from gn3.computations.correlations import compute_tissue_correlation
 from gn3.computations.correlations import map_shared_keys_to_values
 from gn3.db_utils import database_connector
+from gn3.computations.partial_correlations import partial_correlations_entry
 
 correlation = Blueprint("correlation", __name__)
 
@@ -83,3 +86,31 @@ def compute_tissue_corr(corr_method="pearson"):
                                          corr_method=corr_method)
 
     return jsonify(results)
+
+@correlation.route("/partial", methods=["POST"])
+def partial_correlation():
+    """API endpoint for partial correlations."""
+    def trait_fullname(trait):
+        return f"{trait['dataset']}::{trait['name']}"
+
+    class OutputEncoder(json.JSONEncoder):
+        """
+        Class to encode output into JSON, for objects which the default
+        json.JSONEncoder class does not have default encoding for.
+        """
+        def default(self, obj):
+            if isinstance(obj, bytes):
+                return str(obj, encoding="utf-8")
+            return json.JSONEncoder.default(self, obj)
+
+    args = request.get_json()
+    conn, _cursor_object = database_connector()
+    corr_results = partial_correlations_entry(
+        conn, trait_fullname(args["primary_trait"]),
+        tuple(trait_fullname(trait) for trait in args["control_traits"]),
+        args["method"], int(args["criteria"]), args["target_db"])
+    response = make_response(
+        json.dumps(corr_results, cls=OutputEncoder).replace(": NaN", ": null"),
+        400 if "error" in corr_results.keys() else 200)
+    response.headers["Content-Type"] = "application/json"
+    return response
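
The `OutputEncoder` above extends the stock `json.JSONEncoder` to handle the `bytes` values that come back from the database, while `NaN` scores (which `json.dumps` would otherwise emit as the invalid JSON literal `NaN`) are patched into `null` with a plain string replacement. A minimal, self-contained sketch of that combination; the payload values are invented for illustration:

```
import json

class OutputEncoder(json.JSONEncoder):
    """Encode objects the default json.JSONEncoder cannot handle."""
    def default(self, obj):
        if isinstance(obj, bytes):
            return str(obj, encoding="utf-8")
        return json.JSONEncoder.default(self, obj)

# Hypothetical payload: bytes from the database, NaN from a computation
payload = {"symbol": b"Shh", "partial_corr": float("nan")}
encoded = json.dumps(payload, cls=OutputEncoder).replace(": NaN", ": null")
print(encoded)  # {"symbol": "Shh", "partial_corr": null}
```
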
diff --git a/gn3/authentication.py b/gn3/authentication.py
index a6372c1..d0b35bc 100644
--- a/gn3/authentication.py
+++ b/gn3/authentication.py
@@ -163,3 +163,4 @@ def create_group(conn: Redis, group_name: Optional[str],
         }
         conn.hset("groups", group_id, json.dumps(group))
         return group
+    return None
diff --git a/gn3/computations/correlations.py b/gn3/computations/correlations.py
index 37c70e9..09288c5 100644
--- a/gn3/computations/correlations.py
+++ b/gn3/computations/correlations.py
@@ -8,6 +8,7 @@ from typing import List
 from typing import Tuple
 from typing import Optional
 from typing import Callable
+from typing import Generator
 
 import scipy.stats
 import pingouin as pg
@@ -80,7 +81,7 @@ def compute_sample_r_correlation(trait_name, corr_method, trait_vals,
             zip(*list(normalize_values(trait_vals, target_samples_vals))))
         num_overlap = len(normalized_traits_vals)
     except ValueError:
-        return
+        return None
 
     if num_overlap > 5:
 
@@ -107,7 +108,7 @@ package :not packaged in guix
 
 
 def filter_shared_sample_keys(this_samplelist,
-                              target_samplelist) -> Tuple[List, List]:
+                              target_samplelist) -> Generator:
     """Given primary and target sample-list for two base and target trait select
     filter the values using the shared keys
 
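
The annotation change above reflects that `filter_shared_sample_keys` now yields its value pairs lazily instead of returning two lists. A minimal sketch of that shape, assuming the sample-lists are dicts keyed by sample name (the real structure may differ):

```
from typing import Generator

def filter_shared_sample_keys(this_samplelist: dict,
                              target_samplelist: dict) -> Generator:
    """Yield (primary, target) value pairs for samples present in both."""
    for key, this_val in this_samplelist.items():
        if key in target_samplelist:
            yield this_val, target_samplelist[key]

# Callers recover the old two-sequence shape by unzipping the pairs:
pairs = filter_shared_sample_keys({"BXD1": 8.1, "BXD2": 9.3},
                                  {"BXD2": 7.2, "BXD5": 6.5})
this_vals, target_vals = zip(*pairs)  # (9.3,) and (7.2,)
```
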
diff --git a/gn3/computations/partial_correlations.py b/gn3/computations/partial_correlations.py
index 719c605..984c15a 100644
--- a/gn3/computations/partial_correlations.py
+++ b/gn3/computations/partial_correlations.py
@@ -18,6 +18,7 @@ from gn3.random import random_string
 from gn3.function_helpers import  compose
 from gn3.data_helpers import parse_csv_line
 from gn3.db.traits import export_informative
+from gn3.db.datasets import retrieve_trait_dataset
 from gn3.db.traits import retrieve_trait_info, retrieve_trait_data
 from gn3.db.species import species_name, translate_to_mouse_gene_id
 from gn3.db.correlations import (
@@ -216,7 +217,7 @@ def good_dataset_samples_indexes(
 def partial_correlations_fast(# pylint: disable=[R0913, R0914]
         samples, primary_vals, control_vals, database_filename,
         fetched_correlations, method: str, correlation_type: str) -> Tuple[
-            float, Tuple[float, ...]]:
+            int, Tuple[float, ...]]:
     """
     Computes partial correlation coefficients using data from a CSV file.
 
@@ -257,8 +258,9 @@ def partial_correlations_fast(# pylint: disable=[R0913, R0914]
     ## `correlation_type` parameter
     return len(all_correlations), tuple(
         corr + (
-            (fetched_correlations[corr[0]],) if correlation_type == "literature"
-            else fetched_correlations[corr[0]][0:2])
+            (fetched_correlations[corr[0]],) # type: ignore[index]
+            if correlation_type == "literature"
+            else fetched_correlations[corr[0]][0:2]) # type: ignore[index]
         for idx, corr in enumerate(all_correlations))
 
 def build_data_frame(
@@ -305,11 +307,19 @@ def compute_partial(
             prim for targ, prim in zip(targ_vals, primary_vals)
             if targ is not None]
 
+        if len(primary) < 3:
+            return None
+
+        def __remove_controls_for_target_nones(cont_targ):
+            return tuple(cont for cont, targ in cont_targ if targ is not None)
+
+        conts_targs = tuple(tuple(
+            zip(control, targ_vals)) for control in control_vals)
         datafrm = build_data_frame(
             primary,
-            tuple(targ for targ in targ_vals if targ is not None),
-            tuple(cont for i, cont in enumerate(control_vals)
-                  if target[0][i] is not None))
+            [targ for targ in targ_vals if targ is not None],
+            [__remove_controls_for_target_nones(cont_targ)
+             for cont_targ in conts_targs])
         covariates = "z" if datafrm.shape[1] == 3 else [
             col for col in datafrm.columns if col not in ("x", "y")]
         ppc = pingouin.partial_corr(
@@ -332,13 +342,17 @@ def compute_partial(
             zero_order_corr["r"][0], zero_order_corr["p-val"][0])
 
     return tuple(
-        __compute_trait_info__(target)
-        for target in zip(target_vals, target_names))
+        result for result in (
+            __compute_trait_info__(target)
+            for target in zip(target_vals, target_names))
+        if result is not None)
 
 def partial_correlations_normal(# pylint: disable=R0913
         primary_vals, control_vals, input_trait_gene_id, trait_database,
         data_start_pos: int, db_type: str, method: str) -> Tuple[
-            float, Tuple[float, ...]]:
+            int, Tuple[Union[
+                Tuple[str, int, float, float, float, float], None],
+                       ...]]:
     """
     Computes the correlation coefficients.
 
@@ -360,7 +374,7 @@ def partial_correlations_normal(# pylint: disable=R0913
             return tuple(item) + (trait_database[1], trait_database[2])
         return item
 
-    target_trait_names, target_trait_vals = reduce(
+    target_trait_names, target_trait_vals = reduce(# type: ignore[var-annotated]
         lambda acc, item: (acc[0]+(item[0],), acc[1]+(item[data_start_pos:],)),
         trait_database, (tuple(), tuple()))
 
@@ -413,7 +427,7 @@ def partial_corrs(# pylint: disable=[R0913]
         data_start_pos, dataset, method)
 
 def literature_correlation_by_list(
-        conn: Any, species: str, trait_list: Tuple[dict]) -> Tuple[dict]:
+        conn: Any, species: str, trait_list: Tuple[dict]) -> Tuple[dict, ...]:
     """
     This is a migration of the
     `web.webqtl.correlation.CorrelationPage.getLiteratureCorrelationByList`
@@ -473,7 +487,7 @@ def literature_correlation_by_list(
 
 def tissue_correlation_by_list(
         conn: Any, primary_trait_symbol: str, tissue_probeset_freeze_id: int,
-        method: str, trait_list: Tuple[dict]) -> Tuple[dict]:
+        method: str, trait_list: Tuple[dict]) -> Tuple[dict, ...]:
     """
     This is a migration of the
     `web.webqtl.correlation.CorrelationPage.getTissueCorrelationByList`
@@ -496,7 +510,7 @@ def tissue_correlation_by_list(
             primary_trait_value = prim_trait_symbol_value_dict[
                 primary_trait_symbol.lower()]
             gene_symbol_list = tuple(
-                trait for trait in trait_list if "symbol" in trait.keys())
+                trait["symbol"] for trait in trait_list if "symbol" in trait.keys())
             symbol_value_dict = fetch_gene_symbol_tissue_value_dict_for_trait(
                 gene_symbol_list, tissue_probeset_freeze_id, conn)
             return tuple(
@@ -514,6 +528,54 @@ def tissue_correlation_by_list(
         } for trait in trait_list)
     return trait_list
 
+def trait_for_output(trait):
+    """
+    Process a trait for output.
+
+    Removes extraneous data from the trait that is not needed to display
+    partial correlation results.
+    This function also removes all key-value pairs whose value is `None`,
+    since it is a waste of network resources to transmit a key-value pair
+    just to indicate that it does not exist.
+    """
+    trait = {
+        "trait_type": trait["trait_type"],
+        "dataset_name": trait["db"]["dataset_name"],
+        "dataset_type": trait["db"]["dataset_type"],
+        "group": trait["db"]["group"],
+        "trait_fullname": trait["trait_fullname"],
+        "trait_name": trait["trait_name"],
+        "symbol": trait.get("symbol"),
+        "description": trait.get("description"),
+        "pre_publication_description": trait.get(
+            "pre_publication_description"),
+        "post_publication_description": trait.get(
+            "post_publication_description"),
+        "original_description": trait.get(
+            "original_description"),
+        "authors": trait.get("authors"),
+        "year": trait.get("year"),
+        "probe_target_description": trait.get(
+            "probe_target_description"),
+        "chr": trait.get("chr"),
+        "mb": trait.get("mb"),
+        "geneid": trait.get("geneid"),
+        "homologeneid": trait.get("homologeneid"),
+        "noverlap": trait.get("noverlap"),
+        "partial_corr": trait.get("partial_corr"),
+        "partial_corr_p_value": trait.get("partial_corr_p_value"),
+        "corr": trait.get("corr"),
+        "corr_p_value": trait.get("corr_p_value"),
+        "rank_order": trait.get("rank_order"),
+        "delta": (
+            None if trait.get("partial_corr") is None
+            else (trait.get("partial_corr") - trait.get("corr"))),
+        "l_corr":  trait.get("l_corr"),
+        "tissue_corr": trait.get("tissue_corr"),
+        "tissue_p_value": trait.get("tissue_p_value")
+    }
+    return {key: val for key, val in trait.items() if val is not None}
+
 def partial_correlations_entry(# pylint: disable=[R0913, R0914, R0911]
         conn: Any, primary_trait_name: str,
         control_trait_names: Tuple[str, ...], method: str,
@@ -640,28 +702,47 @@ def partial_correlations_entry(# pylint: disable=[R0913, R0914, R0911]
                 "any associated Tissue Correlation Information."),
             "error_type": "Tissue Correlation"}
 
+    target_dataset = retrieve_trait_dataset(
+        ("Temp" if "Temp" in target_db_name else
+         ("Publish" if "Publish" in target_db_name else
+          "Geno" if "Geno" in target_db_name else "ProbeSet")),
+        {"db": {"dataset_name": target_db_name}, "trait_name": "_"},
+        threshold,
+        conn)
+
     database_filename = get_filename(conn, target_db_name, TEXTDIR)
     _total_traits, all_correlations = partial_corrs(
         conn, common_primary_control_samples, fixed_primary_vals,
         fixed_control_vals, len(fixed_primary_vals), species,
         input_trait_geneid, input_trait_symbol, tissue_probeset_freeze_id,
-        method, primary_trait["db"], database_filename)
+        method, {**target_dataset, "dataset_type": target_dataset["type"]}, database_filename)
 
 
     def __make_sorter__(method):
-        def __sort_6__(row):
-            return row[6]
-
-        def __sort_3__(row):
+        def __compare_lit_or_tiss_correlation_values__(row):
+            # Index  Content
+            # 0      trait name
+            # 1      N
+            # 2      partial correlation coefficient
+            # 3      p value of partial correlation
+            # 6      literature/tissue correlation value
+            return (row[6], row[3])
+
+        def __compare_partial_correlation_p_values__(row):
+            # Index  Content
+            # 0      trait name
+            # 1      N
+            # 2      partial correlation coefficient
+            # 3      p value of partial correlation
             return row[3]
 
         if "literature" in method.lower():
-            return __sort_6__
+            return __compare_lit_or_tiss_correlation_values__
 
         if "tissue" in method.lower():
-            return __sort_6__
+            return __compare_lit_or_tiss_correlation_values__
 
-        return __sort_3__
+        return __compare_partial_correlation_p_values__
 
     sorted_correlations = sorted(
         all_correlations, key=__make_sorter__(method))
@@ -676,7 +757,7 @@ def partial_correlations_entry(# pylint: disable=[R0913, R0914, R0911]
         {
             **retrieve_trait_info(
                 threshold,
-                f"{primary_trait['db']['dataset_name']}::{item[0]}",
+                f"{target_dataset['dataset_name']}::{item[0]}",
                 conn),
             "noverlap": item[1],
             "partial_corr": item[2],
@@ -694,4 +775,14 @@ def partial_correlations_entry(# pylint: disable=[R0913, R0914, R0911]
         for item in
         sorted_correlations[:min(criteria, len(all_correlations))]))
 
-    return trait_list
+    return {
+        "status": "success",
+        "results": {
+            "primary_trait": trait_for_output(primary_trait),
+            "control_traits": tuple(
+                trait_for_output(trait) for trait in cntrl_traits),
+            "correlations": tuple(
+                trait_for_output(trait) for trait in trait_list),
+            "dataset_type": target_dataset["type"],
+            "method": "spearman" if "spearman" in method.lower() else "pearson"
+        }}
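
At its core, `compute_partial` above hands `pingouin.partial_corr` a data frame with the primary trait in column `x`, the target trait in `y`, and one column per control trait as covariates; the new `len(primary) < 3` guard simply skips targets with too few overlapping samples to correlate. A minimal sketch of that call on synthetic numbers, assuming `pingouin` and `pandas` are installed (both are already used by this module):

```
import pandas
import pingouin

# x: primary trait, y: one target trait, z: a single control trait
datafrm = pandas.DataFrame({
    "x": [7.1, 7.9, 8.3, 6.8, 7.5, 8.0],
    "y": [5.2, 5.9, 6.4, 5.0, 5.6, 6.1],
    "z": [1.0, 1.2, 1.5, 0.9, 1.1, 1.4]})
ppc = pingouin.partial_corr(
    data=datafrm, x="x", y="y", covar="z", method="pearson")
# The same fields the code above reads: sample size, coefficient, p-value
print(ppc["n"][0], ppc["r"][0], ppc["p-val"][0])
```
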
diff --git a/gn3/db/correlations.py b/gn3/db/correlations.py
index 254af10..5361a1e 100644
--- a/gn3/db/correlations.py
+++ b/gn3/db/correlations.py
@@ -157,11 +157,12 @@ def fetch_symbol_value_pair_dict(
         symbol: data_id_dict.get(symbol) for symbol in symbol_list
         if data_id_dict.get(symbol) is not None
     }
-    query = "SELECT Id, value FROM TissueProbeSetData WHERE Id IN %(data_ids)s"
+    query = "SELECT Id, value FROM TissueProbeSetData WHERE Id IN ({})".format(
+        ",".join(f"%(id{i})s" for i in range(len(data_ids.values()))))
     with conn.cursor() as cursor:
         cursor.execute(
             query,
-            data_ids=tuple(data_ids.values()))
+            **{f"id{i}": did for i, did in enumerate(data_ids.values())})
         value_results = cursor.fetchall()
         return {
             key: tuple(row[1] for row in value_results if row[0] == key)
@@ -406,21 +407,22 @@ def fetch_sample_ids(
     """
     query = (
         "SELECT Strain.Id FROM Strain, Species "
-        "WHERE Strain.Name IN %(samples_names)s "
+        "WHERE Strain.Name IN ({}) "
         "AND Strain.SpeciesId=Species.Id "
-        "AND Species.name=%(species_name)s")
+        "AND Species.name=%(species_name)s").format(
+            ",".join(f"%(s{i})s" for i in range(len(sample_names))))
     with conn.cursor() as cursor:
         cursor.execute(
             query,
             {
-                "samples_names": tuple(sample_names),
+                **{f"s{i}": sname for i, sname in enumerate(sample_names)},
                 "species_name": species_name
             })
         return tuple(row[0] for row in cursor.fetchall())
 
 def build_query_sgo_lit_corr(
         db_type: str, temp_table: str, sample_id_columns: str,
-        joins: Tuple[str, ...]) -> str:
+        joins: Tuple[str, ...]) -> Tuple[str, int]:
     """
     Build query for `SGO Literature Correlation` data, when querying the given
     `temp_table` temporary table.
@@ -483,14 +485,14 @@ def fetch_all_database_data(# pylint: disable=[R0913, R0914]
         sample_id_columns = ", ".join(f"T{smpl}.value" for smpl in sample_ids)
         if db_type == "Publish":
             joins = tuple(
-                ("LEFT JOIN PublishData AS T{item} "
-                 "ON T{item}.Id = PublishXRef.DataId "
-                 "AND T{item}.StrainId = %(T{item}_sample_id)s")
+                (f"LEFT JOIN PublishData AS T{item} "
+                 f"ON T{item}.Id = PublishXRef.DataId "
+                 f"AND T{item}.StrainId = %(T{item}_sample_id)s")
                 for item in sample_ids)
             return (
                 ("SELECT PublishXRef.Id, " +
                  sample_id_columns +
-                 "FROM (PublishXRef, PublishFreeze) " +
+                 " FROM (PublishXRef, PublishFreeze) " +
                  " ".join(joins) +
                  " WHERE PublishXRef.InbredSetId = PublishFreeze.InbredSetId "
                  "AND PublishFreeze.Name = %(db_name)s"),
diff --git a/gn3/db/datasets.py b/gn3/db/datasets.py
index c50e148..a41e228 100644
--- a/gn3/db/datasets.py
+++ b/gn3/db/datasets.py
@@ -3,7 +3,7 @@ This module contains functions relating to specific trait dataset manipulation
 """
 import re
 from string import Template
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, Optional
 from SPARQLWrapper import JSON, SPARQLWrapper
 from gn3.settings import SPARQL_ENDPOINT
 
@@ -297,7 +297,7 @@ def retrieve_trait_dataset(trait_type, trait, threshold, conn):
         **group
     }
 
-def sparql_query(query: str) -> Dict[str, Any]:
+def sparql_query(query: str) -> List[Dict[str, Any]]:
     """Run a SPARQL query and return the bound variables."""
     sparql = SPARQLWrapper(SPARQL_ENDPOINT)
     sparql.setQuery(query)
@@ -328,7 +328,7 @@ WHERE {
   OPTIONAL { ?dataset gn:geoSeries ?geo_series } .
 }
 """,
-             """
+               """
 PREFIX gn: <http://genenetwork.org/>
 SELECT ?platform_name ?normalization_name ?species_name ?inbred_set_name ?tissue_name
 WHERE {
@@ -341,7 +341,7 @@ WHERE {
   OPTIONAL { ?dataset gn:datasetOfPlatform / gn:name ?platform_name } .
 }
 """,
-             """
+               """
 PREFIX gn: <http://genenetwork.org/>
 SELECT ?specifics ?summary ?about_cases ?about_tissue ?about_platform
        ?about_data_processing ?notes ?experiment_design ?contributors
@@ -362,8 +362,8 @@ WHERE {
   OPTIONAL { ?dataset gn:acknowledgment ?acknowledgment . }
 }
 """]
-    result = {'accession_id': accession_id,
-              'investigator': {}}
+    result: Dict[str, Any] = {'accession_id': accession_id,
+                              'investigator': {}}
     query_result = {}
     for query in queries:
         if sparql_result := sparql_query(Template(query).substitute(accession_id=accession_id)):
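
The tightened `List[Dict[str, Any]]` annotation on `sparql_query` matches the shape of SPARQL JSON results: `results.bindings` is a list holding one dict of variable bindings per result row. Assuming the function body looks roughly like this (only its first lines are visible in the hunk above):

```
from typing import Any, Dict, List
from SPARQLWrapper import JSON, SPARQLWrapper
from gn3.settings import SPARQL_ENDPOINT

def sparql_query(query: str) -> List[Dict[str, Any]]:
    """Run a SPARQL query and return the bound variables."""
    sparql = SPARQLWrapper(SPARQL_ENDPOINT)
    sparql.setQuery(query)
    sparql.setReturnFormat(JSON)
    # Each binding maps a variable name to {"type": ..., "value": ...}
    return sparql.queryAndConvert()["results"]["bindings"]
```
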
diff --git a/gn3/db/traits.py b/gn3/db/traits.py
index 7994aef..338b320 100644
--- a/gn3/db/traits.py
+++ b/gn3/db/traits.py
@@ -111,7 +111,6 @@ def get_trait_csv_sample_data(conn: Any,
 
 
 def update_sample_data(conn: Any, #pylint: disable=[R0913]
-
                        trait_name: str,
                        strain_name: str,
                        phenotype_id: int,
@@ -204,25 +203,32 @@ def delete_sample_data(conn: Any,
                  "AND Strain.Name = \"%s\"") % (trait_name,
                                                 phenotype_id,
                                                 str(strain_name)))
-            strain_id, data_id = cursor.fetchone()
 
-            cursor.execute(("DELETE FROM PublishData "
+            # The data may already have been deleted, so only unpack the
+            # lookup result if a row was actually returned:
+            strain_id = data_id = None
+            if _result := cursor.fetchone():
+                strain_id, data_id = _result
+
+            # Only run the deletions if both strain_id and data_id exist
+            if strain_id and data_id:
+                cursor.execute(("DELETE FROM PublishData "
                             "WHERE StrainId = %s AND Id = %s")
-                           % (strain_id, data_id))
-            deleted_published_data = cursor.rowcount
-
-            # Delete the PublishSE table
-            cursor.execute(("DELETE FROM PublishSE "
-                            "WHERE StrainId = %s AND DataId = %s") %
-                           (strain_id, data_id))
-            deleted_se_data = cursor.rowcount
-
-            # Delete the NStrain table
-            cursor.execute(("DELETE FROM NStrain "
-                            "WHERE StrainId = %s AND DataId = %s" %
-                            (strain_id, data_id)))
-            deleted_n_strains = cursor.rowcount
-        except Exception as e: #pylint: disable=[C0103, W0612]
+                               % (strain_id, data_id))
+                deleted_published_data = cursor.rowcount
+
+                # Delete the PublishSE table
+                cursor.execute(("DELETE FROM PublishSE "
+                                "WHERE StrainId = %s AND DataId = %s") %
+                               (strain_id, data_id))
+                deleted_se_data = cursor.rowcount
+
+                # Delete the NStrain table
+                cursor.execute(("DELETE FROM NStrain "
+                                "WHERE StrainId = %s AND DataId = %s" %
+                                (strain_id, data_id)))
+                deleted_n_strains = cursor.rowcount
+        except Exception as e:  #pylint: disable=[C0103, W0612]
             conn.rollback()
             raise MySQLdb.Error
         conn.commit()
@@ -255,6 +261,13 @@ def insert_sample_data(conn: Any, #pylint: disable=[R0913]
                            (strain_name,))
             strain_id = cursor.fetchone()
 
+            # Return early if this sample data was already inserted
+            cursor.execute("SELECT Id FROM PublishData WHERE Id = %s "
+                           "AND StrainId = %s",
+                           (data_id, strain_id))
+            if cursor.fetchone():  # This data point already exists
+                return (0, 0, 0)
+
             # Insert the PublishData table
             cursor.execute(("INSERT INTO PublishData (Id, StrainId, value)"
                             "VALUES (%s, %s, %s)"),
diff --git a/mypy.ini b/mypy.ini
index b0c48df..ef93008 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -19,4 +19,16 @@ ignore_missing_imports = True
 ignore_missing_imports = True
 
 [mypy-requests.*]
+ignore_missing_imports = True
+
+[mypy-flask.*]
+ignore_missing_imports = True
+
+[mypy-werkzeug.*]
+ignore_missing_imports = True
+
+[mypy-SPARQLWrapper.*]
+ignore_missing_imports = True
+
+[mypy-pandas.*]
 ignore_missing_imports = True
\ No newline at end of file
diff --git a/scripts/wgcna_analysis.R b/scripts/wgcna_analysis.R
index b0d25a9..d368013 100644
--- a/scripts/wgcna_analysis.R
+++ b/scripts/wgcna_analysis.R
@@ -25,6 +25,8 @@ if (length(args)==0) {
 inputData <- fromJSON(file = json_file_path)
 imgDir = inputData$TMPDIR
 
+# Print the parsed input data (presumably kept for logging/debugging)
+inputData
 
 trait_sample_data <- do.call(rbind, inputData$trait_sample_data)
 
@@ -51,10 +52,8 @@ names(dataExpr) = inputData$trait_names
 # Allow multi-threading within WGCNA
 enableWGCNAThreads()
 
-# choose softthreshhold (Calculate soft threshold)
-# xtodo allow users to pass args
-
-powers <- c(c(1:10), seq(from = 12, to=20, by=2)) 
+# powers <- c(c(1:10), seq(from = 12, to=20, by=2)) 
+powers <- unlist(c(inputData$SoftThresholds))
 sft <- pickSoftThreshold(dataExpr, powerVector = powers, verbose = 5)
 
 # check the power estimate
diff --git a/tests/unit/computations/test_correlation.py b/tests/unit/computations/test_correlation.py
index 0de347d..7523d99 100644
--- a/tests/unit/computations/test_correlation.py
+++ b/tests/unit/computations/test_correlation.py
@@ -1,7 +1,6 @@
 """Module contains the tests for correlation"""
 from unittest import TestCase
 from unittest import mock
-import unittest
 
 from collections import namedtuple
 import math
diff --git a/tests/unit/db/test_datasets.py b/tests/unit/db/test_datasets.py
index 39f4af9..0b8c2fe 100644
--- a/tests/unit/db/test_datasets.py
+++ b/tests/unit/db/test_datasets.py
@@ -13,15 +13,17 @@ class TestDatasetsDBFunctions(TestCase):
 
     def test_retrieve_dataset_name(self):
         """Test that the function is called correctly."""
-        for trait_type, thresh, trait_name, dataset_name, columns, table in [
+        for trait_type, thresh, trait_name, dataset_name, columns, table, expected in [
                 ["ProbeSet", 9, "probesetTraitName", "probesetDatasetName",
-                 "Id, Name, FullName, ShortName, DataScale", "ProbeSetFreeze"],
+                 "Id, Name, FullName, ShortName, DataScale", "ProbeSetFreeze",
+                 {"dataset_id": None, "dataset_name": "probesetDatasetName",
+                  "dataset_fullname": "probesetDatasetName"}],
                 ["Geno", 3, "genoTraitName", "genoDatasetName",
-                 "Id, Name, FullName, ShortName", "GenoFreeze"],
+                 "Id, Name, FullName, ShortName", "GenoFreeze", {}],
                 ["Publish", 6, "publishTraitName", "publishDatasetName",
-                 "Id, Name, FullName, ShortName", "PublishFreeze"],
+                 "Id, Name, FullName, ShortName", "PublishFreeze", {}],
                 ["Temp", 4, "tempTraitName", "tempTraitName",
-                 "Id, Name, FullName, ShortName", "TempFreeze"]]:
+                 "Id, Name, FullName, ShortName", "TempFreeze", {}]]:
             db_mock = mock.MagicMock()
             with self.subTest(trait_type=trait_type):
                 with db_mock.cursor() as cursor:
@@ -29,7 +31,7 @@ class TestDatasetsDBFunctions(TestCase):
                     self.assertEqual(
                         retrieve_dataset_name(
                             trait_type, thresh, trait_name, dataset_name, db_mock),
-                        {})
+                        expected)
                     cursor.execute.assert_called_once_with(
                         "SELECT {cols} "
                         "FROM {table} "
diff --git a/tests/unit/db/test_traits.py b/tests/unit/db/test_traits.py
index 4aa9389..75f3d4c 100644
--- a/tests/unit/db/test_traits.py
+++ b/tests/unit/db/test_traits.py
@@ -202,8 +202,6 @@ class TestTraitsDBFunctions(TestCase):
         """
         # pylint: disable=C0103
         db_mock = mock.MagicMock()
-
-        STRAIN_ID_SQL: str = "UPDATE Strain SET Name = %s WHERE Id = %s"
         PUBLISH_DATA_SQL: str = (
             "UPDATE PublishData SET value = %s "
             "WHERE StrainId = %s AND Id = %s")
@@ -216,16 +214,33 @@ class TestTraitsDBFunctions(TestCase):
 
         with db_mock.cursor() as cursor:
             type(cursor).rowcount = 1
+            mock_fetchone = mock.MagicMock()
+            mock_fetchone.return_value = (1, 1)
+            type(cursor).fetchone = mock_fetchone
             self.assertEqual(update_sample_data(
                 conn=db_mock, strain_name="BXD11",
-                strain_id=10, publish_data_id=8967049,
-                value=18.7, error=2.3, count=2),
-                             (1, 1, 1, 1))
+                trait_name="1",
+                phenotype_id=10, value=18.7,
+                error=2.3, count=2),
+                             (1, 1, 1))
             cursor.execute.assert_has_calls(
-                [mock.call(STRAIN_ID_SQL, ('BXD11', 10)),
-                 mock.call(PUBLISH_DATA_SQL, (18.7, 10, 8967049)),
-                 mock.call(PUBLISH_SE_SQL, (2.3, 10, 8967049)),
-                 mock.call(N_STRAIN_SQL, (2, 10, 8967049))]
+                [mock.call('SELECT Strain.Id, PublishData.Id FROM'
+                           ' (PublishData, Strain, PublishXRef, '
+                           'PublishFreeze) LEFT JOIN PublishSE ON '
+                           '(PublishSE.DataId = PublishData.Id '
+                           'AND PublishSE.StrainId = '
+                           'PublishData.StrainId) LEFT JOIN NStrain ON '
+                           '(NStrain.DataId = PublishData.Id AND '
+                           'NStrain.StrainId = PublishData.StrainId) WHERE '
+                           'PublishXRef.InbredSetId = '
+                           'PublishFreeze.InbredSetId AND PublishData.Id = '
+                           'PublishXRef.DataId AND PublishXRef.Id = 1 AND '
+                           'PublishXRef.PhenotypeId = 10 AND '
+                           'PublishData.StrainId = Strain.Id AND '
+                           'Strain.Name = "BXD11"'),
+                 mock.call(PUBLISH_DATA_SQL, (18.7, 1, 1)),
+                 mock.call(PUBLISH_SE_SQL, (2.3, 1, 1)),
+                 mock.call(N_STRAIN_SQL, (2, 1, 1))]
             )
 
     def test_set_haveinfo_field(self):