Diffstat (limited to 'gn3')
-rw-r--r--  gn3/api/async_commands.py                           |  16
-rw-r--r--  gn3/api/correlation.py                              |  73
-rw-r--r--  gn3/api/ctl.py                                      |  24
-rw-r--r--  gn3/api/general.py                                  |   7
-rw-r--r--  gn3/api/heatmaps.py                                 |  21
-rw-r--r--  gn3/api/rqtl.py                                     |   2
-rw-r--r--  gn3/app.py                                          |   4
-rw-r--r--  gn3/authentication.py                               |  20
-rw-r--r--  gn3/commands.py                                     |  34
-rw-r--r--  gn3/computations/correlations.py                    |  58
-rw-r--r--  gn3/computations/correlations2.py                   |  36
-rw-r--r--  gn3/computations/ctl.py                             |  30
-rw-r--r--  gn3/computations/diff.py                            |   2
-rw-r--r--  gn3/computations/gemma.py                           |   2
-rw-r--r--  gn3/computations/parsers.py                         |   2
-rw-r--r--  gn3/computations/partial_correlations.py            | 628
-rw-r--r--  gn3/computations/partial_correlations_optimised.py  | 244
-rw-r--r--  gn3/computations/pca.py                             | 189
-rw-r--r--  gn3/computations/qtlreaper.py                       |  16
-rw-r--r--  gn3/computations/rqtl.py                            |   5
-rw-r--r--  gn3/computations/wgcna.py                           |  28
-rw-r--r--  gn3/csvcmp.py                                       | 146
-rw-r--r--  gn3/data_helpers.py                                 |  28
-rw-r--r--  gn3/db/correlations.py                              | 234
-rw-r--r--  gn3/db/datasets.py                                  | 152
-rw-r--r--  gn3/db/genotypes.py                                 |  44
-rw-r--r--  gn3/db/partial_correlations.py                      | 791
-rw-r--r--  gn3/db/sample_data.py                               | 365
-rw-r--r--  gn3/db/species.py                                   |  17
-rw-r--r--  gn3/db/traits.py                                    | 195
-rw-r--r--  gn3/db_utils.py                                     |   7
-rw-r--r--  gn3/fs_helpers.py                                   |   7
-rw-r--r--  gn3/heatmaps.py                                     |  36
-rw-r--r--  gn3/responses/__init__.py                           |   0
-rw-r--r--  gn3/responses/pcorrs_responses.py                   |  24
-rw-r--r--  gn3/settings.py                                     |  12
36 files changed, 3067 insertions(+), 432 deletions(-)
diff --git a/gn3/api/async_commands.py b/gn3/api/async_commands.py
new file mode 100644
index 0000000..c0cf4bb
--- /dev/null
+++ b/gn3/api/async_commands.py
@@ -0,0 +1,16 @@
+"""Endpoints and functions concerning commands run in external processes."""
+import redis
+from flask import jsonify, Blueprint
+
+async_commands = Blueprint("async_commands", __name__)
+
+@async_commands.route("/state/<command_id>")
+def command_state(command_id):
+ """Respond with the current state of command identified by `command_id`."""
+ with redis.Redis(decode_responses=True) as rconn:
+ state = rconn.hgetall(name=command_id)
+ if not state:
+ return jsonify(
+ status=404,
+ error="The command id provided does not exist.")
+ return jsonify(dict(state.items()))
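
This endpoint reads back the same Redis hash that `queue_cmd` (gn3/commands.py,
below) creates for every queued job. A minimal sketch of querying it with the
Flask test client; the command id is made up:

    from gn3.app import create_app

    app = create_app()
    with app.test_client() as client:
        response = client.get("/api/async_commands/state/some-command-id")
        print(response.get_json())
        # {"error": "The command id provided does not exist.", "status": 404}
        # for an unknown id; otherwise the stored hash fields (cmd, status, ...)
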
diff --git a/gn3/api/correlation.py b/gn3/api/correlation.py
index 46121f8..7eb7cd6 100644
--- a/gn3/api/correlation.py
+++ b/gn3/api/correlation.py
@@ -1,13 +1,21 @@
"""Endpoints for running correlations"""
+import sys
+from functools import reduce
+
+import redis
from flask import jsonify
from flask import Blueprint
from flask import request
+from flask import current_app
-from gn3.computations.correlations import compute_all_sample_correlation
-from gn3.computations.correlations import compute_all_lit_correlation
-from gn3.computations.correlations import compute_tissue_correlation
-from gn3.computations.correlations import map_shared_keys_to_values
+from gn3.settings import SQL_URI
+from gn3.commands import queue_cmd, compose_pcorrs_command
from gn3.db_utils import database_connector
+from gn3.responses.pcorrs_responses import build_response
+from gn3.computations.correlations import map_shared_keys_to_values
+from gn3.computations.correlations import compute_tissue_correlation
+from gn3.computations.correlations import compute_all_lit_correlation
+from gn3.computations.correlations import compute_all_sample_correlation
correlation = Blueprint("correlation", __name__)
@@ -58,17 +66,15 @@ def compute_lit_corr(species=None, gene_id=None):
might be needed for actual computing of the correlation results
"""
- conn, _cursor_object = database_connector()
- target_traits_gene_ids = request.get_json()
- target_trait_gene_list = list(target_traits_gene_ids.items())
+ with database_connector() as conn:
+ target_traits_gene_ids = request.get_json()
+ target_trait_gene_list = list(target_traits_gene_ids.items())
- lit_corr_results = compute_all_lit_correlation(
- conn=conn, trait_lists=target_trait_gene_list,
- species=species, gene_id=gene_id)
+ lit_corr_results = compute_all_lit_correlation(
+ conn=conn, trait_lists=target_trait_gene_list,
+ species=species, gene_id=gene_id)
- conn.close()
-
- return jsonify(lit_corr_results)
+ return jsonify(lit_corr_results)
@correlation.route("/tissue_corr/<string:corr_method>", methods=["POST"])
@@ -83,3 +89,44 @@ def compute_tissue_corr(corr_method="pearson"):
corr_method=corr_method)
return jsonify(results)
+
+@correlation.route("/partial", methods=["POST"])
+def partial_correlation():
+ """API endpoint for partial correlations."""
+ def trait_fullname(trait):
+ return f"{trait['dataset']}::{trait['trait_name']}"
+
+ def __field_errors__(args):
+ def __check__(acc, field):
+ if args.get(field) is None:
+ return acc + (f"Field '{field}' missing",)
+ return acc
+ return __check__
+
+ def __errors__(request_data, fields):
+ errors = tuple()
+ if request_data is None:
+ return ("No request data",)
+
+ return reduce(__field_errors__(request_data), fields, errors)
+
+ args = request.get_json()
+ request_errors = __errors__(
+ args, ("primary_trait", "control_traits", "target_db", "method"))
+ if request_errors:
+ return build_response({
+ "status": "error",
+ "messages": request_errors,
+ "error_type": "Client Error"})
+ return build_response({
+ "status": "success",
+ "results": queue_cmd(
+ conn=redis.Redis(),
+ cmd=compose_pcorrs_command(
+ trait_fullname(args["primary_trait"]),
+ tuple(
+ trait_fullname(trait) for trait in args["control_traits"]),
+ args["method"], args["target_db"],
+ int(args.get("criteria", 500))),
+ job_queue=current_app.config.get("REDIS_JOB_QUEUE"),
+ env = {"PYTHONPATH": ":".join(sys.path), "SQL_URI": SQL_URI})})
diff --git a/gn3/api/ctl.py b/gn3/api/ctl.py
new file mode 100644
index 0000000..ac33d63
--- /dev/null
+++ b/gn3/api/ctl.py
@@ -0,0 +1,24 @@
+"""module contains endpoints for ctl"""
+
+from flask import Blueprint
+from flask import request
+from flask import jsonify
+
+from gn3.computations.ctl import call_ctl_script
+
+ctl = Blueprint("ctl", __name__)
+
+
+@ctl.route("/run_ctl", methods=["POST"])
+def run_ctl():
+ """endpoint to run ctl
+ input: request form object
+ output:json object enum::(response,error)
+
+ """
+ ctl_data = request.json
+
+ (cmd_results, response) = call_ctl_script(ctl_data)
+ return (jsonify({
+ "results": response
+ }), 200) if response is not None else (jsonify({"error": str(cmd_results)}), 401)
diff --git a/gn3/api/general.py b/gn3/api/general.py
index 69ec343..e0bfc81 100644
--- a/gn3/api/general.py
+++ b/gn3/api/general.py
@@ -7,7 +7,7 @@ from flask import request
from gn3.fs_helpers import extract_uploaded_file
from gn3.commands import run_cmd
-
+from gn3.db import datasets
general = Blueprint("general", __name__)
@@ -68,3 +68,8 @@ def run_r_qtl(geno_filestr, pheno_filestr):
cmd = (f"Rscript {rqtl_wrapper} "
f"{geno_filestr} {pheno_filestr}")
return jsonify(run_cmd(cmd)), 201
+
+@general.route("/dataset/<accession_id>")
+def dataset_metadata(accession_id):
+ """Return info as JSON for dataset with ACCESSION_ID."""
+ return jsonify(datasets.dataset_metadata(accession_id))
diff --git a/gn3/api/heatmaps.py b/gn3/api/heatmaps.py
index 633a061..80c8ca8 100644
--- a/gn3/api/heatmaps.py
+++ b/gn3/api/heatmaps.py
@@ -24,15 +24,14 @@ def clustered_heatmaps():
return jsonify({
"message": "You need to provide at least two trait names."
}), 400
- conn, _cursor = database_connector()
- def parse_trait_fullname(trait):
- name_parts = trait.split(":")
- return "{dataset_name}::{trait_name}".format(
- dataset_name=name_parts[1], trait_name=name_parts[0])
- traits_fullnames = [parse_trait_fullname(trait) for trait in traits_names]
+ with database_connector() as conn:
+ def parse_trait_fullname(trait):
+ name_parts = trait.split(":")
+ return f"{name_parts[1]}::{name_parts[0]}"
+ traits_fullnames = [parse_trait_fullname(trait) for trait in traits_names]
- with io.StringIO() as io_str:
- figure = build_heatmap(traits_fullnames, conn, vertical=vertical)
- figure.write_json(io_str)
- fig_json = io_str.getvalue()
- return fig_json, 200
+ with io.StringIO() as io_str:
+ figure = build_heatmap(traits_fullnames, conn, vertical=vertical)
+ figure.write_json(io_str)
+ fig_json = io_str.getvalue()
+ return fig_json, 200
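
Note that `parse_trait_fullname` swaps the two colon-separated parts: the
client sends "trait_name:dataset_name" while the heatmap code expects
"dataset_name::trait_name". With made-up names:

    parse_trait_fullname("trait1:SomeDataset")  # -> "SomeDataset::trait1"
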
diff --git a/gn3/api/rqtl.py b/gn3/api/rqtl.py
index 85b2460..70ebe12 100644
--- a/gn3/api/rqtl.py
+++ b/gn3/api/rqtl.py
@@ -25,7 +25,7 @@ run the rqtl_wrapper script and return the results as JSON
raise FileNotFoundError
# Split kwargs by those with values and boolean ones that just convert to True/False
- kwargs = ["model", "method", "nperm", "scale", "control_marker"]
+ kwargs = ["covarstruct", "model", "method", "nperm", "scale", "control_marker"]
boolean_kwargs = ["addcovar", "interval", "pstrata", "pairscan"]
all_kwargs = kwargs + boolean_kwargs
diff --git a/gn3/app.py b/gn3/app.py
index 3d68b3f..790e87c 100644
--- a/gn3/app.py
+++ b/gn3/app.py
@@ -14,6 +14,8 @@ from gn3.api.heatmaps import heatmaps
from gn3.api.correlation import correlation
from gn3.api.data_entry import data_entry
from gn3.api.wgcna import wgcna
+from gn3.api.ctl import ctl
+from gn3.api.async_commands import async_commands
def create_app(config: Union[Dict, str, None] = None) -> Flask:
"""Create a new flask object"""
@@ -45,4 +47,6 @@ def create_app(config: Union[Dict, str, None] = None) -> Flask:
app.register_blueprint(correlation, url_prefix="/api/correlation")
app.register_blueprint(data_entry, url_prefix="/api/dataentry")
app.register_blueprint(wgcna, url_prefix="/api/wgcna")
+ app.register_blueprint(ctl, url_prefix="/api/ctl")
+ app.register_blueprint(async_commands, url_prefix="/api/async_commands")
return app
diff --git a/gn3/authentication.py b/gn3/authentication.py
index 6719631..d0b35bc 100644
--- a/gn3/authentication.py
+++ b/gn3/authentication.py
@@ -113,9 +113,9 @@ def get_groups_by_user_uid(user_uid: str, conn: Redis) -> Dict:
"""
admin = []
member = []
- for uuid, group_info in conn.hgetall("groups").items():
+ for group_uuid, group_info in conn.hgetall("groups").items():
group_info = json.loads(group_info)
- group_info["uuid"] = uuid
+ group_info["uuid"] = group_uuid
if user_uid in group_info.get('admins'):
admin.append(group_info)
if user_uid in group_info.get('members'):
@@ -130,11 +130,10 @@ def get_user_info_by_key(key: str, value: str,
conn: Redis) -> Optional[Dict]:
"""Given a key, get a user's information if value is matched"""
if key != "user_id":
- for uuid, user_info in conn.hgetall("users").items():
+ for user_uuid, user_info in conn.hgetall("users").items():
user_info = json.loads(user_info)
- if (key in user_info and
- user_info.get(key) == value):
- user_info["user_id"] = uuid
+ if (key in user_info and user_info.get(key) == value):
+ user_info["user_id"] = user_uuid
return user_info
elif key == "user_id":
if user_info := conn.hget("users", value):
@@ -145,9 +144,13 @@ def get_user_info_by_key(key: str, value: str,
def create_group(conn: Redis, group_name: Optional[str],
- admin_user_uids: List = [],
- member_user_uids: List = []) -> Optional[Dict]:
+                 admin_user_uids: Optional[List] = None,
+                 member_user_uids: Optional[List] = None) -> Optional[Dict]:
"""Create a group given the group name, members and admins of that group."""
+ if admin_user_uids is None:
+ admin_user_uids = []
+ if member_user_uids is None:
+ member_user_uids = []
if group_name and bool(admin_user_uids + member_user_uids):
timestamp = datetime.datetime.utcnow().strftime('%b %d %Y %I:%M%p')
group = {
@@ -160,3 +163,4 @@ def create_group(conn: Redis, group_name: Optional[str],
}
conn.hset("groups", group_id, json.dumps(group))
return group
+ return None
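
The switch from `admin_user_uids: List = []` to a `None` default avoids
Python's shared-mutable-default pitfall: a default list is created once, at
function definition time, and is then shared by every call. A minimal
illustration of the pitfall, unrelated to the gn3 code itself:

    def buggy(items: list = []) -> list:
        items.append(1)
        return items

    buggy()  # [1]
    buggy()  # [1, 1] -- the same list object is reused across calls
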
diff --git a/gn3/commands.py b/gn3/commands.py
index 7d42ced..e622068 100644
--- a/gn3/commands.py
+++ b/gn3/commands.py
@@ -1,5 +1,8 @@
"""Procedures used to work with the various bio-informatics cli
commands"""
+import os
+import sys
+import json
import subprocess
from datetime import datetime
@@ -7,6 +10,8 @@ from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
+from typing import Union
+from typing import Sequence
from uuid import uuid4
from redis.client import Redis # Used only in type hinting
@@ -46,10 +51,21 @@ def compose_rqtl_cmd(rqtl_wrapper_cmd: str,
return cmd
+def compose_pcorrs_command(
+ primary_trait: str, control_traits: Tuple[str, ...], method: str,
+ target_database: str, criteria: int = 500):
+ """Compose the command to run partias correlations"""
+ rundir = os.path.abspath(".")
+ return (
+ f"{sys.executable}", f"{rundir}/scripts/partial_correlations.py",
+ primary_trait, ",".join(control_traits), f'"{method}"',
+ f"{target_database}", f"--criteria={criteria}")
+
def queue_cmd(conn: Redis,
job_queue: str,
- cmd: str,
- email: Optional[str] = None) -> str:
+ cmd: Union[str, Sequence[str]],
+ email: Optional[str] = None,
+ env: Optional[dict] = None) -> str:
"""Given a command CMD; (optional) EMAIL; and a redis connection CONN, queue
it in Redis with an initial status of 'queued'. The following status codes
are supported:
@@ -68,17 +84,23 @@ Returns the name of the specific redis hash for the specific task.
f"{datetime.now().strftime('%Y-%m-%d%H-%M%S-%M%S-')}"
f"{str(uuid4())}")
conn.rpush(job_queue, unique_id)
- for key, value in {"cmd": cmd, "result": "", "status": "queued"}.items():
+ for key, value in {
+ "cmd": json.dumps(cmd), "result": "", "status": "queued"}.items():
conn.hset(name=unique_id, key=key, value=value)
if email:
conn.hset(name=unique_id, key="email", value=email)
+ if env:
+ conn.hset(name=unique_id, key="env", value=json.dumps(env))
return unique_id
-def run_cmd(cmd: str, success_codes: Tuple = (0,)) -> Dict:
+def run_cmd(cmd: str, success_codes: Tuple = (0,), env: Optional[str] = None) -> Dict:
"""Run CMD and return the CMD's status code and output as a dict"""
- results = subprocess.run(cmd, capture_output=True, shell=True,
- check=False)
+ parsed_cmd = json.loads(cmd)
+ parsed_env = (json.loads(env) if env is not None else None)
+ results = subprocess.run(
+ parsed_cmd, capture_output=True, shell=isinstance(parsed_cmd, str),
+ check=False, env=parsed_env)
out = str(results.stdout, 'utf-8')
if results.returncode not in success_codes: # Error!
out = str(results.stderr, 'utf-8')
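
With these changes, `queue_cmd` stores the command (and optional environment)
JSON-encoded in the Redis job hash, and `run_cmd` decodes them again before
calling `subprocess.run`; a decoded string still goes through the shell, while
a decoded list is executed directly. A rough round-trip sketch, assuming a
local Redis server and that `run_cmd` still returns a dict with "code" and
"output" keys; the queue name is illustrative:

    import redis
    from gn3.commands import queue_cmd, run_cmd

    conn = redis.Redis()
    job_id = queue_cmd(
        conn=conn, job_queue="GN3::job-queue",
        cmd=("echo", "hello"), env={"MY_VAR": "1"})
    stored = conn.hgetall(job_id)
    # A worker process would later do something like:
    result = run_cmd(stored[b"cmd"].decode(), env=stored[b"env"].decode())
    print(result)  # e.g. {"code": 0, "output": "hello\n"}
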
diff --git a/gn3/computations/correlations.py b/gn3/computations/correlations.py
index c5c56db..a0da2c4 100644
--- a/gn3/computations/correlations.py
+++ b/gn3/computations/correlations.py
@@ -7,6 +7,7 @@ from typing import List
from typing import Tuple
from typing import Optional
from typing import Callable
+from typing import Generator
import scipy.stats
import pingouin as pg
@@ -38,20 +39,15 @@ def map_shared_keys_to_values(target_sample_keys: List,
return target_dataset_data
-def normalize_values(a_values: List,
- b_values: List) -> Tuple[List[float], List[float], int]:
- """Trim two lists of values to contain only the values they both share Given
- two lists of sample values, trim each list so that it contains only the
- samples that contain a value in both lists. Also returns the number of
- such samples.
-
- >>> normalize_values([2.3, None, None, 3.2, 4.1, 5],
- [3.4, 7.2, 1.3, None, 6.2, 4.1])
- ([2.3, 4.1, 5], [3.4, 6.2, 4.1], 3)
-
+def normalize_values(a_values: List, b_values: List) -> Generator:
+ """
+ :param a_values: list of primary strain values
+    :param b_values: list of target strain values
+    :return: yields pairs of values where neither is None
"""
+
for a_val, b_val in zip(a_values, b_values):
- if (a_val and b_val is not None):
+ if (a_val is not None) and (b_val is not None):
yield a_val, b_val
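
`normalize_values` now yields matching pairs instead of returning two trimmed
lists plus a count; callers unzip the pairs themselves. Re-using the data from
the removed doctest:

    >>> list(normalize_values([2.3, None, None, 3.2, 4.1, 5],
    ...                       [3.4, 7.2, 1.3, None, 6.2, 4.1]))
    [(2.3, 3.4), (4.1, 6.2), (5, 4.1)]
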
@@ -79,15 +75,18 @@ def compute_sample_r_correlation(trait_name, corr_method, trait_vals,
"""
- sanitized_traits_vals, sanitized_target_vals = list(
- zip(*list(normalize_values(trait_vals, target_samples_vals))))
- num_overlap = len(sanitized_traits_vals)
+ try:
+ normalized_traits_vals, normalized_target_vals = list(
+ zip(*list(normalize_values(trait_vals, target_samples_vals))))
+ num_overlap = len(normalized_traits_vals)
+ except ValueError:
+ return None
if num_overlap > 5:
(corr_coefficient, p_value) =\
- compute_corr_coeff_p_value(primary_values=sanitized_traits_vals,
- target_values=sanitized_target_vals,
+ compute_corr_coeff_p_value(primary_values=normalized_traits_vals,
+ target_values=normalized_target_vals,
corr_method=corr_method)
if corr_coefficient is not None and not math.isnan(corr_coefficient):
@@ -108,7 +107,7 @@ package :not packaged in guix
def filter_shared_sample_keys(this_samplelist,
- target_samplelist) -> Tuple[List, List]:
+ target_samplelist) -> Generator:
"""Given primary and target sample-list for two base and target trait select
filter the values using the shared keys
@@ -134,9 +133,16 @@ def fast_compute_all_sample_correlation(this_trait,
for target_trait in target_dataset:
trait_name = target_trait.get("trait_id")
target_trait_data = target_trait["trait_sample_data"]
- processed_values.append((trait_name, corr_method,
- list(zip(*list(filter_shared_sample_keys(
- this_trait_samples, target_trait_data))))))
+
+ try:
+ this_vals, target_vals = list(zip(*list(filter_shared_sample_keys(
+ this_trait_samples, target_trait_data))))
+
+ processed_values.append(
+ (trait_name, corr_method, this_vals, target_vals))
+ except ValueError:
+ continue
+
with closing(multiprocessing.Pool()) as pool:
results = pool.starmap(compute_sample_r_correlation, processed_values)
@@ -168,8 +174,14 @@ def compute_all_sample_correlation(this_trait,
for target_trait in target_dataset:
trait_name = target_trait.get("trait_id")
target_trait_data = target_trait["trait_sample_data"]
- this_vals, target_vals = list(zip(*list(filter_shared_sample_keys(
- this_trait_samples, target_trait_data))))
+
+ try:
+ this_vals, target_vals = list(zip(*list(filter_shared_sample_keys(
+ this_trait_samples, target_trait_data))))
+
+ except ValueError:
+            # case where there are no matching strain names
+ continue
sample_correlation = compute_sample_r_correlation(
trait_name=trait_name,
diff --git a/gn3/computations/correlations2.py b/gn3/computations/correlations2.py
index 93db3fa..d0222ae 100644
--- a/gn3/computations/correlations2.py
+++ b/gn3/computations/correlations2.py
@@ -6,45 +6,21 @@ FUNCTIONS:
compute_correlation:
TODO: Describe what the function does..."""
-from math import sqrt
-from functools import reduce
+from scipy import stats
## From GN1: mostly for clustering and heatmap generation
def __items_with_values(dbdata, userdata):
"""Retains only corresponding items in the data items that are not `None` values.
This should probably be renamed to something sensible"""
- def both_not_none(item1, item2):
- """Check that both items are not the value `None`."""
- if (item1 is not None) and (item2 is not None):
- return (item1, item2)
- return None
- def split_lists(accumulator, item):
- """Separate the 'x' and 'y' items."""
- return [accumulator[0] + [item[0]], accumulator[1] + [item[1]]]
- return reduce(
- split_lists,
- filter(lambda x: x is not None, map(both_not_none, dbdata, userdata)),
- [[], []])
+ filtered = [x for x in zip(dbdata, userdata) if x[0] is not None and x[1] is not None]
+ return tuple(zip(*filtered)) if filtered else ([], [])
def compute_correlation(dbdata, userdata):
- """Compute some form of correlation.
+ """Compute the Pearson correlation coefficient.
This is extracted from
https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/utility/webqtlUtil.py#L622-L647
"""
x_items, y_items = __items_with_values(dbdata, userdata)
- if len(x_items) < 6:
- return (0.0, len(x_items))
- meanx = sum(x_items)/len(x_items)
- meany = sum(y_items)/len(y_items)
- def cal_corr_vals(acc, item):
- xitem, yitem = item
- return [
- acc[0] + ((xitem - meanx) * (yitem - meany)),
- acc[1] + ((xitem - meanx) * (xitem - meanx)),
- acc[2] + ((yitem - meany) * (yitem - meany))]
- xyd, sxd, syd = reduce(cal_corr_vals, zip(x_items, y_items), [0.0, 0.0, 0.0])
- try:
- return ((xyd/(sqrt(sxd)*sqrt(syd))), len(x_items))
- except ZeroDivisionError:
- return(0, len(x_items))
+ correlation = stats.pearsonr(x_items, y_items)[0] if len(x_items) >= 6 else 0
+ return (correlation, len(x_items))
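
The reduce-based Pearson computation above is replaced with
`scipy.stats.pearsonr`, keeping the old behaviour of returning a correlation
of 0 when fewer than six paired values remain. A quick sketch on toy data:

    from gn3.computations.correlations2 import compute_correlation

    corr, n_pairs = compute_correlation(
        [1, 2, 3, 4, 5, 6], [2, 4, 6, 8, 10, 12])
    # corr is (approximately) 1.0 and n_pairs is 6
    compute_correlation([1, 2, None, 4], [2, 4, 6, None])  # (0, 2): too few pairs
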
diff --git a/gn3/computations/ctl.py b/gn3/computations/ctl.py
new file mode 100644
index 0000000..f881410
--- /dev/null
+++ b/gn3/computations/ctl.py
@@ -0,0 +1,30 @@
+"""module contains code to process ctl analysis data"""
+import json
+from gn3.commands import run_cmd
+
+from gn3.computations.wgcna import dump_wgcna_data
+from gn3.computations.wgcna import compose_wgcna_cmd
+from gn3.computations.wgcna import process_image
+
+from gn3.settings import TMPDIR
+
+
+def call_ctl_script(data):
+ """function to call ctl script"""
+ data["imgDir"] = TMPDIR
+ temp_file_name = dump_wgcna_data(data)
+ cmd = compose_wgcna_cmd("ctl_analysis.R", temp_file_name)
+
+ cmd_results = run_cmd(cmd)
+ with open(temp_file_name, "r", encoding="utf-8") as outputfile:
+ if cmd_results["code"] != 0:
+ return (cmd_results, None)
+ output_file_data = json.load(outputfile)
+
+ output_file_data["image_data"] = process_image(
+ output_file_data["image_loc"]).decode("ascii")
+
+ output_file_data["ctl_plots"] = [process_image(ctl_plot).decode("ascii") for
+ ctl_plot in output_file_data["ctl_plots"]]
+
+ return (cmd_results, output_file_data)
diff --git a/gn3/computations/diff.py b/gn3/computations/diff.py
index af02f7f..0b6edd6 100644
--- a/gn3/computations/diff.py
+++ b/gn3/computations/diff.py
@@ -6,7 +6,7 @@ from gn3.commands import run_cmd
def generate_diff(data: str, edited_data: str) -> Optional[str]:
"""Generate the diff between 2 files"""
- results = run_cmd(f"diff {data} {edited_data}", success_codes=(1, 2))
+ results = run_cmd(f'"diff {data} {edited_data}"', success_codes=(1, 2))
if results.get("code", -1) > 0:
return results.get("output")
return None
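
The added quotes make the command a valid JSON string, since `run_cmd`
(gn3/commands.py, above) now decodes its argument with `json.loads` before
executing it:

    >>> import json
    >>> json.loads('"diff file1 file2"')
    'diff file1 file2'
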
diff --git a/gn3/computations/gemma.py b/gn3/computations/gemma.py
index 0b22d3c..8036a7b 100644
--- a/gn3/computations/gemma.py
+++ b/gn3/computations/gemma.py
@@ -31,7 +31,7 @@ def generate_pheno_txt_file(trait_filename: str,
# Early return if this already exists!
if os.path.isfile(f"{tmpdir}/gn2/{trait_filename}"):
return f"{tmpdir}/gn2/{trait_filename}"
- with open(f"{tmpdir}/gn2/{trait_filename}", "w") as _file:
+ with open(f"{tmpdir}/gn2/{trait_filename}", "w", encoding="utf-8") as _file:
for value in values:
if value == "x":
_file.write("NA\n")
diff --git a/gn3/computations/parsers.py b/gn3/computations/parsers.py
index 1af35d6..79e3955 100644
--- a/gn3/computations/parsers.py
+++ b/gn3/computations/parsers.py
@@ -15,7 +15,7 @@ def parse_genofile(file_path: str) -> Tuple[List[str],
'u': None,
}
genotypes, samples = [], []
- with open(file_path, "r") as _genofile:
+ with open(file_path, "r", encoding="utf-8") as _genofile:
for line in _genofile:
line = line.strip()
if line.startswith(("#", "@")):
diff --git a/gn3/computations/partial_correlations.py b/gn3/computations/partial_correlations.py
index 07dc16d..5017796 100644
--- a/gn3/computations/partial_correlations.py
+++ b/gn3/computations/partial_correlations.py
@@ -5,12 +5,32 @@ It is an attempt to migrate over the partial correlations feature from
GeneNetwork1.
"""
-from functools import reduce
-from typing import Any, Tuple, Sequence
+import math
+import warnings
+from functools import reduce, partial
+from typing import Any, Tuple, Union, Sequence
+
+import numpy
+import pandas
+import pingouin
from scipy.stats import pearsonr, spearmanr
from gn3.settings import TEXTDIR
+from gn3.random import random_string
+from gn3.function_helpers import compose
from gn3.data_helpers import parse_csv_line
+from gn3.db.traits import export_informative
+from gn3.db.datasets import retrieve_trait_dataset
+from gn3.db.partial_correlations import traits_info, traits_data
+from gn3.db.species import species_name, translate_to_mouse_gene_id
+from gn3.db.correlations import (
+ get_filename,
+ fetch_all_database_data,
+ check_for_literature_info,
+ fetch_tissue_correlations,
+ fetch_literature_correlations,
+ check_symbol_for_tissue_correlation,
+ fetch_gene_symbol_tissue_value_dict_for_trait)
def control_samples(controls: Sequence[dict], sampleslist: Sequence[str]):
"""
@@ -40,7 +60,7 @@ def control_samples(controls: Sequence[dict], sampleslist: Sequence[str]):
__process_sample__, sampleslist, (tuple(), tuple(), tuple()))
return reduce(
- lambda acc, item: (
+ lambda acc, item: (# type: ignore[arg-type, return-value]
acc[0] + (item[0],),
acc[1] + (item[1],),
acc[2] + (item[2],),
@@ -49,22 +69,6 @@ def control_samples(controls: Sequence[dict], sampleslist: Sequence[str]):
[__process_control__(trait_data) for trait_data in controls],
(tuple(), tuple(), tuple(), tuple()))
-def dictify_by_samples(samples_vals_vars: Sequence[Sequence]) -> Sequence[dict]:
- """
- Build a sequence of dictionaries from a sequence of separate sequences of
- samples, values and variances.
-
- This is a partial migration of
- `web.webqtl.correlation.correlationFunction.fixStrains` function in GN1.
- This implementation extracts code that will find common use, and that will
- find use in more than one place.
- """
- return tuple(
- {
- sample: {"sample_name": sample, "value": val, "variance": var}
- for sample, val, var in zip(*trait_line)
- } for trait_line in zip(*(samples_vals_vars[0:3])))
-
def fix_samples(primary_trait: dict, control_traits: Sequence[dict]) -> Sequence[Sequence[Any]]:
"""
Corrects sample_names, values and variance such that they all contain only
@@ -108,7 +112,7 @@ def find_identical_traits(
return acc + ident[1]
def __dictify_controls__(acc, control_item):
- ckey = "{:.3f}".format(control_item[0])
+ ckey = tuple(f"{item:.3f}" for item in control_item[0])
return {**acc, ckey: acc.get(ckey, tuple()) + (control_item[1],)}
return (reduce(## for identical control traits
@@ -148,11 +152,11 @@ def tissue_correlation(
assert len(primary_trait_values) == len(target_trait_values), (
"The lengths of the `primary_trait_values` and `target_trait_values` "
"must be equal")
- assert method in method_fns.keys(), (
- "Method must be one of: {}".format(",".join(method_fns.keys())))
+ assert method in method_fns, (
+ "Method must be one of: {','.join(method_fns.keys())}")
corr, pvalue = method_fns[method](primary_trait_values, target_trait_values)
- return (round(corr, 10), round(pvalue, 10))
+ return (corr, pvalue)
def batch_computed_tissue_correlation(
primary_trait_values: Tuple[float, ...], target_traits_dict: dict,
@@ -196,33 +200,19 @@ def good_dataset_samples_indexes(
samples_from_file.index(good) for good in
set(samples).intersection(set(samples_from_file))))
-def determine_partials(
- primary_vals, control_vals, all_target_trait_names,
- all_target_trait_values, method):
- """
- This **WILL** be a migration of
- `web.webqtl.correlation.correlationFunction.determinePartialsByR` function
- in GeneNetwork1.
-
- The function in GeneNetwork1 contains code written in R that is then used to
- compute the partial correlations.
- """
- ## This function is not implemented at this stage
- return tuple(
- primary_vals, control_vals, all_target_trait_names,
- all_target_trait_values, method)
-
-def compute_partial_correlations_fast(# pylint: disable=[R0913, R0914]
+def partial_correlations_fast(# pylint: disable=[R0913, R0914]
samples, primary_vals, control_vals, database_filename,
fetched_correlations, method: str, correlation_type: str) -> Tuple[
- float, Tuple[float, ...]]:
+ int, Tuple[float, ...]]:
"""
+ Computes partial correlation coefficients using data from a CSV file.
+
This is a partial migration of the
`web.webqtl.correlation.PartialCorrDBPage.getPartialCorrelationsFast`
function in GeneNetwork1.
"""
assert method in ("spearman", "pearson")
- with open(f"{TEXTDIR}/{database_filename}", "r") as dataset_file:
+    with open(database_filename, "r", encoding="utf-8") as dataset_file: # pylint: disable=[W1514]
dataset = tuple(dataset_file.readlines())
good_dataset_samples = good_dataset_samples_indexes(
@@ -245,7 +235,7 @@ def compute_partial_correlations_fast(# pylint: disable=[R0913, R0914]
all_target_trait_names: Tuple[str, ...] = processed_trait_names_values[0]
all_target_trait_values: Tuple[float, ...] = processed_trait_names_values[1]
- all_correlations = determine_partials(
+ all_correlations = compute_partial(
primary_vals, control_vals, all_target_trait_names,
all_target_trait_values, method)
## Line 772 to 779 in GN1 are the cause of the weird complexity in the
@@ -254,36 +244,544 @@ def compute_partial_correlations_fast(# pylint: disable=[R0913, R0914]
## `correlation_type` parameter
return len(all_correlations), tuple(
corr + (
- (fetched_correlations[corr[0]],) if correlation_type == "literature"
- else fetched_correlations[corr[0]][0:2])
+ (fetched_correlations[corr[0]],) # type: ignore[index]
+ if correlation_type == "literature"
+ else fetched_correlations[corr[0]][0:2]) # type: ignore[index]
for idx, corr in enumerate(all_correlations))
-def partial_correlation_matrix(
+def build_data_frame(
xdata: Tuple[float, ...], ydata: Tuple[float, ...],
- zdata: Tuple[float, ...], method: str = "pearsons",
- omit_nones: bool = True) -> float:
+ zdata: Union[
+ Tuple[float, ...],
+ Tuple[Tuple[float, ...], ...]]) -> pandas.DataFrame:
+ """
+ Build a pandas DataFrame object from xdata, ydata and zdata
+ """
+ x_y_df = pandas.DataFrame({"x": xdata, "y": ydata})
+ if isinstance(zdata[0], float):
+ return x_y_df.join(pandas.DataFrame({"z": zdata}))
+ interm_df = x_y_df.join(pandas.DataFrame(
+ {f"z{i}": val for i, val in enumerate(zdata)}))
+ if interm_df.shape[1] == 3:
+ return interm_df.rename(columns={"z0": "z"})
+ return interm_df
+
+def compute_trait_info(primary_vals, control_vals, target, method):
"""
- Computes the partial correlation coefficient using the
- 'variance-covariance matrix' method
+ Compute the correlation values for the given arguments.
+ """
+ targ_vals = target[0]
+ targ_name = target[1]
+ primary = [
+ prim for targ, prim in zip(targ_vals, primary_vals)
+ if targ is not None]
+
+ if len(primary) < 3:
+ return None
+
+ def __remove_controls_for_target_nones(cont_targ):
+ return tuple(cont for cont, targ in cont_targ if targ is not None)
+
+ datafrm = build_data_frame(
+ primary,
+ [targ for targ in targ_vals if targ is not None],
+ [__remove_controls_for_target_nones(tuple(zip(control, targ_vals)))
+ for control in control_vals])
+ covariates = "z" if datafrm.shape[1] == 3 else [
+ col for col in datafrm.columns if col not in ("x", "y")]
+ ppc = pingouin.partial_corr(
+ data=datafrm, x="x", y="y", covar=covariates, method=(
+ "pearson" if "pearson" in method.lower() else "spearman"))
+ pc_coeff = ppc["r"][0]
+
+ zero_order_corr = pingouin.corr(
+ datafrm["x"], datafrm["y"], method=(
+ "pearson" if "pearson" in method.lower() else "spearman"))
+
+ if math.isnan(pc_coeff):
+ return (
+ targ_name, len(primary), pc_coeff, 1, zero_order_corr["r"][0],
+ zero_order_corr["p-val"][0])
+ return (
+ targ_name, len(primary), pc_coeff,
+ (ppc["p-val"][0] if not math.isnan(ppc["p-val"][0]) else (
+ 0 if (abs(pc_coeff - 1) < 0.0000001) else 1)),
+ zero_order_corr["r"][0], zero_order_corr["p-val"][0])
+
+def compute_partial(
+ primary_vals, control_vals, target_vals, target_names,
+ method: str) -> Tuple[
+ Union[
+ Tuple[str, int, float, float, float, float], None],
+ ...]:
+ """
+ Compute the partial correlations.
- This is a partial migration of the
- `web.webqtl.correlation.correlationFunction.determinPartialsByR` function in
- GeneNetwork1, specifically the `pcor.mat` function written in the R
- programming language.
+ This is a re-implementation of the
+ `web.webqtl.correlation.correlationFunction.determinePartialsByR` function
+ in GeneNetwork1.
+
+ This implementation reworks the child function `compute_partial` which will
+ then be used in the place of `determinPartialsByR`.
+ """
+ return tuple(
+ result for result in (
+ compute_trait_info(
+ primary_vals, control_vals, (tvals, tname), method)
+ for tvals, tname in zip(target_vals, target_names))
+ if result is not None)
+
+def partial_correlations_normal(# pylint: disable=R0913
+ primary_vals, control_vals, input_trait_gene_id, trait_database,
+ data_start_pos: int, db_type: str, method: str) -> Tuple[
+ int, Tuple[Union[
+ Tuple[str, int, float, float, float, float], None],
+ ...]]:#Tuple[float, ...]
"""
- return 0
+ Computes the correlation coefficients.
-def partial_correlation_recursive(
- xdata: Tuple[float, ...], ydata: Tuple[float, ...],
- zdata: Tuple[float, ...], method: str = "pearsons",
- omit_nones: bool = True) -> float:
+ This is a migration of the
+ `web.webqtl.correlation.PartialCorrDBPage.getPartialCorrelationsNormal`
+ function in GeneNetwork1.
"""
- Computes the partial correlation coefficient using the 'recursive formula'
- method
+ def __add_lit_and_tiss_corr__(item):
+ if method.lower() == "sgo literature correlation":
+ # if method is 'SGO Literature Correlation', `compute_partial`
+ # would give us LitCorr in the [1] position
+ return tuple(item) + trait_database[1]
+ if method.lower() in (
+ "tissue correlation, pearson's r",
+ "tissue correlation, spearman's rho"):
+ # if method is 'Tissue Correlation, *', `compute_partial` would give
+ # us Tissue Corr in the [1] position and Tissue Corr P Value in the
+ # [2] position
+ return tuple(item) + (trait_database[1], trait_database[2])
+ return item
+
+ target_trait_names, target_trait_vals = reduce(# type: ignore[var-annotated]
+ lambda acc, item: (acc[0]+(item[0],), acc[1]+(item[data_start_pos:],)),
+ trait_database, (tuple(), tuple()))
+
+ all_correlations = compute_partial(
+ primary_vals, control_vals, target_trait_vals, target_trait_names,
+ method)
+
+ if (input_trait_gene_id and db_type == "ProbeSet" and method.lower() in (
+ "sgo literature correlation", "tissue correlation, pearson's r",
+ "tissue correlation, spearman's rho")):
+ return (
+ len(trait_database),
+ tuple(
+ __add_lit_and_tiss_corr__(item)
+ for idx, item in enumerate(all_correlations)))
+
+ return len(trait_database), all_correlations
+
+def partial_corrs(# pylint: disable=[R0913]
+ conn, samples, primary_vals, control_vals, return_number, species,
+ input_trait_geneid, input_trait_symbol, tissue_probeset_freeze_id,
+ method, dataset, database_filename):
+ """
+ Compute the partial correlations, selecting the fast or normal method
+ depending on the existence of the database text file.
This is a partial migration of the
- `web.webqtl.correlation.correlationFunction.determinPartialsByR` function in
- GeneNetwork1, specifically the `pcor.rec` function written in the R
- programming language.
+ `web.webqtl.correlation.PartialCorrDBPage.__init__` function in
+ GeneNetwork1.
+ """
+ if database_filename:
+ return partial_correlations_fast(
+ samples, primary_vals, control_vals, database_filename,
+ (
+ fetch_literature_correlations(
+ species, input_trait_geneid, dataset, return_number, conn)
+ if "literature" in method.lower() else
+ fetch_tissue_correlations(
+ dataset, input_trait_symbol, tissue_probeset_freeze_id,
+ method, return_number, conn)),
+ method,
+ ("literature" if method.lower() == "sgo literature correlation"
+ else ("tissue" if "tissue" in method.lower() else "genetic")))
+
+ trait_database, data_start_pos = fetch_all_database_data(
+ conn, species, input_trait_geneid, input_trait_symbol, samples, dataset,
+ method, return_number, tissue_probeset_freeze_id)
+ return partial_correlations_normal(
+ primary_vals, control_vals, input_trait_geneid, trait_database,
+ data_start_pos, dataset, method)
+
+def literature_correlation_by_list(
+ conn: Any, species: str, trait_list: Tuple[dict]) -> Tuple[dict, ...]:
+ """
+ This is a migration of the
+ `web.webqtl.correlation.CorrelationPage.getLiteratureCorrelationByList`
+ function in GeneNetwork1.
+ """
+    if any(bool(trait.get("tissue_corr")) and
+           bool(trait.get("tissue_p_value"))
+           for trait in trait_list):
+ temporary_table_name = f"LITERATURE{random_string(8)}"
+ query1 = (
+ f"CREATE TEMPORARY TABLE {temporary_table_name} "
+ "(GeneId1 INT(12) UNSIGNED, GeneId2 INT(12) UNSIGNED PRIMARY KEY, "
+ "value DOUBLE)")
+ query2 = (
+ f"INSERT INTO {temporary_table_name}(GeneId1, GeneId2, value) "
+ "SELECT GeneId1, GeneId2, value FROM LCorrRamin3 "
+ "WHERE GeneId1=%(geneid)s")
+ query3 = (
+ "INSERT INTO {temporary_table_name}(GeneId1, GeneId2, value) "
+ "SELECT GeneId2, GeneId1, value FROM LCorrRamin3 "
+ "WHERE GeneId2=%s AND GeneId1 != %(geneid)s")
+
+ def __set_mouse_geneid__(trait):
+ if trait.get("geneid"):
+ return {
+ **trait,
+ "mouse_geneid": translate_to_mouse_gene_id(
+ species, trait.get("geneid"), conn)
+ }
+ return {**trait, "mouse_geneid": 0}
+
+ def __retrieve_lcorr__(cursor, geneids):
+ cursor.execute(
+ f"SELECT GeneId2, value FROM {temporary_table_name} "
+ "WHERE GeneId2 IN %(geneids)s",
+            {"geneids": tuple(geneids)})
+ return dict(cursor.fetchall())
+
+ with conn.cursor() as cursor:
+ cursor.execute(query1)
+ cursor.execute(query2)
+ cursor.execute(query3)
+
+ traits = tuple(__set_mouse_geneid__(trait) for trait in trait_list)
+ lcorrs = __retrieve_lcorr__(
+ cursor, (
+ trait["mouse_geneid"] for trait in traits
+ if (trait["mouse_geneid"] != 0 and
+ trait["mouse_geneid"].find(";") < 0)))
+ return tuple(
+ {**trait, "l_corr": lcorrs.get(trait["mouse_geneid"], None)}
+ for trait in traits)
+
+    return trait_list
+
+def tissue_correlation_by_list(
+ conn: Any, primary_trait_symbol: str, tissue_probeset_freeze_id: int,
+ method: str, trait_list: Tuple[dict]) -> Tuple[dict, ...]:
+ """
+ This is a migration of the
+ `web.webqtl.correlation.CorrelationPage.getTissueCorrelationByList`
+ function in GeneNetwork1.
+ """
+ def __add_tissue_corr__(trait, primary_trait_values, trait_values):
+ result = pingouin.corr(
+ primary_trait_values, trait_values,
+ method=("spearman" if "spearman" in method.lower() else "pearson"))
+ return {
+ **trait,
+ "tissue_corr": result["r"],
+ "tissue_p_value": result["p-val"]
+ }
+
+    if any(bool(trait.get("l_corr")) for trait in trait_list):
+ prim_trait_symbol_value_dict = fetch_gene_symbol_tissue_value_dict_for_trait(
+ (primary_trait_symbol,), tissue_probeset_freeze_id, conn)
+ if primary_trait_symbol.lower() in prim_trait_symbol_value_dict:
+ primary_trait_value = prim_trait_symbol_value_dict[
+ primary_trait_symbol.lower()]
+ gene_symbol_list = tuple(
+ trait["symbol"] for trait in trait_list if "symbol" in trait.keys())
+ symbol_value_dict = fetch_gene_symbol_tissue_value_dict_for_trait(
+ gene_symbol_list, tissue_probeset_freeze_id, conn)
+ return tuple(
+ __add_tissue_corr__(
+ trait, primary_trait_value,
+ symbol_value_dict[trait["symbol"].lower()])
+ for trait in trait_list
+ if ("symbol" in trait and
+ bool(trait["symbol"]) and
+ trait["symbol"].lower() in symbol_value_dict))
+ return tuple({
+ **trait,
+ "tissue_corr": None,
+ "tissue_p_value": None
+ } for trait in trait_list)
+ return trait_list
+
+def trait_for_output(trait):
+ """
+ Process a trait for output.
+
+ Removes a lot of extraneous data from the trait, that is not needed for
+ the display of partial correlation results.
+ This function also removes all key-value pairs, for which the value is
+ `None`, because it is a waste of network resources to transmit the key-value
+ pair just to indicate it does not exist.
+ """
+ def __nan_to_none__(val):
+ if val is None:
+ return None
+ if math.isnan(val) or numpy.isnan(val):
+ return None
+ return val
+
+ trait = {
+ "trait_type": trait["db"]["dataset_type"],
+ "dataset_name": trait["db"]["dataset_name"],
+ "dataset_type": trait["db"]["dataset_type"],
+ "group": trait["db"]["group"],
+ "trait_fullname": trait["trait_fullname"],
+ "trait_name": trait["trait_name"],
+ "symbol": trait.get("symbol"),
+ "description": trait.get("description"),
+ "pre_publication_description": trait.get("Pre_publication_description"),
+ "post_publication_description": trait.get(
+ "Post_publication_description"),
+ "original_description": trait.get("Original_description"),
+ "authors": trait.get("Authors"),
+ "year": trait.get("Year"),
+ "probe_target_description": trait.get("Probe_target_description"),
+ "chr": trait.get("chr"),
+ "mb": trait.get("mb"),
+ "geneid": trait.get("geneid"),
+ "homologeneid": trait.get("homologeneid"),
+ "noverlap": trait.get("noverlap"),
+ "partial_corr": __nan_to_none__(trait.get("partial_corr")),
+ "partial_corr_p_value": __nan_to_none__(
+ trait.get("partial_corr_p_value")),
+ "corr": __nan_to_none__(trait.get("corr")),
+ "corr_p_value": __nan_to_none__(trait.get("corr_p_value")),
+ "rank_order": __nan_to_none__(trait.get("rank_order")),
+ "delta": (
+ None if trait.get("partial_corr") is None
+ else (trait.get("partial_corr") - trait.get("corr"))),
+ "l_corr": __nan_to_none__(trait.get("l_corr")),
+ "tissue_corr": __nan_to_none__(trait.get("tissue_corr")),
+ "tissue_p_value": __nan_to_none__(trait.get("tissue_p_value"))
+ }
+ return {key: val for key, val in trait.items() if val is not None}
+
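
Any key whose value ends up as `None` is dropped before the trait is
serialised. A hedged sketch with a made-up trait dict (dataset and group names
are illustrative):

    trait = {
        "db": {"dataset_type": "ProbeSet", "dataset_name": "SomeDataset",
               "group": "SomeGroup"},
        "trait_fullname": "SomeDataset::trait1", "trait_name": "trait1",
        "corr": 0.75, "corr_p_value": 0.01}
    trait_for_output(trait)
    # keys resolving to None (symbol, geneid, partial_corr, ...) are omitted;
    # trait_type, corr, corr_p_value etc. remain
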
+def partial_correlations_entry(# pylint: disable=[R0913, R0914, R0911]
+ conn: Any, primary_trait_name: str,
+ control_trait_names: Tuple[str, ...], method: str,
+ criteria: int, target_db_name: str) -> dict:
+ """
+    This is the 'orchestration' function for the partial-correlation feature.
+
+ This function will dispatch the functions doing data fetches from the
+ database (and various other places) and feed that data to the functions
+ doing the conversions and computations. It will then return the results of
+ all of that work.
+
+ This function is doing way too much. Look into splitting out the
+ functionality into smaller functions that do fewer things.
"""
- return 0
+ threshold = 0
+ corr_min_informative = 4
+
+ all_traits = traits_info(
+ conn, threshold, (primary_trait_name,) + control_trait_names)
+ all_traits_data = traits_data(conn, all_traits)
+
+ primary_trait = tuple(
+ trait for trait in all_traits
+ if trait["trait_fullname"] == primary_trait_name)[0]
+ if not primary_trait["haveinfo"]:
+ return {
+ "status": "not-found",
+ "message": f"Could not find primary trait {primary_trait['trait_fullname']}"
+ }
+ cntrl_traits = tuple(
+ trait for trait in all_traits
+ if trait["trait_fullname"] != primary_trait_name)
+ if not any(trait["haveinfo"] for trait in cntrl_traits):
+ return {
+ "status": "not-found",
+ "message": "None of the requested control traits were found."}
+ for trait in cntrl_traits:
+ if trait["haveinfo"] is False:
+ warnings.warn(
+ (f"Control traits {trait['trait_fullname']} was not found "
+ "- continuing without it."),
+ category=UserWarning)
+
+ group = primary_trait["db"]["group"]
+ primary_trait_data = all_traits_data[primary_trait["trait_name"]]
+ primary_samples, primary_values, _primary_variances = export_informative(
+ primary_trait_data)
+
+ cntrl_traits_data = tuple(
+ data for trait_name, data in all_traits_data.items()
+ if trait_name != primary_trait["trait_name"])
+ species = species_name(conn, group)
+
+ (cntrl_samples,
+ cntrl_values,
+ _cntrl_variances,
+ _cntrl_ns) = control_samples(cntrl_traits_data, primary_samples)
+
+ common_primary_control_samples = primary_samples
+ fixed_primary_vals = primary_values
+ fixed_control_vals = cntrl_values
+ if not all(cnt_smp == primary_samples for cnt_smp in cntrl_samples):
+ (common_primary_control_samples,
+ fixed_primary_vals,
+ fixed_control_vals,
+ _primary_variances,
+ _cntrl_variances) = fix_samples(primary_trait, cntrl_traits)
+
+ if len(common_primary_control_samples) < corr_min_informative:
+ return {
+ "status": "error",
+ "message": (
+ f"Fewer than {corr_min_informative} samples data entered for "
+ f"{group} dataset. No calculation of correlation has been "
+ "attempted."),
+ "error_type": "Inadequate Samples"}
+
+ identical_traits_names = find_identical_traits(
+ primary_trait_name, primary_values, control_trait_names, cntrl_values)
+ if len(identical_traits_names) > 0:
+ return {
+ "status": "error",
+ "message": (
+ f"{identical_traits_names[0]} and {identical_traits_names[1]} "
+ "have the same values for the {len(fixed_primary_vals)} "
+ "samples that will be used to compute the partial correlation "
+ "(common for all primary and control traits). In such cases, "
+ "partial correlation cannot be computed. Please re-select your "
+ "traits."),
+ "error_type": "Identical Traits"}
+
+ input_trait_geneid = primary_trait.get("geneid", 0)
+ input_trait_symbol = primary_trait.get("symbol", "")
+ input_trait_mouse_geneid = translate_to_mouse_gene_id(
+ species, input_trait_geneid, conn)
+
+ tissue_probeset_freeze_id = 1
+ db_type = primary_trait["db"]["dataset_type"]
+
+ if db_type == "ProbeSet" and method.lower() in (
+ "sgo literature correlation",
+ "tissue correlation, pearson's r",
+ "tissue correlation, spearman's rho"):
+ return {
+ "status": "error",
+ "message": (
+ "Wrong correlation type: It is not possible to compute the "
+ f"{method} between your trait and data in the {target_db_name} "
+ "database. Please try again after selecting another type of "
+ "correlation."),
+ "error_type": "Correlation Type"}
+
+ if (method.lower() == "sgo literature correlation" and (
+ bool(input_trait_geneid) is False or
+ check_for_literature_info(conn, input_trait_mouse_geneid))):
+ return {
+ "status": "error",
+ "message": (
+ "No Literature Information: This gene does not have any "
+ "associated Literature Information."),
+ "error_type": "Literature Correlation"}
+
+ if (
+ method.lower() in (
+ "tissue correlation, pearson's r",
+ "tissue correlation, spearman's rho")
+ and bool(input_trait_symbol) is False):
+ return {
+ "status": "error",
+ "message": (
+ "No Tissue Correlation Information: This gene does not have "
+ "any associated Tissue Correlation Information."),
+ "error_type": "Tissue Correlation"}
+
+ if (
+ method.lower() in (
+ "tissue correlation, pearson's r",
+ "tissue correlation, spearman's rho")
+ and check_symbol_for_tissue_correlation(
+ conn, tissue_probeset_freeze_id, input_trait_symbol)):
+ return {
+ "status": "error",
+ "message": (
+ "No Tissue Correlation Information: This gene does not have "
+ "any associated Tissue Correlation Information."),
+ "error_type": "Tissue Correlation"}
+
+ target_dataset = retrieve_trait_dataset(
+ ("Temp" if "Temp" in target_db_name else
+ ("Publish" if "Publish" in target_db_name else
+ "Geno" if "Geno" in target_db_name else "ProbeSet")),
+ {"db": {"dataset_name": target_db_name}, "trait_name": "_"},
+ threshold,
+ conn)
+
+ database_filename = get_filename(conn, target_db_name, TEXTDIR)
+ _total_traits, all_correlations = partial_corrs(
+ conn, common_primary_control_samples, fixed_primary_vals,
+ fixed_control_vals, len(fixed_primary_vals), species,
+ input_trait_geneid, input_trait_symbol, tissue_probeset_freeze_id,
+ method, {**target_dataset, "dataset_type": target_dataset["type"]}, database_filename)
+
+
+ def __make_sorter__(method):
+ def __by_lit_or_tiss_corr_then_p_val__(row):
+ return (row[6], row[3])
+
+ def __by_partial_corr_p_value__(row):
+ return row[3]
+
+ if (("literature" in method.lower()) or ("tissue" in method.lower())):
+ return __by_lit_or_tiss_corr_then_p_val__
+
+ return __by_partial_corr_p_value__
+
+ add_lit_corr_and_tiss_corr = compose(
+ partial(literature_correlation_by_list, conn, species),
+ partial(
+ tissue_correlation_by_list, conn, input_trait_symbol,
+ tissue_probeset_freeze_id, method))
+
+ selected_results = sorted(
+ all_correlations,
+ key=__make_sorter__(method))[:criteria]
+ traits_list_corr_info = {
+ f"{target_dataset['dataset_name']}::{item[0]}": {
+ "noverlap": item[1],
+ "partial_corr": item[2],
+ "partial_corr_p_value": item[3],
+ "corr": item[4],
+ "corr_p_value": item[5],
+ "rank_order": (1 if "spearman" in method.lower() else 0),
+ **({
+ "tissue_corr": item[6],
+ "tissue_p_value": item[7]}
+ if len(item) == 8 else {}),
+ **({"l_corr": item[6]}
+ if len(item) == 7 else {})
+ } for item in selected_results}
+
+ trait_list = add_lit_corr_and_tiss_corr(tuple(
+ {**trait, **traits_list_corr_info.get(trait["trait_fullname"], {})}
+ for trait in traits_info(
+ conn, threshold,
+ tuple(
+ f"{target_dataset['dataset_name']}::{item[0]}"
+ for item in selected_results))))
+
+ return {
+ "status": "success",
+ "results": {
+ "primary_trait": trait_for_output(primary_trait),
+ "control_traits": tuple(
+ trait_for_output(trait) for trait in cntrl_traits),
+ "correlations": tuple(
+ trait_for_output(trait) for trait in trait_list),
+ "dataset_type": target_dataset["type"],
+ "method": "spearman" if "spearman" in method.lower() else "pearson"
+ }}
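
At its core, `compute_trait_info` hands the computation to
`pingouin.partial_corr`, which expects a DataFrame with x, y and covariate
columns (as built by `build_data_frame`). A standalone sketch with made-up
values:

    import pandas
    import pingouin

    datafrm = pandas.DataFrame({
        "x": [1.0, 2.0, 3.0, 4.0, 5.0],   # primary trait values
        "y": [2.1, 3.9, 6.2, 8.1, 9.8],   # target trait values
        "z": [0.5, 0.4, 0.6, 0.5, 0.4]})  # control trait (covariate) values
    ppc = pingouin.partial_corr(
        data=datafrm, x="x", y="y", covar="z", method="pearson")
    print(ppc[["r", "p-val"]])  # the partial correlation and its p-value
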
diff --git a/gn3/computations/partial_correlations_optimised.py b/gn3/computations/partial_correlations_optimised.py
new file mode 100644
index 0000000..601289c
--- /dev/null
+++ b/gn3/computations/partial_correlations_optimised.py
@@ -0,0 +1,244 @@
+"""
+This contains an optimised version of the
+ `gn3.computations.partial_correlations.partial_correlations_entry`
+function.
+"""
+from functools import partial
+from typing import Any, Tuple
+
+from gn3.settings import TEXTDIR
+from gn3.function_helpers import compose
+from gn3.db.partial_correlations import traits_info, traits_data
+from gn3.db.species import species_name, translate_to_mouse_gene_id
+from gn3.db.traits import export_informative, retrieve_trait_dataset
+from gn3.db.correlations import (
+ get_filename,
+ check_for_literature_info,
+ check_symbol_for_tissue_correlation)
+from gn3.computations.partial_correlations import (
+ fix_samples,
+ partial_corrs,
+ control_samples,
+ trait_for_output,
+ find_identical_traits,
+ tissue_correlation_by_list,
+ literature_correlation_by_list)
+
+def partial_correlations_entry(# pylint: disable=[R0913, R0914, R0911]
+ conn: Any, primary_trait_name: str,
+ control_trait_names: Tuple[str, ...], method: str,
+ criteria: int, target_db_name: str) -> dict:
+ """
+    This is the 'orchestration' function for the partial-correlation feature.
+
+ This function will dispatch the functions doing data fetches from the
+ database (and various other places) and feed that data to the functions
+ doing the conversions and computations. It will then return the results of
+ all of that work.
+
+ This function is doing way too much. Look into splitting out the
+ functionality into smaller functions that do fewer things.
+ """
+ threshold = 0
+ corr_min_informative = 4
+
+ all_traits = traits_info(
+ conn, threshold, (primary_trait_name,) + control_trait_names)
+ all_traits_data = traits_data(conn, all_traits)
+
+ # primary_trait = retrieve_trait_info(threshold, primary_trait_name, conn)
+ primary_trait = tuple(
+ trait for trait in all_traits
+ if trait["trait_fullname"] == primary_trait_name)[0]
+ group = primary_trait["db"]["group"]
+ # primary_trait_data = retrieve_trait_data(primary_trait, conn)
+ primary_trait_data = all_traits_data[primary_trait["trait_name"]]
+ primary_samples, primary_values, _primary_variances = export_informative(
+ primary_trait_data)
+
+ # cntrl_traits = tuple(
+ # retrieve_trait_info(threshold, trait_full_name, conn)
+ # for trait_full_name in control_trait_names)
+ # cntrl_traits_data = tuple(
+ # retrieve_trait_data(cntrl_trait, conn)
+ # for cntrl_trait in cntrl_traits)
+ cntrl_traits = tuple(
+ trait for trait in all_traits
+ if trait["trait_fullname"] != primary_trait_name)
+ cntrl_traits_data = tuple(
+ data for trait_name, data in all_traits_data.items()
+ if trait_name != primary_trait["trait_name"])
+ species = species_name(conn, group)
+
+ (cntrl_samples,
+ cntrl_values,
+ _cntrl_variances,
+ _cntrl_ns) = control_samples(cntrl_traits_data, primary_samples)
+
+ common_primary_control_samples = primary_samples
+ fixed_primary_vals = primary_values
+ fixed_control_vals = cntrl_values
+ if not all(cnt_smp == primary_samples for cnt_smp in cntrl_samples):
+ (common_primary_control_samples,
+ fixed_primary_vals,
+ fixed_control_vals,
+ _primary_variances,
+ _cntrl_variances) = fix_samples(primary_trait, cntrl_traits)
+
+ if len(common_primary_control_samples) < corr_min_informative:
+ return {
+ "status": "error",
+ "message": (
+ f"Fewer than {corr_min_informative} samples data entered for "
+ f"{group} dataset. No calculation of correlation has been "
+ "attempted."),
+ "error_type": "Inadequate Samples"}
+
+ identical_traits_names = find_identical_traits(
+ primary_trait_name, primary_values, control_trait_names, cntrl_values)
+ if len(identical_traits_names) > 0:
+ return {
+ "status": "error",
+ "message": (
+ f"{identical_traits_names[0]} and {identical_traits_names[1]} "
+ "have the same values for the {len(fixed_primary_vals)} "
+ "samples that will be used to compute the partial correlation "
+ "(common for all primary and control traits). In such cases, "
+ "partial correlation cannot be computed. Please re-select your "
+ "traits."),
+ "error_type": "Identical Traits"}
+
+ input_trait_geneid = primary_trait.get("geneid", 0)
+ input_trait_symbol = primary_trait.get("symbol", "")
+ input_trait_mouse_geneid = translate_to_mouse_gene_id(
+ species, input_trait_geneid, conn)
+
+ tissue_probeset_freeze_id = 1
+ db_type = primary_trait["db"]["dataset_type"]
+
+ if db_type == "ProbeSet" and method.lower() in (
+ "sgo literature correlation",
+ "tissue correlation, pearson's r",
+ "tissue correlation, spearman's rho"):
+ return {
+ "status": "error",
+ "message": (
+ "Wrong correlation type: It is not possible to compute the "
+ f"{method} between your trait and data in the {target_db_name} "
+ "database. Please try again after selecting another type of "
+ "correlation."),
+ "error_type": "Correlation Type"}
+
+ if (method.lower() == "sgo literature correlation" and (
+ bool(input_trait_geneid) is False or
+ check_for_literature_info(conn, input_trait_mouse_geneid))):
+ return {
+ "status": "error",
+ "message": (
+ "No Literature Information: This gene does not have any "
+ "associated Literature Information."),
+ "error_type": "Literature Correlation"}
+
+ if (
+ method.lower() in (
+ "tissue correlation, pearson's r",
+ "tissue correlation, spearman's rho")
+ and bool(input_trait_symbol) is False):
+ return {
+ "status": "error",
+ "message": (
+ "No Tissue Correlation Information: This gene does not have "
+ "any associated Tissue Correlation Information."),
+ "error_type": "Tissue Correlation"}
+
+ if (
+ method.lower() in (
+ "tissue correlation, pearson's r",
+ "tissue correlation, spearman's rho")
+ and check_symbol_for_tissue_correlation(
+ conn, tissue_probeset_freeze_id, input_trait_symbol)):
+ return {
+ "status": "error",
+ "message": (
+ "No Tissue Correlation Information: This gene does not have "
+ "any associated Tissue Correlation Information."),
+ "error_type": "Tissue Correlation"}
+
+ target_dataset = retrieve_trait_dataset(
+ ("Temp" if "Temp" in target_db_name else
+ ("Publish" if "Publish" in target_db_name else
+ "Geno" if "Geno" in target_db_name else "ProbeSet")),
+ {"db": {"dataset_name": target_db_name}, "trait_name": "_"},
+ threshold,
+ conn)
+
+ database_filename = get_filename(conn, target_db_name, TEXTDIR)
+ _total_traits, all_correlations = partial_corrs(
+ conn, common_primary_control_samples, fixed_primary_vals,
+ fixed_control_vals, len(fixed_primary_vals), species,
+ input_trait_geneid, input_trait_symbol, tissue_probeset_freeze_id,
+ method, {**target_dataset, "dataset_type": target_dataset["type"]}, database_filename)
+
+
+ def __make_sorter__(method):
+ def __sort_6__(row):
+ return row[6]
+
+ def __sort_3__(row):
+ return row[3]
+
+ if "literature" in method.lower():
+ return __sort_6__
+
+ if "tissue" in method.lower():
+ return __sort_6__
+
+ return __sort_3__
+
+ # sorted_correlations = sorted(
+ # all_correlations, key=__make_sorter__(method))
+
+ add_lit_corr_and_tiss_corr = compose(
+ partial(literature_correlation_by_list, conn, species),
+ partial(
+ tissue_correlation_by_list, conn, input_trait_symbol,
+ tissue_probeset_freeze_id, method))
+
+ selected_results = sorted(
+ all_correlations,
+ key=__make_sorter__(method))[:min(criteria, len(all_correlations))]
+ traits_list_corr_info = {
+ "{target_dataset['dataset_name']}::{item[0]}": {
+ "noverlap": item[1],
+ "partial_corr": item[2],
+ "partial_corr_p_value": item[3],
+ "corr": item[4],
+ "corr_p_value": item[5],
+ "rank_order": (1 if "spearman" in method.lower() else 0),
+ **({
+ "tissue_corr": item[6],
+ "tissue_p_value": item[7]}
+ if len(item) == 8 else {}),
+ **({"l_corr": item[6]}
+ if len(item) == 7 else {})
+ } for item in selected_results}
+
+ trait_list = add_lit_corr_and_tiss_corr(tuple(
+ {**trait, **traits_list_corr_info.get(trait["trait_fullname"], {})}
+ for trait in traits_info(
+ conn, threshold,
+ tuple(
+ f"{target_dataset['dataset_name']}::{item[0]}"
+ for item in selected_results))))
+
+ return {
+ "status": "success",
+ "results": {
+ "primary_trait": trait_for_output(primary_trait),
+ "control_traits": tuple(
+ trait_for_output(trait) for trait in cntrl_traits),
+ "correlations": tuple(
+ trait_for_output(trait) for trait in trait_list),
+ "dataset_type": target_dataset["type"],
+ "method": "spearman" if "spearman" in method.lower() else "pearson"
+ }}
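The dict comprehension that builds `traits_list_corr_info` above dispatches on
row length: 8-element rows carry tissue correlation values (positions 6 and 7),
while 7-element rows carry a literature correlation (position 6). A minimal
sketch of that dispatch, using hypothetical rows:

    def extra_fields(item: tuple) -> dict:
        # 8 elements: tissue correlation and its p-value trail the row
        if len(item) == 8:
            return {"tissue_corr": item[6], "tissue_p_value": item[7]}
        # 7 elements: a single literature correlation trails the row
        if len(item) == 7:
            return {"l_corr": item[6]}
        return {}

    extra_fields(("trait1", 10, 0.5, 0.01, 0.6, 0.02, 0.9, 0.001))
    # {'tissue_corr': 0.9, 'tissue_p_value': 0.001}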
diff --git a/gn3/computations/pca.py b/gn3/computations/pca.py
new file mode 100644
index 0000000..35c9f03
--- /dev/null
+++ b/gn3/computations/pca.py
@@ -0,0 +1,189 @@
+"""module contains pca implementation using python"""
+
+
+from typing import Any
+from scipy import stats
+
+from sklearn.decomposition import PCA
+from sklearn import preprocessing
+
+import numpy as np
+import redis
+
+
+from typing_extensions import TypeAlias
+
+fArray: TypeAlias = list[float]
+
+
+def compute_pca(array: list[fArray]) -> dict[str, Any]:
+ """
+ computes the principal component analysis
+
+ Parameters:
+
+ array(list[list]):a list of lists contains data to perform pca
+
+
+ Returns:
+ pca_dict(dict):dict contains the pca_object,pca components,pca scores
+
+
+ """
+
+ corr_matrix = np.array(array)
+
+ pca_obj = PCA()
+ scaled_data = preprocessing.scale(corr_matrix)
+
+ pca_obj.fit(scaled_data)
+
+ return {
+ "pca": pca_obj,
+ "components": pca_obj.components_,
+ "scores": pca_obj.transform(scaled_data)
+ }
+
+
+def generate_scree_plot_data(variance_ratio: fArray) -> tuple[list, fArray]:
+ """
+ generates the scree data for plotting
+
+ Parameters:
+
+ variance_ratio(list[floats]):ratios for contribution of each pca
+
+ Returns:
+
+ coordinates(list[(x_coor,y_coord)])
+
+
+ """
+
+ perc_var = [round(ratio*100, 1) for ratio in variance_ratio]
+
+ x_coordinates = [f"PC{val}" for val in range(1, len(perc_var)+1)]
+
+ return (x_coordinates, perc_var)
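+
+# A usage sketch (hypothetical values): variance ratios of 0.55, 0.30
+# and 0.15 map to scree-plot coordinates as follows:
+# >>> generate_scree_plot_data([0.55, 0.30, 0.15])
+# (['PC1', 'PC2', 'PC3'], [55.0, 30.0, 15.0])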
+
+
+def generate_pca_traits_vals(trait_data_array: list[fArray],
+ corr_array: list[fArray]) -> list[list[Any]]:
+ """
+ generates datasets from zscores of the traits and eigen_vectors\
+ of correlation matrix
+
+ Parameters:
+
+ trait_data_array(list[floats]):an list of the traits
+ corr_array(list[list]): list of arrays for computing eigen_vectors
+
+ Returns:
+
+ pca_vals[list[list]]:
+
+
+ """
+
+ trait_zscores = stats.zscore(trait_data_array)
+
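+ # NOTE: with fewer than 10 samples per trait, the raw values are used
+ # as-is (presumably because z-scores are unstable for so few samples).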
+ if len(trait_data_array[0]) < 10:
+ trait_zscores = trait_data_array
+
+ (eigen_values, corr_eigen_vectors) = np.linalg.eig(np.array(corr_array))
+ idx = eigen_values.argsort()[::-1]
+
+ return np.dot(corr_eigen_vectors[:, idx], trait_zscores)
+
+
+def process_factor_loadings_tdata(factor_loadings, traits_num: int):
+ """
+
+ transform loadings for tables visualization
+
+ Parameters:
+ factor_loading(numpy.ndarray)
+ traits_num(int):number of traits
+
+ Returns:
+ tabular_loadings(list[list[float]])
+ """
+
+ target_columns = 3 if traits_num > 2 else 2
+
+ trait_loadings = list(factor_loadings.T)
+
+ return [list(trait_loading[:target_columns])
+ for trait_loading in trait_loadings]
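+
+# For example (hypothetical loadings, two components x three traits), each
+# trait's transposed row is truncated to `target_columns` entries:
+# >>> import numpy as np
+# >>> loadings = np.array([[0.8, 0.5, 0.2], [0.1, 0.4, 0.9]])
+# >>> process_factor_loadings_tdata(loadings, traits_num=3)
+# [[0.8, 0.1], [0.5, 0.4], [0.2, 0.9]]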
+
+
+def generate_pca_temp_traits(
+ species: str,
+ group: str,
+ traits_data: list[fArray],
+ corr_array: list[fArray],
+ dataset_samples: list[str],
+ shared_samples: list[str],
+ create_time: str
+) -> dict[str, list[Any]]:
+ """
+
+
+ generate pca temp datasets
+
+ """
+
+ # pylint: disable=too-many-arguments
+
+ pca_trait_dict = {}
+
+ pca_vals = generate_pca_traits_vals(traits_data, corr_array)
+
+ for (idx, pca_trait) in enumerate(list(pca_vals)):
+
+ trait_id = f"PCA{str(idx+1)}_{species}_{group}_{create_time}"
+ sample_vals = []
+
+ pointer = 0
+
+ for sample in dataset_samples:
+ if sample in shared_samples:
+
+ sample_vals.append(str(pca_trait[pointer]))
+ pointer += 1
+
+ else:
+ sample_vals.append("x")
+
+ pca_trait_dict[trait_id] = sample_vals
+
+ return pca_trait_dict
+
+
+def cache_pca_dataset(redis_conn: Any, exp_days: int,
+ pca_trait_dict: dict[str, list[Any]]):
+ """
+
+ caches pca dataset to redis
+
+ Parameters:
+
+ redis_conn(object)
+ exp_days(int): fo redis cache
+ pca_trait_dict(Dict): contains traits and traits vals to cache
+
+ Returns:
+
+ boolean(True if correct conn object False incase of exception)
+
+
+ """
+
+ try:
+ for trait_id, sample_data in pca_trait_dict.items():
+ samples_str = " ".join([str(x) for x in sample_data])
+ redis_conn.set(trait_id, samples_str, ex=exp_days)
+ return True
+
+ except (redis.ConnectionError, AttributeError):
+ return False
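A hedged usage sketch for the caching helper above (the local Redis instance
and the trait id are illustrative; note that `redis.set(ex=...)` takes
seconds, so the `exp_days` value is presumably converted by the caller):

    import redis

    rconn = redis.Redis()  # assumes a reachable local Redis server
    cache_pca_dataset(
        rconn,
        exp_days=30 * 24 * 60 * 60,  # `ex` is interpreted as seconds
        pca_trait_dict={"PCA1_mouse_BXD_2022": ["9.3", "x", "8.1"]})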
diff --git a/gn3/computations/qtlreaper.py b/gn3/computations/qtlreaper.py
index d1ff4ac..b61bdae 100644
--- a/gn3/computations/qtlreaper.py
+++ b/gn3/computations/qtlreaper.py
@@ -27,7 +27,7 @@ def generate_traits_file(samples, trait_values, traits_filename):
["{}\t{}".format(
len(trait_values), "\t".join([str(i) for i in t]))
for t in trait_values[-1:]])
- with open(traits_filename, "w") as outfile:
+ with open(traits_filename, "w", encoding="utf8") as outfile:
outfile.writelines(data)
def create_output_directory(path: str):
@@ -68,13 +68,13 @@ def run_reaper(
The function will raise a `subprocess.CalledProcessError` exception in case
of any errors running the `qtlreaper` command.
"""
- create_output_directory("{}/qtlreaper".format(output_dir))
- output_filename = "{}/qtlreaper/main_output_{}.txt".format(
- output_dir, random_string(10))
+ create_output_directory(f"{output_dir}/qtlreaper")
+ output_filename = (
+ f"{output_dir}/qtlreaper/main_output_{random_string(10)}.txt")
output_list = ["--main_output", output_filename]
if separate_nperm_output:
- permu_output_filename: Union[None, str] = "{}/qtlreaper/permu_output_{}.txt".format(
- output_dir, random_string(10))
+ permu_output_filename: Union[None, str] = (
+ f"{output_dir}/qtlreaper/permu_output_{random_string(10)}.txt")
output_list = output_list + [
"--permu_output", permu_output_filename] # type: ignore[list-item]
else:
@@ -135,7 +135,7 @@ def parse_reaper_main_results(results_file):
"""
Parse the results file of running QTLReaper into a list of dicts.
"""
- with open(results_file, "r") as infile:
+ with open(results_file, "r", encoding="utf8") as infile:
lines = infile.readlines()
def __parse_column_float_value(value):
@@ -164,7 +164,7 @@ def parse_reaper_permutation_results(results_file):
"""
Parse the results QTLReaper permutations into a list of values.
"""
- with open(results_file, "r") as infile:
+ with open(results_file, "r", encoding="utf8") as infile:
lines = infile.readlines()
return [float(line.strip()) for line in lines]
diff --git a/gn3/computations/rqtl.py b/gn3/computations/rqtl.py
index e81aba3..65ee6de 100644
--- a/gn3/computations/rqtl.py
+++ b/gn3/computations/rqtl.py
@@ -53,7 +53,7 @@ def process_rqtl_mapping(file_name: str) -> List:
# Later I should probably redo this using csv.read to avoid the
# awkwardness with removing quotes with [1:-1]
with open(os.path.join(current_app.config.get("TMPDIR", "/tmp"),
- "output", file_name), "r") as the_file:
+ "output", file_name), "r", encoding="utf-8") as the_file:
for line in the_file:
line_items = line.split(",")
if line_items[1][1:-1] == "chr" or not line_items:
@@ -118,7 +118,6 @@ def pairscan_for_figure(file_name: str) -> Dict:
return figure_data
-
def get_marker_list(map_file: str) -> List:
"""
Open the map file with the list of markers/pseudomarkers and create list of marker obs
@@ -255,7 +254,7 @@ def process_perm_output(file_name: str) -> Tuple[List, float, float]:
perm_results = []
with open(os.path.join(current_app.config.get("TMPDIR", "/tmp"),
- "output", "PERM_" + file_name), "r") as the_file:
+ "output", "PERM_" + file_name), "r", encoding="utf-8") as the_file:
for i, line in enumerate(the_file):
if i == 0:
# Skip header line
diff --git a/gn3/computations/wgcna.py b/gn3/computations/wgcna.py
index ab12fe7..c985491 100644
--- a/gn3/computations/wgcna.py
+++ b/gn3/computations/wgcna.py
@@ -19,7 +19,7 @@ def dump_wgcna_data(request_data: dict):
request_data["TMPDIR"] = TMPDIR
- with open(temp_file_path, "w") as output_file:
+ with open(temp_file_path, "w", encoding="utf-8") as output_file:
json.dump(request_data, output_file)
return temp_file_path
@@ -31,20 +31,18 @@ def stream_cmd_output(socketio, request_data, cmd: str):
socketio.emit("output", {"data": f"calling you script {cmd}"},
namespace="/", room=request_data["socket_id"])
- results = subprocess.Popen(
- cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, shell=True)
+ with subprocess.Popen(
+ cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, shell=True) as results:
+ if results.stdout is not None:
+ for line in iter(results.stdout.readline, b""):
+ socketio.emit("output",
+ {"data": line.decode("utf-8").rstrip()},
+ namespace="/", room=request_data["socket_id"])
- if results.stdout is not None:
-
- for line in iter(results.stdout.readline, b""):
- socketio.emit("output",
- {"data": line.decode("utf-8").rstrip()},
- namespace="/", room=request_data["socket_id"])
-
- socketio.emit(
- "output", {"data":
- "parsing the output results"}, namespace="/",
- room=request_data["socket_id"])
+ socketio.emit(
+ "output", {"data":
+ "parsing the output results"}, namespace="/",
+ room=request_data["socket_id"])
def process_image(image_loc: str) -> bytes:
@@ -75,7 +73,7 @@ def call_wgcna_script(rscript_path: str, request_data: dict):
run_cmd_results = run_cmd(cmd)
- with open(generated_file, "r") as outputfile:
+ with open(generated_file, "r", encoding="utf-8") as outputfile:
if run_cmd_results["code"] != 0:
return run_cmd_results
diff --git a/gn3/csvcmp.py b/gn3/csvcmp.py
new file mode 100644
index 0000000..8db89ca
--- /dev/null
+++ b/gn3/csvcmp.py
@@ -0,0 +1,146 @@
+"""This module contains functions for manipulating and working with csv
+texts"""
+from typing import Any, List
+
+import json
+import os
+import uuid
+from gn3.commands import run_cmd
+
+
+def extract_strain_name(csv_header, data, seek="Strain Name") -> str:
+ """Extract a strain's name given a csv header"""
+ for column, value in zip(csv_header.split(","), data.split(",")):
+ if seek in column:
+ return value
+ return ""
+
+
+def create_dirs_if_not_exists(dirs: list) -> None:
+ """Create directories from a list"""
+ for dir_ in dirs:
+ if not os.path.exists(dir_):
+ os.makedirs(dir_)
+
+
+def remove_insignificant_edits(diff_data, epsilon=0.001):
+ """Remove or ignore edits that are not within ε"""
+ __mod = []
+ if diff_data.get("Modifications"):
+ for mod in diff_data.get("Modifications"):
+ original = mod.get("Original").split(",")
+ current = mod.get("Current").split(",")
+ for i, (_x, _y) in enumerate(zip(original, current)):
+ if (
+ _x.replace(".", "").isdigit()
+ and _y.replace(".", "").isdigit()
+ and abs(float(_x) - float(_y)) < epsilon
+ ):
+ current[i] = _x
+ if not (__o := ",".join(original)) == (__c := ",".join(current)):
+ __mod.append(
+ {
+ "Original": __o,
+ "Current": __c,
+ }
+ )
+ diff_data["Modifications"] = __mod
+ return diff_data
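+
+# A usage sketch (hypothetical diff data): a 0.0004 change is reverted,
+# and the now-identical modification entry is dropped:
+# >>> remove_insignificant_edits(
+# ... {"Modifications": [{"Original": "BXD1,18.7", "Current": "BXD1,18.7004"}]})
+# {'Modifications': []}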
+
+
+def clean_csv_text(csv_text: str) -> str:
+ """Remove extra white space elements in all elements of the CSV file"""
+ _csv_text = []
+ for line in csv_text.strip().split("\n"):
+ _csv_text.append(
+ ",".join([el.strip() for el in line.split(",")]))
+ return "\n".join(_csv_text)
+
+
+def csv_diff(base_csv, delta_csv, tmp_dir="/tmp") -> dict:
+ """Diff 2 csv strings"""
+ base_csv = clean_csv_text(base_csv)
+ delta_csv = clean_csv_text(delta_csv)
+ base_csv_list = base_csv.split("\n")
+ delta_csv_list = delta_csv.split("\n")
+
+ base_csv_header, delta_csv_header = "", ""
+ for i, line in enumerate(base_csv_list):
+ if line.startswith("Strain Name,Value,SE,Count"):
+ base_csv_header, delta_csv_header = line, delta_csv_list[i]
+ break
+ longest_header = max(base_csv_header, delta_csv_header, key=len)
+
+ if base_csv_header != delta_csv_header:
+ if longest_header != base_csv_header:
+ base_csv = base_csv.replace("Strain Name,Value,SE,Count",
+ longest_header, 1)
+ else:
+ delta_csv = delta_csv.replace(
+ "Strain Name,Value,SE,Count", longest_header, 1
+ )
+ file_name1 = os.path.join(tmp_dir, str(uuid.uuid4()))
+ file_name2 = os.path.join(tmp_dir, str(uuid.uuid4()))
+
+ with open(file_name1, "w", encoding="utf-8") as _f:
+ _l = len(longest_header.split(","))
+ _f.write(fill_csv(csv_text=base_csv, width=_l))
+ with open(file_name2, "w", encoding="utf-8") as _f:
+ _f.write(fill_csv(delta_csv, width=_l))
+
+ # Now we can run the diff!
+ _r = run_cmd(cmd=('"csvdiff '
+ f"{file_name1} {file_name2} "
+ '--format json"'))
+ if _r.get("code") == 0:
+ _r = json.loads(_r.get("output", ""))
+ if any(_r.values()):
+ _r["Columns"] = max(base_csv_header, delta_csv_header)
+ else:
+ _r = {}
+
+ # Clean Up!
+ if os.path.exists(file_name1):
+ os.remove(file_name1)
+ if os.path.exists(file_name2):
+ os.remove(file_name2)
+ return _r
+
+
+def fill_csv(csv_text, width, value="x"):
+ """Fill a csv text with 'value' if it's length is less than width"""
+ data = []
+ for line in csv_text.strip().split("\n"):
+ if line.startswith("Strain") or line.startswith("#"):
+ data.append(line)
+ elif line:
+ _n = line.split(",")
+ for i, val in enumerate(_n):
+ if not val.strip():
+ _n[i] = value
+ data.append(",".join(_n + [value] * (width - len(_n))))
+ return "\n".join(data)
+
+
+def get_allowable_sampledata_headers(conn: Any) -> List:
+ """Get a list of all the case-attributes stored in the database"""
+ attributes = ["Strain Name", "Value", "SE", "Count"]
+ with conn.cursor() as cursor:
+ cursor.execute("SELECT Name from CaseAttribute")
+ attributes += [row[0] for row in
+ cursor.fetchall()]
+ return attributes
+
+
+def extract_invalid_csv_headers(allowed_headers: List, csv_text: str) -> List:
+ """Check whether a csv text's columns contains valid headers"""
+ csv_header = []
+ for line in csv_text.split("\n"):
+ if line.startswith("Strain Name"):
+ csv_header = [_l.strip() for _l in line.split(",")]
+ break
+ invalid_headers = []
+ for header in csv_header:
+ if header not in allowed_headers:
+ invalid_headers.append(header)
+ return invalid_headers
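A short sketch of how the last two helpers combine (the header names here
are illustrative):

    allowed = ["Strain Name", "Value", "SE", "Count", "Sex"]
    csv_text = "Strain Name,Value,SE,Count,Age\nBXD1,18,x,x,x"
    extract_invalid_csv_headers(allowed, csv_text)
    # ['Age']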
diff --git a/gn3/data_helpers.py b/gn3/data_helpers.py
index d3f942b..268a0bb 100644
--- a/gn3/data_helpers.py
+++ b/gn3/data_helpers.py
@@ -5,9 +5,9 @@ data structures.
from math import ceil
from functools import reduce
-from typing import Any, Tuple, Sequence, Optional
+from typing import Any, Tuple, Sequence, Optional, Generator
-def partition_all(num: int, items: Sequence[Any]) -> Tuple[Tuple[Any, ...], ...]:
+def partition_all(num: int, items: Sequence[Any]) -> Generator:
"""
Given a sequence `items`, return a generator of tuples, with the data
partitioned into sections of `num` items per partition.
@@ -19,10 +19,24 @@ def partition_all(num: int, items: Sequence[Any]) -> Tuple[Tuple[Any, ...], ...]
return acc + ((start, start + num),)
iterations = range(ceil(len(items) / num))
- return tuple([# type: ignore[misc]
- tuple(items[start:stop]) for start, stop # type: ignore[has-type]
- in reduce(
- __compute_start_stop__, iterations, tuple())])
+ for start, stop in reduce(# type: ignore[misc]
+ __compute_start_stop__, iterations, tuple()):
+ yield tuple(items[start:stop]) # type: ignore[has-type]
+
+def partition_by(partition_fn, items):
+ """
+ Given a sequence `items`, return a tuple of tuples, partitioning `items`
+ such that the first item in each inner tuple is one for which
+ `partition_fn` returns True.
+
+ This is an approximation of Clojure's `partition-by` function.
+ """
+ def __partitioner__(accumulator, item):
+ if partition_fn(item):
+ return accumulator + ((item,),)
+ return accumulator[:-1] + (accumulator[-1] + (item,),)
+
+ return reduce(__partitioner__, items, tuple())
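+
+# A usage sketch: each partition starts at an item for which `partition_fn`
+# returns True, so the first item is expected to satisfy the predicate:
+# >>> partition_by(lambda s: s.startswith("@"), ["@a", "x", "y", "@b", "z"])
+# (('@a', 'x', 'y'), ('@b', 'z'))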
def parse_csv_line(
line: str, delimiter: str = ",",
@@ -34,4 +48,4 @@ def parse_csv_line(
function in GeneNetwork1.
"""
return tuple(
- col.strip("{} \t\n".format(quoting)) for col in line.split(delimiter))
+ col.strip(f"{quoting} \t\n") for col in line.split(delimiter))
diff --git a/gn3/db/correlations.py b/gn3/db/correlations.py
index 06b3310..3ae66ca 100644
--- a/gn3/db/correlations.py
+++ b/gn3/db/correlations.py
@@ -2,17 +2,16 @@
This module will hold functions that are used in the (partial) correlations
feature to access the database to retrieve data needed for computations.
"""
-
+import os
from functools import reduce
-from typing import Any, Dict, Tuple
+from typing import Any, Dict, Tuple, Union
from gn3.random import random_string
from gn3.data_helpers import partition_all
from gn3.db.species import translate_to_mouse_gene_id
-from gn3.computations.partial_correlations import correlations_of_all_tissue_traits
-
-def get_filename(target_db_name: str, conn: Any) -> str:
+def get_filename(conn: Any, target_db_name: str, text_files_dir: str) -> Union[
+ str, bool]:
"""
Retrieve the name of the reference database file with which correlations are
computed.
@@ -23,18 +22,23 @@ def get_filename(target_db_name: str, conn: Any) -> str:
"""
with conn.cursor() as cursor:
cursor.execute(
- "SELECT Id, FullName from ProbeSetFreeze WHERE Name-%s",
- target_db_name)
+ "SELECT Id, FullName from ProbeSetFreeze WHERE Name=%s",
+ (target_db_name,))
result = cursor.fetchone()
if result:
- return "ProbeSetFreezeId_{tid}_FullName_{fname}.txt".format(
- tid=result[0],
- fname=result[1].replace(' ', '_').replace('/', '_'))
+ filename = (
+ f"ProbeSetFreezeId_{result[0]}_FullName_"
+ f"{result[1].replace(' ', '_').replace('/', '_')}.txt")
+ full_filename = f"{text_files_dir}/{filename}"
+ return (
+ os.path.exists(full_filename) and
+ (filename in os.listdir(text_files_dir)) and
+ full_filename)
- return ""
+ return False
def build_temporary_literature_table(
- species: str, gene_id: int, return_number: int, conn: Any) -> str:
+ conn: Any, species: str, gene_id: int, return_number: int) -> str:
"""
Build and populate a temporary table to hold the literature correlation data
to be used in computations.
@@ -49,7 +53,7 @@ def build_temporary_literature_table(
query = {
"rat": "SELECT rat FROM GeneIDXRef WHERE mouse=%s",
"human": "SELECT human FROM GeneIDXRef WHERE mouse=%d"}
- if species in query.keys():
+ if species in query:
cursor.execute(query[species], row[1])
record = cursor.fetchone()
if record:
@@ -128,7 +132,7 @@ def fetch_literature_correlations(
GeneNetwork1.
"""
temp_table = build_temporary_literature_table(
- species, gene_id, return_number, conn)
+ conn, species, gene_id, return_number)
query_fns = {
"Geno": fetch_geno_literature_correlations,
# "Temp": fetch_temp_literature_correlations,
@@ -156,11 +160,14 @@ def fetch_symbol_value_pair_dict(
symbol: data_id_dict.get(symbol) for symbol in symbol_list
if data_id_dict.get(symbol) is not None
}
- query = "SELECT Id, value FROM TissueProbeSetData WHERE Id IN %(data_ids)s"
+ data_ids_fields = (f"%(id{i})s" for i in range(len(data_ids.values())))
+ query = (
+ "SELECT Id, value FROM TissueProbeSetData "
+ f"WHERE Id IN ({','.join(data_ids_fields)})")
with conn.cursor() as cursor:
cursor.execute(
query,
- data_ids=tuple(data_ids.values()))
+ **{f"id{i}": did for i, did in enumerate(data_ids.values())})
value_results = cursor.fetchall()
return {
key: tuple(row[1] for row in value_results if row[0] == key)
@@ -234,8 +241,10 @@ def fetch_tissue_probeset_xref_info(
"INNER JOIN TissueProbeSetXRef AS t ON t.Symbol = x.Symbol "
"AND t.Mean = x.maxmean")
cursor.execute(
- query, probeset_freeze_id=probeset_freeze_id,
- symbols=tuple(gene_name_list))
+ query, {
+ "probeset_freeze_id": probeset_freeze_id,
+ "symbols": tuple(gene_name_list)
+ })
results = cursor.fetchall()
@@ -268,8 +277,8 @@ def fetch_gene_symbol_tissue_value_dict_for_trait(
return {}
def build_temporary_tissue_correlations_table(
- trait_symbol: str, probeset_freeze_id: int, method: str,
- return_number: int, conn: Any) -> str:
+ conn: Any, trait_symbol: str, probeset_freeze_id: int, method: str,
+ return_number: int) -> str:
"""
Build a temporary table to hold the tissue correlations data.
@@ -279,6 +288,16 @@ def build_temporary_tissue_correlations_table(
# We should probably pass the `correlations_of_all_tissue_traits` function
# as an argument to this function and get rid of the one call immediately
# following this comment.
+ from gn3.computations.partial_correlations import (#pylint: disable=[C0415, R0401]
+ correlations_of_all_tissue_traits)
+ # The import above is done inside the function to avoid circular imports.
+ #
+ # It is also indicative of convoluted code, with the computation
+ # interwoven with the data retrieval. This needs to change, such that the
+ # imported function is either no longer necessary, or is passed to this
+ # function as an argument.
symbol_corr_dict, symbol_p_value_dict = correlations_of_all_tissue_traits(
fetch_gene_symbol_tissue_value_dict_for_trait(
(trait_symbol,), probeset_freeze_id, conn),
@@ -320,7 +339,7 @@ def fetch_tissue_correlations(# pylint: disable=R0913
GeneNetwork1.
"""
temp_table = build_temporary_tissue_correlations_table(
- trait_symbol, probeset_freeze_id, method, return_number, conn)
+ conn, trait_symbol, probeset_freeze_id, method, return_number)
with conn.cursor() as cursor:
cursor.execute(
(
@@ -379,3 +398,176 @@ def check_symbol_for_tissue_correlation(
return True
return False
+
+def fetch_sample_ids(
+ conn: Any, sample_names: Tuple[str, ...], species_name: str) -> Tuple[
+ int, ...]:
+ """
+ Given a sequence of sample names, and a species name, return the sample ids
+ that correspond to both.
+
+ This is a partial migration of the
+ `web.webqtl.correlation.CorrelationPage.fetchAllDatabaseData` function in
+ GeneNetwork1.
+ """
+ samples_fields = (f"%(s{i})s" for i in range(len(sample_names)))
+ query = (
+ "SELECT Strain.Id FROM Strain, Species "
+ f"WHERE Strain.Name IN ({','.join(samples_fields)}) "
+ "AND Strain.SpeciesId=Species.Id "
+ "AND Species.name=%(species_name)s")
+ with conn.cursor() as cursor:
+ cursor.execute(
+ query,
+ {
+ **{f"s{i}": sname for i, sname in enumerate(sample_names)},
+ "species_name": species_name
+ })
+ return tuple(row[0] for row in cursor.fetchall())
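+
+# The helper above binds each sample name to a numbered placeholder; for
+# fetch_sample_ids(conn, ("BXD1", "BXD2"), "mouse") the bound parameters
+# would be (hypothetical values):
+# {"s0": "BXD1", "s1": "BXD2", "species_name": "mouse"}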
+
+def build_query_sgo_lit_corr(
+ db_type: str, temp_table: str, sample_id_columns: str,
+ joins: Tuple[str, ...]) -> Tuple[str, int]:
+ """
+ Build query for `SGO Literature Correlation` data, when querying the given
+ `temp_table` temporary table.
+
+ This is a partial migration of the
+ `web.webqtl.correlation.CorrelationPage.fetchAllDatabaseData` function in
+ GeneNetwork1.
+ """
+ return (
+ (f"SELECT {db_type}.Name, {temp_table}.value, " +
+ sample_id_columns +
+ f" FROM ({db_type}, {db_type}XRef, {db_type}Freeze) " +
+ f"LEFT JOIN {temp_table} ON {temp_table}.GeneId2=ProbeSet.GeneId " +
+ " ".join(joins) +
+ " WHERE ProbeSet.GeneId IS NOT NULL " +
+ f"AND {temp_table}.value IS NOT NULL " +
+ f"AND {db_type}XRef.{db_type}FreezeId = {db_type}Freeze.Id " +
+ f"AND {db_type}Freeze.Name = %(db_name)s " +
+ f"AND {db_type}.Id = {db_type}XRef.{db_type}Id " +
+ f"ORDER BY {db_type}.Id"),
+ 2)
+
+def build_query_tissue_corr(db_type, temp_table, sample_id_columns, joins):
+ """
+ Build query for `Tissue Correlation` data, when querying the given
+ `temp_table` temporary table.
+
+ This is a partial migration of the
+ `web.webqtl.correlation.CorrelationPage.fetchAllDatabaseData` function in
+ GeneNetwork1.
+ """
+ return (
+ (f"SELECT {db_type}.Name, {temp_table}.Correlation, " +
+ f"{temp_table}.PValue, " +
+ sample_id_columns +
+ f" FROM ({db_type}, {db_type}XRef, {db_type}Freeze) " +
+ f"LEFT JOIN {temp_table} ON {temp_table}.Symbol=ProbeSet.Symbol " +
+ " ".join(joins) +
+ " WHERE ProbeSet.Symbol IS NOT NULL " +
+ f"AND {temp_table}.Correlation IS NOT NULL " +
+ f"AND {db_type}XRef.{db_type}FreezeId = {db_type}Freeze.Id " +
+ f"AND {db_type}Freeze.Name = %(db_name)s " +
+ f"AND {db_type}.Id = {db_type}XRef.{db_type}Id "
+ f"ORDER BY {db_type}.Id"),
+ 3)
+
+def fetch_all_database_data(# pylint: disable=[R0913, R0914]
+ conn: Any, species: str, gene_id: int, trait_symbol: str,
+ samples: Tuple[str, ...], dataset: dict, method: str,
+ return_number: int, probeset_freeze_id: int) -> Tuple[
+ Tuple[float], int]:
+ """
+ This is a migration of the
+ `web.webqtl.correlation.CorrelationPage.fetchAllDatabaseData` function in
+ GeneNetwork1.
+ """
+ db_type = dataset["dataset_type"]
+ db_name = dataset["dataset_name"]
+ def __build_query__(sample_ids, temp_table):
+ sample_id_columns = ", ".join(f"T{smpl}.value" for smpl in sample_ids)
+ if db_type == "Publish":
+ joins = tuple(
+ (f"LEFT JOIN PublishData AS T{item} "
+ f"ON T{item}.Id = PublishXRef.DataId "
+ f"AND T{item}.StrainId = %(T{item}_sample_id)s")
+ for item in sample_ids)
+ return (
+ ("SELECT PublishXRef.Id, " +
+ sample_id_columns +
+ " FROM (PublishXRef, PublishFreeze) " +
+ " ".join(joins) +
+ " WHERE PublishXRef.InbredSetId = PublishFreeze.InbredSetId "
+ "AND PublishFreeze.Name = %(db_name)s"),
+ 1)
+ if temp_table is not None:
+ joins = tuple(
+ (f"LEFT JOIN {db_type}Data AS T{item} "
+ f"ON T{item}.Id = {db_type}XRef.DataId "
+ f"AND T{item}.StrainId=%(T{item}_sample_id)s")
+ for item in sample_ids)
+ if method.lower() == "sgo literature correlation":
+ return build_query_sgo_lit_corr(
+ db_type, temp_table, sample_id_columns, joins)
+ if method.lower() in (
+ "tissue correlation, pearson's r",
+ "tissue correlation, spearman's rho"):
+ return build_query_tissue_corr(
+ db_type, temp_table, sample_id_columns, joins)
+ joins = tuple(
+ (f"LEFT JOIN {db_type}Data AS T{item} "
+ f"ON T{item}.Id = {db_type}XRef.DataId "
+ f"AND T{item}.StrainId = %(T{item}_sample_id)s")
+ for item in sample_ids)
+ return (
+ (
+ f"SELECT {db_type}.Name, " +
+ sample_id_columns +
+ f" FROM ({db_type}, {db_type}XRef, {db_type}Freeze) " +
+ " ".join(joins) +
+ f" WHERE {db_type}XRef.{db_type}FreezeId = {db_type}Freeze.Id " +
+ f"AND {db_type}Freeze.Name = %(db_name)s " +
+ f"AND {db_type}.Id = {db_type}XRef.{db_type}Id " +
+ f"ORDER BY {db_type}.Id"),
+ 1)
+
+ def __fetch_data__(sample_ids, temp_table):
+ query, data_start_pos = __build_query__(sample_ids, temp_table)
+ with conn.cursor() as cursor:
+ cursor.execute(
+ query,
+ {"db_name": db_name,
+ **{f"T{item}_sample_id": item for item in sample_ids}})
+ return (cursor.fetchall(), data_start_pos)
+
+ sample_ids = tuple(
+ # look into graduating this to an argument and removing the `samples`
+ # and `species` argument: function currying and compositions might help
+ # with this
+ f"{sample_id}" for sample_id in
+ fetch_sample_ids(conn, samples, species))
+
+ temp_table = None
+ if gene_id and db_type == "probeset":
+ if method.lower() == "sgo literature correlation":
+ temp_table = build_temporary_literature_table(
+ conn, species, gene_id, return_number)
+ if method.lower() in (
+ "tissue correlation, pearson's r",
+ "tissue correlation, spearman's rho"):
+ temp_table = build_temporary_tissue_correlations_table(
+ conn, trait_symbol, probeset_freeze_id, method, return_number)
+
+ trait_database = tuple(
+ item for sublist in
+ (__fetch_data__(ssample_ids, temp_table)
+ for ssample_ids in partition_all(25, sample_ids))
+ for item in sublist)
+
+ if temp_table:
+ with conn.cursor() as cursor:
+ cursor.execute(f"DROP TEMPORARY TABLE {temp_table}")
+
+ return (trait_database[0], trait_database[1])
diff --git a/gn3/db/datasets.py b/gn3/db/datasets.py
index 6c328f5..b19db53 100644
--- a/gn3/db/datasets.py
+++ b/gn3/db/datasets.py
@@ -1,7 +1,11 @@
"""
This module contains functions relating to specific trait dataset manipulation
"""
-from typing import Any
+import re
+from string import Template
+from typing import Any, Dict, List, Optional
+from SPARQLWrapper import JSON, SPARQLWrapper
+from gn3.settings import SPARQL_ENDPOINT
def retrieve_probeset_trait_dataset_name(
threshold: int, name: str, connection: Any):
@@ -22,10 +26,13 @@ def retrieve_probeset_trait_dataset_name(
"threshold": threshold,
"name": name
})
- return dict(zip(
- ["dataset_id", "dataset_name", "dataset_fullname",
- "dataset_shortname", "dataset_datascale"],
- cursor.fetchone()))
+ res = cursor.fetchone()
+ if res:
+ return dict(zip(
+ ["dataset_id", "dataset_name", "dataset_fullname",
+ "dataset_shortname", "dataset_datascale"],
+ res))
+ return {"dataset_id": None, "dataset_name": name, "dataset_fullname": name}
def retrieve_publish_trait_dataset_name(
threshold: int, name: str, connection: Any):
@@ -75,33 +82,8 @@ def retrieve_geno_trait_dataset_name(
"dataset_shortname"],
cursor.fetchone()))
-def retrieve_temp_trait_dataset_name(
- threshold: int, name: str, connection: Any):
- """
- Get the ID, DataScale and various name formats for a `Temp` trait.
- """
- query = (
- "SELECT Id, Name, FullName, ShortName "
- "FROM TempFreeze "
- "WHERE "
- "public > %(threshold)s "
- "AND "
- "(Name = %(name)s OR FullName = %(name)s OR ShortName = %(name)s)")
- with connection.cursor() as cursor:
- cursor.execute(
- query,
- {
- "threshold": threshold,
- "name": name
- })
- return dict(zip(
- ["dataset_id", "dataset_name", "dataset_fullname",
- "dataset_shortname"],
- cursor.fetchone()))
-
def retrieve_dataset_name(
- trait_type: str, threshold: int, trait_name: str, dataset_name: str,
- conn: Any):
+ trait_type: str, threshold: int, dataset_name: str, conn: Any):
"""
Retrieve the name of a trait given the trait's name
@@ -113,9 +95,7 @@ def retrieve_dataset_name(
"ProbeSet": retrieve_probeset_trait_dataset_name,
"Publish": retrieve_publish_trait_dataset_name,
"Geno": retrieve_geno_trait_dataset_name,
- "Temp": retrieve_temp_trait_dataset_name}
- if trait_type == "Temp":
- return retrieve_temp_trait_dataset_name(threshold, trait_name, conn)
+ "Temp": lambda threshold, dataset_name, conn: {}}
return fn_map[trait_type](threshold, dataset_name, conn)
@@ -203,7 +183,6 @@ def retrieve_temp_trait_dataset():
"""
Retrieve the dataset that relates to `Temp` traits
"""
- # pylint: disable=[C0330]
return {
"searchfield": ["name", "description"],
"disfield": ["name", "description"],
@@ -217,7 +196,6 @@ def retrieve_geno_trait_dataset():
"""
Retrieve the dataset that relates to `Geno` traits
"""
- # pylint: disable=[C0330]
return {
"searchfield": ["name", "chr"],
"disfield": ["name", "chr", "mb", "source2", "sequence"],
@@ -228,7 +206,6 @@ def retrieve_publish_trait_dataset():
"""
Retrieve the dataset that relates to `Publish` traits
"""
- # pylint: disable=[C0330]
return {
"searchfield": [
"name", "post_publication_description", "abstract", "title",
@@ -247,7 +224,6 @@ def retrieve_probeset_trait_dataset():
"""
Retrieve the dataset that relates to `ProbeSet` traits
"""
- # pylint: disable=[C0330]
return {
"searchfield": [
"name", "description", "probe_target_description", "symbol",
@@ -278,8 +254,7 @@ def retrieve_trait_dataset(trait_type, trait, threshold, conn):
"dataset_id": None,
"dataset_name": trait["db"]["dataset_name"],
**retrieve_dataset_name(
- trait_type, threshold, trait["trait_name"],
- trait["db"]["dataset_name"], conn)
+ trait_type, threshold, trait["db"]["dataset_name"], conn)
}
group = retrieve_group_fields(
trait_type, trait["trait_name"], dataset_name_info, conn)
@@ -289,3 +264,100 @@ def retrieve_trait_dataset(trait_type, trait, threshold, conn):
**dataset_fns[trait_type](),
**group
}
+
+def sparql_query(query: str) -> List[Dict[str, Any]]:
+ """Run a SPARQL query and return the bound variables."""
+ sparql = SPARQLWrapper(SPARQL_ENDPOINT)
+ sparql.setQuery(query)
+ sparql.setReturnFormat(JSON)
+ return sparql.queryAndConvert()['results']['bindings']
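+
+# A hedged usage sketch (assumes SPARQL_ENDPOINT is reachable and that the
+# gn: vocabulary is populated; the returned value is illustrative):
+# >>> sparql_query(
+# ... "PREFIX gn: <http://genenetwork.org/> "
+# ... "SELECT ?name WHERE { ?ds gn:name ?name } LIMIT 1")
+# [{'name': {'type': 'literal', 'value': '...'}}]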
+
+def dataset_metadata(accession_id: str) -> Optional[Dict[str, Any]]:
+ """Return info about dataset with ACCESSION_ID."""
+ # Check accession_id to protect against query injection.
+ # TODO: This function doesn't yet return the names of the actual dataset files.
+ pattern = re.compile(r'GN\d+', re.ASCII)
+ if not pattern.fullmatch(accession_id):
+ return None
+ # KLUDGE: We split the SPARQL query because virtuoso is very slow on a
+ # single large query.
+ queries = ["""
+PREFIX gn: <http://genenetwork.org/>
+SELECT ?name ?dataset_group ?status ?title ?geo_series
+WHERE {
+ ?dataset gn:accessionId "$accession_id" ;
+ rdf:type gn:dataset ;
+ gn:name ?name .
+ OPTIONAL { ?dataset gn:datasetGroup ?dataset_group } .
+ # FIXME: gn:datasetStatus should not be optional. But, some records don't
+ # have it.
+ OPTIONAL { ?dataset gn:datasetStatus ?status } .
+ OPTIONAL { ?dataset gn:title ?title } .
+ OPTIONAL { ?dataset gn:geoSeries ?geo_series } .
+}
+""",
+ """
+PREFIX gn: <http://genenetwork.org/>
+SELECT ?platform_name ?normalization_name ?species_name ?inbred_set_name ?tissue_name
+WHERE {
+ ?dataset gn:accessionId "$accession_id" ;
+ rdf:type gn:dataset ;
+ gn:normalization / gn:name ?normalization_name ;
+ gn:datasetOfSpecies / gn:menuName ?species_name ;
+ gn:datasetOfInbredSet / gn:name ?inbred_set_name .
+ OPTIONAL { ?dataset gn:datasetOfTissue / gn:name ?tissue_name } .
+ OPTIONAL { ?dataset gn:datasetOfPlatform / gn:name ?platform_name } .
+}
+""",
+ """
+PREFIX gn: <http://genenetwork.org/>
+SELECT ?specifics ?summary ?about_cases ?about_tissue ?about_platform
+ ?about_data_processing ?notes ?experiment_design ?contributors
+ ?citation ?acknowledgment
+WHERE {
+ ?dataset gn:accessionId "$accession_id" ;
+ rdf:type gn:dataset .
+ OPTIONAL { ?dataset gn:specifics ?specifics . }
+ OPTIONAL { ?dataset gn:summary ?summary . }
+ OPTIONAL { ?dataset gn:aboutCases ?about_cases . }
+ OPTIONAL { ?dataset gn:aboutTissue ?about_tissue . }
+ OPTIONAL { ?dataset gn:aboutPlatform ?about_platform . }
+ OPTIONAL { ?dataset gn:aboutDataProcessing ?about_data_processing . }
+ OPTIONAL { ?dataset gn:notes ?notes . }
+ OPTIONAL { ?dataset gn:experimentDesign ?experiment_design . }
+ OPTIONAL { ?dataset gn:contributors ?contributors . }
+ OPTIONAL { ?dataset gn:citation ?citation . }
+ OPTIONAL { ?dataset gn:acknowledgment ?acknowledgment . }
+}
+"""]
+ result: Dict[str, Any] = {'accession_id': accession_id,
+ 'investigator': {}}
+ query_result = {}
+ for query in queries:
+ if sparql_result := sparql_query(Template(query).substitute(accession_id=accession_id)):
+ query_result.update(sparql_result[0])
+ else:
+ return None
+ for key, value in query_result.items():
+ result[key] = value['value']
+ investigator_query_result = sparql_query(Template("""
+PREFIX gn: <http://genenetwork.org/>
+SELECT ?name ?address ?city ?state ?zip ?phone ?email ?country ?homepage
+WHERE {
+ ?dataset gn:accessionId "$accession_id" ;
+ rdf:type gn:dataset ;
+ gn:datasetOfInvestigator ?investigator .
+ OPTIONAL { ?investigator foaf:name ?name . }
+ OPTIONAL { ?investigator gn:address ?address . }
+ OPTIONAL { ?investigator gn:city ?city . }
+ OPTIONAL { ?investigator gn:state ?state . }
+ OPTIONAL { ?investigator gn:zipCode ?zip . }
+ OPTIONAL { ?investigator foaf:phone ?phone . }
+ OPTIONAL { ?investigator foaf:mbox ?email . }
+ OPTIONAL { ?investigator gn:country ?country . }
+ OPTIONAL { ?investigator foaf:homepage ?homepage . }
+}
+""").substitute(accession_id=accession_id))[0]
+ for key, value in investigator_query_result.items():
+ result['investigator'][key] = value['value']
+ return result
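A sketch of the expected call and result shape (the accession id and values
are hypothetical; a malformed id short-circuits to None before any query
runs):

    dataset_metadata("GN112")
    # {'accession_id': 'GN112', 'investigator': {'name': '...', ...},
    #  'name': '...', 'title': '...', ...}
    dataset_metadata("GN112; DROP ALL")  # fails the GN\d+ check
    # None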
diff --git a/gn3/db/genotypes.py b/gn3/db/genotypes.py
index 8f18cac..6f867c7 100644
--- a/gn3/db/genotypes.py
+++ b/gn3/db/genotypes.py
@@ -2,7 +2,6 @@
import os
import gzip
-from typing import Union, TextIO
from gn3.settings import GENOTYPE_FILES
@@ -10,7 +9,7 @@ def build_genotype_file(
geno_name: str, base_dir: str = GENOTYPE_FILES,
extension: str = "geno"):
"""Build the absolute path for the genotype file."""
- return "{}/{}.{}".format(os.path.abspath(base_dir), geno_name, extension)
+ return f"{os.path.abspath(base_dir)}/{geno_name}.{extension}"
def load_genotype_samples(genotype_filename: str, file_type: str = "geno"):
"""
@@ -44,22 +43,23 @@ def __load_genotype_samples_from_geno(genotype_filename: str):
Loads samples from '.geno' files.
"""
- gzipped_filename = "{}.gz".format(genotype_filename)
+ def __remove_comments_and_empty_lines__(rows):
+ return (
+ line for line in rows
+ if line and not line.startswith(("#", "@")))
+
+ gzipped_filename = f"{genotype_filename}.gz"
if os.path.isfile(gzipped_filename):
- genofile: Union[TextIO, gzip.GzipFile] = gzip.open(gzipped_filename)
+ with gzip.open(gzipped_filename, "rt", encoding="utf8") as gz_genofile:
+ rows = __remove_comments_and_empty_lines__(gz_genofile.readlines())
else:
- genofile = open(genotype_filename)
-
- for row in genofile:
- line = row.strip()
- if (not line) or (line.startswith(("#", "@"))): # type: ignore[arg-type]
- continue
- break
+ with open(genotype_filename, encoding="utf8") as genofile:
+ rows = __remove_comments_and_empty_lines__(genofile.readlines())
- headers = line.split("\t") # type: ignore[arg-type]
+ headers = next(rows).split() # type: ignore[arg-type]
if headers[3] == "Mb":
- return headers[4:]
- return headers[3:]
+ return tuple(headers[4:])
+ return tuple(headers[3:])
def __load_genotype_samples_from_plink(genotype_filename: str):
"""
@@ -67,8 +67,8 @@ def __load_genotype_samples_from_plink(genotype_filename: str):
Loads samples from '.plink' files.
"""
- genofile = open(genotype_filename)
- return [line.split(" ")[1] for line in genofile]
+ with open(genotype_filename, encoding="utf8") as genofile:
+ return tuple(line.split()[1] for line in genofile)
def parse_genotype_labels(lines: list):
"""
@@ -129,7 +129,7 @@ def parse_genotype_marker(line: str, geno_obj: dict, parlist: tuple):
alleles = marker_row[start_pos:]
genotype = tuple(
- (geno_table[allele] if allele in geno_table.keys() else "U")
+ (geno_table[allele] if allele in geno_table else "U")
for allele in alleles)
if len(parlist) > 0:
genotype = (-1, 1) + genotype
@@ -164,7 +164,7 @@ def parse_genotype_file(filename: str, parlist: tuple = tuple()):
"""
Parse the provided genotype file into a usable Python 3 data structure.
"""
- with open(filename, "r") as infile:
+ with open(filename, "r", encoding="utf8") as infile:
contents = infile.readlines()
lines = tuple(line for line in contents if
@@ -175,10 +175,10 @@ def parse_genotype_file(filename: str, parlist: tuple = tuple()):
data_lines = tuple(line for line in lines if not line.startswith("@"))
header = parse_genotype_header(data_lines[0], parlist)
geno_obj = dict(labels + header)
- markers = tuple(
- [parse_genotype_marker(line, geno_obj, parlist)
- for line in data_lines[1:]])
+ markers = (
+ parse_genotype_marker(line, geno_obj, parlist)
+ for line in data_lines[1:])
chromosomes = tuple(
dict(chromosome) for chromosome in
- build_genotype_chromosomes(geno_obj, markers))
+ build_genotype_chromosomes(geno_obj, tuple(markers)))
return {**geno_obj, "chromosomes": chromosomes}
diff --git a/gn3/db/partial_correlations.py b/gn3/db/partial_correlations.py
new file mode 100644
index 0000000..72dbf1a
--- /dev/null
+++ b/gn3/db/partial_correlations.py
@@ -0,0 +1,791 @@
+"""
+This module contains the code and queries for fetching data from the database,
+that relates to partial correlations.
+
+It is intended to replace the functions in `gn3.db.traits` and `gn3.db.datasets`
+modules with functions that fetch the data en masse, rather than one at a time.
+
+This module is part of the optimisation effort for the partial correlations.
+"""
+
+from functools import reduce, partial
+from typing import Any, Dict, Tuple, Union, Sequence
+
+from MySQLdb.cursors import DictCursor
+
+from gn3.function_helpers import compose
+from gn3.db.traits import (
+ build_trait_name,
+ with_samplelist_data_setup,
+ without_samplelist_data_setup)
+
+def organise_trait_data_by_trait(
+ traits_data_rows: Tuple[Dict[str, Any], ...]) -> Dict[
+ str, Dict[str, Any]]:
+ """
+ Organise the trait data items by their trait names.
+ """
+ def __organise__(acc, row):
+ trait_name = row["trait_name"]
+ return {
+ **acc,
+ trait_name: acc.get(trait_name, tuple()) + ({
+ key: val for key, val in row.items() if key != "trait_name"},)
+ }
+ if traits_data_rows:
+ return reduce(__organise__, traits_data_rows, {})
+ return {}
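+
+# A usage sketch with hypothetical rows:
+# >>> organise_trait_data_by_trait((
+# ... {"trait_name": "T1", "sample_name": "BXD1", "value": 8.1},))
+# {'T1': ({'sample_name': 'BXD1', 'value': 8.1},)}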
+
+def temp_traits_data(conn, traits):
+ """
+ Retrieve trait data for `Temp` traits.
+ """
+ query = (
+ "SELECT "
+ "Temp.Name AS trait_name, Strain.Name AS sample_name, TempData.value, "
+ "TempData.SE AS se_error, TempData.NStrain AS nstrain, "
+ "TempData.Id AS id "
+ "FROM TempData, Temp, Strain "
+ "WHERE TempData.StrainId = Strain.Id "
+ "AND TempData.Id = Temp.DataId "
+ f"AND Temp.name IN ({', '.join(['%s'] * len(traits))}) "
+ "ORDER BY Strain.Name")
+ with conn.cursor(cursorclass=DictCursor) as cursor:
+ cursor.execute(
+ query,
+ tuple(trait["trait_name"] for trait in traits))
+ return organise_trait_data_by_trait(cursor.fetchall())
+
+def publish_traits_data(conn, traits):
+ """
+ Retrieve trait data for `Publish` traits.
+ """
+ dataset_ids = tuple(set(
+ trait["db"]["dataset_id"] for trait in traits
+ if trait["db"].get("dataset_id") is not None))
+ query = (
+ "SELECT "
+ "PublishXRef.Id AS trait_name, Strain.Name AS sample_name, "
+ "PublishData.value, PublishSE.error AS se_error, "
+ "NStrain.count AS nstrain, PublishData.Id AS id "
+ "FROM (PublishData, Strain, PublishXRef, PublishFreeze) "
+ "LEFT JOIN PublishSE "
+ "ON (PublishSE.DataId = PublishData.Id "
+ "AND PublishSE.StrainId = PublishData.StrainId) "
+ "LEFT JOIN NStrain "
+ "ON (NStrain.DataId = PublishData.Id "
+ "AND NStrain.StrainId = PublishData.StrainId) "
+ "WHERE PublishXRef.InbredSetId = PublishFreeze.InbredSetId "
+ "AND PublishData.Id = PublishXRef.DataId "
+ f"AND PublishXRef.Id IN ({', '.join(['%s'] * len(traits))}) "
+ "AND PublishFreeze.Id IN "
+ f"({', '.join(['%s'] * len(dataset_ids))}) "
+ "AND PublishData.StrainId = Strain.Id "
+ "ORDER BY Strain.Name")
+ if len(dataset_ids) > 0:
+ with conn.cursor(cursorclass=DictCursor) as cursor:
+ cursor.execute(
+ query,
+ tuple(trait["trait_name"] for trait in traits) +
+ tuple(dataset_ids))
+ return organise_trait_data_by_trait(cursor.fetchall())
+ return {}
+
+def cellid_traits_data(conn, traits):
+ """
+ Retrieve trait data for `Probe Data` types.
+ """
+ cellids = tuple(trait["cellid"] for trait in traits)
+ dataset_names = set(trait["db"]["dataset_name"] for trait in traits)
+ query = (
+ "SELECT "
+ "ProbeSet.Name AS trait_name, Strain.Name AS sample_name, "
+ "ProbeData.value, ProbeSE.error AS se_error, ProbeData.Id AS id "
+ "FROM (ProbeData, ProbeFreeze, ProbeSetFreeze, ProbeXRef, Strain, "
+ "Probe, ProbeSet) "
+ "LEFT JOIN ProbeSE "
+ "ON (ProbeSE.DataId = ProbeData.Id "
+ "AND ProbeSE.StrainId = ProbeData.StrainId) "
+ f"WHERE Probe.Name IN ({', '.join(['%s'] * len(cellids))}) "
+ f"AND ProbeSet.Name IN ({', '.join(['%s'] * len(traits))}) "
+ "AND Probe.ProbeSetId = ProbeSet.Id "
+ "AND ProbeXRef.ProbeId = Probe.Id "
+ "AND ProbeXRef.ProbeFreezeId = ProbeFreeze.Id "
+ "AND ProbeSetFreeze.ProbeFreezeId = ProbeFreeze.Id "
+ f"AND ProbeSetFreeze.Name IN ({', '.join(['%s'] * len(dataset_names))}) "
+ "AND ProbeXRef.DataId = ProbeData.Id "
+ "AND ProbeData.StrainId = Strain.Id "
+ "ORDER BY Strain.Name")
+ with conn.cursor(cursorclass=DictCursor) as cursor:
+ cursor.execute(
+ query,
+ cellids + tuple(trait["trait_name"] for trait in traits) +
+ tuple(dataset_names))
+ return organise_trait_data_by_trait(cursor.fetchall())
+
+def probeset_traits_data(conn, traits):
+ """
+ Retrieve trait data for `ProbeSet` traits.
+ """
+ dataset_names = set(trait["db"]["dataset_name"] for trait in traits)
+ query = (
+ "SELECT ProbeSet.Name AS trait_name, Strain.Name AS sample_name, "
+ "ProbeSetData.value, ProbeSetSE.error AS se_error, "
+ "ProbeSetData.Id AS id "
+ "FROM (ProbeSetData, ProbeSetFreeze, Strain, ProbeSet, ProbeSetXRef) "
+ "LEFT JOIN ProbeSetSE ON "
+ "(ProbeSetSE.DataId = ProbeSetData.Id "
+ "AND ProbeSetSE.StrainId = ProbeSetData.StrainId) "
+ f"WHERE ProbeSet.Name IN ({', '.join(['%s'] * len(traits))})"
+ "AND ProbeSetXRef.ProbeSetId = ProbeSet.Id "
+ "AND ProbeSetXRef.ProbeSetFreezeId = ProbeSetFreeze.Id "
+ f"AND ProbeSetFreeze.Name IN ({', '.join(['%s']*len(dataset_names))}) "
+ "AND ProbeSetXRef.DataId = ProbeSetData.Id "
+ "AND ProbeSetData.StrainId = Strain.Id "
+ "ORDER BY Strain.Name")
+ with conn.cursor(cursorclass=DictCursor) as cursor:
+ cursor.execute(
+ query,
+ tuple(trait["trait_name"] for trait in traits) +
+ tuple(dataset_names))
+ return organise_trait_data_by_trait(cursor.fetchall())
+
+def species_ids(conn, traits):
+ """
+ Retrieve the IDS of the related species from the given list of traits.
+ """
+ groups = tuple(set(
+ trait["db"]["group"] for trait in traits
+ if trait["db"].get("group") is not None))
+ query = (
+ "SELECT Name AS `group`, SpeciesId AS species_id "
+ "FROM InbredSet "
+ f"WHERE Name IN ({', '.join(['%s'] * len(groups))})")
+ if len(groups) > 0:
+ with conn.cursor(cursorclass=DictCursor) as cursor:
+ cursor.execute(query, groups)
+ return tuple(row for row in cursor.fetchall())
+ return tuple()
+
+def geno_traits_data(conn, traits):
+ """
+ Retrieve trait data for `Geno` traits.
+ """
+ sp_ids = tuple(item["species_id"] for item in species_ids(conn, traits))
+ dataset_names = set(trait["db"]["dataset_name"] for trait in traits)
+ query = (
+ "SELECT Geno.Name AS trait_name, Strain.Name AS sample_name, "
+ "GenoData.value, GenoSE.error AS se_error, GenoData.Id AS id "
+ "FROM (GenoData, GenoFreeze, Strain, Geno, GenoXRef) "
+ "LEFT JOIN GenoSE ON "
+ "(GenoSE.DataId = GenoData.Id AND GenoSE.StrainId = GenoData.StrainId) "
+ f"WHERE Geno.SpeciesId IN ({', '.join(['%s'] * len(sp_ids))}) "
+ f"AND Geno.Name IN ({', '.join(['%s'] * len(traits))}) "
+ "AND GenoXRef.GenoId = Geno.Id "
+ "AND GenoXRef.GenoFreezeId = GenoFreeze.Id "
+ f"AND GenoFreeze.Name IN ({', '.join(['%s'] * len(dataset_names))}) "
+ "AND GenoXRef.DataId = GenoData.Id "
+ "AND GenoData.StrainId = Strain.Id "
+ "ORDER BY Strain.Name")
+ if len(sp_ids) > 0 and len(dataset_names) > 0:
+ with conn.cursor(cursorclass=DictCursor) as cursor:
+ cursor.execute(
+ query,
+ sp_ids +
+ tuple(trait["trait_name"] for trait in traits) +
+ tuple(dataset_names))
+ return organise_trait_data_by_trait(cursor.fetchall())
+ return {}
+
+def traits_data(
+ conn: Any, traits: Tuple[Dict[str, Any], ...],
+ samplelist: Tuple[str, ...] = tuple()) -> Dict[str, Dict[str, Any]]:
+ """
+ Retrieve trait data for multiple `traits`
+
+ This is a rework of the `gn3.db.traits.retrieve_trait_data` function.
+ """
+ def __organise__(acc, trait):
+ dataset_type = trait["db"]["dataset_type"]
+ if dataset_type == "Temp":
+ return {**acc, "Temp": acc.get("Temp", tuple()) + (trait,)}
+ if dataset_type == "Publish":
+ return {**acc, "Publish": acc.get("Publish", tuple()) + (trait,)}
+ if trait.get("cellid"):
+ return {**acc, "cellid": acc.get("cellid", tuple()) + (trait,)}
+ if dataset_type == "ProbeSet":
+ return {**acc, "ProbeSet": acc.get("ProbeSet", tuple()) + (trait,)}
+ return {**acc, "Geno": acc.get("Geno", tuple()) + (trait,)}
+
+ def __setup_samplelist__(data):
+ if samplelist:
+ return tuple(
+ item for item in
+ map(with_samplelist_data_setup(samplelist), data)
+ if item is not None)
+ return tuple(
+ item for item in
+ map(without_samplelist_data_setup(), data)
+ if item is not None)
+
+ def __process_results__(results):
+ flattened = reduce(lambda acc, res: {**acc, **res}, results)
+ return {
+ trait_name: {"data": dict(map(
+ lambda item: (
+ item["sample_name"],
+ {
+ key: val for key, val in item.items()
+ if item != "sample_name"
+ }),
+ __setup_samplelist__(data)))}
+ for trait_name, data in flattened.items()}
+
+ traits_data_fns = {
+ "Temp": temp_traits_data,
+ "Publish": publish_traits_data,
+ "cellid": cellid_traits_data,
+ "ProbeSet": probeset_traits_data,
+ "Geno": geno_traits_data
+ }
+ return __process_results__(tuple(# type: ignore[var-annotated]
+ traits_data_fns[key](conn, vals)
+ for key, vals in reduce(__organise__, traits, {}).items()))
+
+def merge_traits_and_info(traits, info_results):
+ """
+ Utility to merge trait info retrieved from the database with the given traits.
+ """
+ if info_results:
+ results = {
+ str(trait["trait_name"]): trait for trait in info_results
+ }
+ return tuple(
+ {
+ **trait,
+ **results.get(trait["trait_name"], {}),
+ "haveinfo": bool(results.get(trait["trait_name"]))
+ } for trait in traits)
+ return tuple({**trait, "haveinfo": False} for trait in traits)
+
+def publish_traits_info(
+ conn: Any, traits: Tuple[Dict[str, Any], ...]) -> Tuple[
+ Dict[str, Any], ...]:
+ """
+ Retrieve trait information for type `Publish` traits.
+
+ This is a rework of `gn3.db.traits.retrieve_publish_trait_info` function:
+ this one fetches multiple items in a single query, unlike the original that
+ fetches one item per query.
+ """
+ trait_dataset_ids = set(
+ trait["db"]["dataset_id"] for trait in traits
+ if trait["db"].get("dataset_id") is not None)
+ columns = (
+ "PublishXRef.Id, Publication.PubMed_ID, "
+ "Phenotype.Pre_publication_description, "
+ "Phenotype.Post_publication_description, "
+ "Phenotype.Original_description, "
+ "Phenotype.Pre_publication_abbreviation, "
+ "Phenotype.Post_publication_abbreviation, "
+ "Phenotype.Lab_code, Phenotype.Submitter, Phenotype.Owner, "
+ "Phenotype.Authorized_Users, "
+ "CAST(Publication.Authors AS BINARY) AS Authors, Publication.Title, "
+ "Publication.Abstract, Publication.Journal, Publication.Volume, "
+ "Publication.Pages, Publication.Month, Publication.Year, "
+ "PublishXRef.Sequence, Phenotype.Units, PublishXRef.comments")
+ query = (
+ "SELECT "
+ f"PublishXRef.Id AS trait_name, {columns} "
+ "FROM "
+ "PublishXRef, Publication, Phenotype, PublishFreeze "
+ "WHERE "
+ f"PublishXRef.Id IN ({', '.join(['%s'] * len(traits))}) "
+ "AND Phenotype.Id = PublishXRef.PhenotypeId "
+ "AND Publication.Id = PublishXRef.PublicationId "
+ "AND PublishXRef.InbredSetId = PublishFreeze.InbredSetId "
+ "AND PublishFreeze.Id IN "
+ f"({', '.join(['%s'] * len(trait_dataset_ids))})")
+ if trait_dataset_ids:
+ with conn.cursor(cursorclass=DictCursor) as cursor:
+ cursor.execute(
+ query,
+ (
+ tuple(trait["trait_name"] for trait in traits) +
+ tuple(trait_dataset_ids)))
+ return merge_traits_and_info(traits, cursor.fetchall())
+ return tuple({**trait, "haveinfo": False} for trait in traits)
+
+def probeset_traits_info(
+ conn: Any, traits: Tuple[Dict[str, Any], ...]):
+ """
+ Retrieve information for the probeset traits
+ """
+ dataset_names = set(trait["db"]["dataset_name"] for trait in traits)
+ columns = ", ".join(
+ [f"ProbeSet.{x}" for x in
+ ("name", "symbol", "description", "probe_target_description", "chr",
+ "mb", "alias", "geneid", "genbankid", "unigeneid", "omim",
+ "refseq_transcriptid", "blatseq", "targetseq", "chipid", "comments",
+ "strand_probe", "strand_gene", "probe_set_target_region", "proteinid",
+ "probe_set_specificity", "probe_set_blat_score",
+ "probe_set_blat_mb_start", "probe_set_blat_mb_end",
+ "probe_set_strand", "probe_set_note_by_rw", "flag")])
+ query = (
+ f"SELECT ProbeSet.Name AS trait_name, {columns} "
+ "FROM ProbeSet INNER JOIN ProbeSetXRef "
+ "ON ProbeSetXRef.ProbeSetId = ProbeSet.Id "
+ "INNER JOIN ProbeSetFreeze "
+ "ON ProbeSetFreeze.Id = ProbeSetXRef.ProbeSetFreezeId "
+ "WHERE ProbeSetFreeze.Name IN "
+ f"({', '.join(['%s'] * len(dataset_names))}) "
+ f"AND ProbeSet.Name IN ({', '.join(['%s'] * len(traits))})")
+ with conn.cursor(cursorclass=DictCursor) as cursor:
+ cursor.execute(
+ query,
+ tuple(dataset_names) + tuple(
+ trait["trait_name"] for trait in traits))
+ return merge_traits_and_info(traits, cursor.fetchall())
+ return tuple({**trait, "haveinfo": False} for trait in traits)
+
+def geno_traits_info(
+ conn: Any, traits: Tuple[Dict[str, Any], ...]):
+ """
+ Retrieve trait information for type `Geno` traits.
+
+ This is a rework of the `gn3.db.traits.retrieve_geno_trait_info` function.
+ """
+ dataset_names = set(trait["db"]["dataset_name"] for trait in traits)
+ columns = ", ".join([
+ f"Geno.{x}" for x in ("name", "chr", "mb", "source2", "sequence")])
+ query = (
+ "SELECT "
+ f"Geno.Name AS trait_name, {columns} "
+ "FROM "
+ "Geno INNER JOIN GenoXRef ON GenoXRef.GenoId = Geno.Id "
+ "INNER JOIN GenoFreeze ON GenoFreeze.Id = GenoXRef.GenoFreezeId "
+ f"WHERE GenoFreeze.Name IN ({', '.join(['%s'] * len(dataset_names))}) "
+ f"AND Geno.Name IN ({', '.join(['%s'] * len(traits))})")
+ with conn.cursor(cursorclass=DictCursor) as cursor:
+ cursor.execute(
+ query,
+ tuple(dataset_names) + tuple(
+ trait["trait_name"] for trait in traits))
+ return merge_traits_and_info(traits, cursor.fetchall())
+ return tuple({**trait, "haveinfo": False} for trait in traits)
+
+def temp_traits_info(
+ conn: Any, traits: Tuple[Dict[str, Any], ...]):
+ """
+ Retrieve trait information for type `Temp` traits.
+
+ A rework of the `gn3.db.traits.retrieve_temp_trait_info` function.
+ """
+ query = (
+ "SELECT Name as trait_name, name, description FROM Temp "
+ f"WHERE Name IN ({', '.join(['%s'] * len(traits))})")
+ with conn.cursor(cursorclass=DictCursor) as cursor:
+ cursor.execute(
+ query,
+ tuple(trait["trait_name"] for trait in traits))
+ return merge_traits_and_info(traits, cursor.fetchall())
+ return tuple({**trait, "haveinfo": False} for trait in traits)
+
+def publish_datasets_names(
+ conn: Any, threshold: int, dataset_names: Tuple[str, ...]):
+ """
+ Get the ID, DataScale and various name formats for a `Publish` trait.
+
+ Rework of the `gn3.db.datasets.retrieve_publish_trait_dataset_name`
+ """
+ query = (
+ "SELECT DISTINCT "
+ "Id AS dataset_id, Name AS dataset_name, FullName AS dataset_fullname, "
+ "ShortName AS dataset_shortname "
+ "FROM PublishFreeze "
+ "WHERE "
+ "public > %s "
+ "AND "
+ "(Name IN ({names}) OR FullName IN ({names}) OR ShortName IN ({names}))")
+ with conn.cursor(cursorclass=DictCursor) as cursor:
+ cursor.execute(
+ query.format(names=", ".join(["%s"] * len(dataset_names))),
+ (threshold,) + (dataset_names * 3))
+ return {ds["dataset_name"]: ds for ds in cursor.fetchall()}
+
+def set_bxd(group_info):
+ """Set the group value to BXD if it is 'BXD300'."""
+ return {
+ **group_info,
+ "group": (
+ "BXD" if group_info.get("Name") == "BXD300"
+ else group_info.get("Name", "")),
+ "groupid": group_info["Id"]
+ }
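+
+# For example: set_bxd({"Name": "BXD300", "Id": 1}) returns
+# {"Name": "BXD300", "Id": 1, "group": "BXD", "groupid": 1}.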
+
+def organise_groups_by_dataset(
+ group_rows: Union[Sequence[Dict[str, Any]], None]) -> Dict[str, Any]:
+ """Utility: Organise given groups by their datasets."""
+ if group_rows:
+ return {
+ row["dataset_name"]: set_bxd({
+ key: val for key, val in row.items()
+ if key != "dataset_name"
+ }) for row in group_rows
+ }
+ return {}
+
+def publish_datasets_groups(conn: Any, dataset_names: Tuple[str]):
+ """
+ Retrieve the Group, and GroupID values for various Publish trait types.
+
+ Rework of `gn3.db.datasets.retrieve_publish_group_fields` function.
+ """
+ query = (
+ "SELECT PublishFreeze.Name AS dataset_name, InbredSet.Name, "
+ "InbredSet.Id "
+ "FROM InbredSet, PublishFreeze "
+ "WHERE PublishFreeze.InbredSetId = InbredSet.Id "
+ f"AND PublishFreeze.Name IN ({', '.join(['%s'] * len(dataset_names))})")
+ with conn.cursor(cursorclass=DictCursor) as cursor:
+ cursor.execute(query, tuple(dataset_names))
+ return organise_groups_by_dataset(cursor.fetchall())
+
+def publish_traits_datasets(conn: Any, threshold, traits: Tuple[Dict]):
+ """Retrieve datasets for 'Publish' traits."""
+ dataset_names = tuple(set(trait["db"]["dataset_name"] for trait in traits))
+ dataset_names_info = publish_datasets_names(conn, threshold, dataset_names)
+ dataset_groups = publish_datasets_groups(conn, dataset_names) # type: ignore[arg-type]
+ return tuple({
+ **trait,
+ "db": {
+ **trait["db"],
+ **dataset_names_info.get(trait["db"]["dataset_name"], {}),
+ **dataset_groups.get(trait["db"]["dataset_name"], {})
+ }
+ } for trait in traits)
+
+def probeset_datasets_names(conn: Any, threshold: int, dataset_names: Tuple[str, ...]):
+ """
+ Get the ID, DataScale and various name formats for a `ProbeSet` trait.
+ """
+ query = (
+ "SELECT Id AS dataset_id, Name AS dataset_name, "
+ "FullName AS dataset_fullname, ShortName AS dataset_shortname, "
+ "DataScale AS dataset_datascale "
+ "FROM ProbeSetFreeze "
+ "WHERE "
+ "public > %s "
+ "AND "
+ "(Name IN ({names}) OR FullName IN ({names}) OR ShortName IN ({names}))")
+ with conn.cursor(cursorclass=DictCursor) as cursor:
+ cursor.execute(
+ query.format(names=", ".join(["%s"] * len(dataset_names))),
+ (threshold,) + (dataset_names * 3))
+ return {ds["dataset_name"]: ds for ds in cursor.fetchall()}
+
+def probeset_datasets_groups(conn, dataset_names):
+ """
+ Retrieve the Group, and GroupID values for various ProbeSet trait types.
+ """
+ query = (
+ "SELECT ProbeSetFreeze.Name AS dataset_name, InbredSet.Name, "
+ "InbredSet.Id "
+ "FROM InbredSet, ProbeSetFreeze, ProbeFreeze "
+ "WHERE ProbeFreeze.InbredSetId = InbredSet.Id "
+ "AND ProbeFreeze.Id = ProbeSetFreeze.ProbeFreezeId "
+ f"AND ProbeSetFreeze.Name IN ({', '.join(['%s'] * len(dataset_names))})")
+ with conn.cursor(cursorclass=DictCursor) as cursor:
+ cursor.execute(query, tuple(dataset_names))
+ return organise_groups_by_dataset(cursor.fetchall())
+
+def probeset_traits_datasets(conn: Any, threshold, traits: Tuple[Dict]):
+ """Retrive datasets for 'ProbeSet' traits."""
+ dataset_names = tuple(set(trait["db"]["dataset_name"] for trait in traits))
+ dataset_names_info = probeset_datasets_names(conn, threshold, dataset_names)
+ dataset_groups = probeset_datasets_groups(conn, dataset_names)
+ return tuple({
+ **trait,
+ "db": {
+ **trait["db"],
+ **dataset_names_info.get(trait["db"]["dataset_name"], {}),
+ **dataset_groups.get(trait["db"]["dataset_name"], {})
+ }
+ } for trait in traits)
+
+def geno_datasets_names(conn, threshold, dataset_names):
+ """
+ Get the ID and various name formats for the given `Geno` datasets.
+ """
+ query = (
+ "SELECT Id AS dataset_id, Name AS dataset_name, "
+ "FullName AS dataset_fullname, ShortName AS dataset_short_name "
+ "FROM GenoFreeze "
+ "WHERE "
+ "public > %s "
+ "AND "
+ "(Name IN ({names}) OR FullName IN ({names}) OR ShortName IN ({names}))")
+ with conn.cursor(cursorclass=DictCursor) as cursor:
+ cursor.execute(
+ query.format(names=", ".join(["%s"] * len(dataset_names))),
+ (threshold,) + (tuple(dataset_names) * 3))
+ return {ds["dataset_name"]: ds for ds in cursor.fetchall()}
+
+def geno_datasets_groups(conn, dataset_names):
+ """
+ Retrieve the Group and GroupID values for various Geno trait types.
+ """
+ query = (
+ "SELECT GenoFreeze.Name AS dataset_name, InbredSet.Name, InbredSet.Id "
+ "FROM InbredSet, GenoFreeze "
+ "WHERE GenoFreeze.InbredSetId = InbredSet.Id "
+ f"AND GenoFreeze.Name IN ({', '.join(['%s'] * len(dataset_names))})")
+ with conn.cursor(cursorclass=DictCursor) as cursor:
+ cursor.execute(query, tuple(dataset_names))
+ return organise_groups_by_dataset(cursor.fetchall())
+
+def geno_traits_datasets(conn: Any, threshold: int, traits: Tuple[Dict, ...]):
+ """Retrieve datasets for 'Geno' traits."""
+ dataset_names = tuple(set(trait["db"]["dataset_name"] for trait in traits))
+ dataset_names_info = geno_datasets_names(conn, threshold, dataset_names)
+ dataset_groups = geno_datasets_groups(conn, dataset_names)
+ return tuple({
+ **trait,
+ "db": {
+ **trait["db"],
+ **dataset_names_info.get(trait["db"]["dataset_name"], {}),
+ **dataset_groups.get(trait["db"]["dataset_name"], {})
+ }
+ } for trait in traits)
+
+def temp_datasets_groups(conn, dataset_names):
+ """
+ Retrieve the Group and GroupID values for `Temp` trait types.
+ """
+ query = (
+ "SELECT Temp.Name AS dataset_name, InbredSet.Name, InbredSet.Id "
+ "FROM InbredSet, Temp "
+ "WHERE Temp.InbredSetId = InbredSet.Id "
+ f"AND Temp.Name IN ({', '.join(['%s'] * len(dataset_names))})")
+ with conn.cursor(cursorclass=DictCursor) as cursor:
+ cursor.execute(query, tuple(dataset_names))
+ return organise_groups_by_dataset(cursor.fetchall())
+
+def temp_traits_datasets(conn: Any, threshold: int, traits: Tuple[Dict, ...]): # pylint: disable=[W0613]
+ """
+ Retrieve datasets for 'Temp' traits.
+ """
+ dataset_names = tuple(set(trait["db"]["dataset_name"] for trait in traits))
+ dataset_groups = temp_datasets_groups(conn, dataset_names)
+ return tuple({
+ **trait,
+ "db": {
+ **trait["db"],
+ **dataset_groups.get(trait["db"]["dataset_name"], {})
+ }
+ } for trait in traits)
+
+def set_confidential(traits):
+ """
+ Set the confidential field for traits of type `Publish`.
+ """
+ return tuple({
+ **trait,
+ "confidential": bool(
+ trait.get("pre_publication_description")
+ and not trait.get("pubmed_id"))
+ } for trait in traits)
+
+def query_qtl_info(conn, query, traits, dataset_ids):
+ """
+ Utility: Run the `query` to get the QTL information for the given `traits`.
+ """
+ with conn.cursor(cursorclass=DictCursor) as cursor:
+ cursor.execute(
+ query,
+ tuple(trait["trait_name"] for trait in traits) + dataset_ids)
+ results = {
+ row["trait_name"]: {
+ key: val for key, val in row.items() if key != "trait_name"
+ } for row in cursor.fetchall()
+ }
+ return tuple(
+ {**trait, **results.get(trait["trait_name"], {})}
+ for trait in traits)
+
+def set_publish_qtl_info(conn, qtl, traits):
+ """
+ Load extra QTL information for `Publish` traits
+ """
+ if qtl:
+ dataset_ids = set(trait["db"]["dataset_id"] for trait in traits)
+ query = (
+ "SELECT PublishXRef.Id AS trait_name, PublishXRef.Locus, "
+ "PublishXRef.LRS, PublishXRef.additive "
+ "FROM PublishXRef, PublishFreeze "
+ f"WHERE PublishXRef.Id IN ({', '.join(['%s'] * len(traits))}) "
+ "AND PublishXRef.InbredSetId = PublishFreeze.InbredSetId "
+ f"AND PublishFreeze.Id IN ({', '.join(['%s'] * len(dataset_ids))})")
+ return query_qtl_info(conn, query, traits, tuple(dataset_ids))
+ return traits
+
+def set_probeset_qtl_info(conn, qtl, traits):
+ """
+ Load extra QTL information for `ProbeSet` traits
+ """
+ if qtl:
+ dataset_ids = tuple(set(trait["db"]["dataset_id"] for trait in traits))
+ query = (
+ "SELECT ProbeSet.Name AS trait_name, ProbeSetXRef.Locus, "
+ "ProbeSetXRef.LRS, ProbeSetXRef.pValue, "
+ "ProbeSetXRef.mean, ProbeSetXRef.additive "
+ "FROM ProbeSetXRef, ProbeSet "
+ "WHERE ProbeSetXRef.ProbeSetId = ProbeSet.Id "
+ f"AND ProbeSet.Name IN ({', '.join(['%s'] * len(traits))}) "
+ "AND ProbeSetXRef.ProbeSetFreezeId IN "
+ f"({', '.join(['%s'] * len(dataset_ids))})")
+ return query_qtl_info(conn, query, traits, tuple(dataset_ids))
+ return traits
+
+def set_sequence(conn, traits):
+ """
+ Retrieve 'ProbeSet' traits sequence information
+ """
+ dataset_names = set(trait["db"]["dataset_name"] for trait in traits)
+ query = (
+ "SELECT ProbeSet.Name as trait_name, ProbeSet.BlatSeq "
+ "FROM ProbeSet, ProbeSetFreeze, ProbeSetXRef "
+ "WHERE ProbeSet.Id=ProbeSetXRef.ProbeSetId "
+ "AND ProbeSetFreeze.Id = ProbeSetXRef.ProbeSetFreezeId "
+ f"AND ProbeSet.Name IN ({', '.join(['%s'] * len(traits))}) "
+ f"AND ProbeSetFreeze.Name IN ({', '.join(['%s'] * len(dataset_names))})")
+ with conn.cursor(cursorclass=DictCursor) as cursor:
+ cursor.execute(
+ query,
+ (tuple(trait["trait_name"] for trait in traits) +
+ tuple(dataset_names)))
+ results = {
+ row["trait_name"]: {
+ key: val for key, val in row.items() if key != "trait_name"
+ } for row in cursor.fetchall()
+ }
+ return tuple(
+ {
+ **trait,
+ **results.get(trait["trait_name"], {})
+ } for trait in traits)
+
+def set_homologene_id(conn, traits):
+ """
+ Retrieve and set the 'homologene_id' values for ProbeSet traits.
+ """
+ geneids = set(trait.get("geneid") for trait in traits if trait["haveinfo"])
+ groups = set(
+ trait["db"].get("group") for trait in traits if trait["haveinfo"])
+ if len(geneids) > 0 and len(groups) > 0:
+ query = (
+ "SELECT InbredSet.Name AS `group`, Homologene.GeneId AS geneid, "
+ "HomologeneId "
+ "FROM Homologene, Species, InbredSet "
+ f"WHERE Homologene.GeneId IN ({', '.join(['%s'] * len(geneids))}) "
+ f"AND InbredSet.Name IN ({', '.join(['%s'] * len(groups))}) "
+ "AND InbredSet.SpeciesId = Species.Id "
+ "AND Species.TaxonomyId = Homologene.TaxonomyId")
+ with conn.cursor(cursorclass=DictCursor) as cursor:
+ cursor.execute(query, (tuple(geneids) + tuple(groups)))
+ results = {
+ row["group"]: {
+ row["geneid"]: {
+ key: val for key, val in row.items()
+ if key not in ("group", "geneid")
+ }
+ } for row in cursor.fetchall()
+ }
+ return tuple(
+ {
+ **trait, **results.get(
+ trait["db"]["group"], {}).get(trait["geneid"], {})
+ } for trait in traits)
+ return traits
+
+def traits_datasets(conn, threshold, traits):
+ """
+ Retrieve datasets for various `traits`.
+ """
+ dataset_fns = {
+ "Temp": temp_traits_datasets,
+ "Geno": geno_traits_datasets,
+ "Publish": publish_traits_datasets,
+ "ProbeSet": probeset_traits_datasets
+ }
+ def __organise_by_type__(acc, trait):
+ dataset_type = trait["db"]["dataset_type"]
+ return {
+ **acc,
+ dataset_type: acc.get(dataset_type, tuple()) + (trait,)
+ }
+ with_datasets = {
+ trait["trait_fullname"]: trait for trait in (
+ item for sublist in (
+ dataset_fns[dtype](conn, threshold, ttraits)
+ for dtype, ttraits
+ in reduce(__organise_by_type__, traits, {}).items())
+ for item in sublist)}
+ return tuple(
+ {**trait, **with_datasets.get(trait["trait_fullname"], {})}
+ for trait in traits)
+
+def traits_info(
+ conn: Any, threshold: int, traits_fullnames: Tuple[str, ...],
+ qtl=None) -> Tuple[Dict[str, Any], ...]:
+ """
+ Retrieve basic trait information for multiple `traits`.
+
+ This is a rework of the `gn3.db.traits.retrieve_trait_info` function.
+ """
+ def __organise_by_dataset_type__(acc, trait):
+ dataset_type = trait["db"]["dataset_type"]
+ return {
+ **acc,
+ dataset_type: acc.get(dataset_type, tuple()) + (trait,)
+ }
+ traits = traits_datasets(
+ conn, threshold,
+ tuple(build_trait_name(trait) for trait in traits_fullnames))
+ traits_fns = {
+ "Publish": compose(
+ set_confidential, partial(set_publish_qtl_info, conn, qtl),
+ partial(publish_traits_info, conn),
+ partial(publish_traits_datasets, conn, threshold)),
+ "ProbeSet": compose(
+ partial(set_sequence, conn),
+ partial(set_probeset_qtl_info, conn, qtl),
+ partial(set_homologene_id, conn),
+ partial(probeset_traits_info, conn),
+ partial(probeset_traits_datasets, conn, threshold)),
+ "Geno": compose(
+ partial(geno_traits_info, conn),
+ partial(geno_traits_datasets, conn, threshold)),
+ "Temp": compose(
+ partial(temp_traits_info, conn),
+ partial(temp_traits_datasets, conn, threshold))
+ }
+ return tuple(
+ trait for sublist in (# type: ignore[var-annotated]
+ traits_fns[dataset_type](traits)
+ for dataset_type, traits
+ in reduce(__organise_by_dataset_type__, traits, {}).items())
+ for trait in sublist)
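+
+# Illustrative use of `traits_info` (the trait full-names below are
+# hypothetical; the threshold is whatever the caller deems public, e.g. 9):
+#
+#   traits = traits_info(
+#       conn, 9, ("BXDPublish::10007", "HC_M2_0606_P::1436869_at"))
+#
+# Each returned dict combines the parsed trait name with its dataset
+# information and, when `qtl` is truthy, the extra QTL fields.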
diff --git a/gn3/db/sample_data.py b/gn3/db/sample_data.py
new file mode 100644
index 0000000..f73954f
--- /dev/null
+++ b/gn3/db/sample_data.py
@@ -0,0 +1,365 @@
+"""Module containing functions that work with sample data"""
+from typing import Any, Tuple, Dict, Callable
+
+import MySQLdb
+
+from gn3.csvcmp import extract_strain_name
+
+
+_MAP = {
+ "PublishData": ("StrainId", "Id", "value"),
+ "PublishSE": ("StrainId", "DataId", "error"),
+ "NStrain": ("StrainId", "DataId", "count"),
+}
+
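+# For each table, `_MAP` lists (strain-id column, data-id column, value
+# column): the first two identify a sample's row, the third holds the datum.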
+
+def __extract_actions(original_data: str,
+ updated_data: str,
+ csv_header: str) -> Dict:
+ """Return a dictionary of the elements that need to be deleted, inserted
+ or updated."""
+ result: Dict[str, Any] = {
+ "delete": {"data": [], "csv_header": []},
+ "insert": {"data": [], "csv_header": []},
+ "update": {"data": [], "csv_header": []},
+ }
+ strain_name = ""
+ for _o, _u, _h in zip(original_data.strip().split(","),
+ updated_data.strip().split(","),
+ csv_header.strip().split(",")):
+ if _h == "Strain Name":
+ strain_name = _o
+ if _o == _u: # No change
+ continue
+ if _o and _u == "x": # Deletion
+ result["delete"]["data"].append(_o)
+ result["delete"]["csv_header"].append(_h)
+ elif _o == "x" and _u: # Insert
+ result["insert"]["data"].append(_u)
+ result["insert"]["csv_header"].append(_h)
+ elif _o and _u: # Update
+ result["update"]["data"].append(_u)
+ result["update"]["csv_header"].append(_h)
+ for key, val in result.items():
+ if not val["data"]:
+ result[key] = None
+ else:
+ result[key]["data"] = (f"{strain_name}," +
+ ",".join(result[key]["data"]))
+ result[key]["csv_header"] = ("Strain Name," +
+ ",".join(result[key]["csv_header"]))
+ return result
+
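+# Illustrative sketch of `__extract_actions` (assumed CSV rows, not real
+# data). With:
+#   csv_header    = "Strain Name,Value,SE,Count"
+#   original_data = "BXD1,18,x,5"
+#   updated_data  = "BXD1,19,3,x"
+# "Value" changes (update), "SE" gains a value (insert) and "Count" is
+# cleared (delete), so the result is roughly:
+#   {"update": {"data": "BXD1,19", "csv_header": "Strain Name,Value"},
+#    "insert": {"data": "BXD1,3", "csv_header": "Strain Name,SE"},
+#    "delete": {"data": "BXD1,5", "csv_header": "Strain Name,Count"}}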
+
+def get_trait_csv_sample_data(conn: Any,
+ trait_name: int, phenotype_id: int) -> str:
+ """Fetch a trait and return it as a csv string"""
+ __query = ("SELECT concat(st.Name, ',', ifnull(pd.value, 'x'), ',', "
+ "ifnull(ps.error, 'x'), ',', ifnull(ns.count, 'x')) as 'Data' "
+ ",ifnull(ca.Name, 'x') as 'CaseAttr', "
+ "ifnull(cxref.value, 'x') as 'Value' "
+ "FROM PublishFreeze pf "
+ "JOIN PublishXRef px ON px.InbredSetId = pf.InbredSetId "
+ "JOIN PublishData pd ON pd.Id = px.DataId "
+ "JOIN Strain st ON pd.StrainId = st.Id "
+ "LEFT JOIN PublishSE ps ON ps.DataId = pd.Id "
+ "AND ps.StrainId = pd.StrainId "
+ "LEFT JOIN NStrain ns ON ns.DataId = pd.Id "
+ "AND ns.StrainId = pd.StrainId "
+ "LEFT JOIN CaseAttributeXRefNew cxref ON "
+ "(cxref.InbredSetId = px.InbredSetId AND "
+ "cxref.StrainId = st.Id) "
+ "LEFT JOIN CaseAttribute ca ON ca.Id = cxref.CaseAttributeId "
+ "WHERE px.Id = %s AND px.PhenotypeId = %s ORDER BY st.Name")
+ case_attr_columns = set()
+ csv_data: Dict = {}
+ with conn.cursor() as cursor:
+ cursor.execute(__query, (trait_name, phenotype_id))
+ for data in cursor.fetchall():
+ if data[1] == "x":
+ csv_data[data[0]] = None
+ else:
+ sample, case_attr, value = data[0], data[1], data[2]
+ if not csv_data.get(sample):
+ csv_data[sample] = {}
+ csv_data[sample][case_attr] = None if value == "x" else value
+ case_attr_columns.add(case_attr)
+ if not case_attr_columns:
+ return ("Strain Name,Value,SE,Count\n" +
+ "\n".join(csv_data.keys()))
+ columns = sorted(case_attr_columns)
+ csv = ("Strain Name,Value,SE,Count," +
+ ",".join(columns) + "\n")
+ for key, value in csv_data.items():
+ if not value:
+ csv += (key + (len(case_attr_columns) * ",x") + "\n")
+ else:
+ vals = [str(value.get(column, "x")) for column in columns]
+ csv += (key + "," + ",".join(vals) + "\n")
+ return csv
+ return "No Sample Data Found"
+
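+# The CSV built above looks roughly like (illustrative values; "x" marks a
+# missing value, "Sex" a hypothetical case attribute):
+#   Strain Name,Value,SE,Count,Sex
+#   BXD1,18,x,5,F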
+
+def get_sample_data_ids(conn: Any, publishxref_id: int,
+ phenotype_id: int,
+ strain_name: str) -> Tuple:
+ """Get the strain_id, publishdata_id and inbredset_id for a given strain"""
+ strain_id, publishdata_id, inbredset_id = None, None, None
+ with conn.cursor() as cursor:
+ cursor.execute("SELECT st.id, pd.Id, pf.InbredSetId "
+ "FROM PublishData pd "
+ "JOIN Strain st ON pd.StrainId = st.Id "
+ "JOIN PublishXRef px ON px.DataId = pd.Id "
+ "JOIN PublishFreeze pf ON pf.InbredSetId "
+ "= px.InbredSetId WHERE px.Id = %s "
+ "AND px.PhenotypeId = %s AND st.Name = %s",
+ (publishxref_id, phenotype_id, strain_name))
+ if _result := cursor.fetchone():
+ strain_id, publishdata_id, inbredset_id = _result
+ if not all([strain_id, publishdata_id, inbredset_id]):
+ # Applies for data to be inserted:
+ cursor.execute("SELECT DataId, InbredSetId FROM PublishXRef "
+ "WHERE Id = %s AND PhenotypeId = %s",
+ (publishxref_id, phenotype_id))
+ publishdata_id, inbredset_id = cursor.fetchone()
+ cursor.execute("SELECT Id FROM Strain WHERE Name = %s",
+ (strain_name,))
+ strain_id = cursor.fetchone()[0]
+ return (strain_id, publishdata_id, inbredset_id)
+
+
+# pylint: disable=[R0913, R0914]
+def update_sample_data(conn: Any,
+ trait_name: str,
+ original_data: str,
+ updated_data: str,
+ csv_header: str,
+ phenotype_id: int) -> int:
+ """Given the right parameters, update sample-data from the relevant
+ table."""
+ def __update_data(conn, table, value):
+ if value and value != "x":
+ with conn.cursor() as cursor:
+ sub_query = (" = %s AND ".join(_MAP.get(table)[:2]) +
+ " = %s")
+ _val = _MAP.get(table)[-1]
+ cursor.execute((f"UPDATE {table} SET {_val} = %s "
+ f"WHERE {sub_query}"),
+ (value, strain_id, data_id))
+ return cursor.rowcount
+ return 0
+
+ def __update_case_attribute(conn, value, strain_id,
+ case_attr, inbredset_id):
+ if value != "x":
+ with conn.cursor() as cursor:
+ cursor.execute(
+ "UPDATE CaseAttributeXRefNew "
+ "SET Value = %s "
+ "WHERE StrainId = %s AND CaseAttributeId = "
+ "(SELECT CaseAttributeId FROM "
+ "CaseAttribute WHERE Name = %s) "
+ "AND InbredSetId = %s",
+ (value, strain_id, case_attr, inbredset_id))
+ return cursor.rowcount
+ return 0
+
+ strain_id, data_id, inbredset_id = get_sample_data_ids(
+ conn=conn, publishxref_id=int(trait_name),
+ phenotype_id=phenotype_id,
+ strain_name=extract_strain_name(csv_header, original_data))
+
+ none_case_attrs: Dict[str, Callable] = {
+ "Strain Name": lambda x: 0,
+ "Value": lambda x: __update_data(conn, "PublishData", x),
+ "SE": lambda x: __update_data(conn, "PublishSE", x),
+ "Count": lambda x: __update_data(conn, "NStrain", x),
+ }
+ count = 0
+ try:
+ __actions = __extract_actions(original_data=original_data,
+ updated_data=updated_data,
+ csv_header=csv_header)
+ if __actions.get("update"):
+ _csv_header = __actions["update"]["csv_header"]
+ _data = __actions["update"]["data"]
+ # pylint: disable=[E1101]
+ for header, value in zip(_csv_header.split(","),
+ _data.split(",")):
+ header = header.strip()
+ value = value.strip()
+ if header in none_case_attrs:
+ count += none_case_attrs[header](value)
+ else:
+ count += __update_case_attribute(
+ conn=conn,
+ value=value,
+ strain_id=strain_id,
+ case_attr=header,
+ inbredset_id=inbredset_id)
+ if __actions.get("delete"):
+ _rowcount = delete_sample_data(
+ conn=conn,
+ trait_name=trait_name,
+ data=__actions["delete"]["data"],
+ csv_header=__actions["delete"]["csv_header"],
+ phenotype_id=phenotype_id)
+ if _rowcount:
+ count += 1
+ if __actions.get("insert"):
+ _rowcount = insert_sample_data(
+ conn=conn,
+ trait_name=trait_name,
+ data=__actions["insert"]["data"],
+ csv_header=__actions["insert"]["csv_header"],
+ phenotype_id=phenotype_id)
+ if _rowcount:
+ count += 1
+ except Exception as _e:
+ conn.rollback()
+ raise MySQLdb.Error(_e) from _e
+ conn.commit()
+ return count
+
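+# Illustrative call (identifiers assumed): applying a one-column edit to
+# phenotype trait 10007,
+#
+#   count = update_sample_data(
+#       conn, trait_name="10007",
+#       original_data="BXD1,18,x,5", updated_data="BXD1,19,x,5",
+#       csv_header="Strain Name,Value,SE,Count", phenotype_id=4)
+#
+# would return the number of updates, inserts and deletions applied (1 here).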
+
+def delete_sample_data(conn: Any,
+ trait_name: str,
+ data: str,
+ csv_header: str,
+ phenotype_id: int) -> int:
+ """Given the right parameters, delete sample-data from the relevant
+ tables."""
+ def __delete_data(conn, table):
+ sub_query = (" = %s AND ".join(_MAP.get(table)[:2]) + " = %s")
+ with conn.cursor() as cursor:
+ cursor.execute((f"DELETE FROM {table} "
+ f"WHERE {sub_query}"),
+ (strain_id, data_id))
+ return cursor.rowcount
+
+ def __delete_case_attribute(conn, strain_id,
+ case_attr, inbredset_id):
+ with conn.cursor() as cursor:
+ cursor.execute(
+ "DELETE FROM CaseAttributeXRefNew "
+ "WHERE StrainId = %s AND CaseAttributeId = "
+ "(SELECT CaseAttributeId FROM "
+ "CaseAttribute WHERE Name = %s) "
+ "AND InbredSetId = %s",
+ (strain_id, case_attr, inbredset_id))
+ return cursor.rowcount
+
+ strain_id, data_id, inbredset_id = get_sample_data_ids(
+ conn=conn, publishxref_id=int(trait_name),
+ phenotype_id=phenotype_id,
+ strain_name=extract_strain_name(csv_header, data))
+
+ none_case_attrs: Dict[str, Any] = {
+ "Strain Name": lambda: 0,
+ "Value": lambda: __delete_data(conn, "PublishData"),
+ "SE": lambda: __delete_data(conn, "PublishSE"),
+ "Count": lambda: __delete_data(conn, "NStrain"),
+ }
+ count = 0
+
+ try:
+ for header in csv_header.split(","):
+ header = header.strip()
+ if header in none_case_attrs:
+ count += none_case_attrs[header]()
+ else:
+ count += __delete_case_attribute(
+ conn=conn,
+ strain_id=strain_id,
+ case_attr=header,
+ inbredset_id=inbredset_id)
+ except Exception as _e:
+ conn.rollback()
+ raise MySQLdb.Error(_e) from _e
+ conn.commit()
+ return count
+
+
+# pylint: disable=[R0913, R0914]
+def insert_sample_data(conn: Any,
+ trait_name: str,
+ data: str,
+ csv_header: str,
+ phenotype_id: int) -> int:
+ """Given the right parameters, insert sample-data to the relevant table.
+
+ """
+ def __insert_data(conn, table, value):
+ if value and value != "x":
+ with conn.cursor() as cursor:
+ columns = ", ".join(_MAP.get(table))
+ cursor.execute((f"INSERT INTO {table} "
+ f"({columns}) "
+ f"VALUES (%s, %s, %s)"),
+ (strain_id, data_id, value))
+ return cursor.rowcount
+ return 0
+
+ def __insert_case_attribute(conn, case_attr, value):
+ if value != "x":
+ with conn.cursor() as cursor:
+ cursor.execute("SELECT Id FROM "
+ "CaseAttribute WHERE Name = %s",
+ (case_attr,))
+ if case_attr_id := cursor.fetchone():
+ case_attr_id = case_attr_id[0]
+ cursor.execute("SELECT StrainId FROM "
+ "CaseAttributeXRefNew WHERE StrainId = %s "
+ "AND CaseAttributeId = %s "
+ "AND InbredSetId = %s",
+ (strain_id, case_attr_id, inbredset_id))
+ if (not cursor.fetchone()) and case_attr_id:
+ cursor.execute(
+ "INSERT INTO CaseAttributeXRefNew "
+ "(StrainId, CaseAttributeId, Value, InbredSetId) "
+ "VALUES (%s, %s, %s, %s)",
+ (strain_id, case_attr_id, value, inbredset_id))
+ row_count = cursor.rowcount
+ return row_count
+ return 0
+
+ strain_id, data_id, inbredset_id = get_sample_data_ids(
+ conn=conn, publishxref_id=int(trait_name),
+ phenotype_id=phenotype_id,
+ strain_name=extract_strain_name(csv_header, data))
+
+ none_case_attrs: Dict[str, Any] = {
+ "Strain Name": lambda _: 0,
+ "Value": lambda x: __insert_data(conn, "PublishData", x),
+ "SE": lambda x: __insert_data(conn, "PublishSE", x),
+ "Count": lambda x: __insert_data(conn, "NStrain", x),
+ }
+
+ try:
+ count = 0
+
+ # Check if the data already exists:
+ with conn.cursor() as cursor:
+ cursor.execute(
+ "SELECT Id FROM PublishData where Id = %s "
+ "AND StrainId = %s",
+ (data_id, strain_id))
+ if cursor.fetchone(): # Data already exists
+ return count
+
+ for header, value in zip(csv_header.split(","), data.split(",")):
+ header = header.strip()
+ value = value.strip()
+ if header in none_case_attrs:
+ count += none_case_attrs[header](value)
+ else:
+ count += __insert_case_attribute(
+ conn=conn,
+ case_attr=header,
+ value=value)
+ return count
+ except Exception as _e:
+ conn.rollback()
+ raise MySQLdb.Error(_e) from _e
diff --git a/gn3/db/species.py b/gn3/db/species.py
index 702a9a8..5b8e096 100644
--- a/gn3/db/species.py
+++ b/gn3/db/species.py
@@ -57,3 +57,20 @@ def translate_to_mouse_gene_id(species: str, geneid: int, conn: Any) -> int:
return translated_gene_id[0]
return 0 # default if all else fails
+
+def species_name(conn: Any, group: str) -> str:
+ """
+ Retrieve the name of the species, given the group (RISet).
+
+ This is a migration of the
+ `web.webqtl.dbFunction.webqtlDatabaseFunction.retrieveSpecies` function in
+ GeneNetwork1.
+ """
+ with conn.cursor() as cursor:
+ cursor.execute(
+ ("SELECT Species.Name FROM Species, InbredSet "
+ "WHERE InbredSet.Name = %(group_name)s "
+ "AND InbredSet.SpeciesId = Species.Id"),
+ {"group_name": group})
+ result = cursor.fetchone()
+ return result[0] if result else None
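+
+# e.g. species_name(conn, "BXD") is expected to return "mouse" on a standard
+# GeneNetwork database (group and value shown for illustration only).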
diff --git a/gn3/db/traits.py b/gn3/db/traits.py
index 1c6aaa7..f722e24 100644
--- a/gn3/db/traits.py
+++ b/gn3/db/traits.py
@@ -1,7 +1,7 @@
"""This class contains functions relating to trait data manipulation"""
import os
from functools import reduce
-from typing import Any, Dict, Union, Sequence
+from typing import Any, Dict, Sequence
from gn3.settings import TMPDIR
from gn3.random import random_string
@@ -67,7 +67,7 @@ def export_trait_data(
return accumulator + (trait_data["data"][sample]["ndata"], )
if dtype == "all":
return accumulator + __export_all_types(trait_data["data"], sample)
- raise KeyError("Type `%s` is incorrect" % dtype)
+ raise KeyError(f"Type `{dtype}` is incorrect")
if var_exists and n_exists:
return accumulator + (None, None, None)
if var_exists or n_exists:
@@ -76,80 +76,6 @@ def export_trait_data(
return reduce(__exporter, samplelist, tuple())
-def get_trait_csv_sample_data(conn: Any,
- trait_name: int, phenotype_id: int):
- """Fetch a trait and return it as a csv string"""
- sql = ("SELECT DISTINCT Strain.Id, PublishData.Id, Strain.Name, "
- "PublishData.value, "
- "PublishSE.error, NStrain.count FROM "
- "(PublishData, Strain, PublishXRef, PublishFreeze) "
- "LEFT JOIN PublishSE ON "
- "(PublishSE.DataId = PublishData.Id AND "
- "PublishSE.StrainId = PublishData.StrainId) "
- "LEFT JOIN NStrain ON (NStrain.DataId = PublishData.Id AND "
- "NStrain.StrainId = PublishData.StrainId) WHERE "
- "PublishXRef.InbredSetId = PublishFreeze.InbredSetId AND "
- "PublishData.Id = PublishXRef.DataId AND "
- "PublishXRef.Id = %s AND PublishXRef.PhenotypeId = %s "
- "AND PublishData.StrainId = Strain.Id Order BY Strain.Name")
- csv_data = ["Strain Id,Strain Name,Value,SE,Count"]
- publishdata_id = ""
- with conn.cursor() as cursor:
- cursor.execute(sql, (trait_name, phenotype_id,))
- for record in cursor.fetchall():
- (strain_id, publishdata_id,
- strain_name, value, error, count) = record
- csv_data.append(
- ",".join([str(val) if val else "x"
- for val in (strain_id, strain_name,
- value, error, count)]))
- return f"# Publish Data Id: {publishdata_id}\n\n" + "\n".join(csv_data)
-
-
-def update_sample_data(conn: Any,
- strain_name: str,
- strain_id: int,
- publish_data_id: int,
- value: Union[int, float, str],
- error: Union[int, float, str],
- count: Union[int, str]):
- """Given the right parameters, update sample-data from the relevant
- table."""
- # pylint: disable=[R0913, R0914, C0103]
- STRAIN_ID_SQL: str = "UPDATE Strain SET Name = %s WHERE Id = %s"
- PUBLISH_DATA_SQL: str = ("UPDATE PublishData SET value = %s "
- "WHERE StrainId = %s AND Id = %s")
- PUBLISH_SE_SQL: str = ("UPDATE PublishSE SET error = %s "
- "WHERE StrainId = %s AND DataId = %s")
- N_STRAIN_SQL: str = ("UPDATE NStrain SET count = %s "
- "WHERE StrainId = %s AND DataId = %s")
-
- updated_strains: int = 0
- updated_published_data: int = 0
- updated_se_data: int = 0
- updated_n_strains: int = 0
-
- with conn.cursor() as cursor:
- # Update the Strains table
- cursor.execute(STRAIN_ID_SQL, (strain_name, strain_id))
- updated_strains = cursor.rowcount
- # Update the PublishData table
- cursor.execute(PUBLISH_DATA_SQL,
- (None if value == "x" else value,
- strain_id, publish_data_id))
- updated_published_data = cursor.rowcount
- # Update the PublishSE table
- cursor.execute(PUBLISH_SE_SQL,
- (None if error == "x" else error,
- strain_id, publish_data_id))
- updated_se_data = cursor.rowcount
- # Update the NStrain table
- cursor.execute(N_STRAIN_SQL,
- (None if count == "x" else count,
- strain_id, publish_data_id))
- updated_n_strains = cursor.rowcount
- return (updated_strains, updated_published_data,
- updated_se_data, updated_n_strains)
def retrieve_publish_trait_info(trait_data_source: Dict[str, Any], conn: Any):
"""Retrieve trait information for type `Publish` traits.
@@ -177,24 +103,24 @@ def retrieve_publish_trait_info(trait_data_source: Dict[str, Any], conn: Any):
"PublishXRef.comments")
query = (
"SELECT "
- "{columns} "
+ f"{columns} "
"FROM "
- "PublishXRef, Publication, Phenotype, PublishFreeze "
+ "PublishXRef, Publication, Phenotype "
"WHERE "
"PublishXRef.Id = %(trait_name)s AND "
"Phenotype.Id = PublishXRef.PhenotypeId AND "
"Publication.Id = PublishXRef.PublicationId AND "
- "PublishXRef.InbredSetId = PublishFreeze.InbredSetId AND "
- "PublishFreeze.Id =%(trait_dataset_id)s").format(columns=columns)
+ "PublishXRef.InbredSetId = %(trait_dataset_id)s")
with conn.cursor() as cursor:
cursor.execute(
query,
{
- k:v for k, v in trait_data_source.items()
+ k: v for k, v in trait_data_source.items()
if k in ["trait_name", "trait_dataset_id"]
})
return dict(zip([k.lower() for k in keys], cursor.fetchone()))
+
def set_confidential_field(trait_type, trait_info):
"""Post processing function for 'Publish' trait types.
@@ -207,6 +133,7 @@ def set_confidential_field(trait_type, trait_info):
and not trait_info.get("pubmed_id", None)) else 0}
return trait_info
+
def retrieve_probeset_trait_info(trait_data_source: Dict[str, Any], conn: Any):
"""Retrieve trait information for type `ProbeSet` traits.
@@ -219,67 +146,68 @@ def retrieve_probeset_trait_info(trait_data_source: Dict[str, Any], conn: Any):
"probe_set_specificity", "probe_set_blat_score",
"probe_set_blat_mb_start", "probe_set_blat_mb_end", "probe_set_strand",
"probe_set_note_by_rw", "flag")
+ columns = (f"ProbeSet.{x}" for x in keys)
query = (
- "SELECT "
- "{columns} "
+ f"SELECT {', '.join(columns)} "
"FROM "
"ProbeSet, ProbeSetFreeze, ProbeSetXRef "
"WHERE "
"ProbeSetXRef.ProbeSetFreezeId = ProbeSetFreeze.Id AND "
"ProbeSetXRef.ProbeSetId = ProbeSet.Id AND "
"ProbeSetFreeze.Name = %(trait_dataset_name)s AND "
- "ProbeSet.Name = %(trait_name)s").format(
- columns=", ".join(["ProbeSet.{}".format(x) for x in keys]))
+ "ProbeSet.Name = %(trait_name)s")
with conn.cursor() as cursor:
cursor.execute(
query,
{
- k:v for k, v in trait_data_source.items()
+ k: v for k, v in trait_data_source.items()
if k in ["trait_name", "trait_dataset_name"]
})
return dict(zip(keys, cursor.fetchone()))
+
def retrieve_geno_trait_info(trait_data_source: Dict[str, Any], conn: Any):
"""Retrieve trait information for type `Geno` traits.
https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/base/webqtlTrait.py#L438-L449"""
keys = ("name", "chr", "mb", "source2", "sequence")
+ columns = ", ".join(f"Geno.{x}" for x in keys)
query = (
- "SELECT "
- "{columns} "
+ f"SELECT {columns} "
"FROM "
- "Geno, GenoFreeze, GenoXRef "
+ "Geno INNER JOIN GenoXRef ON GenoXRef.GenoId = Geno.Id "
+ "INNER JOIN GenoFreeze ON GenoFreeze.Id = GenoXRef.GenoFreezeId "
"WHERE "
- "GenoXRef.GenoFreezeId = GenoFreeze.Id AND GenoXRef.GenoId = Geno.Id AND "
"GenoFreeze.Name = %(trait_dataset_name)s AND "
- "Geno.Name = %(trait_name)s").format(
- columns=", ".join(["Geno.{}".format(x) for x in keys]))
+ "Geno.Name = %(trait_name)s")
with conn.cursor() as cursor:
cursor.execute(
query,
{
- k:v for k, v in trait_data_source.items()
+ k: v for k, v in trait_data_source.items()
if k in ["trait_name", "trait_dataset_name"]
})
return dict(zip(keys, cursor.fetchone()))
+
def retrieve_temp_trait_info(trait_data_source: Dict[str, Any], conn: Any):
"""Retrieve trait information for type `Temp` traits.
https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/base/webqtlTrait.py#L450-452"""
keys = ("name", "description")
query = (
- "SELECT {columns} FROM Temp "
- "WHERE Name = %(trait_name)s").format(columns=", ".join(keys))
+ f"SELECT {', '.join(keys)} FROM Temp "
+ "WHERE Name = %(trait_name)s")
with conn.cursor() as cursor:
cursor.execute(
query,
{
- k:v for k, v in trait_data_source.items()
+ k: v for k, v in trait_data_source.items()
if k in ["trait_name"]
})
return dict(zip(keys, cursor.fetchone()))
+
def set_haveinfo_field(trait_info):
"""
Common postprocessing function for all trait types.
@@ -287,6 +215,7 @@ def set_haveinfo_field(trait_info):
Sets the value for the 'haveinfo' field."""
return {**trait_info, "haveinfo": 1 if trait_info else 0}
+
def set_homologene_id_field_probeset(trait_info, conn):
"""
Postprocessing function for 'ProbeSet' traits.
@@ -302,7 +231,7 @@ def set_homologene_id_field_probeset(trait_info, conn):
cursor.execute(
query,
{
- k:v for k, v in trait_info.items()
+ k: v for k, v in trait_info.items()
if k in ["geneid", "group"]
})
res = cursor.fetchone()
@@ -310,12 +239,13 @@ def set_homologene_id_field_probeset(trait_info, conn):
return {**trait_info, "homologeneid": res[0]}
return {**trait_info, "homologeneid": None}
+
def set_homologene_id_field(trait_type, trait_info, conn):
"""
Common postprocessing function for all trait types.
Sets the value for the 'homologene' key."""
- set_to_null = lambda ti: {**ti, "homologeneid": None}
+ def set_to_null(ti):
+ return {**ti, "homologeneid": None}
functions_table = {
"Temp": set_to_null,
"Geno": set_to_null,
@@ -324,6 +254,7 @@ def set_homologene_id_field(trait_type, trait_info, conn):
}
return functions_table[trait_type](trait_info)
+
def load_publish_qtl_info(trait_info, conn):
"""
Load extra QTL information for `Publish` traits
@@ -344,6 +275,7 @@ def load_publish_qtl_info(trait_info, conn):
return dict(zip(["locus", "lrs", "additive"], cursor.fetchone()))
return {"locus": "", "lrs": "", "additive": ""}
+
def load_probeset_qtl_info(trait_info, conn):
"""
Load extra QTL information for `ProbeSet` traits
@@ -366,6 +298,7 @@ def load_probeset_qtl_info(trait_info, conn):
["locus", "lrs", "pvalue", "mean", "additive"], cursor.fetchone()))
return {"locus": "", "lrs": "", "pvalue": "", "mean": "", "additive": ""}
+
def load_qtl_info(qtl, trait_type, trait_info, conn):
"""
Load extra QTL information for traits
@@ -389,11 +322,12 @@ def load_qtl_info(qtl, trait_type, trait_info, conn):
"Publish": load_publish_qtl_info,
"ProbeSet": load_probeset_qtl_info
}
- if trait_info["name"] not in qtl_info_functions.keys():
+ if trait_info["name"] not in qtl_info_functions:
return trait_info
return qtl_info_functions[trait_type](trait_info, conn)
+
def build_trait_name(trait_fullname):
"""
Initialises the trait's name, and other values from the search data provided
@@ -408,7 +342,7 @@ def build_trait_name(trait_fullname):
return "ProbeSet"
name_parts = trait_fullname.split("::")
- assert len(name_parts) >= 2, "Name format error"
+ assert len(name_parts) >= 2, f"Name format error: '{trait_fullname}'"
dataset_name = name_parts[0]
dataset_type = dataset_type(dataset_name)
return {
@@ -420,6 +354,7 @@ def build_trait_name(trait_fullname):
"cellid": name_parts[2] if len(name_parts) == 3 else ""
}
+
def retrieve_probeset_sequence(trait, conn):
"""
Retrieve a 'ProbeSet' trait's sequence information
@@ -441,6 +376,7 @@ def retrieve_probeset_sequence(trait, conn):
seq = cursor.fetchone()
return {**trait, "sequence": seq[0] if seq else ""}
+
def retrieve_trait_info(
threshold: int, trait_full_name: str, conn: Any,
qtl=None):
@@ -496,6 +432,7 @@ def retrieve_trait_info(
}
return trait_info
+
def retrieve_temp_trait_data(trait_info: dict, conn: Any):
"""
Retrieve trait data for `Temp` traits.
@@ -514,10 +451,12 @@ def retrieve_temp_trait_data(trait_info: dict, conn: Any):
query,
{"trait_name": trait_info["trait_name"]})
return [dict(zip(
- ["sample_name", "value", "se_error", "nstrain", "id"], row))
+ ["sample_name", "value", "se_error", "nstrain", "id"],
+ row))
for row in cursor.fetchall()]
return []
+
def retrieve_species_id(group, conn: Any):
"""
Retrieve a species id given the Group value
@@ -529,6 +468,7 @@ def retrieve_species_id(group, conn: Any):
return cursor.fetchone()[0]
return None
+
def retrieve_geno_trait_data(trait_info: Dict, conn: Any):
"""
Retrieve trait data for `Geno` traits.
@@ -552,11 +492,14 @@ def retrieve_geno_trait_data(trait_info: Dict, conn: Any):
"dataset_name": trait_info["db"]["dataset_name"],
"species_id": retrieve_species_id(
trait_info["db"]["group"], conn)})
- return [dict(zip(
- ["sample_name", "value", "se_error", "id"], row))
- for row in cursor.fetchall()]
+ return [
+ dict(zip(
+ ["sample_name", "value", "se_error", "id"],
+ row))
+ for row in cursor.fetchall()]
return []
+
def retrieve_publish_trait_data(trait_info: Dict, conn: Any):
"""
Retrieve trait data for `Publish` traits.
@@ -565,17 +508,16 @@ def retrieve_publish_trait_data(trait_info: Dict, conn: Any):
"SELECT "
"Strain.Name, PublishData.value, PublishSE.error, NStrain.count, "
"PublishData.Id "
- "FROM (PublishData, Strain, PublishXRef, PublishFreeze) "
+ "FROM (PublishData, Strain, PublishXRef) "
"LEFT JOIN PublishSE ON "
"(PublishSE.DataId = PublishData.Id "
"AND PublishSE.StrainId = PublishData.StrainId) "
"LEFT JOIN NStrain ON "
"(NStrain.DataId = PublishData.Id "
"AND NStrain.StrainId = PublishData.StrainId) "
- "WHERE PublishXRef.InbredSetId = PublishFreeze.InbredSetId "
- "AND PublishData.Id = PublishXRef.DataId "
+ "WHERE PublishData.Id = PublishXRef.DataId "
"AND PublishXRef.Id = %(trait_name)s "
- "AND PublishFreeze.Id = %(dataset_id)s "
+ "AND PublishXRef.InbredSetId = %(dataset_id)s "
"AND PublishData.StrainId = Strain.Id "
"ORDER BY Strain.Name")
with conn.cursor() as cursor:
@@ -583,11 +525,13 @@ def retrieve_publish_trait_data(trait_info: Dict, conn: Any):
query,
{"trait_name": trait_info["trait_name"],
"dataset_id": trait_info["db"]["dataset_id"]})
- return [dict(zip(
- ["sample_name", "value", "se_error", "nstrain", "id"], row))
- for row in cursor.fetchall()]
+ return [
+ dict(zip(
+ ["sample_name", "value", "se_error", "nstrain", "id"], row))
+ for row in cursor.fetchall()]
return []
+
def retrieve_cellid_trait_data(trait_info: Dict, conn: Any):
"""
Retrieve trait data for `Probe Data` types.
@@ -616,11 +560,13 @@ def retrieve_cellid_trait_data(trait_info: Dict, conn: Any):
{"cellid": trait_info["cellid"],
"trait_name": trait_info["trait_name"],
"dataset_id": trait_info["db"]["dataset_id"]})
- return [dict(zip(
- ["sample_name", "value", "se_error", "id"], row))
- for row in cursor.fetchall()]
+ return [
+ dict(zip(
+ ["sample_name", "value", "se_error", "id"], row))
+ for row in cursor.fetchall()]
return []
+
def retrieve_probeset_trait_data(trait_info: Dict, conn: Any):
"""
Retrieve trait data for `ProbeSet` traits.
@@ -645,11 +591,13 @@ def retrieve_probeset_trait_data(trait_info: Dict, conn: Any):
query,
{"trait_name": trait_info["trait_name"],
"dataset_name": trait_info["db"]["dataset_name"]})
- return [dict(zip(
- ["sample_name", "value", "se_error", "id"], row))
- for row in cursor.fetchall()]
+ return [
+ dict(zip(
+ ["sample_name", "value", "se_error", "id"], row))
+ for row in cursor.fetchall()]
return []
+
def with_samplelist_data_setup(samplelist: Sequence[str]):
"""
Build function that computes the trait data from provided list of samples.
@@ -676,6 +624,7 @@ def with_samplelist_data_setup(samplelist: Sequence[str]):
return None
return setup_fn
+
def without_samplelist_data_setup():
"""
Build function that computes the trait data.
@@ -696,6 +645,7 @@ def without_samplelist_data_setup():
return None
return setup_fn
+
def retrieve_trait_data(trait: dict, conn: Any, samplelist: Sequence[str] = tuple()):
"""
Retrieve trait data
@@ -735,14 +685,16 @@ def retrieve_trait_data(trait: dict, conn: Any, samplelist: Sequence[str] = tupl
"data": dict(map(
lambda x: (
x["sample_name"],
- {k:v for k, v in x.items() if x != "sample_name"}),
+ {k: v for k, v in x.items() if k != "sample_name"}),
data))}
return {}
+
def generate_traits_filename(base_path: str = TMPDIR):
"""Generate a unique filename for use with generated traits files."""
- return "{}/traits_test_file_{}.txt".format(
- os.path.abspath(base_path), random_string(10))
+ return (
+ f"{os.path.abspath(base_path)}/traits_test_file_{random_string(10)}.txt")
+
def export_informative(trait_data: dict, inc_var: bool = False) -> tuple:
"""
@@ -765,5 +717,6 @@ def export_informative(trait_data: dict, inc_var: bool = False) -> tuple:
return acc
return reduce(
__exporter__,
- filter(lambda td: td["value"] is not None, trait_data["data"].values()),
+ filter(lambda td: td["value"] is not None,
+ trait_data["data"].values()),
(tuple(), tuple(), tuple()))
diff --git a/gn3/db_utils.py b/gn3/db_utils.py
index 7263705..3b72d28 100644
--- a/gn3/db_utils.py
+++ b/gn3/db_utils.py
@@ -14,10 +14,7 @@ def parse_db_url() -> Tuple:
parsed_db.password, parsed_db.path[1:])
-def database_connector() -> Tuple:
+def database_connector() -> mdb.Connection:
"""function to create db connector"""
host, user, passwd, db_name = parse_db_url()
- conn = mdb.connect(host, user, passwd, db_name)
- cursor = conn.cursor()
-
- return (conn, cursor)
+ return mdb.connect(host, user, passwd, db_name)
diff --git a/gn3/fs_helpers.py b/gn3/fs_helpers.py
index 73f6567..f313086 100644
--- a/gn3/fs_helpers.py
+++ b/gn3/fs_helpers.py
@@ -41,7 +41,7 @@ def get_dir_hash(directory: str) -> str:
def jsonfile_to_dict(json_file: str) -> Dict:
"""Give a JSON_FILE, return a python dict"""
- with open(json_file) as _file:
+ with open(json_file, encoding="utf-8") as _file:
data = json.load(_file)
return data
raise FileNotFoundError
@@ -71,9 +71,8 @@ contents to TARGET_DIR/<dir-hash>.
os.mkdir(os.path.join(target_dir, token))
gzipped_file.save(tar_target_loc)
# Extract to "tar_target_loc/token"
- tar = tarfile.open(tar_target_loc)
- tar.extractall(path=os.path.join(target_dir, token))
- tar.close()
+ with tarfile.open(tar_target_loc) as tar:
+ tar.extractall(path=os.path.join(target_dir, token))
# pylint: disable=W0703
except Exception:
return {"status": 128, "error": "gzip failed to unpack file"}
diff --git a/gn3/heatmaps.py b/gn3/heatmaps.py
index bf9dfd1..91437bb 100644
--- a/gn3/heatmaps.py
+++ b/gn3/heatmaps.py
@@ -40,16 +40,15 @@ def trait_display_name(trait: Dict):
if trait["db"]["dataset_type"] == "Temp":
desc = trait["description"]
if desc.find("PCA") >= 0:
- return "%s::%s" % (
- trait["db"]["displayname"],
- desc[desc.rindex(':')+1:].strip())
- return "%s::%s" % (
- trait["db"]["displayname"],
- desc[:desc.index('entered')].strip())
- prefix = "%s::%s" % (
- trait["db"]["dataset_name"], trait["trait_name"])
+ return (
+ f'{trait["db"]["displayname"]}::'
+ f'{desc[desc.rindex(":")+1:].strip()}')
+ return (
+ f'{trait["db"]["displayname"]}::'
+ f'{desc[:desc.index("entered")].strip()}')
+ prefix = f'{trait["db"]["dataset_name"]}::{trait["trait_name"]}'
if trait["cellid"]:
- return "%s::%s" % (prefix, trait["cellid"])
+ return f'{prefix}::{trait["cellid"]}'
return prefix
return trait["description"]
@@ -64,11 +63,7 @@ def cluster_traits(traits_data_list: Sequence[Dict]):
def __compute_corr(tdata_i, tdata_j):
if tdata_i[0] == tdata_j[0]:
return 0.0
- corr_vals = compute_correlation(tdata_i[1], tdata_j[1])
- corr = corr_vals[0]
- if (1 - corr) < 0:
- return 0.0
- return 1 - corr
+ return 1 - compute_correlation(tdata_i[1], tdata_j[1])[0]
def __cluster(tdata_i):
return tuple(
@@ -136,8 +131,7 @@ def build_heatmap(
traits_order = compute_traits_order(slinked)
samples_and_values = retrieve_samples_and_values(
traits_order, samples, exported_traits_data_list)
- traits_filename = "{}/traits_test_file_{}.txt".format(
- TMPDIR, random_string(10))
+ traits_filename = f"{TMPDIR}/traits_test_file_{random_string(10)}.txt"
generate_traits_file(
samples_and_values[0][1],
[t[2] for t in samples_and_values],
@@ -314,7 +308,7 @@ def clustered_heatmap(
vertical_spacing=0.010,
horizontal_spacing=0.001,
subplot_titles=["" if vertical else x_axis["label"]] + [
- "Chromosome: {}".format(chromo) if vertical else chromo
+ f"Chromosome: {chromo}" if vertical else chromo
for chromo in x_axis_data],#+ x_axis_data,
figure=ff.create_dendrogram(
np.array(clustering_data),
@@ -336,7 +330,7 @@ def clustered_heatmap(
col=(1 if vertical else (i + 2)))
axes_layouts = {
- "{axis}axis{count}".format(
+ "{axis}axis{count}".format( # pylint: disable=[C0209]
axis=("y" if vertical else "x"),
count=(i+1 if i > 0 else "")): {
"mirror": False,
@@ -345,12 +339,10 @@ def clustered_heatmap(
}
for i in range(num_plots)}
- print("vertical?: {} ==> {}".format("T" if vertical else "F", axes_layouts))
-
fig.update_layout({
"width": 800 if vertical else 4000,
"height": 4000 if vertical else 800,
- "{}axis".format("x" if vertical else "y"): {
+ "{}axis".format("x" if vertical else "y"): { # pylint: disable=[C0209]
"mirror": False,
"ticks": "",
"side": "top" if vertical else "left",
@@ -358,7 +350,7 @@ def clustered_heatmap(
"tickangle": 90 if vertical else 0,
"ticklabelposition": "outside top" if vertical else "outside left"
},
- "{}axis".format("y" if vertical else "x"): {
+ "{}axis".format("y" if vertical else "x"): { # pylint: disable=[C0209]
"mirror": False,
"showgrid": True,
"title": "Distance",
diff --git a/gn3/responses/__init__.py b/gn3/responses/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/gn3/responses/__init__.py
diff --git a/gn3/responses/pcorrs_responses.py b/gn3/responses/pcorrs_responses.py
new file mode 100644
index 0000000..d6fd9d7
--- /dev/null
+++ b/gn3/responses/pcorrs_responses.py
@@ -0,0 +1,24 @@
+"""Functions and classes that deal with responses and conversion to JSON."""
+import json
+
+from flask import make_response
+
+class OutputEncoder(json.JSONEncoder):
+ """
+ Class to encode output into JSON, for objects which the default
+ json.JSONEncoder class does not have default encoding for.
+ """
+ def default(self, o):
+ if isinstance(o, bytes):
+ return str(o, encoding="utf-8")
+ return json.JSONEncoder.default(self, o)
+
+def build_response(data):
+ """Build the responses for the API"""
+ status_codes = {
+ "error": 400, "not-found": 404, "success": 200, "exception": 500}
+ response = make_response(
+ json.dumps(data, cls=OutputEncoder),
+ status_codes[data["status"]])
+ response.headers["Content-Type"] = "application/json"
+ return response
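+
+# Example (assumed payload):
+#
+#   build_response({"status": "success", "results": (0.23, 0.045)})
+#
+# yields an HTTP 200 JSON response; a status missing from `status_codes`
+# raises a KeyError.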
diff --git a/gn3/settings.py b/gn3/settings.py
index 57c63df..6eec2a1 100644
--- a/gn3/settings.py
+++ b/gn3/settings.py
@@ -13,11 +13,13 @@ REDIS_JOB_QUEUE = "GN3::job-queue"
TMPDIR = os.environ.get("TMPDIR", tempfile.gettempdir())
RQTL_WRAPPER = "rqtl_wrapper.R"
+# SPARQL endpoint
+SPARQL_ENDPOINT = "http://localhost:8891/sparql"
+
# SQL confs
SQL_URI = os.environ.get(
"SQL_URI", "mysql://webqtlout:webqtlout@localhost/db_webqtl")
SECRET_KEY = "password"
-SQLALCHEMY_TRACK_MODIFICATIONS = False
# gn2 results only used in fetching dataset info
GN2_BASE_URL = "http://www.genenetwork.org/"
@@ -25,11 +27,11 @@ GN2_BASE_URL = "http://www.genenetwork.org/"
# wgcna script
WGCNA_RSCRIPT = "wgcna_analysis.R"
# qtlreaper command
-REAPER_COMMAND = "{}/bin/qtlreaper".format(os.environ.get("GUIX_ENVIRONMENT"))
+REAPER_COMMAND = f"{os.environ.get('GUIX_ENVIRONMENT')}/bin/qtlreaper"
# genotype files
GENOTYPE_FILES = os.environ.get(
- "GENOTYPE_FILES", "{}/genotype_files/genotype".format(os.environ.get("HOME")))
+ "GENOTYPE_FILES", f"{os.environ.get('HOME')}/genotype_files/genotype")
# CROSS-ORIGIN SETUP
def parse_env_cors(default):
@@ -53,3 +55,7 @@ CORS_HEADERS = [
GNSHARE = os.environ.get("GNSHARE", "/gnshare/gn/")
TEXTDIR = f"{GNSHARE}/web/ProbeSetFreeze_DataMatrix"
+
+ROUND_TO = 10
+
+MULTIPROCESSOR_PROCS = 6 # Number of processes to spawn