aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.guix_deploy8
-rw-r--r--.pylintrc6
-rw-r--r--README.md3
-rw-r--r--gn3/api/correlation.py31
-rw-r--r--gn3/authentication.py1
-rw-r--r--gn3/computations/correlations.py5
-rw-r--r--gn3/computations/partial_correlations.py137
-rw-r--r--gn3/db/correlations.py22
-rw-r--r--gn3/db/datasets.py12
-rw-r--r--gn3/db/traits.py48
-rw-r--r--mypy.ini12
-rw-r--r--scripts/wgcna_analysis.R7
-rw-r--r--tests/unit/computations/test_correlation.py1
-rw-r--r--tests/unit/db/test_datasets.py14
-rw-r--r--tests/unit/db/test_traits.py33
15 files changed, 258 insertions, 82 deletions
diff --git a/.guix_deploy b/.guix_deploy
new file mode 100644
index 0000000..c7bbb5b
--- /dev/null
+++ b/.guix_deploy
@@ -0,0 +1,8 @@
+# Deploy script on tux01
+#
+# echo Run tests:
+# echo python -m unittest discover -v
+# echo Run service (single process):
+# echo flask run --port=8080
+
+/home/wrk/opt/guix-pull/bin/guix shell -L /home/wrk/guix-bioinformatics/ --expose=$HOME/production/genotype_files/ -C -N -Df guix.scm
diff --git a/.pylintrc b/.pylintrc
index 0bdef23..00dd6cd 100644
--- a/.pylintrc
+++ b/.pylintrc
@@ -1,3 +1,7 @@
[SIMILARITIES]
-ignore-imports=yes \ No newline at end of file
+ignore-imports=yes
+
+[MESSAGES CONTROL]
+
+disable=fixme \ No newline at end of file
diff --git a/README.md b/README.md
index d3470ee..5669192 100644
--- a/README.md
+++ b/README.md
@@ -38,7 +38,6 @@ python3
guix shell -C --network --expose=$HOME/genotype_files/ -Df guix.scm
```
-
#### Using a Guix profile (or rolling back)
Create a new profile with
@@ -128,6 +127,8 @@ And for the scalable production version run
gunicorn --bind 0.0.0.0:8080 --workers 8 --keep-alive 6000 --max-requests 10 --max-requests-jitter 5 --timeout 1200 wsgi:app
```
+(see also the [.guix_deploy](./.guix_deploy) script)
+
## Using python-pip
IMPORTANT NOTE: we do not recommend using pip tools, use Guix instead
diff --git a/gn3/api/correlation.py b/gn3/api/correlation.py
index 46121f8..1caf31f 100644
--- a/gn3/api/correlation.py
+++ b/gn3/api/correlation.py
@@ -1,13 +1,16 @@
"""Endpoints for running correlations"""
+import json
from flask import jsonify
from flask import Blueprint
from flask import request
+from flask import make_response
from gn3.computations.correlations import compute_all_sample_correlation
from gn3.computations.correlations import compute_all_lit_correlation
from gn3.computations.correlations import compute_tissue_correlation
from gn3.computations.correlations import map_shared_keys_to_values
from gn3.db_utils import database_connector
+from gn3.computations.partial_correlations import partial_correlations_entry
correlation = Blueprint("correlation", __name__)
@@ -83,3 +86,31 @@ def compute_tissue_corr(corr_method="pearson"):
corr_method=corr_method)
return jsonify(results)
+
+@correlation.route("/partial", methods=["POST"])
+def partial_correlation():
+ """API endpoint for partial correlations."""
+ def trait_fullname(trait):
+ return f"{trait['dataset']}::{trait['name']}"
+
+ class OutputEncoder(json.JSONEncoder):
+ """
+ Class to encode output into JSON, for objects which the default
+ json.JSONEncoder class does not have default encoding for.
+ """
+ def default(self, obj):
+ if isinstance(obj, bytes):
+ return str(obj, encoding="utf-8")
+ return json.JSONEncoder.default(self, obj)
+
+ args = request.get_json()
+ conn, _cursor_object = database_connector()
+ corr_results = partial_correlations_entry(
+ conn, trait_fullname(args["primary_trait"]),
+ tuple(trait_fullname(trait) for trait in args["control_traits"]),
+ args["method"], int(args["criteria"]), args["target_db"])
+ response = make_response(
+ json.dumps(corr_results, cls=OutputEncoder).replace(": NaN", ": null"),
+ 400 if "error" in corr_results.keys() else 200)
+ response.headers["Content-Type"] = "application/json"
+ return response
diff --git a/gn3/authentication.py b/gn3/authentication.py
index a6372c1..d0b35bc 100644
--- a/gn3/authentication.py
+++ b/gn3/authentication.py
@@ -163,3 +163,4 @@ def create_group(conn: Redis, group_name: Optional[str],
}
conn.hset("groups", group_id, json.dumps(group))
return group
+ return None
diff --git a/gn3/computations/correlations.py b/gn3/computations/correlations.py
index 37c70e9..09288c5 100644
--- a/gn3/computations/correlations.py
+++ b/gn3/computations/correlations.py
@@ -8,6 +8,7 @@ from typing import List
from typing import Tuple
from typing import Optional
from typing import Callable
+from typing import Generator
import scipy.stats
import pingouin as pg
@@ -80,7 +81,7 @@ def compute_sample_r_correlation(trait_name, corr_method, trait_vals,
zip(*list(normalize_values(trait_vals, target_samples_vals))))
num_overlap = len(normalized_traits_vals)
except ValueError:
- return
+ return None
if num_overlap > 5:
@@ -107,7 +108,7 @@ package :not packaged in guix
def filter_shared_sample_keys(this_samplelist,
- target_samplelist) -> Tuple[List, List]:
+ target_samplelist) -> Generator:
"""Given primary and target sample-list for two base and target trait select
filter the values using the shared keys
diff --git a/gn3/computations/partial_correlations.py b/gn3/computations/partial_correlations.py
index 719c605..984c15a 100644
--- a/gn3/computations/partial_correlations.py
+++ b/gn3/computations/partial_correlations.py
@@ -18,6 +18,7 @@ from gn3.random import random_string
from gn3.function_helpers import compose
from gn3.data_helpers import parse_csv_line
from gn3.db.traits import export_informative
+from gn3.db.datasets import retrieve_trait_dataset
from gn3.db.traits import retrieve_trait_info, retrieve_trait_data
from gn3.db.species import species_name, translate_to_mouse_gene_id
from gn3.db.correlations import (
@@ -216,7 +217,7 @@ def good_dataset_samples_indexes(
def partial_correlations_fast(# pylint: disable=[R0913, R0914]
samples, primary_vals, control_vals, database_filename,
fetched_correlations, method: str, correlation_type: str) -> Tuple[
- float, Tuple[float, ...]]:
+ int, Tuple[float, ...]]:
"""
Computes partial correlation coefficients using data from a CSV file.
@@ -257,8 +258,9 @@ def partial_correlations_fast(# pylint: disable=[R0913, R0914]
## `correlation_type` parameter
return len(all_correlations), tuple(
corr + (
- (fetched_correlations[corr[0]],) if correlation_type == "literature"
- else fetched_correlations[corr[0]][0:2])
+ (fetched_correlations[corr[0]],) # type: ignore[index]
+ if correlation_type == "literature"
+ else fetched_correlations[corr[0]][0:2]) # type: ignore[index]
for idx, corr in enumerate(all_correlations))
def build_data_frame(
@@ -305,11 +307,19 @@ def compute_partial(
prim for targ, prim in zip(targ_vals, primary_vals)
if targ is not None]
+ if len(primary) < 3:
+ return None
+
+ def __remove_controls_for_target_nones(cont_targ):
+ return tuple(cont for cont, targ in cont_targ if targ is not None)
+
+ conts_targs = tuple(tuple(
+ zip(control, targ_vals)) for control in control_vals)
datafrm = build_data_frame(
primary,
- tuple(targ for targ in targ_vals if targ is not None),
- tuple(cont for i, cont in enumerate(control_vals)
- if target[0][i] is not None))
+ [targ for targ in targ_vals if targ is not None],
+ [__remove_controls_for_target_nones(cont_targ)
+ for cont_targ in conts_targs])
covariates = "z" if datafrm.shape[1] == 3 else [
col for col in datafrm.columns if col not in ("x", "y")]
ppc = pingouin.partial_corr(
@@ -332,13 +342,17 @@ def compute_partial(
zero_order_corr["r"][0], zero_order_corr["p-val"][0])
return tuple(
- __compute_trait_info__(target)
- for target in zip(target_vals, target_names))
+ result for result in (
+ __compute_trait_info__(target)
+ for target in zip(target_vals, target_names))
+ if result is not None)
def partial_correlations_normal(# pylint: disable=R0913
primary_vals, control_vals, input_trait_gene_id, trait_database,
data_start_pos: int, db_type: str, method: str) -> Tuple[
- float, Tuple[float, ...]]:
+ int, Tuple[Union[
+ Tuple[str, int, float, float, float, float], None],
+ ...]]:#Tuple[float, ...]
"""
Computes the correlation coefficients.
@@ -360,7 +374,7 @@ def partial_correlations_normal(# pylint: disable=R0913
return tuple(item) + (trait_database[1], trait_database[2])
return item
- target_trait_names, target_trait_vals = reduce(
+ target_trait_names, target_trait_vals = reduce(# type: ignore[var-annotated]
lambda acc, item: (acc[0]+(item[0],), acc[1]+(item[data_start_pos:],)),
trait_database, (tuple(), tuple()))
@@ -413,7 +427,7 @@ def partial_corrs(# pylint: disable=[R0913]
data_start_pos, dataset, method)
def literature_correlation_by_list(
- conn: Any, species: str, trait_list: Tuple[dict]) -> Tuple[dict]:
+ conn: Any, species: str, trait_list: Tuple[dict]) -> Tuple[dict, ...]:
"""
This is a migration of the
`web.webqtl.correlation.CorrelationPage.getLiteratureCorrelationByList`
@@ -473,7 +487,7 @@ def literature_correlation_by_list(
def tissue_correlation_by_list(
conn: Any, primary_trait_symbol: str, tissue_probeset_freeze_id: int,
- method: str, trait_list: Tuple[dict]) -> Tuple[dict]:
+ method: str, trait_list: Tuple[dict]) -> Tuple[dict, ...]:
"""
This is a migration of the
`web.webqtl.correlation.CorrelationPage.getTissueCorrelationByList`
@@ -496,7 +510,7 @@ def tissue_correlation_by_list(
primary_trait_value = prim_trait_symbol_value_dict[
primary_trait_symbol.lower()]
gene_symbol_list = tuple(
- trait for trait in trait_list if "symbol" in trait.keys())
+ trait["symbol"] for trait in trait_list if "symbol" in trait.keys())
symbol_value_dict = fetch_gene_symbol_tissue_value_dict_for_trait(
gene_symbol_list, tissue_probeset_freeze_id, conn)
return tuple(
@@ -514,6 +528,54 @@ def tissue_correlation_by_list(
} for trait in trait_list)
return trait_list
+def trait_for_output(trait):
+ """
+ Process a trait for output.
+
+ Removes a lot of extraneous data from the trait, that is not needed for
+ the display of partial correlation results.
+ This function also removes all key-value pairs, for which the value is
+ `None`, because it is a waste of network resources to transmit the key-value
+ pair just to indicate it does not exist.
+ """
+ trait = {
+ "trait_type": trait["trait_type"],
+ "dataset_name": trait["db"]["dataset_name"],
+ "dataset_type": trait["db"]["dataset_type"],
+ "group": trait["db"]["group"],
+ "trait_fullname": trait["trait_fullname"],
+ "trait_name": trait["trait_name"],
+ "symbol": trait.get("symbol"),
+ "description": trait.get("description"),
+ "pre_publication_description": trait.get(
+ "pre_publication_description"),
+ "post_publication_description": trait.get(
+ "post_publication_description"),
+ "original_description": trait.get(
+ "original_description"),
+ "authors": trait.get("authors"),
+ "year": trait.get("year"),
+ "probe_target_description": trait.get(
+ "probe_target_description"),
+ "chr": trait.get("chr"),
+ "mb": trait.get("mb"),
+ "geneid": trait.get("geneid"),
+ "homologeneid": trait.get("homologeneid"),
+ "noverlap": trait.get("noverlap"),
+ "partial_corr": trait.get("partial_corr"),
+ "partial_corr_p_value": trait.get("partial_corr_p_value"),
+ "corr": trait.get("corr"),
+ "corr_p_value": trait.get("corr_p_value"),
+ "rank_order": trait.get("rank_order"),
+ "delta": (
+ None if trait.get("partial_corr") is None
+ else (trait.get("partial_corr") - trait.get("corr"))),
+ "l_corr": trait.get("l_corr"),
+ "tissue_corr": trait.get("tissue_corr"),
+ "tissue_p_value": trait.get("tissue_p_value")
+ }
+ return {key: val for key, val in trait.items() if val is not None}
+
def partial_correlations_entry(# pylint: disable=[R0913, R0914, R0911]
conn: Any, primary_trait_name: str,
control_trait_names: Tuple[str, ...], method: str,
@@ -640,28 +702,47 @@ def partial_correlations_entry(# pylint: disable=[R0913, R0914, R0911]
"any associated Tissue Correlation Information."),
"error_type": "Tissue Correlation"}
+ target_dataset = retrieve_trait_dataset(
+ ("Temp" if "Temp" in target_db_name else
+ ("Publish" if "Publish" in target_db_name else
+ "Geno" if "Geno" in target_db_name else "ProbeSet")),
+ {"db": {"dataset_name": target_db_name}, "trait_name": "_"},
+ threshold,
+ conn)
+
database_filename = get_filename(conn, target_db_name, TEXTDIR)
_total_traits, all_correlations = partial_corrs(
conn, common_primary_control_samples, fixed_primary_vals,
fixed_control_vals, len(fixed_primary_vals), species,
input_trait_geneid, input_trait_symbol, tissue_probeset_freeze_id,
- method, primary_trait["db"], database_filename)
+ method, {**target_dataset, "dataset_type": target_dataset["type"]}, database_filename)
def __make_sorter__(method):
- def __sort_6__(row):
- return row[6]
-
- def __sort_3__(row):
+ def __compare_lit_or_tiss_correlation_values_(row):
+ # Index Content
+ # 0 trait name
+ # 1 N
+ # 2 partial correlation coefficient
+ # 3 p value of partial correlation
+ # 6 literature/tissue correlation value
+ return (row[6], row[3])
+
+ def __compare_partial_correlation_p_values__(row):
+ # Index Content
+ # 0 trait name
+ # 1 partial correlation coefficient
+ # 2 N
+ # 3 p value of partial correlation
return row[3]
if "literature" in method.lower():
- return __sort_6__
+ return __compare_lit_or_tiss_correlation_values_
if "tissue" in method.lower():
- return __sort_6__
+ return __compare_lit_or_tiss_correlation_values_
- return __sort_3__
+ return __compare_partial_correlation_p_values__
sorted_correlations = sorted(
all_correlations, key=__make_sorter__(method))
@@ -676,7 +757,7 @@ def partial_correlations_entry(# pylint: disable=[R0913, R0914, R0911]
{
**retrieve_trait_info(
threshold,
- f"{primary_trait['db']['dataset_name']}::{item[0]}",
+ f"{target_dataset['dataset_name']}::{item[0]}",
conn),
"noverlap": item[1],
"partial_corr": item[2],
@@ -694,4 +775,14 @@ def partial_correlations_entry(# pylint: disable=[R0913, R0914, R0911]
for item in
sorted_correlations[:min(criteria, len(all_correlations))]))
- return trait_list
+ return {
+ "status": "success",
+ "results": {
+ "primary_trait": trait_for_output(primary_trait),
+ "control_traits": tuple(
+ trait_for_output(trait) for trait in cntrl_traits),
+ "correlations": tuple(
+ trait_for_output(trait) for trait in trait_list),
+ "dataset_type": target_dataset["type"],
+ "method": "spearman" if "spearman" in method.lower() else "pearson"
+ }}
diff --git a/gn3/db/correlations.py b/gn3/db/correlations.py
index 254af10..5361a1e 100644
--- a/gn3/db/correlations.py
+++ b/gn3/db/correlations.py
@@ -157,11 +157,12 @@ def fetch_symbol_value_pair_dict(
symbol: data_id_dict.get(symbol) for symbol in symbol_list
if data_id_dict.get(symbol) is not None
}
- query = "SELECT Id, value FROM TissueProbeSetData WHERE Id IN %(data_ids)s"
+ query = "SELECT Id, value FROM TissueProbeSetData WHERE Id IN ({})".format(
+ ",".join(f"%(id{i})s" for i in range(len(data_ids.values()))))
with conn.cursor() as cursor:
cursor.execute(
query,
- data_ids=tuple(data_ids.values()))
+ **{f"id{i}": did for i, did in enumerate(data_ids.values())})
value_results = cursor.fetchall()
return {
key: tuple(row[1] for row in value_results if row[0] == key)
@@ -406,21 +407,22 @@ def fetch_sample_ids(
"""
query = (
"SELECT Strain.Id FROM Strain, Species "
- "WHERE Strain.Name IN %(samples_names)s "
+ "WHERE Strain.Name IN ({}) "
"AND Strain.SpeciesId=Species.Id "
- "AND Species.name=%(species_name)s")
+ "AND Species.name=%(species_name)s").format(
+ ",".join(f"%(s{i})s" for i in range(len(sample_names))))
with conn.cursor() as cursor:
cursor.execute(
query,
{
- "samples_names": tuple(sample_names),
+ **{f"s{i}": sname for i, sname in enumerate(sample_names)},
"species_name": species_name
})
return tuple(row[0] for row in cursor.fetchall())
def build_query_sgo_lit_corr(
db_type: str, temp_table: str, sample_id_columns: str,
- joins: Tuple[str, ...]) -> str:
+ joins: Tuple[str, ...]) -> Tuple[str, int]:
"""
Build query for `SGO Literature Correlation` data, when querying the given
`temp_table` temporary table.
@@ -483,14 +485,14 @@ def fetch_all_database_data(# pylint: disable=[R0913, R0914]
sample_id_columns = ", ".join(f"T{smpl}.value" for smpl in sample_ids)
if db_type == "Publish":
joins = tuple(
- ("LEFT JOIN PublishData AS T{item} "
- "ON T{item}.Id = PublishXRef.DataId "
- "AND T{item}.StrainId = %(T{item}_sample_id)s")
+ (f"LEFT JOIN PublishData AS T{item} "
+ f"ON T{item}.Id = PublishXRef.DataId "
+ f"AND T{item}.StrainId = %(T{item}_sample_id)s")
for item in sample_ids)
return (
("SELECT PublishXRef.Id, " +
sample_id_columns +
- "FROM (PublishXRef, PublishFreeze) " +
+ " FROM (PublishXRef, PublishFreeze) " +
" ".join(joins) +
" WHERE PublishXRef.InbredSetId = PublishFreeze.InbredSetId "
"AND PublishFreeze.Name = %(db_name)s"),
diff --git a/gn3/db/datasets.py b/gn3/db/datasets.py
index c50e148..a41e228 100644
--- a/gn3/db/datasets.py
+++ b/gn3/db/datasets.py
@@ -3,7 +3,7 @@ This module contains functions relating to specific trait dataset manipulation
"""
import re
from string import Template
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, Optional
from SPARQLWrapper import JSON, SPARQLWrapper
from gn3.settings import SPARQL_ENDPOINT
@@ -297,7 +297,7 @@ def retrieve_trait_dataset(trait_type, trait, threshold, conn):
**group
}
-def sparql_query(query: str) -> Dict[str, Any]:
+def sparql_query(query: str) -> List[Dict[str, Any]]:
"""Run a SPARQL query and return the bound variables."""
sparql = SPARQLWrapper(SPARQL_ENDPOINT)
sparql.setQuery(query)
@@ -328,7 +328,7 @@ WHERE {
OPTIONAL { ?dataset gn:geoSeries ?geo_series } .
}
""",
- """
+ """
PREFIX gn: <http://genenetwork.org/>
SELECT ?platform_name ?normalization_name ?species_name ?inbred_set_name ?tissue_name
WHERE {
@@ -341,7 +341,7 @@ WHERE {
OPTIONAL { ?dataset gn:datasetOfPlatform / gn:name ?platform_name } .
}
""",
- """
+ """
PREFIX gn: <http://genenetwork.org/>
SELECT ?specifics ?summary ?about_cases ?about_tissue ?about_platform
?about_data_processing ?notes ?experiment_design ?contributors
@@ -362,8 +362,8 @@ WHERE {
OPTIONAL { ?dataset gn:acknowledgment ?acknowledgment . }
}
"""]
- result = {'accession_id': accession_id,
- 'investigator': {}}
+ result: Dict[str, Any] = {'accession_id': accession_id,
+ 'investigator': {}}
query_result = {}
for query in queries:
if sparql_result := sparql_query(Template(query).substitute(accession_id=accession_id)):
diff --git a/gn3/db/traits.py b/gn3/db/traits.py
index 7994aef..338b320 100644
--- a/gn3/db/traits.py
+++ b/gn3/db/traits.py
@@ -1,6 +1,5 @@
"""This class contains functions relating to trait data manipulation"""
import os
-import MySQLdb
from functools import reduce
from typing import Any, Dict, Union, Sequence
@@ -111,7 +110,6 @@ def get_trait_csv_sample_data(conn: Any,
def update_sample_data(conn: Any, #pylint: disable=[R0913]
-
trait_name: str,
strain_name: str,
phenotype_id: int,
@@ -204,25 +202,30 @@ def delete_sample_data(conn: Any,
"AND Strain.Name = \"%s\"") % (trait_name,
phenotype_id,
str(strain_name)))
- strain_id, data_id = cursor.fetchone()
- cursor.execute(("DELETE FROM PublishData "
+ # Check if it exists if the data was already deleted:
+ if _result := cursor.fetchone():
+ strain_id, data_id = _result
+
+ # Only run if the strain_id and data_id exist
+ if strain_id and data_id:
+ cursor.execute(("DELETE FROM PublishData "
"WHERE StrainId = %s AND Id = %s")
- % (strain_id, data_id))
- deleted_published_data = cursor.rowcount
-
- # Delete the PublishSE table
- cursor.execute(("DELETE FROM PublishSE "
- "WHERE StrainId = %s AND DataId = %s") %
- (strain_id, data_id))
- deleted_se_data = cursor.rowcount
-
- # Delete the NStrain table
- cursor.execute(("DELETE FROM NStrain "
- "WHERE StrainId = %s AND DataId = %s" %
- (strain_id, data_id)))
- deleted_n_strains = cursor.rowcount
- except Exception as e: #pylint: disable=[C0103, W0612]
+ % (strain_id, data_id))
+ deleted_published_data = cursor.rowcount
+
+ # Delete the PublishSE table
+ cursor.execute(("DELETE FROM PublishSE "
+ "WHERE StrainId = %s AND DataId = %s") %
+ (strain_id, data_id))
+ deleted_se_data = cursor.rowcount
+
+ # Delete the NStrain table
+ cursor.execute(("DELETE FROM NStrain "
+ "WHERE StrainId = %s AND DataId = %s" %
+ (strain_id, data_id)))
+ deleted_n_strains = cursor.rowcount
+ except Exception as e: #pylint: disable=[C0103, W0612]
conn.rollback()
raise MySQLdb.Error
conn.commit()
@@ -255,6 +258,13 @@ def insert_sample_data(conn: Any, #pylint: disable=[R0913]
(strain_name,))
strain_id = cursor.fetchone()
+ # Return early if an insert already exists!
+ cursor.execute("SELECT Id FROM PublishData where Id = %s "
+ "AND StrainId = %s",
+ (data_id, strain_id))
+ if cursor.fetchone(): # This strain already exists
+ return (0, 0, 0)
+
# Insert the PublishData table
cursor.execute(("INSERT INTO PublishData (Id, StrainId, value)"
"VALUES (%s, %s, %s)"),
diff --git a/mypy.ini b/mypy.ini
index b0c48df..ef93008 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -19,4 +19,16 @@ ignore_missing_imports = True
ignore_missing_imports = True
[mypy-requests.*]
+ignore_missing_imports = True
+
+[mypy-flask.*]
+ignore_missing_imports = True
+
+[mypy-werkzeug.*]
+ignore_missing_imports = True
+
+[mypy-SPARQLWrapper.*]
+ignore_missing_imports = True
+
+[mypy-pandas.*]
ignore_missing_imports = True \ No newline at end of file
diff --git a/scripts/wgcna_analysis.R b/scripts/wgcna_analysis.R
index b0d25a9..d368013 100644
--- a/scripts/wgcna_analysis.R
+++ b/scripts/wgcna_analysis.R
@@ -25,6 +25,7 @@ if (length(args)==0) {
inputData <- fromJSON(file = json_file_path)
imgDir = inputData$TMPDIR
+inputData
trait_sample_data <- do.call(rbind, inputData$trait_sample_data)
@@ -51,10 +52,8 @@ names(dataExpr) = inputData$trait_names
# Allow multi-threading within WGCNA
enableWGCNAThreads()
-# choose softthreshhold (Calculate soft threshold)
-# xtodo allow users to pass args
-
-powers <- c(c(1:10), seq(from = 12, to=20, by=2))
+# powers <- c(c(1:10), seq(from = 12, to=20, by=2))
+powers <- unlist(c(inputData$SoftThresholds))
sft <- pickSoftThreshold(dataExpr, powerVector = powers, verbose = 5)
# check the power estimate
diff --git a/tests/unit/computations/test_correlation.py b/tests/unit/computations/test_correlation.py
index 0de347d..7523d99 100644
--- a/tests/unit/computations/test_correlation.py
+++ b/tests/unit/computations/test_correlation.py
@@ -1,7 +1,6 @@
"""Module contains the tests for correlation"""
from unittest import TestCase
from unittest import mock
-import unittest
from collections import namedtuple
import math
diff --git a/tests/unit/db/test_datasets.py b/tests/unit/db/test_datasets.py
index 39f4af9..0b8c2fe 100644
--- a/tests/unit/db/test_datasets.py
+++ b/tests/unit/db/test_datasets.py
@@ -13,15 +13,17 @@ class TestDatasetsDBFunctions(TestCase):
def test_retrieve_dataset_name(self):
"""Test that the function is called correctly."""
- for trait_type, thresh, trait_name, dataset_name, columns, table in [
+ for trait_type, thresh, trait_name, dataset_name, columns, table, expected in [
["ProbeSet", 9, "probesetTraitName", "probesetDatasetName",
- "Id, Name, FullName, ShortName, DataScale", "ProbeSetFreeze"],
+ "Id, Name, FullName, ShortName, DataScale", "ProbeSetFreeze",
+ {"dataset_id": None, "dataset_name": "probesetDatasetName",
+ "dataset_fullname": "probesetDatasetName"}],
["Geno", 3, "genoTraitName", "genoDatasetName",
- "Id, Name, FullName, ShortName", "GenoFreeze"],
+ "Id, Name, FullName, ShortName", "GenoFreeze", {}],
["Publish", 6, "publishTraitName", "publishDatasetName",
- "Id, Name, FullName, ShortName", "PublishFreeze"],
+ "Id, Name, FullName, ShortName", "PublishFreeze", {}],
["Temp", 4, "tempTraitName", "tempTraitName",
- "Id, Name, FullName, ShortName", "TempFreeze"]]:
+ "Id, Name, FullName, ShortName", "TempFreeze", {}]]:
db_mock = mock.MagicMock()
with self.subTest(trait_type=trait_type):
with db_mock.cursor() as cursor:
@@ -29,7 +31,7 @@ class TestDatasetsDBFunctions(TestCase):
self.assertEqual(
retrieve_dataset_name(
trait_type, thresh, trait_name, dataset_name, db_mock),
- {})
+ expected)
cursor.execute.assert_called_once_with(
"SELECT {cols} "
"FROM {table} "
diff --git a/tests/unit/db/test_traits.py b/tests/unit/db/test_traits.py
index 4aa9389..75f3d4c 100644
--- a/tests/unit/db/test_traits.py
+++ b/tests/unit/db/test_traits.py
@@ -202,8 +202,6 @@ class TestTraitsDBFunctions(TestCase):
"""
# pylint: disable=C0103
db_mock = mock.MagicMock()
-
- STRAIN_ID_SQL: str = "UPDATE Strain SET Name = %s WHERE Id = %s"
PUBLISH_DATA_SQL: str = (
"UPDATE PublishData SET value = %s "
"WHERE StrainId = %s AND Id = %s")
@@ -216,16 +214,33 @@ class TestTraitsDBFunctions(TestCase):
with db_mock.cursor() as cursor:
type(cursor).rowcount = 1
+ mock_fetchone = mock.MagicMock()
+ mock_fetchone.return_value = (1, 1)
+ type(cursor).fetchone = mock_fetchone
self.assertEqual(update_sample_data(
conn=db_mock, strain_name="BXD11",
- strain_id=10, publish_data_id=8967049,
- value=18.7, error=2.3, count=2),
- (1, 1, 1, 1))
+ trait_name="1",
+ phenotype_id=10, value=18.7,
+ error=2.3, count=2),
+ (1, 1, 1))
cursor.execute.assert_has_calls(
- [mock.call(STRAIN_ID_SQL, ('BXD11', 10)),
- mock.call(PUBLISH_DATA_SQL, (18.7, 10, 8967049)),
- mock.call(PUBLISH_SE_SQL, (2.3, 10, 8967049)),
- mock.call(N_STRAIN_SQL, (2, 10, 8967049))]
+ [mock.call('SELECT Strain.Id, PublishData.Id FROM'
+ ' (PublishData, Strain, PublishXRef, '
+ 'PublishFreeze) LEFT JOIN PublishSE ON '
+ '(PublishSE.DataId = PublishData.Id '
+ 'AND PublishSE.StrainId = '
+ 'PublishData.StrainId) LEFT JOIN NStrain ON '
+ '(NStrain.DataId = PublishData.Id AND '
+ 'NStrain.StrainId = PublishData.StrainId) WHERE '
+ 'PublishXRef.InbredSetId = '
+ 'PublishFreeze.InbredSetId AND PublishData.Id = '
+ 'PublishXRef.DataId AND PublishXRef.Id = 1 AND '
+ 'PublishXRef.PhenotypeId = 10 AND '
+ 'PublishData.StrainId = Strain.Id AND '
+ 'Strain.Name = "BXD11"'),
+ mock.call(PUBLISH_DATA_SQL, (18.7, 1, 1)),
+ mock.call(PUBLISH_SE_SQL, (2.3, 1, 1)),
+ mock.call(N_STRAIN_SQL, (2, 1, 1))]
)
def test_set_haveinfo_field(self):