path: root/gn3
Diffstat (limited to 'gn3')
-rw-r--r--  gn3/api/heatmaps.py                       |    6
-rw-r--r--  gn3/app.py                                |   13
-rw-r--r--  gn3/authentication.py                     |  165
-rw-r--r--  gn3/computations/biweight.py              |   27
-rw-r--r--  gn3/computations/correlations.py          |   41
-rw-r--r--  gn3/computations/correlations2.py         |   36
-rw-r--r--  gn3/computations/partial_correlations.py  |  696
-rw-r--r--  gn3/computations/wgcna.py                 |   49
-rw-r--r--  gn3/data_helpers.py                       |   52
-rw-r--r--  gn3/db/correlations.py                    |  564
-rw-r--r--  gn3/db/species.py                         |   44
-rw-r--r--  gn3/db/traits.py                          |  361
-rw-r--r--  gn3/heatmaps.py                           |  202
-rw-r--r--  gn3/settings.py                           |   21
14 files changed, 2001 insertions, 276 deletions
diff --git a/gn3/api/heatmaps.py b/gn3/api/heatmaps.py
index 62ca2ad..633a061 100644
--- a/gn3/api/heatmaps.py
+++ b/gn3/api/heatmaps.py
@@ -17,7 +17,9 @@ def clustered_heatmaps():
Parses the incoming data and responds with the JSON-serialized plotly figure
representing the clustered heatmap.
"""
- traits_names = request.get_json().get("traits_names", tuple())
+ heatmap_request = request.get_json()
+ traits_names = heatmap_request.get("traits_names", tuple())
+ vertical = heatmap_request.get("vertical", False)
if len(traits_names) < 2:
return jsonify({
"message": "You need to provide at least two trait names."
@@ -30,7 +32,7 @@ def clustered_heatmaps():
traits_fullnames = [parse_trait_fullname(trait) for trait in traits_names]
with io.StringIO() as io_str:
- _filename, figure = build_heatmap(traits_fullnames, conn)
+ figure = build_heatmap(traits_fullnames, conn, vertical=vertical)
figure.write_json(io_str)
fig_json = io_str.getvalue()
return fig_json, 200
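The heatmap endpoint now reads an optional "vertical" flag from the request
body and passes it through to build_heatmap(). A minimal client sketch,
assuming the blueprint is mounted at /api/heatmaps/clustered (the URL, host
and trait names below are placeholders):

    import requests

    payload = {
        "traits_names": ["Dataset1::trait_1", "Dataset1::trait_2"],
        "vertical": True,  # new flag: orient the heatmap vertically
    }
    response = requests.post(
        "http://localhost:8080/api/heatmaps/clustered", json=payload)
    figure_json = response.text  # JSON-serialised plotly figure on success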
diff --git a/gn3/app.py b/gn3/app.py
index a25332c..3d68b3f 100644
--- a/gn3/app.py
+++ b/gn3/app.py
@@ -21,12 +21,6 @@ def create_app(config: Union[Dict, str, None] = None) -> Flask:
# Load default configuration
app.config.from_object("gn3.settings")
- CORS(
- app,
- origins=app.config["CORS_ORIGINS"],
- allow_headers=app.config["CORS_HEADERS"],
- supports_credentials=True, intercept_exceptions=False)
-
# Load environment configuration
if "GN3_CONF" in os.environ:
app.config.from_envvar('GN3_CONF')
@@ -37,6 +31,13 @@ def create_app(config: Union[Dict, str, None] = None) -> Flask:
app.config.update(config)
elif config.endswith(".py"):
app.config.from_pyfile(config)
+
+ CORS(
+ app,
+ origins=app.config["CORS_ORIGINS"],
+ allow_headers=app.config["CORS_HEADERS"],
+ supports_credentials=True, intercept_exceptions=False)
+
app.register_blueprint(general, url_prefix="/api/")
app.register_blueprint(gemma, url_prefix="/api/gemma")
app.register_blueprint(rqtl, url_prefix="/api/rqtl")
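The CORS initialisation moves below the configuration loading so that any
CORS_ORIGINS/CORS_HEADERS overrides coming from GN3_CONF, a dict or a pyfile
are the values actually handed to flask-cors. A standalone sketch of the same
ordering concern (the settings and values here are illustrative):

    from flask import Flask
    from flask_cors import CORS

    def create_app(overrides=None):
        app = Flask(__name__)
        app.config["CORS_ORIGINS"] = ["http://localhost:5003"]  # default
        if overrides:  # e.g. values loaded from GN3_CONF
            app.config.update(overrides)
        # Initialise CORS only after all overrides are applied; the value
        # passed here is whatever the config holds at this point.
        CORS(app, origins=app.config["CORS_ORIGINS"])
        return app

    app = create_app({"CORS_ORIGINS": ["https://genenetwork.org"]})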
diff --git a/gn3/authentication.py b/gn3/authentication.py
new file mode 100644
index 0000000..a6372c1
--- /dev/null
+++ b/gn3/authentication.py
@@ -0,0 +1,165 @@
+"""Methods for interacting with gn-proxy."""
+import functools
+import json
+import uuid
+import datetime
+
+from urllib.parse import urljoin
+from enum import Enum, unique
+from typing import Dict, List, Optional, Union
+
+from redis import Redis
+import requests
+
+
+@functools.total_ordering
+class OrderedEnum(Enum):
+ """A class that ordered Enums in order of position"""
+ @classmethod
+ @functools.lru_cache(None)
+ def _member_list(cls):
+ return list(cls)
+
+ def __lt__(self, other):
+ if self.__class__ is other.__class__:
+ member_list = self.__class__._member_list()
+ return member_list.index(self) < member_list.index(other)
+ return NotImplemented
+
+
+@unique
+class DataRole(OrderedEnum):
+ """Enums for Data Access"""
+ NO_ACCESS = "no-access"
+ VIEW = "view"
+ EDIT = "edit"
+
+
+@unique
+class AdminRole(OrderedEnum):
+ """Enums for Admin status"""
+ NOT_ADMIN = "not-admin"
+ EDIT_ACCESS = "edit-access"
+ EDIT_ADMINS = "edit-admins"
+
+
+def get_user_membership(conn: Redis, user_id: str,
+ group_id: str) -> Dict:
+ """Return a dictionary that indicates whether the `user_id` is a
+ member or admin of `group_id`.
+
+ Args:
+ - conn: a Redis Connection with the responses decoded.
+ - user_id: a user's unique id
+ e.g. '8ad942fe-490d-453e-bd37-56f252e41603'
+ - group_id: a group's unique id
+ e.g. '7fa95d07-0e2d-4bc5-b47c-448fdc1260b2'
+
+ Returns:
+ A dict indicating whether the user is an admin or a member of
+ the group: {"member": True, "admin": False}
+
+ """
+ results = {"member": False, "admin": False}
+ for key, value in conn.hgetall('groups').items():
+ if key == group_id:
+ group_info = json.loads(value)
+ if user_id in group_info.get("admins"):
+ results["admin"] = True
+ if user_id in group_info.get("members"):
+ results["member"] = True
+ break
+ return results
+
+
+def get_highest_user_access_role(
+ resource_id: str,
+ user_id: str,
+ gn_proxy_url: str = "http://localhost:8080") -> Dict:
+ """Get the highest access roles for a given user
+
+ Args:
+ - resource_id: The unique id of a given resource.
+ - user_id: The unique id of a given user.
+ - gn_proxy_url: The URL where gn-proxy is running.
+
+ Returns:
+ A dict indicating the highest access role the user has.
+
+ """
+ role_mapping: Dict[str, Union[DataRole, AdminRole]] = {}
+ for data_role, admin_role in zip(DataRole, AdminRole):
+ role_mapping.update({data_role.value: data_role, })
+ role_mapping.update({admin_role.value: admin_role, })
+ access_role = {}
+ response = requests.get(urljoin(gn_proxy_url,
+ ("available?resource="
+ f"{resource_id}&user={user_id}")))
+ for key, value in json.loads(response.content).items():
+ access_role[key] = max(map(lambda role: role_mapping[role], value))
+ return access_role
+
+
+def get_groups_by_user_uid(user_uid: str, conn: Redis) -> Dict:
+ """Given a user uid, get the groups in which they are a member or admin of.
+
+ Args:
+ - user_uid: A user's unique id
+ - conn: A redis connection
+
+ Returns:
+ - A dictionary containing the list of groups the user is part of e.g.:
+ {"admin": [], "member": ["ce0dddd1-6c50-4587-9eec-6c687a54ad86"]}
+ """
+ admin = []
+ member = []
+ for group_uuid, group_info in conn.hgetall("groups").items():
+ group_info = json.loads(group_info)
+ group_info["uuid"] = group_uuid
+ if user_uid in group_info.get('admins'):
+ admin.append(group_info)
+ if user_uid in group_info.get('members'):
+ member.append(group_info)
+ return {
+ "admin": admin,
+ "member": member,
+ }
+
+
+def get_user_info_by_key(key: str, value: str,
+ conn: Redis) -> Optional[Dict]:
+ """Given a key, get a user's information if value is matched"""
+ if key != "user_id":
+ for user_uuid, user_info in conn.hgetall("users").items():
+ user_info = json.loads(user_info)
+ if (key in user_info and user_info.get(key) == value):
+ user_info["user_id"] = user_uuid
+ return user_info
+ elif key == "user_id":
+ if user_info := conn.hget("users", value):
+ user_info = json.loads(user_info)
+ user_info["user_id"] = value
+ return user_info
+ return None
+
+
+def create_group(conn: Redis, group_name: Optional[str],
+ admin_user_uids: List = None,
+ member_user_uids: List = None) -> Optional[Dict]:
+ """Create a group given the group name, members and admins of that group."""
+ if admin_user_uids is None:
+ admin_user_uids = []
+ if member_user_uids is None:
+ member_user_uids = []
+ if group_name and bool(admin_user_uids + member_user_uids):
+ timestamp = datetime.datetime.utcnow().strftime('%b %d %Y %I:%M%p')
+ group = {
+ "id": (group_id := str(uuid.uuid4())),
+ "admins": admin_user_uids,
+ "members": member_user_uids,
+ "name": group_name,
+ "created_timestamp": timestamp,
+ "changed_timestamp": timestamp,
+ }
+ conn.hset("groups", group_id, json.dumps(group))
+ return group
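The OrderedEnum base makes role members comparable by their position of
definition, which is what lets get_highest_user_access_role() reduce the role
names reported by gn-proxy with max(). A small illustration (the reported
role lists below are made up):

    from gn3.authentication import DataRole, AdminRole

    assert DataRole.NO_ACCESS < DataRole.VIEW < DataRole.EDIT
    assert AdminRole.NOT_ADMIN < AdminRole.EDIT_ACCESS < AdminRole.EDIT_ADMINS

    # Hypothetical roles reported for one resource, keyed by role group:
    reported = {"data": ["no-access", "view"], "admin": ["not-admin"]}
    mapping = {role.value: role for role in list(DataRole) + list(AdminRole)}
    highest = {key: max(mapping[name] for name in names)
               for key, names in reported.items()}
    # highest == {"data": DataRole.VIEW, "admin": AdminRole.NOT_ADMIN}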
diff --git a/gn3/computations/biweight.py b/gn3/computations/biweight.py
deleted file mode 100644
index 7accd0c..0000000
--- a/gn3/computations/biweight.py
+++ /dev/null
@@ -1,27 +0,0 @@
-"""module contains script to call biweight midcorrelation in R"""
-import subprocess
-
-from typing import List
-from typing import Tuple
-
-from gn3.settings import BIWEIGHT_RSCRIPT
-
-
-def calculate_biweight_corr(trait_vals: List,
- target_vals: List,
- path_to_script: str = BIWEIGHT_RSCRIPT,
- command: str = "Rscript"
- ) -> Tuple[float, float]:
- """biweight function"""
-
- args_1 = ' '.join(str(trait_val) for trait_val in trait_vals)
- args_2 = ' '.join(str(target_val) for target_val in target_vals)
- cmd = [command, path_to_script] + [args_1] + [args_2]
-
- results = subprocess.check_output(cmd, universal_newlines=True)
- try:
- (corr_coeff, p_val) = tuple(
- [float(y.strip()) for y in results.split()])
- return (corr_coeff, p_val)
- except Exception as error:
- raise error
diff --git a/gn3/computations/correlations.py b/gn3/computations/correlations.py
index bb13ff1..c5c56db 100644
--- a/gn3/computations/correlations.py
+++ b/gn3/computations/correlations.py
@@ -1,6 +1,7 @@
"""module contains code for correlations"""
import math
import multiprocessing
+from contextlib import closing
from typing import List
from typing import Tuple
@@ -8,7 +9,7 @@ from typing import Optional
from typing import Callable
import scipy.stats
-from gn3.computations.biweight import calculate_biweight_corr
+import pingouin as pg
def map_shared_keys_to_values(target_sample_keys: List,
@@ -49,13 +50,9 @@ def normalize_values(a_values: List,
([2.3, 4.1, 5], [3.4, 6.2, 4.1], 3)
"""
- a_new = []
- b_new = []
for a_val, b_val in zip(a_values, b_values):
if (a_val and b_val is not None):
- a_new.append(a_val)
- b_new.append(b_val)
- return a_new, b_new, len(a_new)
+ yield a_val, b_val
def compute_corr_coeff_p_value(primary_values: List, target_values: List,
@@ -81,8 +78,10 @@ def compute_sample_r_correlation(trait_name, corr_method, trait_vals,
correlation coeff and p value
"""
- (sanitized_traits_vals, sanitized_target_vals,
- num_overlap) = normalize_values(trait_vals, target_samples_vals)
+
+ sanitized_traits_vals, sanitized_target_vals = list(
+ zip(*list(normalize_values(trait_vals, target_samples_vals))))
+ num_overlap = len(sanitized_traits_vals)
if num_overlap > 5:
@@ -102,11 +101,10 @@ package :not packaged in guix
"""
- try:
- results = calculate_biweight_corr(x_val, y_val)
- return results
- except Exception as error:
- raise error
+ results = pg.corr(x_val, y_val, method="bicor")
+ corr_coeff = results["r"].values[0]
+ p_val = results["p-val"].values[0]
+ return (corr_coeff, p_val)
def filter_shared_sample_keys(this_samplelist,
@@ -115,13 +113,9 @@ def filter_shared_sample_keys(this_samplelist,
filter the values using the shared keys
"""
- this_vals = []
- target_vals = []
for key, value in target_samplelist.items():
if key in this_samplelist:
- target_vals.append(value)
- this_vals.append(this_samplelist[key])
- return (this_vals, target_vals)
+ yield this_samplelist[key], value
def fast_compute_all_sample_correlation(this_trait,
@@ -140,9 +134,10 @@ def fast_compute_all_sample_correlation(this_trait,
for target_trait in target_dataset:
trait_name = target_trait.get("trait_id")
target_trait_data = target_trait["trait_sample_data"]
- processed_values.append((trait_name, corr_method, *filter_shared_sample_keys(
- this_trait_samples, target_trait_data)))
- with multiprocessing.Pool(4) as pool:
+ processed_values.append((trait_name, corr_method,
+ list(zip(*list(filter_shared_sample_keys(
+ this_trait_samples, target_trait_data))))))
+ with closing(multiprocessing.Pool()) as pool:
results = pool.starmap(compute_sample_r_correlation, processed_values)
for sample_correlation in results:
@@ -173,8 +168,8 @@ def compute_all_sample_correlation(this_trait,
for target_trait in target_dataset:
trait_name = target_trait.get("trait_id")
target_trait_data = target_trait["trait_sample_data"]
- this_vals, target_vals = filter_shared_sample_keys(
- this_trait_samples, target_trait_data)
+ this_vals, target_vals = list(zip(*list(filter_shared_sample_keys(
+ this_trait_samples, target_trait_data))))
sample_correlation = compute_sample_r_correlation(
trait_name=trait_name,
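The biweight midcorrelation no longer shells out to an R script; it is
computed in-process with pingouin. A quick sketch of what the
pg.corr(..., method="bicor") call used above returns (sample values are
arbitrary):

    import pingouin as pg

    x = [6.2, 5.8, 7.1, 6.9, 5.5, 6.4, 7.0]
    y = [5.9, 6.1, 7.3, 6.7, 5.4, 6.2, 7.2]
    result = pg.corr(x, y, method="bicor")  # one-row pandas DataFrame
    corr_coeff = result["r"].values[0]
    p_val = result["p-val"].values[0]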
diff --git a/gn3/computations/correlations2.py b/gn3/computations/correlations2.py
index 93db3fa..d0222ae 100644
--- a/gn3/computations/correlations2.py
+++ b/gn3/computations/correlations2.py
@@ -6,45 +6,21 @@ FUNCTIONS:
compute_correlation:
TODO: Describe what the function does..."""
-from math import sqrt
-from functools import reduce
+from scipy import stats
## From GN1: mostly for clustering and heatmap generation
def __items_with_values(dbdata, userdata):
"""Retains only corresponding items in the data items that are not `None` values.
This should probably be renamed to something sensible"""
- def both_not_none(item1, item2):
- """Check that both items are not the value `None`."""
- if (item1 is not None) and (item2 is not None):
- return (item1, item2)
- return None
- def split_lists(accumulator, item):
- """Separate the 'x' and 'y' items."""
- return [accumulator[0] + [item[0]], accumulator[1] + [item[1]]]
- return reduce(
- split_lists,
- filter(lambda x: x is not None, map(both_not_none, dbdata, userdata)),
- [[], []])
+ filtered = [x for x in zip(dbdata, userdata) if x[0] is not None and x[1] is not None]
+ return tuple(zip(*filtered)) if filtered else ([], [])
def compute_correlation(dbdata, userdata):
- """Compute some form of correlation.
+ """Compute the Pearson correlation coefficient.
This is extracted from
https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/utility/webqtlUtil.py#L622-L647
"""
x_items, y_items = __items_with_values(dbdata, userdata)
- if len(x_items) < 6:
- return (0.0, len(x_items))
- meanx = sum(x_items)/len(x_items)
- meany = sum(y_items)/len(y_items)
- def cal_corr_vals(acc, item):
- xitem, yitem = item
- return [
- acc[0] + ((xitem - meanx) * (yitem - meany)),
- acc[1] + ((xitem - meanx) * (xitem - meanx)),
- acc[2] + ((yitem - meany) * (yitem - meany))]
- xyd, sxd, syd = reduce(cal_corr_vals, zip(x_items, y_items), [0.0, 0.0, 0.0])
- try:
- return ((xyd/(sqrt(sxd)*sqrt(syd))), len(x_items))
- except ZeroDivisionError:
- return(0, len(x_items))
+ correlation = stats.pearsonr(x_items, y_items)[0] if len(x_items) >= 6 else 0
+ return (correlation, len(x_items))
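A short usage sketch of the rewritten compute_correlation(), showing the None
filtering and the fewer-than-six-samples cutoff (the vectors are made up):

    from gn3.computations.correlations2 import compute_correlation

    dbdata   = [7.1, None, 6.3, 6.9, 5.8, 6.4, 7.0]
    userdata = [6.8, 6.1,  None, 6.5, 5.9, 6.2, 7.3]
    # Pairs where either side is None are dropped, leaving 5 pairs here,
    # which is below the six-sample cutoff, so the coefficient is 0.
    print(compute_correlation(dbdata, userdata))  # (0, 5)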
diff --git a/gn3/computations/partial_correlations.py b/gn3/computations/partial_correlations.py
new file mode 100644
index 0000000..231b0a7
--- /dev/null
+++ b/gn3/computations/partial_correlations.py
@@ -0,0 +1,696 @@
+"""
+This module deals with partial correlations.
+
+It is an attempt to migrate over the partial correlations feature from
+GeneNetwork1.
+"""
+
+import math
+from functools import reduce, partial
+from typing import Any, Tuple, Union, Sequence
+
+import pandas
+import pingouin
+from scipy.stats import pearsonr, spearmanr
+
+from gn3.settings import TEXTDIR
+from gn3.random import random_string
+from gn3.function_helpers import compose
+from gn3.data_helpers import parse_csv_line
+from gn3.db.traits import export_informative
+from gn3.db.traits import retrieve_trait_info, retrieve_trait_data
+from gn3.db.species import species_name, translate_to_mouse_gene_id
+from gn3.db.correlations import (
+ get_filename,
+ fetch_all_database_data,
+ check_for_literature_info,
+ fetch_tissue_correlations,
+ fetch_literature_correlations,
+ check_symbol_for_tissue_correlation,
+ fetch_gene_symbol_tissue_value_dict_for_trait)
+
+def control_samples(controls: Sequence[dict], sampleslist: Sequence[str]):
+ """
+ Fetches data for the control traits.
+
+ This migrates `web/webqtl/correlation/correlationFunction.controlStrain` in
+ GN1, with a few modifications to the arguments passed in.
+
+ PARAMETERS:
+ controls: A map of sample names to trait data. Equivalent to the `cvals`
+ value in the corresponding source function in GN1.
+ sampleslist: A list of samples. Equivalent to `strainlst` in the
+ corresponding source function in GN1
+ """
+ def __process_control__(trait_data):
+ def __process_sample__(acc, sample):
+ if sample in trait_data["data"].keys():
+ sample_item = trait_data["data"][sample]
+ val = sample_item["value"]
+ if val is not None:
+ return (
+ acc[0] + (sample,),
+ acc[1] + (val,),
+ acc[2] + (sample_item["variance"],))
+ return acc
+ return reduce(
+ __process_sample__, sampleslist, (tuple(), tuple(), tuple()))
+
+ return reduce(
+ lambda acc, item: (
+ acc[0] + (item[0],),
+ acc[1] + (item[1],),
+ acc[2] + (item[2],),
+ acc[3] + (len(item[0]),),
+ ),
+ [__process_control__(trait_data) for trait_data in controls],
+ (tuple(), tuple(), tuple(), tuple()))
+
+def dictify_by_samples(samples_vals_vars: Sequence[Sequence]) -> Sequence[dict]:
+ """
+ Build a sequence of dictionaries from a sequence of separate sequences of
+ samples, values and variances.
+
+ This is a partial migration of
+ `web.webqtl.correlation.correlationFunction.fixStrains` function in GN1.
+ This implementation extracts code that will find common use, and that will
+ find use in more than one place.
+ """
+ return tuple(
+ {
+ sample: {"sample_name": sample, "value": val, "variance": var}
+ for sample, val, var in zip(*trait_line)
+ } for trait_line in zip(*(samples_vals_vars[0:3])))
+
+def fix_samples(primary_trait: dict, control_traits: Sequence[dict]) -> Sequence[Sequence[Any]]:
+ """
+ Corrects sample_names, values and variance such that they all contain only
+ those samples that are common to the reference trait and all control traits.
+
+ This is a partial migration of the
+ `web.webqtl.correlation.correlationFunction.fixStrain` function in GN1.
+ """
+ primary_samples = tuple(
+ present[0] for present in
+ ((sample, all(sample in control.keys() for control in control_traits))
+ for sample in primary_trait.keys())
+ if present[1])
+ control_vals_vars: tuple = reduce(
+ lambda acc, x: (acc[0] + (x[0],), acc[1] + (x[1],)),
+ ((item["value"], item["variance"])
+ for sublist in [tuple(control.values()) for control in control_traits]
+ for item in sublist),
+ (tuple(), tuple()))
+ return (
+ primary_samples,
+ tuple(primary_trait[sample]["value"] for sample in primary_samples),
+ control_vals_vars[0],
+ tuple(primary_trait[sample]["variance"] for sample in primary_samples),
+ control_vals_vars[1])
+
+def find_identical_traits(
+ primary_name: str, primary_value: float, control_names: Tuple[str, ...],
+ control_values: Tuple[float, ...]) -> Tuple[str, ...]:
+ """
+ Find traits that have the same value when the values are considered to
+ 3 decimal places.
+
+ This is a migration of the
+ `web.webqtl.correlation.correlationFunction.findIdenticalTraits` function in
+ GN1.
+ """
+ def __merge_identicals__(
+ acc: Tuple[str, ...],
+ ident: Tuple[str, Tuple[str, ...]]) -> Tuple[str, ...]:
+ return acc + ident[1]
+
+ def __dictify_controls__(acc, control_item):
+ ckey = tuple("{:.3f}".format(item) for item in control_item[0])
+ return {**acc, ckey: acc.get(ckey, tuple()) + (control_item[1],)}
+
+ return (reduce(## for identical control traits
+ __merge_identicals__,
+ (item for item in reduce(# type: ignore[var-annotated]
+ __dictify_controls__, zip(control_values, control_names),
+ {}).items() if len(item[1]) > 1),
+ tuple())
+ or
+ reduce(## If no identical control traits, try primary and controls
+ __merge_identicals__,
+ (item for item in reduce(# type: ignore[var-annotated]
+ __dictify_controls__,
+ zip((primary_value,) + control_values,
+ (primary_name,) + control_names), {}).items()
+ if len(item[1]) > 1),
+ tuple()))
+
+def tissue_correlation(
+ primary_trait_values: Tuple[float, ...],
+ target_trait_values: Tuple[float, ...],
+ method: str) -> Tuple[float, float]:
+ """
+    Compute the correlation between the primary trait values and the values of
+    a single target trait.
+
+ This migrates the `cal_tissue_corr` function embedded in the larger
+ `web.webqtl.correlation.correlationFunction.batchCalTissueCorr` function in
+ GeneNetwork1.
+ """
+ def spearman_corr(*args):
+ result = spearmanr(*args)
+ return (result.correlation, result.pvalue)
+
+ method_fns = {"pearson": pearsonr, "spearman": spearman_corr}
+
+ assert len(primary_trait_values) == len(target_trait_values), (
+ "The lengths of the `primary_trait_values` and `target_trait_values` "
+ "must be equal")
+ assert method in method_fns.keys(), (
+ "Method must be one of: {}".format(",".join(method_fns.keys())))
+
+ corr, pvalue = method_fns[method](primary_trait_values, target_trait_values)
+ return (corr, pvalue)
+
+def batch_computed_tissue_correlation(
+ primary_trait_values: Tuple[float, ...], target_traits_dict: dict,
+ method: str) -> Tuple[dict, dict]:
+ """
+ This is a migration of the
+ `web.webqtl.correlation.correlationFunction.batchCalTissueCorr` function in
+ GeneNetwork1
+ """
+ def __corr__(acc, target):
+ corr = tissue_correlation(primary_trait_values, target[1], method)
+ return ({**acc[0], target[0]: corr[0]}, {**acc[0], target[1]: corr[1]})
+ return reduce(__corr__, target_traits_dict.items(), ({}, {}))
+
+def correlations_of_all_tissue_traits(
+ primary_trait_symbol_value_dict: dict, symbol_value_dict: dict,
+ method: str) -> Tuple[dict, dict]:
+ """
+ Computes and returns the correlation of all tissue traits.
+
+ This is a migration of the
+ `web.webqtl.correlation.correlationFunction.calculateCorrOfAllTissueTrait`
+ function in GeneNetwork1.
+ """
+ primary_trait_values = tuple(primary_trait_symbol_value_dict.values())[0]
+ return batch_computed_tissue_correlation(
+ primary_trait_values, symbol_value_dict, method)
+
+def good_dataset_samples_indexes(
+ samples: Tuple[str, ...],
+ samples_from_file: Tuple[str, ...]) -> Tuple[int, ...]:
+ """
+ Return the indexes of the items in `samples_from_files` that are also found
+ in `samples`.
+
+ This is a partial migration of the
+ `web.webqtl.correlation.PartialCorrDBPage.getPartialCorrelationsFast`
+ function in GeneNetwork1.
+ """
+ return tuple(sorted(
+ samples_from_file.index(good) for good in
+ set(samples).intersection(set(samples_from_file))))
+
+def partial_correlations_fast(# pylint: disable=[R0913, R0914]
+ samples, primary_vals, control_vals, database_filename,
+ fetched_correlations, method: str, correlation_type: str) -> Tuple[
+ float, Tuple[float, ...]]:
+ """
+ Computes partial correlation coefficients using data from a CSV file.
+
+ This is a partial migration of the
+ `web.webqtl.correlation.PartialCorrDBPage.getPartialCorrelationsFast`
+ function in GeneNetwork1.
+ """
+ assert method in ("spearman", "pearson")
+ with open(database_filename, "r") as dataset_file:
+ dataset = tuple(dataset_file.readlines())
+
+ good_dataset_samples = good_dataset_samples_indexes(
+ samples, parse_csv_line(dataset[0])[1:])
+
+ def __process_trait_names_and_values__(acc, line):
+ trait_line = parse_csv_line(line)
+ trait_name = trait_line[0]
+ trait_data = trait_line[1:]
+ if trait_name in fetched_correlations.keys():
+ return (
+ acc[0] + (trait_name,),
+ acc[1] + tuple(
+ trait_data[i] if i in good_dataset_samples else None
+ for i in range(len(trait_data))))
+ return acc
+
+ processed_trait_names_values: tuple = reduce(
+ __process_trait_names_and_values__, dataset[1:], (tuple(), tuple()))
+ all_target_trait_names: Tuple[str, ...] = processed_trait_names_values[0]
+ all_target_trait_values: Tuple[float, ...] = processed_trait_names_values[1]
+
+ all_correlations = compute_partial(
+ primary_vals, control_vals, all_target_trait_names,
+ all_target_trait_values, method)
+ ## Line 772 to 779 in GN1 are the cause of the weird complexity in the
+ ## return below. Once the surrounding code is successfully migrated and
+ ## reworked, this complexity might go away, by getting rid of the
+ ## `correlation_type` parameter
+ return len(all_correlations), tuple(
+ corr + (
+ (fetched_correlations[corr[0]],) if correlation_type == "literature"
+ else fetched_correlations[corr[0]][0:2])
+ for idx, corr in enumerate(all_correlations))
+
+def build_data_frame(
+ xdata: Tuple[float, ...], ydata: Tuple[float, ...],
+ zdata: Union[
+ Tuple[float, ...],
+ Tuple[Tuple[float, ...], ...]]) -> pandas.DataFrame:
+ """
+ Build a pandas DataFrame object from xdata, ydata and zdata
+ """
+ x_y_df = pandas.DataFrame({"x": xdata, "y": ydata})
+ if isinstance(zdata[0], float):
+ return x_y_df.join(pandas.DataFrame({"z": zdata}))
+ interm_df = x_y_df.join(pandas.DataFrame(
+ {"z{}".format(i): val for i, val in enumerate(zdata)}))
+ if interm_df.shape[1] == 3:
+ return interm_df.rename(columns={"z0": "z"})
+ return interm_df
+
+def compute_partial(
+ primary_vals, control_vals, target_vals, target_names,
+ method: str) -> Tuple[
+ Union[
+ Tuple[str, int, float, float, float, float], None],
+ ...]:
+ """
+ Compute the partial correlations.
+
+ This is a re-implementation of the
+ `web.webqtl.correlation.correlationFunction.determinePartialsByR` function
+ in GeneNetwork1.
+
+ This implementation reworks the child function `compute_partial` which will
+    then be used in the place of `determinePartialsByR`.
+
+ TODO: moving forward, we might need to use the multiprocessing library to
+ speed up the computations, in case they are found to be slow.
+ """
+ # replace the R code with `pingouin.partial_corr`
+ def __compute_trait_info__(target):
+ targ_vals = target[0]
+ targ_name = target[1]
+ primary = [
+ prim for targ, prim in zip(targ_vals, primary_vals)
+ if targ is not None]
+
+ datafrm = build_data_frame(
+ primary,
+ tuple(targ for targ in targ_vals if targ is not None),
+ tuple(cont for i, cont in enumerate(control_vals)
+ if target[i] is not None))
+ covariates = "z" if datafrm.shape[1] == 3 else [
+ col for col in datafrm.columns if col not in ("x", "y")]
+ ppc = pingouin.partial_corr(
+ data=datafrm, x="x", y="y", covar=covariates, method=(
+ "pearson" if "pearson" in method.lower() else "spearman"))
+ pc_coeff = ppc["r"][0]
+
+ zero_order_corr = pingouin.corr(
+ datafrm["x"], datafrm["y"], method=(
+ "pearson" if "pearson" in method.lower() else "spearman"))
+
+ if math.isnan(pc_coeff):
+ return (
+ targ_name, len(primary), pc_coeff, 1, zero_order_corr["r"][0],
+ zero_order_corr["p-val"][0])
+ return (
+ targ_name, len(primary), pc_coeff,
+ (ppc["p-val"][0] if not math.isnan(ppc["p-val"][0]) else (
+ 0 if (abs(pc_coeff - 1) < 0.0000001) else 1)),
+ zero_order_corr["r"][0], zero_order_corr["p-val"][0])
+
+ return tuple(
+ __compute_trait_info__(target)
+ for target in zip(target_vals, target_names))
+
+def partial_correlations_normal(# pylint: disable=R0913
+ primary_vals, control_vals, input_trait_gene_id, trait_database,
+ data_start_pos: int, db_type: str, method: str) -> Tuple[
+ float, Tuple[float, ...]]:
+ """
+ Computes the correlation coefficients.
+
+ This is a migration of the
+ `web.webqtl.correlation.PartialCorrDBPage.getPartialCorrelationsNormal`
+ function in GeneNetwork1.
+ """
+ def __add_lit_and_tiss_corr__(item):
+ if method.lower() == "sgo literature correlation":
+ # if method is 'SGO Literature Correlation', `compute_partial`
+ # would give us LitCorr in the [1] position
+ return tuple(item) + trait_database[1]
+ if method.lower() in (
+ "tissue correlation, pearson's r",
+ "tissue correlation, spearman's rho"):
+ # if method is 'Tissue Correlation, *', `compute_partial` would give
+ # us Tissue Corr in the [1] position and Tissue Corr P Value in the
+ # [2] position
+ return tuple(item) + (trait_database[1], trait_database[2])
+ return item
+
+ target_trait_names, target_trait_vals = reduce(
+ lambda acc, item: (acc[0]+(item[0],), acc[1]+(item[data_start_pos:],)),
+ trait_database, (tuple(), tuple()))
+
+ all_correlations = compute_partial(
+ primary_vals, control_vals, target_trait_vals, target_trait_names,
+ method)
+
+ if (input_trait_gene_id and db_type == "ProbeSet" and method.lower() in (
+ "sgo literature correlation", "tissue correlation, pearson's r",
+ "tissue correlation, spearman's rho")):
+ return (
+ len(trait_database),
+ tuple(
+ __add_lit_and_tiss_corr__(item)
+ for idx, item in enumerate(all_correlations)))
+
+ return len(trait_database), all_correlations
+
+def partial_corrs(# pylint: disable=[R0913]
+ conn, samples, primary_vals, control_vals, return_number, species,
+ input_trait_geneid, input_trait_symbol, tissue_probeset_freeze_id,
+ method, dataset, database_filename):
+ """
+ Compute the partial correlations, selecting the fast or normal method
+ depending on the existence of the database text file.
+
+ This is a partial migration of the
+ `web.webqtl.correlation.PartialCorrDBPage.__init__` function in
+ GeneNetwork1.
+ """
+ if database_filename:
+ return partial_correlations_fast(
+ samples, primary_vals, control_vals, database_filename,
+ (
+ fetch_literature_correlations(
+ species, input_trait_geneid, dataset, return_number, conn)
+ if "literature" in method.lower() else
+ fetch_tissue_correlations(
+ dataset, input_trait_symbol, tissue_probeset_freeze_id,
+ method, return_number, conn)),
+ method,
+ ("literature" if method.lower() == "sgo literature correlation"
+ else ("tissue" if "tissue" in method.lower() else "genetic")))
+
+ trait_database, data_start_pos = fetch_all_database_data(
+ conn, species, input_trait_geneid, input_trait_symbol, samples, dataset,
+ method, return_number, tissue_probeset_freeze_id)
+ return partial_correlations_normal(
+ primary_vals, control_vals, input_trait_geneid, trait_database,
+ data_start_pos, dataset, method)
+
+def literature_correlation_by_list(
+ conn: Any, species: str, trait_list: Tuple[dict]) -> Tuple[dict]:
+ """
+ This is a migration of the
+ `web.webqtl.correlation.CorrelationPage.getLiteratureCorrelationByList`
+ function in GeneNetwork1.
+ """
+ if any((lambda t: (
+ bool(t.get("tissue_corr")) and
+ bool(t.get("tissue_p_value"))))(trait)
+ for trait in trait_list):
+ temporary_table_name = f"LITERATURE{random_string(8)}"
+ query1 = (
+ f"CREATE TEMPORARY TABLE {temporary_table_name} "
+ "(GeneId1 INT(12) UNSIGNED, GeneId2 INT(12) UNSIGNED PRIMARY KEY, "
+ "value DOUBLE)")
+ query2 = (
+ f"INSERT INTO {temporary_table_name}(GeneId1, GeneId2, value) "
+ "SELECT GeneId1, GeneId2, value FROM LCorrRamin3 "
+ "WHERE GeneId1=%(geneid)s")
+ query3 = (
+ "INSERT INTO {temporary_table_name}(GeneId1, GeneId2, value) "
+ "SELECT GeneId2, GeneId1, value FROM LCorrRamin3 "
+ "WHERE GeneId2=%s AND GeneId1 != %(geneid)s")
+
+ def __set_mouse_geneid__(trait):
+ if trait.get("geneid"):
+ return {
+ **trait,
+ "mouse_geneid": translate_to_mouse_gene_id(
+ species, trait.get("geneid"), conn)
+ }
+ return {**trait, "mouse_geneid": 0}
+
+ def __retrieve_lcorr__(cursor, geneids):
+ cursor.execute(
+ f"SELECT GeneId2, value FROM {temporary_table_name} "
+ "WHERE GeneId2 IN %(geneids)s",
+ geneids=geneids)
+ return dict(cursor.fetchall())
+
+ with conn.cursor() as cursor:
+ cursor.execute(query1)
+ cursor.execute(query2)
+ cursor.execute(query3)
+
+ traits = tuple(__set_mouse_geneid__(trait) for trait in trait_list)
+ lcorrs = __retrieve_lcorr__(
+ cursor, (
+ trait["mouse_geneid"] for trait in traits
+ if (trait["mouse_geneid"] != 0 and
+ trait["mouse_geneid"].find(";") < 0)))
+ return tuple(
+ {**trait, "l_corr": lcorrs.get(trait["mouse_geneid"], None)}
+ for trait in traits)
+
+ return trait_list
+ return trait_list
+
+def tissue_correlation_by_list(
+ conn: Any, primary_trait_symbol: str, tissue_probeset_freeze_id: int,
+ method: str, trait_list: Tuple[dict]) -> Tuple[dict]:
+ """
+ This is a migration of the
+ `web.webqtl.correlation.CorrelationPage.getTissueCorrelationByList`
+ function in GeneNetwork1.
+ """
+ def __add_tissue_corr__(trait, primary_trait_values, trait_values):
+ result = pingouin.corr(
+ primary_trait_values, trait_values,
+ method=("spearman" if "spearman" in method.lower() else "pearson"))
+ return {
+ **trait,
+ "tissue_corr": result["r"],
+ "tissue_p_value": result["p-val"]
+ }
+
+ if any((lambda t: bool(t.get("l_corr")))(trait) for trait in trait_list):
+ prim_trait_symbol_value_dict = fetch_gene_symbol_tissue_value_dict_for_trait(
+ (primary_trait_symbol,), tissue_probeset_freeze_id, conn)
+ if primary_trait_symbol.lower() in prim_trait_symbol_value_dict:
+ primary_trait_value = prim_trait_symbol_value_dict[
+ primary_trait_symbol.lower()]
+ gene_symbol_list = tuple(
+ trait for trait in trait_list if "symbol" in trait.keys())
+ symbol_value_dict = fetch_gene_symbol_tissue_value_dict_for_trait(
+ gene_symbol_list, tissue_probeset_freeze_id, conn)
+ return tuple(
+ __add_tissue_corr__(
+ trait, primary_trait_value,
+ symbol_value_dict[trait["symbol"].lower()])
+ for trait in trait_list
+ if ("symbol" in trait and
+ bool(trait["symbol"]) and
+ trait["symbol"].lower() in symbol_value_dict))
+ return tuple({
+ **trait,
+ "tissue_corr": None,
+ "tissue_p_value": None
+ } for trait in trait_list)
+ return trait_list
+
+def partial_correlations_entry(# pylint: disable=[R0913, R0914, R0911]
+ conn: Any, primary_trait_name: str,
+ control_trait_names: Tuple[str, ...], method: str,
+ criteria: int, group: str, target_db_name: str) -> dict:
+ """
+    This is the 'orchestration' function for the partial-correlation feature.
+
+ This function will dispatch the functions doing data fetches from the
+ database (and various other places) and feed that data to the functions
+ doing the conversions and computations. It will then return the results of
+ all of that work.
+
+ This function is doing way too much. Look into splitting out the
+ functionality into smaller functions that do fewer things.
+ """
+ threshold = 0
+ corr_min_informative = 4
+
+ primary_trait = retrieve_trait_info(threshold, primary_trait_name, conn)
+ primary_trait_data = retrieve_trait_data(primary_trait, conn)
+ primary_samples, primary_values, _primary_variances = export_informative(
+ primary_trait_data)
+
+ cntrl_traits = tuple(
+ retrieve_trait_info(threshold, trait_full_name, conn)
+ for trait_full_name in control_trait_names)
+ cntrl_traits_data = tuple(
+ retrieve_trait_data(cntrl_trait, conn)
+ for cntrl_trait in cntrl_traits)
+ species = species_name(conn, group)
+
+ (cntrl_samples,
+ cntrl_values,
+ _cntrl_variances,
+ _cntrl_ns) = control_samples(cntrl_traits_data, primary_samples)
+
+ common_primary_control_samples = primary_samples
+ fixed_primary_vals = primary_values
+ fixed_control_vals = cntrl_values
+ if not all(cnt_smp == primary_samples for cnt_smp in cntrl_samples):
+ (common_primary_control_samples,
+ fixed_primary_vals,
+ fixed_control_vals,
+ _primary_variances,
+ _cntrl_variances) = fix_samples(primary_trait, cntrl_traits)
+
+ if len(common_primary_control_samples) < corr_min_informative:
+ return {
+ "status": "error",
+ "message": (
+ f"Fewer than {corr_min_informative} samples data entered for "
+ f"{group} dataset. No calculation of correlation has been "
+ "attempted."),
+ "error_type": "Inadequate Samples"}
+
+ identical_traits_names = find_identical_traits(
+ primary_trait_name, primary_values, control_trait_names, cntrl_values)
+ if len(identical_traits_names) > 0:
+ return {
+ "status": "error",
+ "message": (
+ f"{identical_traits_names[0]} and {identical_traits_names[1]} "
+ "have the same values for the {len(fixed_primary_vals)} "
+ "samples that will be used to compute the partial correlation "
+ "(common for all primary and control traits). In such cases, "
+ "partial correlation cannot be computed. Please re-select your "
+ "traits."),
+ "error_type": "Identical Traits"}
+
+ input_trait_geneid = primary_trait.get("geneid")
+ input_trait_symbol = primary_trait.get("symbol")
+ input_trait_mouse_geneid = translate_to_mouse_gene_id(
+ species, input_trait_geneid, conn)
+
+ tissue_probeset_freeze_id = 1
+ db_type = primary_trait["db"]["dataset_type"]
+
+ if db_type == "ProbeSet" and method.lower() in (
+ "sgo literature correlation",
+ "tissue correlation, pearson's r",
+ "tissue correlation, spearman's rho"):
+ return {
+ "status": "error",
+ "message": (
+ "Wrong correlation type: It is not possible to compute the "
+ f"{method} between your trait and data in the {target_db_name} "
+ "database. Please try again after selecting another type of "
+ "correlation."),
+ "error_type": "Correlation Type"}
+
+ if (method.lower() == "sgo literature correlation" and (
+ input_trait_geneid is None or
+ check_for_literature_info(conn, input_trait_mouse_geneid))):
+ return {
+ "status": "error",
+ "message": (
+ "No Literature Information: This gene does not have any "
+ "associated Literature Information."),
+ "error_type": "Literature Correlation"}
+
+ if (
+ method.lower() in (
+ "tissue correlation, pearson's r",
+ "tissue correlation, spearman's rho")
+ and input_trait_symbol is None):
+ return {
+ "status": "error",
+ "message": (
+ "No Tissue Correlation Information: This gene does not have "
+ "any associated Tissue Correlation Information."),
+ "error_type": "Tissue Correlation"}
+
+ if (
+ method.lower() in (
+ "tissue correlation, pearson's r",
+ "tissue correlation, spearman's rho")
+ and check_symbol_for_tissue_correlation(
+ conn, tissue_probeset_freeze_id, input_trait_symbol)):
+ return {
+ "status": "error",
+ "message": (
+ "No Tissue Correlation Information: This gene does not have "
+ "any associated Tissue Correlation Information."),
+ "error_type": "Tissue Correlation"}
+
+ database_filename = get_filename(conn, target_db_name, TEXTDIR)
+ _total_traits, all_correlations = partial_corrs(
+ conn, common_primary_control_samples, fixed_primary_vals,
+ fixed_control_vals, len(fixed_primary_vals), species,
+ input_trait_geneid, input_trait_symbol, tissue_probeset_freeze_id,
+ method, primary_trait["db"], database_filename)
+
+
+ def __make_sorter__(method):
+ def __sort_6__(row):
+ return row[6]
+
+ def __sort_3__(row):
+ return row[3]
+
+ if "literature" in method.lower():
+ return __sort_6__
+
+ if "tissue" in method.lower():
+ return __sort_6__
+
+ return __sort_3__
+
+ sorted_correlations = sorted(
+ all_correlations, key=__make_sorter__(method))
+
+ add_lit_corr_and_tiss_corr = compose(
+ partial(literature_correlation_by_list, conn, species),
+ partial(
+ tissue_correlation_by_list, conn, input_trait_symbol,
+ tissue_probeset_freeze_id, method))
+
+ trait_list = add_lit_corr_and_tiss_corr(tuple(
+ {
+ **retrieve_trait_info(
+ threshold,
+ f"{primary_trait['db']['dataset_name']}::{item[0]}",
+ conn),
+ "noverlap": item[1],
+ "partial_corr": item[2],
+ "partial_corr_p_value": item[3],
+ "corr": item[4],
+ "corr_p_value": item[5],
+ "rank_order": (1 if "spearman" in method.lower() else 0),
+ **({
+ "tissue_corr": item[6],
+ "tissue_p_value": item[7]}
+ if len(item) == 8 else {}),
+ **({"l_corr": item[6]}
+ if len(item) == 7 else {})
+ }
+ for item in
+ sorted_correlations[:min(criteria, len(all_correlations))]))
+
+ return trait_list
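The core of compute_partial() is a pingouin.partial_corr() call over a small
DataFrame holding the primary trait ("x"), the target trait ("y") and the
control traits ("z", "z0", "z1", ...) as covariates. A standalone sketch of
that call (all values are arbitrary):

    import pandas
    import pingouin

    datafrm = pandas.DataFrame({
        "x":  [6.2, 5.8, 7.1, 6.9, 5.5, 6.4, 7.0],  # primary trait
        "y":  [5.9, 6.1, 7.3, 6.7, 5.4, 6.2, 7.2],  # target trait
        "z0": [1.0, 1.2, 0.9, 1.1, 1.3, 1.0, 0.8],  # control trait 1
        "z1": [0.4, 0.6, 0.5, 0.3, 0.7, 0.5, 0.4],  # control trait 2
    })
    covariates = [col for col in datafrm.columns if col not in ("x", "y")]
    ppc = pingouin.partial_corr(
        data=datafrm, x="x", y="y", covar=covariates, method="pearson")
    # Same fields the module reads as ppc["r"][0] and ppc["p-val"][0]:
    partial_r, p_value = ppc["r"].iloc[0], ppc["p-val"].iloc[0]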
diff --git a/gn3/computations/wgcna.py b/gn3/computations/wgcna.py
index fd508fa..ab12fe7 100644
--- a/gn3/computations/wgcna.py
+++ b/gn3/computations/wgcna.py
@@ -3,8 +3,11 @@
import os
import json
import uuid
-from gn3.settings import TMPDIR
+import subprocess
+import base64
+
+from gn3.settings import TMPDIR
from gn3.commands import run_cmd
@@ -14,12 +17,46 @@ def dump_wgcna_data(request_data: dict):
temp_file_path = os.path.join(TMPDIR, filename)
+ request_data["TMPDIR"] = TMPDIR
+
with open(temp_file_path, "w") as output_file:
json.dump(request_data, output_file)
return temp_file_path
+def stream_cmd_output(socketio, request_data, cmd: str):
+ """function to stream in realtime"""
+ # xtodo syncing and closing /edge cases
+
+ socketio.emit("output", {"data": f"calling you script {cmd}"},
+ namespace="/", room=request_data["socket_id"])
+ results = subprocess.Popen(
+ cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, shell=True)
+
+ if results.stdout is not None:
+
+ for line in iter(results.stdout.readline, b""):
+ socketio.emit("output",
+ {"data": line.decode("utf-8").rstrip()},
+ namespace="/", room=request_data["socket_id"])
+
+ socketio.emit(
+ "output", {"data":
+ "parsing the output results"}, namespace="/",
+ room=request_data["socket_id"])
+
+
+def process_image(image_loc: str) -> bytes:
+ """encode the image"""
+
+ try:
+ with open(image_loc, "rb") as image_file:
+ return base64.b64encode(image_file.read())
+ except FileNotFoundError:
+ return b""
+
+
def compose_wgcna_cmd(rscript_path: str, temp_file_path: str):
"""function to componse wgcna cmd"""
# (todo):issue relative paths to abs paths
@@ -32,6 +69,8 @@ def call_wgcna_script(rscript_path: str, request_data: dict):
generated_file = dump_wgcna_data(request_data)
cmd = compose_wgcna_cmd(rscript_path, generated_file)
+    # streaming of data disabled: stream_cmd_output(socketio, request_data, cmd)
+
try:
run_cmd_results = run_cmd(cmd)
@@ -40,8 +79,14 @@ def call_wgcna_script(rscript_path: str, request_data: dict):
if run_cmd_results["code"] != 0:
return run_cmd_results
+
+ output_file_data = json.load(outputfile)
+ output_file_data["output"]["image_data"] = process_image(
+ output_file_data["output"]["imageLoc"]).decode("ascii")
+    # JSON only carries unicode strings; decode the base64 bytes back to ASCII text
+
return {
- "data": json.load(outputfile),
+ "data": output_file_data,
**run_cmd_results
}
except FileNotFoundError:
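process_image() base64-encodes the generated WGCNA plot so it can travel
inside the JSON response (JSON cannot carry raw bytes), and the caller
decodes it back to ASCII text. A round-trip sketch using a throw-away file
as a stand-in for imageLoc:

    import base64

    with open("/tmp/fake_plot.png", "wb") as handle:  # stand-in image
        handle.write(b"\x89PNG fake bytes")

    with open("/tmp/fake_plot.png", "rb") as image_file:
        encoded = base64.b64encode(image_file.read()).decode("ascii")

    # `encoded` is a plain str, safe to embed in the JSON payload ...
    assert base64.b64decode(encoded) == b"\x89PNG fake bytes"  # ... and back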
diff --git a/gn3/data_helpers.py b/gn3/data_helpers.py
new file mode 100644
index 0000000..b72fbc5
--- /dev/null
+++ b/gn3/data_helpers.py
@@ -0,0 +1,52 @@
+"""
+This module will hold generic functions that can operate on a wide array of
+data structures.
+"""
+
+from math import ceil
+from functools import reduce
+from typing import Any, Tuple, Sequence, Optional
+
+def partition_all(num: int, items: Sequence[Any]) -> Tuple[Tuple[Any, ...], ...]:
+ """
+    Given a sequence `items`, return a tuple of tuples with the data
+    partitioned into sections of `num` items per partition.
+
+ This is an approximation of clojure's `partition-all` function.
+ """
+ def __compute_start_stop__(acc, iteration):
+ start = iteration * num
+ return acc + ((start, start + num),)
+
+ iterations = range(ceil(len(items) / num))
+ return tuple([# type: ignore[misc]
+ tuple(items[start:stop]) for start, stop # type: ignore[has-type]
+ in reduce(
+ __compute_start_stop__, iterations, tuple())])
+
+def partition_by(partition_fn, items):
+ """
+    Given a sequence `items`, return a tuple of tuples, each of which contains
+    the values in `items` partitioned such that the first item in each internal
+    tuple, when passed to `partition_fn`, returns True.
+
+ This is an approximation of Clojure's `partition-by` function.
+ """
+ def __partitioner__(accumulator, item):
+ if partition_fn(item):
+ return accumulator + ((item,),)
+ return accumulator[:-1] + (accumulator[-1] + (item,),)
+
+ return reduce(__partitioner__, items, tuple())
+
+def parse_csv_line(
+ line: str, delimiter: str = ",",
+ quoting: Optional[str] = '"') -> Tuple[str, ...]:
+ """
+ Parses a line from a CSV file into a tuple of strings.
+
+ This is a migration of the `web.webqtl.utility.webqtlUtil.readLineCSV`
+ function in GeneNetwork1.
+ """
+ return tuple(
+ col.strip("{} \t\n".format(quoting)) for col in line.split(delimiter))
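Quick usage sketches for the three helpers (inputs are made up; note that
partition_by expects the first item to start a new partition):

    from gn3.data_helpers import partition_all, partition_by, parse_csv_line

    partition_all(3, [1, 2, 3, 4, 5, 6, 7])
    # ((1, 2, 3), (4, 5, 6), (7,))

    partition_by(lambda x: x == 0, [0, 1, 2, 0, 3, 0, 4, 5])
    # ((0, 1, 2), (0, 3), (0, 4, 5))

    parse_csv_line('"BXD1", "BXD2" ,"BXD5"')
    # ('BXD1', 'BXD2', 'BXD5')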
diff --git a/gn3/db/correlations.py b/gn3/db/correlations.py
new file mode 100644
index 0000000..3d12019
--- /dev/null
+++ b/gn3/db/correlations.py
@@ -0,0 +1,564 @@
+"""
+This module will hold functions that are used in the (partial) correlations
+feature to access the database to retrieve data needed for computations.
+"""
+import os
+from functools import reduce
+from typing import Any, Dict, Tuple, Union
+
+from gn3.random import random_string
+from gn3.data_helpers import partition_all
+from gn3.db.species import translate_to_mouse_gene_id
+
+def get_filename(conn: Any, target_db_name: str, text_files_dir: str) -> Union[
+ str, bool]:
+ """
+ Retrieve the name of the reference database file with which correlations are
+ computed.
+
+ This is a migration of the
+ `web.webqtl.correlation.CorrelationPage.getFileName` function in
+ GeneNetwork1.
+ """
+ with conn.cursor() as cursor:
+ cursor.execute(
+ "SELECT Id, FullName from ProbeSetFreeze WHERE Name=%s",
+ (target_db_name,))
+ result = cursor.fetchone()
+ if result:
+ filename = "ProbeSetFreezeId_{tid}_FullName_{fname}.txt".format(
+ tid=result[0],
+ fname=result[1].replace(' ', '_').replace('/', '_'))
+ return ((filename in os.listdir(text_files_dir))
+ and f"{text_files_dir}/{filename}")
+
+ return False
+
+def build_temporary_literature_table(
+ conn: Any, species: str, gene_id: int, return_number: int) -> str:
+ """
+ Build and populate a temporary table to hold the literature correlation data
+ to be used in computations.
+
+ "This is a migration of the
+ `web.webqtl.correlation.CorrelationPage.getTempLiteratureTable` function in
+ GeneNetwork1.
+ """
+ def __translated_species_id(row, cursor):
+ if species == "mouse":
+ return row[1]
+ query = {
+ "rat": "SELECT rat FROM GeneIDXRef WHERE mouse=%s",
+ "human": "SELECT human FROM GeneIDXRef WHERE mouse=%d"}
+ if species in query.keys():
+ cursor.execute(query[species], row[1])
+ record = cursor.fetchone()
+ if record:
+ return record[0]
+ return None
+ return None
+
+ temp_table_name = f"TOPLITERATURE{random_string(8)}"
+ with conn.cursor as cursor:
+ mouse_geneid = translate_to_mouse_gene_id(species, gene_id, conn)
+ data_query = (
+ "SELECT GeneId1, GeneId2, value FROM LCorrRamin3 "
+ "WHERE GeneId1 = %(mouse_gene_id)s "
+ "UNION ALL "
+ "SELECT GeneId2, GeneId1, value FROM LCorrRamin3 "
+ "WHERE GeneId2 = %(mouse_gene_id)s "
+ "AND GeneId1 != %(mouse_gene_id)s")
+ cursor.execute(
+ (f"CREATE TEMPORARY TABLE {temp_table_name} ("
+ "GeneId1 int(12) unsigned, "
+ "GeneId2 int(12) unsigned PRIMARY KEY, "
+ "value double)"))
+ cursor.execute(data_query, mouse_gene_id=mouse_geneid)
+ literature_data = [
+ {"GeneId1": row[0], "GeneId2": row[1], "value": row[2]}
+ for row in cursor.fetchall()
+ if __translated_species_id(row, cursor)]
+
+ cursor.execute(
+ (f"INSERT INTO {temp_table_name} "
+ "VALUES (%(GeneId1)s, %(GeneId2)s, %(value)s)"),
+ literature_data[0:(2 * return_number)])
+
+ return temp_table_name
+
+def fetch_geno_literature_correlations(temp_table: str) -> str:
+ """
+ Helper function for `fetch_literature_correlations` below, to build query
+ for `Geno*` tables.
+ """
+ return (
+ f"SELECT Geno.Name, {temp_table}.value "
+ "FROM Geno, GenoXRef, GenoFreeze "
+ f"LEFT JOIN {temp_table} ON {temp_table}.GeneId2=ProbeSet.GeneId "
+ "WHERE ProbeSet.GeneId IS NOT NULL "
+ f"AND {temp_table}.value IS NOT NULL "
+ "AND GenoXRef.GenoFreezeId = GenoFreeze.Id "
+ "AND GenoFreeze.Name = %(db_name)s "
+ "AND Geno.Id=GenoXRef.GenoId "
+ "ORDER BY Geno.Id")
+
+def fetch_probeset_literature_correlations(temp_table: str) -> str:
+ """
+ Helper function for `fetch_literature_correlations` below, to build query
+ for `ProbeSet*` tables.
+ """
+ return (
+ f"SELECT ProbeSet.Name, {temp_table}.value "
+ "FROM ProbeSet, ProbeSetXRef, ProbeSetFreeze "
+ "LEFT JOIN {temp_table} ON {temp_table}.GeneId2=ProbeSet.GeneId "
+ "WHERE ProbeSet.GeneId IS NOT NULL "
+ "AND {temp_table}.value IS NOT NULL "
+ "AND ProbeSetXRef.ProbeSetFreezeId = ProbeSetFreeze.Id "
+ "AND ProbeSetFreeze.Name = %(db_name)s "
+ "AND ProbeSet.Id=ProbeSetXRef.ProbeSetId "
+ "ORDER BY ProbeSet.Id")
+
+def fetch_literature_correlations(
+ species: str, gene_id: int, dataset: dict, return_number: int,
+ conn: Any) -> dict:
+ """
+ Gather the literature correlation data and pair it with trait id string(s).
+
+ This is a migration of the
+ `web.webqtl.correlation.CorrelationPage.fetchLitCorrelations` function in
+ GeneNetwork1.
+ """
+ temp_table = build_temporary_literature_table(
+ conn, species, gene_id, return_number)
+ query_fns = {
+ "Geno": fetch_geno_literature_correlations,
+ # "Temp": fetch_temp_literature_correlations,
+ # "Publish": fetch_publish_literature_correlations,
+ "ProbeSet": fetch_probeset_literature_correlations}
+ with conn.cursor as cursor:
+ cursor.execute(
+ query_fns[dataset["dataset_type"]](temp_table),
+ db_name=dataset["dataset_name"])
+ results = cursor.fetchall()
+ cursor.execute("DROP TEMPORARY TABLE %s", temp_table)
+ return dict(results)
+
+def fetch_symbol_value_pair_dict(
+ symbol_list: Tuple[str, ...], data_id_dict: dict,
+ conn: Any) -> Dict[str, Tuple[float, ...]]:
+ """
+    Map each gene symbol to its corresponding tissue expression data.
+
+ This is a migration of the
+ `web.webqtl.correlation.correlationFunction.getSymbolValuePairDict` function
+ in GeneNetwork1.
+ """
+ data_ids = {
+ symbol: data_id_dict.get(symbol) for symbol in symbol_list
+ if data_id_dict.get(symbol) is not None
+ }
+ query = "SELECT Id, value FROM TissueProbeSetData WHERE Id IN %(data_ids)s"
+ with conn.cursor() as cursor:
+ cursor.execute(
+ query,
+ data_ids=tuple(data_ids.values()))
+ value_results = cursor.fetchall()
+ return {
+ key: tuple(row[1] for row in value_results if row[0] == key)
+ for key in data_ids.keys()
+ }
+
+ return {}
+
+def fetch_gene_symbol_tissue_value_dict(
+ symbol_list: Tuple[str, ...], data_id_dict: dict, conn: Any,
+ limit_num: int = 1000) -> dict:#getGeneSymbolTissueValueDict
+ """
+ Wrapper function for `gn3.db.correlations.fetch_symbol_value_pair_dict`.
+
+    This is a migration of the
+ `web.webqtl.correlation.correlationFunction.getGeneSymbolTissueValueDict` in
+ GeneNetwork1.
+ """
+ count = len(symbol_list)
+ if count != 0 and count <= limit_num:
+ return fetch_symbol_value_pair_dict(symbol_list, data_id_dict, conn)
+
+ if count > limit_num:
+ return {
+ key: value for dct in [
+ fetch_symbol_value_pair_dict(sl, data_id_dict, conn)
+ for sl in partition_all(limit_num, symbol_list)]
+ for key, value in dct.items()
+ }
+
+ return {}
+
+def fetch_tissue_probeset_xref_info(
+ gene_name_list: Tuple[str, ...], probeset_freeze_id: int,
+ conn: Any) -> Tuple[tuple, dict, dict, dict, dict, dict, dict]:
+ """
+ Retrieve the ProbeSet XRef information for tissues.
+
+ This is a migration of the
+ `web.webqtl.correlation.correlationFunction.getTissueProbeSetXRefInfo`
+ function in GeneNetwork1."""
+ with conn.cursor() as cursor:
+ if len(gene_name_list) == 0:
+ query = (
+ "SELECT t.Symbol, t.GeneId, t.DataId, t.Chr, t.Mb, "
+ "t.description, t.Probe_Target_Description "
+ "FROM "
+ "("
+ " SELECT Symbol, max(Mean) AS maxmean "
+ " FROM TissueProbeSetXRef "
+ " WHERE TissueProbeSetFreezeId=%(probeset_freeze_id)s "
+ " AND Symbol != '' "
+ " AND Symbol IS NOT NULL "
+ " GROUP BY Symbol"
+ ") AS x "
+ "INNER JOIN TissueProbeSetXRef AS t ON t.Symbol = x.Symbol "
+ "AND t.Mean = x.maxmean")
+ cursor.execute(query, probeset_freeze_id=probeset_freeze_id)
+ else:
+ query = (
+ "SELECT t.Symbol, t.GeneId, t.DataId, t.Chr, t.Mb, "
+ "t.description, t.Probe_Target_Description "
+ "FROM "
+ "("
+ " SELECT Symbol, max(Mean) AS maxmean "
+ " FROM TissueProbeSetXRef "
+ " WHERE TissueProbeSetFreezeId=%(probeset_freeze_id)s "
+ " AND Symbol in %(symbols)s "
+ " GROUP BY Symbol"
+ ") AS x "
+ "INNER JOIN TissueProbeSetXRef AS t ON t.Symbol = x.Symbol "
+ "AND t.Mean = x.maxmean")
+ cursor.execute(
+ query, probeset_freeze_id=probeset_freeze_id,
+ symbols=tuple(gene_name_list))
+
+ results = cursor.fetchall()
+
+ return reduce(
+ lambda acc, item: (
+ acc[0] + (item[0],),
+ {**acc[1], item[0].lower(): item[1]},
+ {**acc[1], item[0].lower(): item[2]},
+ {**acc[1], item[0].lower(): item[3]},
+ {**acc[1], item[0].lower(): item[4]},
+ {**acc[1], item[0].lower(): item[5]},
+ {**acc[1], item[0].lower(): item[6]}),
+ results or tuple(),
+ (tuple(), {}, {}, {}, {}, {}, {}))
+
+def fetch_gene_symbol_tissue_value_dict_for_trait(
+ gene_name_list: Tuple[str, ...], probeset_freeze_id: int,
+ conn: Any) -> dict:
+ """
+ Fetches a map of the gene symbols to the tissue values.
+
+ This is a migration of the
+ `web.webqtl.correlation.correlationFunction.getGeneSymbolTissueValueDictForTrait`
+ function in GeneNetwork1.
+ """
+ xref_info = fetch_tissue_probeset_xref_info(
+ gene_name_list, probeset_freeze_id, conn)
+ if xref_info[0]:
+ return fetch_gene_symbol_tissue_value_dict(xref_info[0], xref_info[2], conn)
+ return {}
+
+def build_temporary_tissue_correlations_table(
+ conn: Any, trait_symbol: str, probeset_freeze_id: int, method: str,
+ return_number: int) -> str:
+ """
+ Build a temporary table to hold the tissue correlations data.
+
+ This is a migration of the
+ `web.webqtl.correlation.CorrelationPage.getTempTissueCorrTable` function in
+ GeneNetwork1."""
+ # We should probably pass the `correlations_of_all_tissue_traits` function
+ # as an argument to this function and get rid of the one call immediately
+ # following this comment.
+ from gn3.computations.partial_correlations import (#pylint: disable=[C0415, R0401]
+ correlations_of_all_tissue_traits)
+ # This import above is necessary within the function to avoid
+ # circular-imports.
+ #
+ #
+ # This import above is indicative of convoluted code, with the computation
+ # being interwoven with the data retrieval. This needs to be changed, such
+ # that the function being imported here is no longer necessary, or have the
+ # imported function passed to this function as an argument.
+ symbol_corr_dict, symbol_p_value_dict = correlations_of_all_tissue_traits(
+ fetch_gene_symbol_tissue_value_dict_for_trait(
+ (trait_symbol,), probeset_freeze_id, conn),
+ fetch_gene_symbol_tissue_value_dict_for_trait(
+ tuple(), probeset_freeze_id, conn),
+ method)
+
+ symbol_corr_list = sorted(
+ symbol_corr_dict.items(), key=lambda key_val: key_val[1])
+
+ temp_table_name = f"TOPTISSUE{random_string(8)}"
+ create_query = (
+ "CREATE TEMPORARY TABLE {temp_table_name}"
+ "(Symbol varchar(100) PRIMARY KEY, Correlation float, PValue float)")
+ insert_query = (
+ f"INSERT INTO {temp_table_name}(Symbol, Correlation, PValue) "
+ " VALUES (%(symbol)s, %(correlation)s, %(pvalue)s)")
+
+ with conn.cursor() as cursor:
+ cursor.execute(create_query)
+ cursor.execute(
+ insert_query,
+ tuple({
+ "symbol": symbol,
+ "correlation": corr,
+ "pvalue": symbol_p_value_dict[symbol]
+ } for symbol, corr in symbol_corr_list[0: 2 * return_number]))
+
+ return temp_table_name
+
+def fetch_tissue_correlations(# pylint: disable=R0913
+ dataset: dict, trait_symbol: str, probeset_freeze_id: int, method: str,
+ return_number: int, conn: Any) -> dict:
+ """
+ Pair tissue correlations data with a trait id string.
+
+ This is a migration of the
+ `web.webqtl.correlation.CorrelationPage.fetchTissueCorrelations` function in
+ GeneNetwork1.
+ """
+ temp_table = build_temporary_tissue_correlations_table(
+ conn, trait_symbol, probeset_freeze_id, method, return_number)
+ with conn.cursor() as cursor:
+ cursor.execute(
+ (
+ f"SELECT ProbeSet.Name, {temp_table}.Correlation, "
+ f"{temp_table}.PValue "
+ "FROM (ProbeSet, ProbeSetXRef, ProbeSetFreeze) "
+ "LEFT JOIN {temp_table} ON {temp_table}.Symbol=ProbeSet.Symbol "
+ "WHERE ProbeSetFreeze.Name = %(db_name) "
+ "AND ProbeSetFreeze.Id=ProbeSetXRef.ProbeSetFreezeId "
+ "AND ProbeSet.Id = ProbeSetXRef.ProbeSetId "
+ "AND ProbeSet.Symbol IS NOT NULL "
+ "AND %s.Correlation IS NOT NULL"),
+ db_name=dataset["dataset_name"])
+ results = cursor.fetchall()
+ cursor.execute("DROP TEMPORARY TABLE %s", temp_table)
+ return {
+ trait_name: (tiss_corr, tiss_p_val)
+ for trait_name, tiss_corr, tiss_p_val in results}
+
+def check_for_literature_info(conn: Any, geneid: int) -> bool:
+ """
+ Checks the database to find out whether the trait with `geneid` has any
+ associated literature.
+
+ This is a migration of the
+ `web.webqtl.correlation.CorrelationPage.checkForLitInfo` function in
+ GeneNetwork1.
+ """
+ query = "SELECT 1 FROM LCorrRamin3 WHERE GeneId1=%s LIMIT 1"
+ with conn.cursor() as cursor:
+ cursor.execute(query, geneid)
+ result = cursor.fetchone()
+ if result:
+ return True
+
+ return False
+
+def check_symbol_for_tissue_correlation(
+ conn: Any, tissue_probeset_freeze_id: int, symbol: str = "") -> bool:
+ """
+ Checks whether a symbol has any associated tissue correlations.
+
+ This is a migration of the
+ `web.webqtl.correlation.CorrelationPage.checkSymbolForTissueCorr` function
+ in GeneNetwork1.
+ """
+ query = (
+ "SELECT 1 FROM TissueProbeSetXRef "
+ "WHERE TissueProbeSetFreezeId=%(probeset_freeze_id)s "
+ "AND Symbol=%(symbol)s LIMIT 1")
+ with conn.cursor() as cursor:
+ cursor.execute(
+ query, probeset_freeze_id=tissue_probeset_freeze_id, symbol=symbol)
+ result = cursor.fetchone()
+ if result:
+ return True
+
+ return False
+
+def fetch_sample_ids(
+ conn: Any, sample_names: Tuple[str, ...], species_name: str) -> Tuple[
+ int, ...]:
+ """
+ Given a sequence of sample names, and a species name, return the sample ids
+ that correspond to both.
+
+ This is a partial migration of the
+ `web.webqtl.correlation.CorrelationPage.fetchAllDatabaseData` function in
+ GeneNetwork1.
+ """
+ query = (
+ "SELECT Strain.Id FROM Strain, Species "
+ "WHERE Strain.Name IN %(samples_names)s "
+ "AND Strain.SpeciesId=Species.Id "
+ "AND Species.name=%(species_name)s")
+ with conn.cursor() as cursor:
+ cursor.execute(
+ query,
+ {
+ "samples_names": tuple(sample_names),
+ "species_name": species_name
+ })
+ return tuple(row[0] for row in cursor.fetchall())
+
+def build_query_sgo_lit_corr(
+ db_type: str, temp_table: str, sample_id_columns: str,
+ joins: Tuple[str, ...]) -> str:
+ """
+    Build the query for `SGO Literature Correlation` data when querying the
+    given `temp_table` temporary table.
+
+ This is a partial migration of the
+ `web.webqtl.correlation.CorrelationPage.fetchAllDatabaseData` function in
+ GeneNetwork1.
+ """
+ return (
+ (f"SELECT {db_type}.Name, {temp_table}.value, " +
+ sample_id_columns +
+ f" FROM ({db_type}, {db_type}XRef, {db_type}Freeze) " +
+ f"LEFT JOIN {temp_table} ON {temp_table}.GeneId2=ProbeSet.GeneId " +
+ " ".join(joins) +
+ " WHERE ProbeSet.GeneId IS NOT NULL " +
+ f"AND {temp_table}.value IS NOT NULL " +
+ f"AND {db_type}XRef.{db_type}FreezeId = {db_type}Freeze.Id " +
+ f"AND {db_type}Freeze.Name = %(db_name)s " +
+ f"AND {db_type}.Id = {db_type}XRef.{db_type}Id " +
+ f"ORDER BY {db_type}.Id"),
+ 2)
+
+def build_query_tissue_corr(db_type, temp_table, sample_id_columns, joins):
+ """
+    Build the query for `Tissue Correlation` data when querying the given
+    `temp_table` temporary table.
+
+ This is a partial migration of the
+ `web.webqtl.correlation.CorrelationPage.fetchAllDatabaseData` function in
+ GeneNetwork1.
+ """
+ return (
+ (f"SELECT {db_type}.Name, {temp_table}.Correlation, " +
+ f"{temp_table}.PValue, " +
+ sample_id_columns +
+ f" FROM ({db_type}, {db_type}XRef, {db_type}Freeze) " +
+ f"LEFT JOIN {temp_table} ON {temp_table}.Symbol=ProbeSet.Symbol " +
+ " ".join(joins) +
+ " WHERE ProbeSet.Symbol IS NOT NULL " +
+ f"AND {temp_table}.Correlation IS NOT NULL " +
+ f"AND {db_type}XRef.{db_type}FreezeId = {db_type}Freeze.Id " +
+ f"AND {db_type}Freeze.Name = %(db_name)s " +
+ f"AND {db_type}.Id = {db_type}XRef.{db_type}Id "
+ f"ORDER BY {db_type}.Id"),
+ 3)
+
+def fetch_all_database_data(# pylint: disable=[R0913, R0914]
+ conn: Any, species: str, gene_id: int, trait_symbol: str,
+ samples: Tuple[str, ...], dataset: dict, method: str,
+ return_number: int, probeset_freeze_id: int) -> Tuple[
+ Tuple[float], int]:
+ """
+ This is a migration of the
+ `web.webqtl.correlation.CorrelationPage.fetchAllDatabaseData` function in
+ GeneNetwork1.
+ """
+ db_type = dataset["dataset_type"]
+ db_name = dataset["dataset_name"]
+ def __build_query__(sample_ids, temp_table):
+ sample_id_columns = ", ".join(f"T{smpl}.value" for smpl in sample_ids)
+ if db_type == "Publish":
+ joins = tuple(
+ ("LEFT JOIN PublishData AS T{item} "
+ "ON T{item}.Id = PublishXRef.DataId "
+ "AND T{item}.StrainId = %(T{item}_sample_id)s")
+ for item in sample_ids)
+ return (
+ ("SELECT PublishXRef.Id, " +
+ sample_id_columns +
+ "FROM (PublishXRef, PublishFreeze) " +
+ " ".join(joins) +
+ " WHERE PublishXRef.InbredSetId = PublishFreeze.InbredSetId "
+ "AND PublishFreeze.Name = %(db_name)s"),
+ 1)
+ if temp_table is not None:
+ joins = tuple(
+ (f"LEFT JOIN {db_type}Data AS T{item} "
+ f"ON T{item}.Id = {db_type}XRef.DataId "
+ f"AND T{item}.StrainId=%(T{item}_sample_id)s")
+ for item in sample_ids)
+ if method.lower() == "sgo literature correlation":
+ return build_query_sgo_lit_corr(
+                    db_type, temp_table, sample_id_columns, joins)
+ if method.lower() in (
+ "tissue correlation, pearson's r",
+ "tissue correlation, spearman's rho"):
+ return build_query_tissue_corr(
+                    db_type, temp_table, sample_id_columns, joins)
+ joins = tuple(
+ (f"LEFT JOIN {db_type}Data AS T{item} "
+ f"ON T{item}.Id = {db_type}XRef.DataId "
+ f"AND T{item}.StrainId = %(T{item}_sample_id)s")
+ for item in sample_ids)
+ return (
+ (
+ f"SELECT {db_type}.Name, " +
+ sample_id_columns +
+ f" FROM ({db_type}, {db_type}XRef, {db_type}Freeze) " +
+ " ".join(joins) +
+ f" WHERE {db_type}XRef.{db_type}FreezeId = {db_type}Freeze.Id " +
+ f"AND {db_type}Freeze.Name = %(db_name)s " +
+ f"AND {db_type}.Id = {db_type}XRef.{db_type}Id " +
+ f"ORDER BY {db_type}.Id"),
+ 1)
+
+ def __fetch_data__(sample_ids, temp_table):
+ query, data_start_pos = __build_query__(sample_ids, temp_table)
+ with conn.cursor() as cursor:
+ cursor.execute(
+ query,
+ {"db_name": db_name,
+ **{f"T{item}_sample_id": item for item in sample_ids}})
+ return (cursor.fetchall(), data_start_pos)
+
+ sample_ids = tuple(
+ # look into graduating this to an argument and removing the `samples`
+        # and `species` arguments: function currying and composition might help
+ # with this
+ f"{sample_id}" for sample_id in
+ fetch_sample_ids(conn, samples, species))
+
+ temp_table = None
+    if gene_id and db_type == "ProbeSet":
+ if method.lower() == "sgo literature correlation":
+ temp_table = build_temporary_literature_table(
+ conn, species, gene_id, return_number)
+ if method.lower() in (
+ "tissue correlation, pearson's r",
+ "tissue correlation, spearman's rho"):
+ temp_table = build_temporary_tissue_correlations_table(
+ conn, trait_symbol, probeset_freeze_id, method, return_number)
+
+ trait_database = tuple(
+ item for sublist in
+ (__fetch_data__(ssample_ids, temp_table)
+ for ssample_ids in partition_all(25, sample_ids))
+ for item in sublist)
+
+ if temp_table:
+ with conn.cursor() as cursor:
+ cursor.execute(f"DROP TEMPORARY TABLE {temp_table}")
+
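+    # `trait_database` interleaves each batch's result rows with that batch's
+    # data-start position, so the first two elements are the first batch's
+    # rows and the position at which per-sample values begin.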
+ return (trait_database[0], trait_database[1])
diff --git a/gn3/db/species.py b/gn3/db/species.py
index 0deae4e..5b8e096 100644
--- a/gn3/db/species.py
+++ b/gn3/db/species.py
@@ -30,3 +30,47 @@ def get_chromosome(name: str, is_species: bool, conn: Any) -> Optional[Tuple]:
with conn.cursor() as cursor:
cursor.execute(_sql)
return cursor.fetchall()
+
+def translate_to_mouse_gene_id(species: str, geneid: int, conn: Any) -> int:
+ """
+ Translate rat or human geneid to mouse geneid
+
+ This is a migration of the
+ `web.webqtl.correlation/CorrelationPage.translateToMouseGeneID` function in
+ GN1
+ """
+ assert species in ("rat", "mouse", "human"), "Invalid species"
+ if geneid is None:
+ return 0
+
+ if species == "mouse":
+ return geneid
+
+    with conn.cursor() as cursor:
+ query = {
+ "rat": "SELECT mouse FROM GeneIDXRef WHERE rat = %s",
+ "human": "SELECT mouse FROM GeneIDXRef WHERE human = %s"
+ }
+        cursor.execute(query[species], (geneid,))
+ translated_gene_id = cursor.fetchone()
+ if translated_gene_id:
+ return translated_gene_id[0]
+
+ return 0 # default if all else fails
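+
+# Usage sketch (gene ids below are hypothetical):
+#
+#   translate_to_mouse_gene_id("mouse", 12345, conn)  # => 12345 (no lookup)
+#   translate_to_mouse_gene_id("human", 6469, conn)   # => mouse GeneId from
+#                                                     #    GeneIDXRef, or 0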
+
+def species_name(conn: Any, group: str) -> str:
+ """
+ Retrieve the name of the species, given the group (RISet).
+
+ This is a migration of the
+ `web.webqtl.dbFunction.webqtlDatabaseFunction.retrieveSpecies` function in
+ GeneNetwork1.
+ """
+ with conn.cursor() as cursor:
+ cursor.execute(
+ ("SELECT Species.Name FROM Species, InbredSet "
+ "WHERE InbredSet.Name = %(group_name)s "
+ "AND InbredSet.SpeciesId = Species.Id"),
+ {"group_name": group})
+        result = cursor.fetchone()
+        return result[0] if result else None
diff --git a/gn3/db/traits.py b/gn3/db/traits.py
index f2673c8..4098b08 100644
--- a/gn3/db/traits.py
+++ b/gn3/db/traits.py
@@ -1,17 +1,93 @@
"""This class contains functions relating to trait data manipulation"""
import os
+from functools import reduce
from typing import Any, Dict, Union, Sequence
+
+import MySQLdb
+
from gn3.settings import TMPDIR
from gn3.random import random_string
from gn3.function_helpers import compose
from gn3.db.datasets import retrieve_trait_dataset
+
+
+def export_trait_data(
+ trait_data: dict, samplelist: Sequence[str], dtype: str = "val",
+ var_exists: bool = False, n_exists: bool = False):
+ """
+ Export data according to `samplelist`. Mostly used in calculating
+ correlations.
+
+ DESCRIPTION:
+ Migrated from
+ https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/base/webqtlTrait.py#L166-L211
+
+ PARAMETERS
+    trait_data: (dict)
+      The dictionary of key-value pairs representing a trait
+    samplelist: (list)
+      A list of sample names
+    dtype: (str)
+      The data to export for each sample: one of "val", "var", "N" or "all"
+    var_exists: (bool)
+      A flag indicating existence of variance
+    n_exists: (bool)
+      A flag indicating existence of ndata
+ """
+ def __export_all_types(tdata, sample):
+ sample_data = []
+ if tdata[sample]["value"]:
+ sample_data.append(tdata[sample]["value"])
+ if var_exists:
+ if tdata[sample]["variance"]:
+ sample_data.append(tdata[sample]["variance"])
+ else:
+ sample_data.append(None)
+ if n_exists:
+ if tdata[sample]["ndata"]:
+ sample_data.append(tdata[sample]["ndata"])
+ else:
+ sample_data.append(None)
+ else:
+ if var_exists and n_exists:
+ sample_data += [None, None, None]
+ elif var_exists or n_exists:
+ sample_data += [None, None]
+ else:
+ sample_data.append(None)
+
+ return tuple(sample_data)
+
+ def __exporter(accumulator, sample):
+ # pylint: disable=[R0911]
+ if sample in trait_data["data"]:
+ if dtype == "val":
+ return accumulator + (trait_data["data"][sample]["value"], )
+ if dtype == "var":
+ return accumulator + (trait_data["data"][sample]["variance"], )
+ if dtype == "N":
+ return accumulator + (trait_data["data"][sample]["ndata"], )
+ if dtype == "all":
+ return accumulator + __export_all_types(trait_data["data"], sample)
+ raise KeyError("Type `%s` is incorrect" % dtype)
+ if var_exists and n_exists:
+ return accumulator + (None, None, None)
+ if var_exists or n_exists:
+ return accumulator + (None, None)
+ return accumulator + (None,)
+
+ return reduce(__exporter, samplelist, tuple())
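+
+# Usage sketch (the trait data is made up for illustration):
+#
+#   tdata = {"data": {"BXD1": {"value": 7.5, "variance": None, "ndata": None}}}
+#   export_trait_data(tdata, ["BXD1", "BXD2"])
+#   # => (7.5, None)  - samples missing from the data export as None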
+
+
def get_trait_csv_sample_data(conn: Any,
trait_name: int, phenotype_id: int):
"""Fetch a trait and return it as a csv string"""
- sql = ("SELECT DISTINCT Strain.Id, PublishData.Id, Strain.Name, "
- "PublishData.value, "
+ def __float_strip(num_str):
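+        # Render whole-number floats (e.g. 12.0) as "12"; leave everything
+        # else as its default string representation.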
+ if str(num_str)[-2:] == ".0":
+ return str(int(num_str))
+ return str(num_str)
+ sql = ("SELECT DISTINCT Strain.Name, PublishData.value, "
"PublishSE.error, NStrain.count FROM "
"(PublishData, Strain, PublishXRef, PublishFreeze) "
"LEFT JOIN PublishSE ON "
@@ -23,65 +99,189 @@ def get_trait_csv_sample_data(conn: Any,
"PublishData.Id = PublishXRef.DataId AND "
"PublishXRef.Id = %s AND PublishXRef.PhenotypeId = %s "
"AND PublishData.StrainId = Strain.Id Order BY Strain.Name")
- csv_data = ["Strain Id,Strain Name,Value,SE,Count"]
- publishdata_id = ""
+ csv_data = ["Strain Name,Value,SE,Count"]
with conn.cursor() as cursor:
cursor.execute(sql, (trait_name, phenotype_id,))
for record in cursor.fetchall():
- (strain_id, publishdata_id,
- strain_name, value, error, count) = record
+ (strain_name, value, error, count) = record
csv_data.append(
- ",".join([str(val) if val else "x"
- for val in (strain_id, strain_name,
- value, error, count)]))
- return f"# Publish Data Id: {publishdata_id}\n\n" + "\n".join(csv_data)
+ ",".join([__float_strip(val) if val else "x"
+ for val in (strain_name, value, error, count)]))
+ return "\n".join(csv_data)
+
+def update_sample_data(conn: Any, #pylint: disable=[R0913]
-def update_sample_data(conn: Any,
+ trait_name: str,
strain_name: str,
- strain_id: int,
- publish_data_id: int,
+ phenotype_id: int,
value: Union[int, float, str],
error: Union[int, float, str],
count: Union[int, str]):
"""Given the right parameters, update sample-data from the relevant
table."""
- # pylint: disable=[R0913, R0914, C0103]
- STRAIN_ID_SQL: str = "UPDATE Strain SET Name = %s WHERE Id = %s"
- PUBLISH_DATA_SQL: str = ("UPDATE PublishData SET value = %s "
- "WHERE StrainId = %s AND Id = %s")
- PUBLISH_SE_SQL: str = ("UPDATE PublishSE SET error = %s "
- "WHERE StrainId = %s AND DataId = %s")
- N_STRAIN_SQL: str = ("UPDATE NStrain SET count = %s "
- "WHERE StrainId = %s AND DataId = %s")
-
- updated_strains: int = 0
+ strain_id, data_id = "", ""
+
+ with conn.cursor() as cursor:
+ cursor.execute(
+ ("SELECT Strain.Id, PublishData.Id FROM "
+ "(PublishData, Strain, PublishXRef, PublishFreeze) "
+ "LEFT JOIN PublishSE ON "
+ "(PublishSE.DataId = PublishData.Id AND "
+ "PublishSE.StrainId = PublishData.StrainId) "
+ "LEFT JOIN NStrain ON "
+ "(NStrain.DataId = PublishData.Id AND "
+ "NStrain.StrainId = PublishData.StrainId) "
+ "WHERE PublishXRef.InbredSetId = "
+ "PublishFreeze.InbredSetId AND "
+ "PublishData.Id = PublishXRef.DataId AND "
+ "PublishXRef.Id = %s AND "
+ "PublishXRef.PhenotypeId = %s "
+ "AND PublishData.StrainId = Strain.Id "
+ "AND Strain.Name = \"%s\"") % (trait_name,
+ phenotype_id,
+ str(strain_name)))
+ strain_id, data_id = cursor.fetchone()
updated_published_data: int = 0
updated_se_data: int = 0
updated_n_strains: int = 0
with conn.cursor() as cursor:
- # Update the Strains table
- cursor.execute(STRAIN_ID_SQL, (strain_name, strain_id))
- updated_strains = cursor.rowcount
# Update the PublishData table
- cursor.execute(PUBLISH_DATA_SQL,
+ cursor.execute(("UPDATE PublishData SET value = %s "
+ "WHERE StrainId = %s AND Id = %s"),
(None if value == "x" else value,
- strain_id, publish_data_id))
+ strain_id, data_id))
updated_published_data = cursor.rowcount
+
# Update the PublishSE table
- cursor.execute(PUBLISH_SE_SQL,
+ cursor.execute(("UPDATE PublishSE SET error = %s "
+ "WHERE StrainId = %s AND DataId = %s"),
(None if error == "x" else error,
- strain_id, publish_data_id))
+ strain_id, data_id))
updated_se_data = cursor.rowcount
+
# Update the NStrain table
- cursor.execute(N_STRAIN_SQL,
+ cursor.execute(("UPDATE NStrain SET count = %s "
+ "WHERE StrainId = %s AND DataId = %s"),
(None if count == "x" else count,
- strain_id, publish_data_id))
+ strain_id, data_id))
updated_n_strains = cursor.rowcount
- return (updated_strains, updated_published_data,
+ return (updated_published_data,
updated_se_data, updated_n_strains)
+
+def delete_sample_data(conn: Any,
+ trait_name: str,
+ strain_name: str,
+ phenotype_id: int):
+ """Given the right parameters, delete sample-data from the relevant
+ table."""
+ strain_id, data_id = "", ""
+
+ deleted_published_data: int = 0
+ deleted_se_data: int = 0
+ deleted_n_strains: int = 0
+
+ with conn.cursor() as cursor:
+ # Delete the PublishData table
+ try:
+ cursor.execute(
+ ("SELECT Strain.Id, PublishData.Id FROM "
+ "(PublishData, Strain, PublishXRef, PublishFreeze) "
+ "LEFT JOIN PublishSE ON "
+ "(PublishSE.DataId = PublishData.Id AND "
+ "PublishSE.StrainId = PublishData.StrainId) "
+ "LEFT JOIN NStrain ON "
+ "(NStrain.DataId = PublishData.Id AND "
+ "NStrain.StrainId = PublishData.StrainId) "
+ "WHERE PublishXRef.InbredSetId = "
+ "PublishFreeze.InbredSetId AND "
+ "PublishData.Id = PublishXRef.DataId AND "
+ "PublishXRef.Id = %s AND "
+ "PublishXRef.PhenotypeId = %s "
+ "AND PublishData.StrainId = Strain.Id "
+ "AND Strain.Name = \"%s\"") % (trait_name,
+ phenotype_id,
+ str(strain_name)))
+ strain_id, data_id = cursor.fetchone()
+
+ cursor.execute(("DELETE FROM PublishData "
+ "WHERE StrainId = %s AND Id = %s")
+ % (strain_id, data_id))
+ deleted_published_data = cursor.rowcount
+
+ # Delete the PublishSE table
+ cursor.execute(("DELETE FROM PublishSE "
+ "WHERE StrainId = %s AND DataId = %s") %
+ (strain_id, data_id))
+ deleted_se_data = cursor.rowcount
+
+ # Delete the NStrain table
+ cursor.execute(("DELETE FROM NStrain "
+ "WHERE StrainId = %s AND DataId = %s" %
+ (strain_id, data_id)))
+ deleted_n_strains = cursor.rowcount
+ except Exception as e: #pylint: disable=[C0103, W0612]
+ conn.rollback()
+            raise MySQLdb.Error from e
+ conn.commit()
+
+ return (deleted_published_data,
+ deleted_se_data, deleted_n_strains)
+
+
+def insert_sample_data(conn: Any, #pylint: disable=[R0913]
+ trait_name: str,
+ strain_name: str,
+ phenotype_id: int,
+ value: Union[int, float, str],
+ error: Union[int, float, str],
+ count: Union[int, str]):
+ """Given the right parameters, insert sample-data to the relevant table.
+
+ """
+
+ inserted_published_data, inserted_se_data, inserted_n_strains = 0, 0, 0
+ with conn.cursor() as cursor:
+ try:
+ cursor.execute("SELECT DataId FROM PublishXRef WHERE Id = %s AND "
+ "PhenotypeId = %s", (trait_name, phenotype_id))
+            data_id = cursor.fetchone()[0]
+
+ cursor.execute("SELECT Id FROM Strain WHERE Name = %s",
+ (strain_name,))
+            strain_id = cursor.fetchone()[0]
+
+ # Insert the PublishData table
+ cursor.execute(("INSERT INTO PublishData (Id, StrainId, value)"
+ "VALUES (%s, %s, %s)"),
+ (data_id, strain_id, value))
+ inserted_published_data = cursor.rowcount
+
+ # Insert into the PublishSE table if error is specified
+ if error and error != "x":
+ cursor.execute(("INSERT INTO PublishSE (StrainId, DataId, "
+ " error) VALUES (%s, %s, %s)") %
+ (strain_id, data_id, error))
+ inserted_se_data = cursor.rowcount
+
+ # Insert into the NStrain table
+ if count and count != "x":
+ cursor.execute(("INSERT INTO NStrain "
+ "(StrainId, DataId, error) "
+ "VALUES (%s, %s, %s)") %
+ (strain_id, data_id, count))
+ inserted_n_strains = cursor.rowcount
+ except Exception as e: #pylint: disable=[C0103, W0612]
+ conn.rollback()
+            raise MySQLdb.Error from e
+ return (inserted_published_data,
+ inserted_se_data, inserted_n_strains)
+
+
def retrieve_publish_trait_info(trait_data_source: Dict[str, Any], conn: Any):
"""Retrieve trait information for type `Publish` traits.
@@ -121,11 +321,12 @@ def retrieve_publish_trait_info(trait_data_source: Dict[str, Any], conn: Any):
cursor.execute(
query,
{
- k:v for k, v in trait_data_source.items()
+ k: v for k, v in trait_data_source.items()
if k in ["trait_name", "trait_dataset_id"]
})
return dict(zip([k.lower() for k in keys], cursor.fetchone()))
+
def set_confidential_field(trait_type, trait_info):
"""Post processing function for 'Publish' trait types.
@@ -138,6 +339,7 @@ def set_confidential_field(trait_type, trait_info):
and not trait_info.get("pubmed_id", None)) else 0}
return trait_info
+
def retrieve_probeset_trait_info(trait_data_source: Dict[str, Any], conn: Any):
"""Retrieve trait information for type `ProbeSet` traits.
@@ -165,11 +367,12 @@ def retrieve_probeset_trait_info(trait_data_source: Dict[str, Any], conn: Any):
cursor.execute(
query,
{
- k:v for k, v in trait_data_source.items()
+ k: v for k, v in trait_data_source.items()
if k in ["trait_name", "trait_dataset_name"]
})
return dict(zip(keys, cursor.fetchone()))
+
def retrieve_geno_trait_info(trait_data_source: Dict[str, Any], conn: Any):
"""Retrieve trait information for type `Geno` traits.
@@ -189,11 +392,12 @@ def retrieve_geno_trait_info(trait_data_source: Dict[str, Any], conn: Any):
cursor.execute(
query,
{
- k:v for k, v in trait_data_source.items()
+ k: v for k, v in trait_data_source.items()
if k in ["trait_name", "trait_dataset_name"]
})
return dict(zip(keys, cursor.fetchone()))
+
def retrieve_temp_trait_info(trait_data_source: Dict[str, Any], conn: Any):
"""Retrieve trait information for type `Temp` traits.
@@ -206,11 +410,12 @@ def retrieve_temp_trait_info(trait_data_source: Dict[str, Any], conn: Any):
cursor.execute(
query,
{
- k:v for k, v in trait_data_source.items()
+ k: v for k, v in trait_data_source.items()
if k in ["trait_name"]
})
return dict(zip(keys, cursor.fetchone()))
+
def set_haveinfo_field(trait_info):
"""
Common postprocessing function for all trait types.
@@ -218,6 +423,7 @@ def set_haveinfo_field(trait_info):
Sets the value for the 'haveinfo' field."""
return {**trait_info, "haveinfo": 1 if trait_info else 0}
+
def set_homologene_id_field_probeset(trait_info, conn):
"""
Postprocessing function for 'ProbeSet' traits.
@@ -233,7 +439,7 @@ def set_homologene_id_field_probeset(trait_info, conn):
cursor.execute(
query,
{
- k:v for k, v in trait_info.items()
+ k: v for k, v in trait_info.items()
if k in ["geneid", "group"]
})
res = cursor.fetchone()
@@ -241,12 +447,13 @@ def set_homologene_id_field_probeset(trait_info, conn):
return {**trait_info, "homologeneid": res[0]}
return {**trait_info, "homologeneid": None}
+
def set_homologene_id_field(trait_type, trait_info, conn):
"""
Common postprocessing function for all trait types.
Sets the value for the 'homologene' key."""
- set_to_null = lambda ti: {**ti, "homologeneid": None}
+ def set_to_null(ti): return {**ti, "homologeneid": None} # pylint: disable=[C0103, C0321]
functions_table = {
"Temp": set_to_null,
"Geno": set_to_null,
@@ -255,6 +462,7 @@ def set_homologene_id_field(trait_type, trait_info, conn):
}
return functions_table[trait_type](trait_info)
+
def load_publish_qtl_info(trait_info, conn):
"""
Load extra QTL information for `Publish` traits
@@ -275,6 +483,7 @@ def load_publish_qtl_info(trait_info, conn):
return dict(zip(["locus", "lrs", "additive"], cursor.fetchone()))
return {"locus": "", "lrs": "", "additive": ""}
+
def load_probeset_qtl_info(trait_info, conn):
"""
Load extra QTL information for `ProbeSet` traits
@@ -297,6 +506,7 @@ def load_probeset_qtl_info(trait_info, conn):
["locus", "lrs", "pvalue", "mean", "additive"], cursor.fetchone()))
return {"locus": "", "lrs": "", "pvalue": "", "mean": "", "additive": ""}
+
def load_qtl_info(qtl, trait_type, trait_info, conn):
"""
Load extra QTL information for traits
@@ -325,6 +535,7 @@ def load_qtl_info(qtl, trait_type, trait_info, conn):
return qtl_info_functions[trait_type](trait_info, conn)
+
def build_trait_name(trait_fullname):
"""
Initialises the trait's name, and other values from the search data provided
@@ -351,6 +562,7 @@ def build_trait_name(trait_fullname):
"cellid": name_parts[2] if len(name_parts) == 3 else ""
}
+
def retrieve_probeset_sequence(trait, conn):
"""
Retrieve a 'ProbeSet' trait's sequence information
@@ -372,6 +584,7 @@ def retrieve_probeset_sequence(trait, conn):
seq = cursor.fetchone()
return {**trait, "sequence": seq[0] if seq else ""}
+
def retrieve_trait_info(
threshold: int, trait_full_name: str, conn: Any,
qtl=None):
@@ -427,6 +640,7 @@ def retrieve_trait_info(
}
return trait_info
+
def retrieve_temp_trait_data(trait_info: dict, conn: Any):
"""
Retrieve trait data for `Temp` traits.
@@ -445,10 +659,12 @@ def retrieve_temp_trait_data(trait_info: dict, conn: Any):
query,
{"trait_name": trait_info["trait_name"]})
return [dict(zip(
- ["sample_name", "value", "se_error", "nstrain", "id"], row))
+ ["sample_name", "value", "se_error", "nstrain", "id"],
+ row))
for row in cursor.fetchall()]
return []
+
def retrieve_species_id(group, conn: Any):
"""
Retrieve a species id given the Group value
@@ -460,6 +676,7 @@ def retrieve_species_id(group, conn: Any):
return cursor.fetchone()[0]
return None
+
def retrieve_geno_trait_data(trait_info: Dict, conn: Any):
"""
Retrieve trait data for `Geno` traits.
@@ -483,11 +700,14 @@ def retrieve_geno_trait_data(trait_info: Dict, conn: Any):
"dataset_name": trait_info["db"]["dataset_name"],
"species_id": retrieve_species_id(
trait_info["db"]["group"], conn)})
- return [dict(zip(
- ["sample_name", "value", "se_error", "id"], row))
- for row in cursor.fetchall()]
+ return [
+ dict(zip(
+ ["sample_name", "value", "se_error", "id"],
+ row))
+ for row in cursor.fetchall()]
return []
+
def retrieve_publish_trait_data(trait_info: Dict, conn: Any):
"""
Retrieve trait data for `Publish` traits.
@@ -514,11 +734,13 @@ def retrieve_publish_trait_data(trait_info: Dict, conn: Any):
query,
{"trait_name": trait_info["trait_name"],
"dataset_id": trait_info["db"]["dataset_id"]})
- return [dict(zip(
- ["sample_name", "value", "se_error", "nstrain", "id"], row))
- for row in cursor.fetchall()]
+ return [
+ dict(zip(
+ ["sample_name", "value", "se_error", "nstrain", "id"], row))
+ for row in cursor.fetchall()]
return []
+
def retrieve_cellid_trait_data(trait_info: Dict, conn: Any):
"""
Retrieve trait data for `Probe Data` types.
@@ -547,11 +769,13 @@ def retrieve_cellid_trait_data(trait_info: Dict, conn: Any):
{"cellid": trait_info["cellid"],
"trait_name": trait_info["trait_name"],
"dataset_id": trait_info["db"]["dataset_id"]})
- return [dict(zip(
- ["sample_name", "value", "se_error", "id"], row))
- for row in cursor.fetchall()]
+ return [
+ dict(zip(
+ ["sample_name", "value", "se_error", "id"], row))
+ for row in cursor.fetchall()]
return []
+
def retrieve_probeset_trait_data(trait_info: Dict, conn: Any):
"""
Retrieve trait data for `ProbeSet` traits.
@@ -576,11 +800,13 @@ def retrieve_probeset_trait_data(trait_info: Dict, conn: Any):
query,
{"trait_name": trait_info["trait_name"],
"dataset_name": trait_info["db"]["dataset_name"]})
- return [dict(zip(
- ["sample_name", "value", "se_error", "id"], row))
- for row in cursor.fetchall()]
+ return [
+ dict(zip(
+ ["sample_name", "value", "se_error", "id"], row))
+ for row in cursor.fetchall()]
return []
+
def with_samplelist_data_setup(samplelist: Sequence[str]):
"""
Build function that computes the trait data from provided list of samples.
@@ -607,6 +833,7 @@ def with_samplelist_data_setup(samplelist: Sequence[str]):
return None
return setup_fn
+
def without_samplelist_data_setup():
"""
Build function that computes the trait data.
@@ -627,6 +854,7 @@ def without_samplelist_data_setup():
return None
return setup_fn
+
def retrieve_trait_data(trait: dict, conn: Any, samplelist: Sequence[str] = tuple()):
"""
Retrieve trait data
@@ -666,11 +894,38 @@ def retrieve_trait_data(trait: dict, conn: Any, samplelist: Sequence[str] = tupl
"data": dict(map(
lambda x: (
x["sample_name"],
- {k:v for k, v in x.items() if x != "sample_name"}),
+ {k: v for k, v in x.items() if x != "sample_name"}),
data))}
return {}
+
def generate_traits_filename(base_path: str = TMPDIR):
"""Generate a unique filename for use with generated traits files."""
return "{}/traits_test_file_{}.txt".format(
os.path.abspath(base_path), random_string(10))
+
+
+def export_informative(trait_data: dict, inc_var: bool = False) -> tuple:
+ """
+ Export informative strain
+
+ This is a migration of the `exportInformative` function in
+ web/webqtl/base/webqtlTrait.py module in GeneNetwork1.
+
+    There is a chance that the original implementation has a bug, especially
+    in how it deals with the `inc_var` value. If the `inc_var` value is meant
+    to control the inclusion of the `variance` value, then the current
+    implementation, and the one in GN1, both have a bug.
+ """
+ def __exporter__(acc, data_item):
+ if not inc_var or data_item["variance"] is not None:
+ return (
+ acc[0] + (data_item["sample_name"],),
+ acc[1] + (data_item["value"],),
+ acc[2] + (data_item["variance"],))
+ return acc
+ return reduce(
+ __exporter__,
+ filter(lambda td: td["value"] is not None,
+ trait_data["data"].values()),
+ (tuple(), tuple(), tuple()))
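+
+# Usage sketch (the trait data is made up for illustration): the result is
+# three parallel tuples of sample names, values and variances.
+#
+#   export_informative({"data": {
+#       "BXD1": {"sample_name": "BXD1", "value": 7.5, "variance": None}}})
+#   # => (("BXD1",), (7.5,), (None,))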
diff --git a/gn3/heatmaps.py b/gn3/heatmaps.py
index adbfbc6..f0af409 100644
--- a/gn3/heatmaps.py
+++ b/gn3/heatmaps.py
@@ -14,6 +14,7 @@ from plotly.subplots import make_subplots # type: ignore
from gn3.settings import TMPDIR
from gn3.random import random_string
from gn3.computations.slink import slink
+from gn3.db.traits import export_trait_data
from gn3.computations.correlations2 import compute_correlation
from gn3.db.genotypes import (
build_genotype_file, load_genotype_samples)
@@ -26,72 +27,6 @@ from gn3.computations.qtlreaper import (
parse_reaper_main_results,
organise_reaper_main_results)
-def export_trait_data(
- trait_data: dict, samplelist: Sequence[str], dtype: str = "val",
- var_exists: bool = False, n_exists: bool = False):
- """
- Export data according to `samplelist`. Mostly used in calculating
- correlations.
-
- DESCRIPTION:
- Migrated from
- https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/base/webqtlTrait.py#L166-L211
-
- PARAMETERS
- trait: (dict)
- The dictionary of key-value pairs representing a trait
- samplelist: (list)
- A list of sample names
- dtype: (str)
- ... verify what this is ...
- var_exists: (bool)
- A flag indicating existence of variance
- n_exists: (bool)
- A flag indicating existence of ndata
- """
- def __export_all_types(tdata, sample):
- sample_data = []
- if tdata[sample]["value"]:
- sample_data.append(tdata[sample]["value"])
- if var_exists:
- if tdata[sample]["variance"]:
- sample_data.append(tdata[sample]["variance"])
- else:
- sample_data.append(None)
- if n_exists:
- if tdata[sample]["ndata"]:
- sample_data.append(tdata[sample]["ndata"])
- else:
- sample_data.append(None)
- else:
- if var_exists and n_exists:
- sample_data += [None, None, None]
- elif var_exists or n_exists:
- sample_data += [None, None]
- else:
- sample_data.append(None)
-
- return tuple(sample_data)
-
- def __exporter(accumulator, sample):
- # pylint: disable=[R0911]
- if sample in trait_data["data"]:
- if dtype == "val":
- return accumulator + (trait_data["data"][sample]["value"], )
- if dtype == "var":
- return accumulator + (trait_data["data"][sample]["variance"], )
- if dtype == "N":
- return accumulator + (trait_data["data"][sample]["ndata"], )
- if dtype == "all":
- return accumulator + __export_all_types(trait_data["data"], sample)
- raise KeyError("Type `%s` is incorrect" % dtype)
- if var_exists and n_exists:
- return accumulator + (None, None, None)
- if var_exists or n_exists:
- return accumulator + (None, None)
- return accumulator + (None,)
-
- return reduce(__exporter, samplelist, tuple())
def trait_display_name(trait: Dict):
"""
@@ -129,11 +64,7 @@ def cluster_traits(traits_data_list: Sequence[Dict]):
def __compute_corr(tdata_i, tdata_j):
if tdata_i[0] == tdata_j[0]:
return 0.0
- corr_vals = compute_correlation(tdata_i[1], tdata_j[1])
- corr = corr_vals[0]
- if (1 - corr) < 0:
- return 0.0
- return 1 - corr
+ return 1 - compute_correlation(tdata_i[1], tdata_j[1])[0]
def __cluster(tdata_i):
return tuple(
@@ -168,7 +99,9 @@ def get_loci_names(
__get_trait_loci, [v[1] for v in organised.items()], {})
return tuple(loci_dict[_chr] for _chr in chromosome_names)
-def build_heatmap(traits_names, conn: Any):
+def build_heatmap(
+ traits_names: Sequence[str], conn: Any,
+ vertical: bool = False) -> go.Figure:
"""
heatmap function
@@ -220,17 +153,21 @@ def build_heatmap(traits_names, conn: Any):
zip(traits_ids,
[traits[idx]["trait_fullname"] for idx in traits_order]))
- return generate_clustered_heatmap(
+ return clustered_heatmap(
process_traits_data_for_heatmap(
organised, traits_ids, chromosome_names),
clustered,
- "single_heatmap_{}".format(random_string(10)),
- y_axis=tuple(
- ordered_traits_names[traits_ids[order]]
- for order in traits_order),
- y_label="Traits",
- x_axis=chromosome_names,
- x_label="Chromosomes",
+ x_axis={
+ "label": "Chromosomes",
+ "data": chromosome_names
+ },
+ y_axis={
+ "label": "Traits",
+ "data": tuple(
+ ordered_traits_names[traits_ids[order]]
+ for order in traits_order)
+ },
+ vertical=vertical,
loci_names=get_loci_names(organised, chromosome_names))
def compute_traits_order(slink_data, neworder: tuple = tuple()):
@@ -349,68 +286,81 @@ def process_traits_data_for_heatmap(data, trait_names, chromosome_names):
for chr_name in chromosome_names]
return hdata
-def generate_clustered_heatmap(
- data, clustering_data, image_filename_prefix, x_axis=None,
- x_label: str = "", y_axis=None, y_label: str = "",
+def clustered_heatmap(
+ data: Sequence[Sequence[float]], clustering_data: Sequence[float],
+        x_axis,  #: Dict[Union[str, int], Union[str, Sequence[str]]],
+ y_axis: Dict[str, Union[str, Sequence[str]]],
loci_names: Sequence[Sequence[str]] = tuple(),
- output_dir: str = TMPDIR,
- colorscale=((0.0, '#0000FF'), (0.5, '#00FF00'), (1.0, '#FF0000'))):
+ vertical: bool = False,
+ colorscale: Sequence[Sequence[Union[float, str]]] = (
+ (0.0, '#0000FF'), (0.5, '#00FF00'), (1.0, '#FF0000'))) -> go.Figure:
"""
Generate a dendrogram, and heatmaps for each chromosome, and put them all
into one plot.
"""
# pylint: disable=[R0913, R0914]
- num_cols = 1 + len(x_axis)
+ x_axis_data = x_axis["data"]
+ y_axis_data = y_axis["data"]
+ num_plots = 1 + len(x_axis_data)
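+    # One sub-plot for the dendrogram plus one heatmap per chromosome.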
fig = make_subplots(
- rows=1,
- cols=num_cols,
- shared_yaxes="rows",
+ rows=num_plots if vertical else 1,
+ cols=1 if vertical else num_plots,
+ shared_xaxes="columns" if vertical else False,
+ shared_yaxes=False if vertical else "rows",
+ vertical_spacing=0.010,
horizontal_spacing=0.001,
- subplot_titles=["distance"] + x_axis,
+ subplot_titles=["" if vertical else x_axis["label"]] + [
+ "Chromosome: {}".format(chromo) if vertical else chromo
+            for chromo in x_axis_data],
figure=ff.create_dendrogram(
- np.array(clustering_data), orientation="right", labels=y_axis))
+ np.array(clustering_data),
+ orientation="bottom" if vertical else "right",
+ labels=y_axis_data))
hms = [go.Heatmap(
name=chromo,
- x=loci,
- y=y_axis,
+ x=y_axis_data if vertical else loci,
+ y=loci if vertical else y_axis_data,
z=data_array,
+ transpose=vertical,
showscale=False)
for chromo, data_array, loci
- in zip(x_axis, data, loci_names)]
+ in zip(x_axis_data, data, loci_names)]
for i, heatmap in enumerate(hms):
- fig.add_trace(heatmap, row=1, col=(i + 2))
-
- fig.update_layout(
- {
- "width": 1500,
- "height": 800,
- "xaxis": {
+ fig.add_trace(
+ heatmap,
+ row=((i + 2) if vertical else 1),
+ col=(1 if vertical else (i + 2)))
+
+ axes_layouts = {
+ "{axis}axis{count}".format(
+ axis=("y" if vertical else "x"),
+ count=(i+1 if i > 0 else "")): {
"mirror": False,
- "showgrid": True,
- "title": x_label
- },
- "yaxis": {
- "title": y_label
+ "showticklabels": i == 0,
+ "ticks": "outside" if i == 0 else ""
}
- })
+ for i in range(num_plots)}
- x_axes_layouts = {
- "xaxis{}".format(i+1 if i > 0 else ""): {
- "mirror": False,
- "showticklabels": i == 0,
- "ticks": "outside" if i == 0 else ""
- }
- for i in range(num_cols)}
+ print("vertical?: {} ==> {}".format("T" if vertical else "F", axes_layouts))
- fig.update_layout(
- {
- "width": 4000,
- "height": 800,
- "yaxis": {
- "mirror": False,
- "ticks": ""
- },
- **x_axes_layouts})
+ fig.update_layout({
+ "width": 800 if vertical else 4000,
+ "height": 4000 if vertical else 800,
+ "{}axis".format("x" if vertical else "y"): {
+ "mirror": False,
+ "ticks": "",
+ "side": "top" if vertical else "left",
+ "title": y_axis["label"],
+ "tickangle": 90 if vertical else 0,
+ "ticklabelposition": "outside top" if vertical else "outside left"
+ },
+ "{}axis".format("y" if vertical else "x"): {
+ "mirror": False,
+ "showgrid": True,
+ "title": "Distance",
+ "side": "right" if vertical else "top"
+ },
+ **axes_layouts})
fig.update_traces(
showlegend=False,
colorscale=colorscale,
@@ -418,7 +368,5 @@ def generate_clustered_heatmap(
fig.update_traces(
showlegend=True,
showscale=True,
- selector={"name": x_axis[-1]})
- image_filename = "{}/{}.html".format(output_dir, image_filename_prefix)
- fig.write_html(image_filename)
- return image_filename, fig
+ selector={"name": x_axis_data[-1]})
+ return fig
diff --git a/gn3/settings.py b/gn3/settings.py
index 150d96d..0ac6698 100644
--- a/gn3/settings.py
+++ b/gn3/settings.py
@@ -17,14 +17,10 @@ RQTL_WRAPPER = "rqtl_wrapper.R"
SQL_URI = os.environ.get(
"SQL_URI", "mysql://webqtlout:webqtlout@localhost/db_webqtl")
SECRET_KEY = "password"
-SQLALCHEMY_TRACK_MODIFICATIONS = False
# gn2 results only used in fetching dataset info
GN2_BASE_URL = "http://www.genenetwork.org/"
-# biweight script
-BIWEIGHT_RSCRIPT = "~/genenetwork3/scripts/calculate_biweight.R"
-
# wgcna script
WGCNA_RSCRIPT = "wgcna_analysis.R"
# qtlreaper command
@@ -35,13 +31,26 @@ GENOTYPE_FILES = os.environ.get(
"GENOTYPE_FILES", "{}/genotype_files/genotype".format(os.environ.get("HOME")))
# CROSS-ORIGIN SETUP
-CORS_ORIGINS = [
+def parse_env_cors(default):
+ """Parse comma-separated configuration into list of strings."""
+ origins_str = os.environ.get("CORS_ORIGINS", None)
+ if origins_str:
+ return [
+ origin.strip() for origin in origins_str.split(",") if origin != ""]
+ return default
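+
+# Example (hypothetical environment value): with
+#   CORS_ORIGINS="https://genenetwork.org, http://localhost:5003"
+# set in the environment, parse_env_cors(...) returns
+#   ["https://genenetwork.org", "http://localhost:5003"].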
+
+CORS_ORIGINS = parse_env_cors([
"http://localhost:*",
"http://127.0.0.1:*"
-]
+])
CORS_HEADERS = [
"Content-Type",
"Authorization",
"Access-Control-Allow-Credentials"
]
+
+GNSHARE = os.environ.get("GNSHARE", "/gnshare/gn/")
+TEXTDIR = f"{GNSHARE}/web/ProbeSetFreeze_DataMatrix"
+
+ROUND_TO = 10