Diffstat (limited to 'gn3')
36 files changed, 3067 insertions, 432 deletions
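Among other things, the diffs below add a partial-correlations API: POST /api/correlation/partial queues a job through Redis (gn3/api/correlation.py, gn3/commands.py) and returns a command id, whose state can then be polled at GET /api/async_commands/state/<command_id> (gn3/api/async_commands.py). A minimal client sketch of that flow, assuming a GN3 instance at http://localhost:8080 and purely illustrative trait/dataset names (none of these values come from this change, and the exact JSON envelope depends on gn3/responses/pcorrs_responses.build_response, which is not shown here):

# Hypothetical client for the new partial-correlations endpoints.
# BASE_URL, trait names and dataset names are illustrative placeholders.
import time
import requests

BASE_URL = "http://localhost:8080"  # assumed local GN3 instance

payload = {
    "primary_trait": {"dataset": "HC_M2_0606_P", "trait_name": "1426679_at"},
    "control_traits": [
        {"dataset": "HC_M2_0606_P", "trait_name": "1417893_at"}],
    "target_db": "HC_M2_0606_P",
    "method": "pearson",
    "criteria": 500}

# Queue the computation; on success the "results" field carries the
# redis hash name (command id) returned by queue_cmd.
queued = requests.post(
    f"{BASE_URL}/api/correlation/partial", json=payload).json()
command_id = queued["results"]

# Poll the async-commands endpoint until the job is no longer marked
# "queued"; any further states depend on the external worker process.
while True:
    state = requests.get(
        f"{BASE_URL}/api/async_commands/state/{command_id}").json()
    if state.get("status") != "queued":
        break
    time.sleep(2)
print(state)

The actual execution is handled by whatever worker consumes the REDIS_JOB_QUEUE list named in the Flask config; that worker is outside the scope of this diff.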
diff --git a/gn3/api/async_commands.py b/gn3/api/async_commands.py new file mode 100644 index 0000000..c0cf4bb --- /dev/null +++ b/gn3/api/async_commands.py @@ -0,0 +1,16 @@ +"""Endpoints and functions concerning commands run in external processes.""" +import redis +from flask import jsonify, Blueprint + +async_commands = Blueprint("async_commands", __name__) + +@async_commands.route("/state/<command_id>") +def command_state(command_id): + """Respond with the current state of command identified by `command_id`.""" + with redis.Redis(decode_responses=True) as rconn: + state = rconn.hgetall(name=command_id) + if not state: + return jsonify( + status=404, + error="The command id provided does not exist.") + return jsonify(dict(state.items())) diff --git a/gn3/api/correlation.py b/gn3/api/correlation.py index 46121f8..7eb7cd6 100644 --- a/gn3/api/correlation.py +++ b/gn3/api/correlation.py @@ -1,13 +1,21 @@ """Endpoints for running correlations""" +import sys +from functools import reduce + +import redis from flask import jsonify from flask import Blueprint from flask import request +from flask import current_app -from gn3.computations.correlations import compute_all_sample_correlation -from gn3.computations.correlations import compute_all_lit_correlation -from gn3.computations.correlations import compute_tissue_correlation -from gn3.computations.correlations import map_shared_keys_to_values +from gn3.settings import SQL_URI +from gn3.commands import queue_cmd, compose_pcorrs_command from gn3.db_utils import database_connector +from gn3.responses.pcorrs_responses import build_response +from gn3.computations.correlations import map_shared_keys_to_values +from gn3.computations.correlations import compute_tissue_correlation +from gn3.computations.correlations import compute_all_lit_correlation +from gn3.computations.correlations import compute_all_sample_correlation correlation = Blueprint("correlation", __name__) @@ -58,17 +66,15 @@ def compute_lit_corr(species=None, gene_id=None): might be needed for actual computing of the correlation results """ - conn, _cursor_object = database_connector() - target_traits_gene_ids = request.get_json() - target_trait_gene_list = list(target_traits_gene_ids.items()) + with database_connector() as conn: + target_traits_gene_ids = request.get_json() + target_trait_gene_list = list(target_traits_gene_ids.items()) - lit_corr_results = compute_all_lit_correlation( - conn=conn, trait_lists=target_trait_gene_list, - species=species, gene_id=gene_id) + lit_corr_results = compute_all_lit_correlation( + conn=conn, trait_lists=target_trait_gene_list, + species=species, gene_id=gene_id) - conn.close() - - return jsonify(lit_corr_results) + return jsonify(lit_corr_results) @correlation.route("/tissue_corr/<string:corr_method>", methods=["POST"]) @@ -83,3 +89,44 @@ def compute_tissue_corr(corr_method="pearson"): corr_method=corr_method) return jsonify(results) + +@correlation.route("/partial", methods=["POST"]) +def partial_correlation(): + """API endpoint for partial correlations.""" + def trait_fullname(trait): + return f"{trait['dataset']}::{trait['trait_name']}" + + def __field_errors__(args): + def __check__(acc, field): + if args.get(field) is None: + return acc + (f"Field '{field}' missing",) + return acc + return __check__ + + def __errors__(request_data, fields): + errors = tuple() + if request_data is None: + return ("No request data",) + + return reduce(__field_errors__(request_data), fields, errors) + + args = request.get_json() + request_errors = __errors__( + 
args, ("primary_trait", "control_traits", "target_db", "method")) + if request_errors: + return build_response({ + "status": "error", + "messages": request_errors, + "error_type": "Client Error"}) + return build_response({ + "status": "success", + "results": queue_cmd( + conn=redis.Redis(), + cmd=compose_pcorrs_command( + trait_fullname(args["primary_trait"]), + tuple( + trait_fullname(trait) for trait in args["control_traits"]), + args["method"], args["target_db"], + int(args.get("criteria", 500))), + job_queue=current_app.config.get("REDIS_JOB_QUEUE"), + env = {"PYTHONPATH": ":".join(sys.path), "SQL_URI": SQL_URI})}) diff --git a/gn3/api/ctl.py b/gn3/api/ctl.py new file mode 100644 index 0000000..ac33d63 --- /dev/null +++ b/gn3/api/ctl.py @@ -0,0 +1,24 @@ +"""module contains endpoints for ctl""" + +from flask import Blueprint +from flask import request +from flask import jsonify + +from gn3.computations.ctl import call_ctl_script + +ctl = Blueprint("ctl", __name__) + + +@ctl.route("/run_ctl", methods=["POST"]) +def run_ctl(): + """endpoint to run ctl + input: request form object + output:json object enum::(response,error) + + """ + ctl_data = request.json + + (cmd_results, response) = call_ctl_script(ctl_data) + return (jsonify({ + "results": response + }), 200) if response is not None else (jsonify({"error": str(cmd_results)}), 401) diff --git a/gn3/api/general.py b/gn3/api/general.py index 69ec343..e0bfc81 100644 --- a/gn3/api/general.py +++ b/gn3/api/general.py @@ -7,7 +7,7 @@ from flask import request from gn3.fs_helpers import extract_uploaded_file from gn3.commands import run_cmd - +from gn3.db import datasets general = Blueprint("general", __name__) @@ -68,3 +68,8 @@ def run_r_qtl(geno_filestr, pheno_filestr): cmd = (f"Rscript {rqtl_wrapper} " f"{geno_filestr} {pheno_filestr}") return jsonify(run_cmd(cmd)), 201 + +@general.route("/dataset/<accession_id>") +def dataset_metadata(accession_id): + """Return info as JSON for dataset with ACCESSION_ID.""" + return jsonify(datasets.dataset_metadata(accession_id)) diff --git a/gn3/api/heatmaps.py b/gn3/api/heatmaps.py index 633a061..80c8ca8 100644 --- a/gn3/api/heatmaps.py +++ b/gn3/api/heatmaps.py @@ -24,15 +24,14 @@ def clustered_heatmaps(): return jsonify({ "message": "You need to provide at least two trait names." 
}), 400 - conn, _cursor = database_connector() - def parse_trait_fullname(trait): - name_parts = trait.split(":") - return "{dataset_name}::{trait_name}".format( - dataset_name=name_parts[1], trait_name=name_parts[0]) - traits_fullnames = [parse_trait_fullname(trait) for trait in traits_names] + with database_connector() as conn: + def parse_trait_fullname(trait): + name_parts = trait.split(":") + return f"{name_parts[1]}::{name_parts[0]}" + traits_fullnames = [parse_trait_fullname(trait) for trait in traits_names] - with io.StringIO() as io_str: - figure = build_heatmap(traits_fullnames, conn, vertical=vertical) - figure.write_json(io_str) - fig_json = io_str.getvalue() - return fig_json, 200 + with io.StringIO() as io_str: + figure = build_heatmap(traits_fullnames, conn, vertical=vertical) + figure.write_json(io_str) + fig_json = io_str.getvalue() + return fig_json, 200 diff --git a/gn3/api/rqtl.py b/gn3/api/rqtl.py index 85b2460..70ebe12 100644 --- a/gn3/api/rqtl.py +++ b/gn3/api/rqtl.py @@ -25,7 +25,7 @@ run the rqtl_wrapper script and return the results as JSON raise FileNotFoundError # Split kwargs by those with values and boolean ones that just convert to True/False - kwargs = ["model", "method", "nperm", "scale", "control_marker"] + kwargs = ["covarstruct", "model", "method", "nperm", "scale", "control_marker"] boolean_kwargs = ["addcovar", "interval", "pstrata", "pairscan"] all_kwargs = kwargs + boolean_kwargs @@ -14,6 +14,8 @@ from gn3.api.heatmaps import heatmaps from gn3.api.correlation import correlation from gn3.api.data_entry import data_entry from gn3.api.wgcna import wgcna +from gn3.api.ctl import ctl +from gn3.api.async_commands import async_commands def create_app(config: Union[Dict, str, None] = None) -> Flask: """Create a new flask object""" @@ -45,4 +47,6 @@ def create_app(config: Union[Dict, str, None] = None) -> Flask: app.register_blueprint(correlation, url_prefix="/api/correlation") app.register_blueprint(data_entry, url_prefix="/api/dataentry") app.register_blueprint(wgcna, url_prefix="/api/wgcna") + app.register_blueprint(ctl, url_prefix="/api/ctl") + app.register_blueprint(async_commands, url_prefix="/api/async_commands") return app diff --git a/gn3/authentication.py b/gn3/authentication.py index 6719631..d0b35bc 100644 --- a/gn3/authentication.py +++ b/gn3/authentication.py @@ -113,9 +113,9 @@ def get_groups_by_user_uid(user_uid: str, conn: Redis) -> Dict: """ admin = [] member = [] - for uuid, group_info in conn.hgetall("groups").items(): + for group_uuid, group_info in conn.hgetall("groups").items(): group_info = json.loads(group_info) - group_info["uuid"] = uuid + group_info["uuid"] = group_uuid if user_uid in group_info.get('admins'): admin.append(group_info) if user_uid in group_info.get('members'): @@ -130,11 +130,10 @@ def get_user_info_by_key(key: str, value: str, conn: Redis) -> Optional[Dict]: """Given a key, get a user's information if value is matched""" if key != "user_id": - for uuid, user_info in conn.hgetall("users").items(): + for user_uuid, user_info in conn.hgetall("users").items(): user_info = json.loads(user_info) - if (key in user_info and - user_info.get(key) == value): - user_info["user_id"] = uuid + if (key in user_info and user_info.get(key) == value): + user_info["user_id"] = user_uuid return user_info elif key == "user_id": if user_info := conn.hget("users", value): @@ -145,9 +144,13 @@ def get_user_info_by_key(key: str, value: str, def create_group(conn: Redis, group_name: Optional[str], - admin_user_uids: List = [], - 
member_user_uids: List = []) -> Optional[Dict]: + admin_user_uids: List = None, + member_user_uids: List = None) -> Optional[Dict]: """Create a group given the group name, members and admins of that group.""" + if admin_user_uids is None: + admin_user_uids = [] + if member_user_uids is None: + member_user_uids = [] if group_name and bool(admin_user_uids + member_user_uids): timestamp = datetime.datetime.utcnow().strftime('%b %d %Y %I:%M%p') group = { @@ -160,3 +163,4 @@ def create_group(conn: Redis, group_name: Optional[str], } conn.hset("groups", group_id, json.dumps(group)) return group + return None diff --git a/gn3/commands.py b/gn3/commands.py index 7d42ced..e622068 100644 --- a/gn3/commands.py +++ b/gn3/commands.py @@ -1,5 +1,8 @@ """Procedures used to work with the various bio-informatics cli commands""" +import os +import sys +import json import subprocess from datetime import datetime @@ -7,6 +10,8 @@ from typing import Dict from typing import List from typing import Optional from typing import Tuple +from typing import Union +from typing import Sequence from uuid import uuid4 from redis.client import Redis # Used only in type hinting @@ -46,10 +51,21 @@ def compose_rqtl_cmd(rqtl_wrapper_cmd: str, return cmd +def compose_pcorrs_command( + primary_trait: str, control_traits: Tuple[str, ...], method: str, + target_database: str, criteria: int = 500): + """Compose the command to run partias correlations""" + rundir = os.path.abspath(".") + return ( + f"{sys.executable}", f"{rundir}/scripts/partial_correlations.py", + primary_trait, ",".join(control_traits), f'"{method}"', + f"{target_database}", f"--criteria={criteria}") + def queue_cmd(conn: Redis, job_queue: str, - cmd: str, - email: Optional[str] = None) -> str: + cmd: Union[str, Sequence[str]], + email: Optional[str] = None, + env: Optional[dict] = None) -> str: """Given a command CMD; (optional) EMAIL; and a redis connection CONN, queue it in Redis with an initial status of 'queued'. The following status codes are supported: @@ -68,17 +84,23 @@ Returns the name of the specific redis hash for the specific task. f"{datetime.now().strftime('%Y-%m-%d%H-%M%S-%M%S-')}" f"{str(uuid4())}") conn.rpush(job_queue, unique_id) - for key, value in {"cmd": cmd, "result": "", "status": "queued"}.items(): + for key, value in { + "cmd": json.dumps(cmd), "result": "", "status": "queued"}.items(): conn.hset(name=unique_id, key=key, value=value) if email: conn.hset(name=unique_id, key="email", value=email) + if env: + conn.hset(name=unique_id, key="env", value=json.dumps(env)) return unique_id -def run_cmd(cmd: str, success_codes: Tuple = (0,)) -> Dict: +def run_cmd(cmd: str, success_codes: Tuple = (0,), env: str = None) -> Dict: """Run CMD and return the CMD's status code and output as a dict""" - results = subprocess.run(cmd, capture_output=True, shell=True, - check=False) + parsed_cmd = json.loads(cmd) + parsed_env = (json.loads(env) if env is not None else None) + results = subprocess.run( + parsed_cmd, capture_output=True, shell=isinstance(parsed_cmd, str), + check=False, env=parsed_env) out = str(results.stdout, 'utf-8') if results.returncode not in success_codes: # Error! 
out = str(results.stderr, 'utf-8') diff --git a/gn3/computations/correlations.py b/gn3/computations/correlations.py index c5c56db..a0da2c4 100644 --- a/gn3/computations/correlations.py +++ b/gn3/computations/correlations.py @@ -7,6 +7,7 @@ from typing import List from typing import Tuple from typing import Optional from typing import Callable +from typing import Generator import scipy.stats import pingouin as pg @@ -38,20 +39,15 @@ def map_shared_keys_to_values(target_sample_keys: List, return target_dataset_data -def normalize_values(a_values: List, - b_values: List) -> Tuple[List[float], List[float], int]: - """Trim two lists of values to contain only the values they both share Given - two lists of sample values, trim each list so that it contains only the - samples that contain a value in both lists. Also returns the number of - such samples. - - >>> normalize_values([2.3, None, None, 3.2, 4.1, 5], - [3.4, 7.2, 1.3, None, 6.2, 4.1]) - ([2.3, 4.1, 5], [3.4, 6.2, 4.1], 3) - +def normalize_values(a_values: List, b_values: List) -> Generator: + """ + :param a_values: list of primary strain values + :param b_values: a list of target strain values + :return: yield 2 values if none of them is none """ + for a_val, b_val in zip(a_values, b_values): - if (a_val and b_val is not None): + if (a_val is not None) and (b_val is not None): yield a_val, b_val @@ -79,15 +75,18 @@ def compute_sample_r_correlation(trait_name, corr_method, trait_vals, """ - sanitized_traits_vals, sanitized_target_vals = list( - zip(*list(normalize_values(trait_vals, target_samples_vals)))) - num_overlap = len(sanitized_traits_vals) + try: + normalized_traits_vals, normalized_target_vals = list( + zip(*list(normalize_values(trait_vals, target_samples_vals)))) + num_overlap = len(normalized_traits_vals) + except ValueError: + return None if num_overlap > 5: (corr_coefficient, p_value) =\ - compute_corr_coeff_p_value(primary_values=sanitized_traits_vals, - target_values=sanitized_target_vals, + compute_corr_coeff_p_value(primary_values=normalized_traits_vals, + target_values=normalized_target_vals, corr_method=corr_method) if corr_coefficient is not None and not math.isnan(corr_coefficient): @@ -108,7 +107,7 @@ package :not packaged in guix def filter_shared_sample_keys(this_samplelist, - target_samplelist) -> Tuple[List, List]: + target_samplelist) -> Generator: """Given primary and target sample-list for two base and target trait select filter the values using the shared keys @@ -134,9 +133,16 @@ def fast_compute_all_sample_correlation(this_trait, for target_trait in target_dataset: trait_name = target_trait.get("trait_id") target_trait_data = target_trait["trait_sample_data"] - processed_values.append((trait_name, corr_method, - list(zip(*list(filter_shared_sample_keys( - this_trait_samples, target_trait_data)))))) + + try: + this_vals, target_vals = list(zip(*list(filter_shared_sample_keys( + this_trait_samples, target_trait_data)))) + + processed_values.append( + (trait_name, corr_method, this_vals, target_vals)) + except ValueError: + continue + with closing(multiprocessing.Pool()) as pool: results = pool.starmap(compute_sample_r_correlation, processed_values) @@ -168,8 +174,14 @@ def compute_all_sample_correlation(this_trait, for target_trait in target_dataset: trait_name = target_trait.get("trait_id") target_trait_data = target_trait["trait_sample_data"] - this_vals, target_vals = list(zip(*list(filter_shared_sample_keys( - this_trait_samples, target_trait_data)))) + + try: + this_vals, target_vals = 
list(zip(*list(filter_shared_sample_keys( + this_trait_samples, target_trait_data)))) + + except ValueError: + # case where no matching strain names + continue sample_correlation = compute_sample_r_correlation( trait_name=trait_name, diff --git a/gn3/computations/correlations2.py b/gn3/computations/correlations2.py index 93db3fa..d0222ae 100644 --- a/gn3/computations/correlations2.py +++ b/gn3/computations/correlations2.py @@ -6,45 +6,21 @@ FUNCTIONS: compute_correlation: TODO: Describe what the function does...""" -from math import sqrt -from functools import reduce +from scipy import stats ## From GN1: mostly for clustering and heatmap generation def __items_with_values(dbdata, userdata): """Retains only corresponding items in the data items that are not `None` values. This should probably be renamed to something sensible""" - def both_not_none(item1, item2): - """Check that both items are not the value `None`.""" - if (item1 is not None) and (item2 is not None): - return (item1, item2) - return None - def split_lists(accumulator, item): - """Separate the 'x' and 'y' items.""" - return [accumulator[0] + [item[0]], accumulator[1] + [item[1]]] - return reduce( - split_lists, - filter(lambda x: x is not None, map(both_not_none, dbdata, userdata)), - [[], []]) + filtered = [x for x in zip(dbdata, userdata) if x[0] is not None and x[1] is not None] + return tuple(zip(*filtered)) if filtered else ([], []) def compute_correlation(dbdata, userdata): - """Compute some form of correlation. + """Compute the Pearson correlation coefficient. This is extracted from https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/utility/webqtlUtil.py#L622-L647 """ x_items, y_items = __items_with_values(dbdata, userdata) - if len(x_items) < 6: - return (0.0, len(x_items)) - meanx = sum(x_items)/len(x_items) - meany = sum(y_items)/len(y_items) - def cal_corr_vals(acc, item): - xitem, yitem = item - return [ - acc[0] + ((xitem - meanx) * (yitem - meany)), - acc[1] + ((xitem - meanx) * (xitem - meanx)), - acc[2] + ((yitem - meany) * (yitem - meany))] - xyd, sxd, syd = reduce(cal_corr_vals, zip(x_items, y_items), [0.0, 0.0, 0.0]) - try: - return ((xyd/(sqrt(sxd)*sqrt(syd))), len(x_items)) - except ZeroDivisionError: - return(0, len(x_items)) + correlation = stats.pearsonr(x_items, y_items)[0] if len(x_items) >= 6 else 0 + return (correlation, len(x_items)) diff --git a/gn3/computations/ctl.py b/gn3/computations/ctl.py new file mode 100644 index 0000000..f881410 --- /dev/null +++ b/gn3/computations/ctl.py @@ -0,0 +1,30 @@ +"""module contains code to process ctl analysis data""" +import json +from gn3.commands import run_cmd + +from gn3.computations.wgcna import dump_wgcna_data +from gn3.computations.wgcna import compose_wgcna_cmd +from gn3.computations.wgcna import process_image + +from gn3.settings import TMPDIR + + +def call_ctl_script(data): + """function to call ctl script""" + data["imgDir"] = TMPDIR + temp_file_name = dump_wgcna_data(data) + cmd = compose_wgcna_cmd("ctl_analysis.R", temp_file_name) + + cmd_results = run_cmd(cmd) + with open(temp_file_name, "r", encoding="utf-8") as outputfile: + if cmd_results["code"] != 0: + return (cmd_results, None) + output_file_data = json.load(outputfile) + + output_file_data["image_data"] = process_image( + output_file_data["image_loc"]).decode("ascii") + + output_file_data["ctl_plots"] = [process_image(ctl_plot).decode("ascii") for + ctl_plot in output_file_data["ctl_plots"]] + + return (cmd_results, output_file_data) diff --git a/gn3/computations/diff.py 
b/gn3/computations/diff.py index af02f7f..0b6edd6 100644 --- a/gn3/computations/diff.py +++ b/gn3/computations/diff.py @@ -6,7 +6,7 @@ from gn3.commands import run_cmd def generate_diff(data: str, edited_data: str) -> Optional[str]: """Generate the diff between 2 files""" - results = run_cmd(f"diff {data} {edited_data}", success_codes=(1, 2)) + results = run_cmd(f'"diff {data} {edited_data}"', success_codes=(1, 2)) if results.get("code", -1) > 0: return results.get("output") return None diff --git a/gn3/computations/gemma.py b/gn3/computations/gemma.py index 0b22d3c..8036a7b 100644 --- a/gn3/computations/gemma.py +++ b/gn3/computations/gemma.py @@ -31,7 +31,7 @@ def generate_pheno_txt_file(trait_filename: str, # Early return if this already exists! if os.path.isfile(f"{tmpdir}/gn2/{trait_filename}"): return f"{tmpdir}/gn2/{trait_filename}" - with open(f"{tmpdir}/gn2/{trait_filename}", "w") as _file: + with open(f"{tmpdir}/gn2/{trait_filename}", "w", encoding="utf-8") as _file: for value in values: if value == "x": _file.write("NA\n") diff --git a/gn3/computations/parsers.py b/gn3/computations/parsers.py index 1af35d6..79e3955 100644 --- a/gn3/computations/parsers.py +++ b/gn3/computations/parsers.py @@ -15,7 +15,7 @@ def parse_genofile(file_path: str) -> Tuple[List[str], 'u': None, } genotypes, samples = [], [] - with open(file_path, "r") as _genofile: + with open(file_path, "r", encoding="utf-8") as _genofile: for line in _genofile: line = line.strip() if line.startswith(("#", "@")): diff --git a/gn3/computations/partial_correlations.py b/gn3/computations/partial_correlations.py index 07dc16d..5017796 100644 --- a/gn3/computations/partial_correlations.py +++ b/gn3/computations/partial_correlations.py @@ -5,12 +5,32 @@ It is an attempt to migrate over the partial correlations feature from GeneNetwork1. """ -from functools import reduce -from typing import Any, Tuple, Sequence +import math +import warnings +from functools import reduce, partial +from typing import Any, Tuple, Union, Sequence + +import numpy +import pandas +import pingouin from scipy.stats import pearsonr, spearmanr from gn3.settings import TEXTDIR +from gn3.random import random_string +from gn3.function_helpers import compose from gn3.data_helpers import parse_csv_line +from gn3.db.traits import export_informative +from gn3.db.datasets import retrieve_trait_dataset +from gn3.db.partial_correlations import traits_info, traits_data +from gn3.db.species import species_name, translate_to_mouse_gene_id +from gn3.db.correlations import ( + get_filename, + fetch_all_database_data, + check_for_literature_info, + fetch_tissue_correlations, + fetch_literature_correlations, + check_symbol_for_tissue_correlation, + fetch_gene_symbol_tissue_value_dict_for_trait) def control_samples(controls: Sequence[dict], sampleslist: Sequence[str]): """ @@ -40,7 +60,7 @@ def control_samples(controls: Sequence[dict], sampleslist: Sequence[str]): __process_sample__, sampleslist, (tuple(), tuple(), tuple())) return reduce( - lambda acc, item: ( + lambda acc, item: (# type: ignore[arg-type, return-value] acc[0] + (item[0],), acc[1] + (item[1],), acc[2] + (item[2],), @@ -49,22 +69,6 @@ def control_samples(controls: Sequence[dict], sampleslist: Sequence[str]): [__process_control__(trait_data) for trait_data in controls], (tuple(), tuple(), tuple(), tuple())) -def dictify_by_samples(samples_vals_vars: Sequence[Sequence]) -> Sequence[dict]: - """ - Build a sequence of dictionaries from a sequence of separate sequences of - samples, values and variances. 
- - This is a partial migration of - `web.webqtl.correlation.correlationFunction.fixStrains` function in GN1. - This implementation extracts code that will find common use, and that will - find use in more than one place. - """ - return tuple( - { - sample: {"sample_name": sample, "value": val, "variance": var} - for sample, val, var in zip(*trait_line) - } for trait_line in zip(*(samples_vals_vars[0:3]))) - def fix_samples(primary_trait: dict, control_traits: Sequence[dict]) -> Sequence[Sequence[Any]]: """ Corrects sample_names, values and variance such that they all contain only @@ -108,7 +112,7 @@ def find_identical_traits( return acc + ident[1] def __dictify_controls__(acc, control_item): - ckey = "{:.3f}".format(control_item[0]) + ckey = tuple(f"{item:.3f}" for item in control_item[0]) return {**acc, ckey: acc.get(ckey, tuple()) + (control_item[1],)} return (reduce(## for identical control traits @@ -148,11 +152,11 @@ def tissue_correlation( assert len(primary_trait_values) == len(target_trait_values), ( "The lengths of the `primary_trait_values` and `target_trait_values` " "must be equal") - assert method in method_fns.keys(), ( - "Method must be one of: {}".format(",".join(method_fns.keys()))) + assert method in method_fns, ( + "Method must be one of: {','.join(method_fns.keys())}") corr, pvalue = method_fns[method](primary_trait_values, target_trait_values) - return (round(corr, 10), round(pvalue, 10)) + return (corr, pvalue) def batch_computed_tissue_correlation( primary_trait_values: Tuple[float, ...], target_traits_dict: dict, @@ -196,33 +200,19 @@ def good_dataset_samples_indexes( samples_from_file.index(good) for good in set(samples).intersection(set(samples_from_file)))) -def determine_partials( - primary_vals, control_vals, all_target_trait_names, - all_target_trait_values, method): - """ - This **WILL** be a migration of - `web.webqtl.correlation.correlationFunction.determinePartialsByR` function - in GeneNetwork1. - - The function in GeneNetwork1 contains code written in R that is then used to - compute the partial correlations. - """ - ## This function is not implemented at this stage - return tuple( - primary_vals, control_vals, all_target_trait_names, - all_target_trait_values, method) - -def compute_partial_correlations_fast(# pylint: disable=[R0913, R0914] +def partial_correlations_fast(# pylint: disable=[R0913, R0914] samples, primary_vals, control_vals, database_filename, fetched_correlations, method: str, correlation_type: str) -> Tuple[ - float, Tuple[float, ...]]: + int, Tuple[float, ...]]: """ + Computes partial correlation coefficients using data from a CSV file. + This is a partial migration of the `web.webqtl.correlation.PartialCorrDBPage.getPartialCorrelationsFast` function in GeneNetwork1. """ assert method in ("spearman", "pearson") - with open(f"{TEXTDIR}/{database_filename}", "r") as dataset_file: + with open(database_filename, "r", encoding="utf-8") as dataset_file: # pytest: disable=[W1514] dataset = tuple(dataset_file.readlines()) good_dataset_samples = good_dataset_samples_indexes( @@ -245,7 +235,7 @@ def compute_partial_correlations_fast(# pylint: disable=[R0913, R0914] all_target_trait_names: Tuple[str, ...] = processed_trait_names_values[0] all_target_trait_values: Tuple[float, ...] 
= processed_trait_names_values[1] - all_correlations = determine_partials( + all_correlations = compute_partial( primary_vals, control_vals, all_target_trait_names, all_target_trait_values, method) ## Line 772 to 779 in GN1 are the cause of the weird complexity in the @@ -254,36 +244,544 @@ def compute_partial_correlations_fast(# pylint: disable=[R0913, R0914] ## `correlation_type` parameter return len(all_correlations), tuple( corr + ( - (fetched_correlations[corr[0]],) if correlation_type == "literature" - else fetched_correlations[corr[0]][0:2]) + (fetched_correlations[corr[0]],) # type: ignore[index] + if correlation_type == "literature" + else fetched_correlations[corr[0]][0:2]) # type: ignore[index] for idx, corr in enumerate(all_correlations)) -def partial_correlation_matrix( +def build_data_frame( xdata: Tuple[float, ...], ydata: Tuple[float, ...], - zdata: Tuple[float, ...], method: str = "pearsons", - omit_nones: bool = True) -> float: + zdata: Union[ + Tuple[float, ...], + Tuple[Tuple[float, ...], ...]]) -> pandas.DataFrame: + """ + Build a pandas DataFrame object from xdata, ydata and zdata + """ + x_y_df = pandas.DataFrame({"x": xdata, "y": ydata}) + if isinstance(zdata[0], float): + return x_y_df.join(pandas.DataFrame({"z": zdata})) + interm_df = x_y_df.join(pandas.DataFrame( + {f"z{i}": val for i, val in enumerate(zdata)})) + if interm_df.shape[1] == 3: + return interm_df.rename(columns={"z0": "z"}) + return interm_df + +def compute_trait_info(primary_vals, control_vals, target, method): """ - Computes the partial correlation coefficient using the - 'variance-covariance matrix' method + Compute the correlation values for the given arguments. + """ + targ_vals = target[0] + targ_name = target[1] + primary = [ + prim for targ, prim in zip(targ_vals, primary_vals) + if targ is not None] + + if len(primary) < 3: + return None + + def __remove_controls_for_target_nones(cont_targ): + return tuple(cont for cont, targ in cont_targ if targ is not None) + + datafrm = build_data_frame( + primary, + [targ for targ in targ_vals if targ is not None], + [__remove_controls_for_target_nones(tuple(zip(control, targ_vals))) + for control in control_vals]) + covariates = "z" if datafrm.shape[1] == 3 else [ + col for col in datafrm.columns if col not in ("x", "y")] + ppc = pingouin.partial_corr( + data=datafrm, x="x", y="y", covar=covariates, method=( + "pearson" if "pearson" in method.lower() else "spearman")) + pc_coeff = ppc["r"][0] + + zero_order_corr = pingouin.corr( + datafrm["x"], datafrm["y"], method=( + "pearson" if "pearson" in method.lower() else "spearman")) + + if math.isnan(pc_coeff): + return ( + targ_name, len(primary), pc_coeff, 1, zero_order_corr["r"][0], + zero_order_corr["p-val"][0]) + return ( + targ_name, len(primary), pc_coeff, + (ppc["p-val"][0] if not math.isnan(ppc["p-val"][0]) else ( + 0 if (abs(pc_coeff - 1) < 0.0000001) else 1)), + zero_order_corr["r"][0], zero_order_corr["p-val"][0]) + +def compute_partial( + primary_vals, control_vals, target_vals, target_names, + method: str) -> Tuple[ + Union[ + Tuple[str, int, float, float, float, float], None], + ...]: + """ + Compute the partial correlations. - This is a partial migration of the - `web.webqtl.correlation.correlationFunction.determinPartialsByR` function in - GeneNetwork1, specifically the `pcor.mat` function written in the R - programming language. + This is a re-implementation of the + `web.webqtl.correlation.correlationFunction.determinePartialsByR` function + in GeneNetwork1. 
+ + This implementation reworks the child function `compute_partial` which will + then be used in the place of `determinPartialsByR`. + """ + return tuple( + result for result in ( + compute_trait_info( + primary_vals, control_vals, (tvals, tname), method) + for tvals, tname in zip(target_vals, target_names)) + if result is not None) + +def partial_correlations_normal(# pylint: disable=R0913 + primary_vals, control_vals, input_trait_gene_id, trait_database, + data_start_pos: int, db_type: str, method: str) -> Tuple[ + int, Tuple[Union[ + Tuple[str, int, float, float, float, float], None], + ...]]:#Tuple[float, ...] """ - return 0 + Computes the correlation coefficients. -def partial_correlation_recursive( - xdata: Tuple[float, ...], ydata: Tuple[float, ...], - zdata: Tuple[float, ...], method: str = "pearsons", - omit_nones: bool = True) -> float: + This is a migration of the + `web.webqtl.correlation.PartialCorrDBPage.getPartialCorrelationsNormal` + function in GeneNetwork1. """ - Computes the partial correlation coefficient using the 'recursive formula' - method + def __add_lit_and_tiss_corr__(item): + if method.lower() == "sgo literature correlation": + # if method is 'SGO Literature Correlation', `compute_partial` + # would give us LitCorr in the [1] position + return tuple(item) + trait_database[1] + if method.lower() in ( + "tissue correlation, pearson's r", + "tissue correlation, spearman's rho"): + # if method is 'Tissue Correlation, *', `compute_partial` would give + # us Tissue Corr in the [1] position and Tissue Corr P Value in the + # [2] position + return tuple(item) + (trait_database[1], trait_database[2]) + return item + + target_trait_names, target_trait_vals = reduce(# type: ignore[var-annotated] + lambda acc, item: (acc[0]+(item[0],), acc[1]+(item[data_start_pos:],)), + trait_database, (tuple(), tuple())) + + all_correlations = compute_partial( + primary_vals, control_vals, target_trait_vals, target_trait_names, + method) + + if (input_trait_gene_id and db_type == "ProbeSet" and method.lower() in ( + "sgo literature correlation", "tissue correlation, pearson's r", + "tissue correlation, spearman's rho")): + return ( + len(trait_database), + tuple( + __add_lit_and_tiss_corr__(item) + for idx, item in enumerate(all_correlations))) + + return len(trait_database), all_correlations + +def partial_corrs(# pylint: disable=[R0913] + conn, samples, primary_vals, control_vals, return_number, species, + input_trait_geneid, input_trait_symbol, tissue_probeset_freeze_id, + method, dataset, database_filename): + """ + Compute the partial correlations, selecting the fast or normal method + depending on the existence of the database text file. This is a partial migration of the - `web.webqtl.correlation.correlationFunction.determinPartialsByR` function in - GeneNetwork1, specifically the `pcor.rec` function written in the R - programming language. + `web.webqtl.correlation.PartialCorrDBPage.__init__` function in + GeneNetwork1. 
+ """ + if database_filename: + return partial_correlations_fast( + samples, primary_vals, control_vals, database_filename, + ( + fetch_literature_correlations( + species, input_trait_geneid, dataset, return_number, conn) + if "literature" in method.lower() else + fetch_tissue_correlations( + dataset, input_trait_symbol, tissue_probeset_freeze_id, + method, return_number, conn)), + method, + ("literature" if method.lower() == "sgo literature correlation" + else ("tissue" if "tissue" in method.lower() else "genetic"))) + + trait_database, data_start_pos = fetch_all_database_data( + conn, species, input_trait_geneid, input_trait_symbol, samples, dataset, + method, return_number, tissue_probeset_freeze_id) + return partial_correlations_normal( + primary_vals, control_vals, input_trait_geneid, trait_database, + data_start_pos, dataset, method) + +def literature_correlation_by_list( + conn: Any, species: str, trait_list: Tuple[dict]) -> Tuple[dict, ...]: + """ + This is a migration of the + `web.webqtl.correlation.CorrelationPage.getLiteratureCorrelationByList` + function in GeneNetwork1. + """ + if any((lambda t: ( + bool(t.get("tissue_corr")) and + bool(t.get("tissue_p_value"))))(trait) + for trait in trait_list): + temporary_table_name = f"LITERATURE{random_string(8)}" + query1 = ( + f"CREATE TEMPORARY TABLE {temporary_table_name} " + "(GeneId1 INT(12) UNSIGNED, GeneId2 INT(12) UNSIGNED PRIMARY KEY, " + "value DOUBLE)") + query2 = ( + f"INSERT INTO {temporary_table_name}(GeneId1, GeneId2, value) " + "SELECT GeneId1, GeneId2, value FROM LCorrRamin3 " + "WHERE GeneId1=%(geneid)s") + query3 = ( + "INSERT INTO {temporary_table_name}(GeneId1, GeneId2, value) " + "SELECT GeneId2, GeneId1, value FROM LCorrRamin3 " + "WHERE GeneId2=%s AND GeneId1 != %(geneid)s") + + def __set_mouse_geneid__(trait): + if trait.get("geneid"): + return { + **trait, + "mouse_geneid": translate_to_mouse_gene_id( + species, trait.get("geneid"), conn) + } + return {**trait, "mouse_geneid": 0} + + def __retrieve_lcorr__(cursor, geneids): + cursor.execute( + f"SELECT GeneId2, value FROM {temporary_table_name} " + "WHERE GeneId2 IN %(geneids)s", + geneids=geneids) + return dict(cursor.fetchall()) + + with conn.cursor() as cursor: + cursor.execute(query1) + cursor.execute(query2) + cursor.execute(query3) + + traits = tuple(__set_mouse_geneid__(trait) for trait in trait_list) + lcorrs = __retrieve_lcorr__( + cursor, ( + trait["mouse_geneid"] for trait in traits + if (trait["mouse_geneid"] != 0 and + trait["mouse_geneid"].find(";") < 0))) + return tuple( + {**trait, "l_corr": lcorrs.get(trait["mouse_geneid"], None)} + for trait in traits) + + return trait_list + return trait_list + +def tissue_correlation_by_list( + conn: Any, primary_trait_symbol: str, tissue_probeset_freeze_id: int, + method: str, trait_list: Tuple[dict]) -> Tuple[dict, ...]: + """ + This is a migration of the + `web.webqtl.correlation.CorrelationPage.getTissueCorrelationByList` + function in GeneNetwork1. 
+ """ + def __add_tissue_corr__(trait, primary_trait_values, trait_values): + result = pingouin.corr( + primary_trait_values, trait_values, + method=("spearman" if "spearman" in method.lower() else "pearson")) + return { + **trait, + "tissue_corr": result["r"], + "tissue_p_value": result["p-val"] + } + + if any((lambda t: bool(t.get("l_corr")))(trait) for trait in trait_list): + prim_trait_symbol_value_dict = fetch_gene_symbol_tissue_value_dict_for_trait( + (primary_trait_symbol,), tissue_probeset_freeze_id, conn) + if primary_trait_symbol.lower() in prim_trait_symbol_value_dict: + primary_trait_value = prim_trait_symbol_value_dict[ + primary_trait_symbol.lower()] + gene_symbol_list = tuple( + trait["symbol"] for trait in trait_list if "symbol" in trait.keys()) + symbol_value_dict = fetch_gene_symbol_tissue_value_dict_for_trait( + gene_symbol_list, tissue_probeset_freeze_id, conn) + return tuple( + __add_tissue_corr__( + trait, primary_trait_value, + symbol_value_dict[trait["symbol"].lower()]) + for trait in trait_list + if ("symbol" in trait and + bool(trait["symbol"]) and + trait["symbol"].lower() in symbol_value_dict)) + return tuple({ + **trait, + "tissue_corr": None, + "tissue_p_value": None + } for trait in trait_list) + return trait_list + +def trait_for_output(trait): + """ + Process a trait for output. + + Removes a lot of extraneous data from the trait, that is not needed for + the display of partial correlation results. + This function also removes all key-value pairs, for which the value is + `None`, because it is a waste of network resources to transmit the key-value + pair just to indicate it does not exist. + """ + def __nan_to_none__(val): + if val is None: + return None + if math.isnan(val) or numpy.isnan(val): + return None + return val + + trait = { + "trait_type": trait["db"]["dataset_type"], + "dataset_name": trait["db"]["dataset_name"], + "dataset_type": trait["db"]["dataset_type"], + "group": trait["db"]["group"], + "trait_fullname": trait["trait_fullname"], + "trait_name": trait["trait_name"], + "symbol": trait.get("symbol"), + "description": trait.get("description"), + "pre_publication_description": trait.get("Pre_publication_description"), + "post_publication_description": trait.get( + "Post_publication_description"), + "original_description": trait.get("Original_description"), + "authors": trait.get("Authors"), + "year": trait.get("Year"), + "probe_target_description": trait.get("Probe_target_description"), + "chr": trait.get("chr"), + "mb": trait.get("mb"), + "geneid": trait.get("geneid"), + "homologeneid": trait.get("homologeneid"), + "noverlap": trait.get("noverlap"), + "partial_corr": __nan_to_none__(trait.get("partial_corr")), + "partial_corr_p_value": __nan_to_none__( + trait.get("partial_corr_p_value")), + "corr": __nan_to_none__(trait.get("corr")), + "corr_p_value": __nan_to_none__(trait.get("corr_p_value")), + "rank_order": __nan_to_none__(trait.get("rank_order")), + "delta": ( + None if trait.get("partial_corr") is None + else (trait.get("partial_corr") - trait.get("corr"))), + "l_corr": __nan_to_none__(trait.get("l_corr")), + "tissue_corr": __nan_to_none__(trait.get("tissue_corr")), + "tissue_p_value": __nan_to_none__(trait.get("tissue_p_value")) + } + return {key: val for key, val in trait.items() if val is not None} + +def partial_correlations_entry(# pylint: disable=[R0913, R0914, R0911] + conn: Any, primary_trait_name: str, + control_trait_names: Tuple[str, ...], method: str, + criteria: int, target_db_name: str) -> dict: + """ + This is the 
'ochestration' function for the partial-correlation feature. + + This function will dispatch the functions doing data fetches from the + database (and various other places) and feed that data to the functions + doing the conversions and computations. It will then return the results of + all of that work. + + This function is doing way too much. Look into splitting out the + functionality into smaller functions that do fewer things. """ - return 0 + threshold = 0 + corr_min_informative = 4 + + all_traits = traits_info( + conn, threshold, (primary_trait_name,) + control_trait_names) + all_traits_data = traits_data(conn, all_traits) + + primary_trait = tuple( + trait for trait in all_traits + if trait["trait_fullname"] == primary_trait_name)[0] + if not primary_trait["haveinfo"]: + return { + "status": "not-found", + "message": f"Could not find primary trait {primary_trait['trait_fullname']}" + } + cntrl_traits = tuple( + trait for trait in all_traits + if trait["trait_fullname"] != primary_trait_name) + if not any(trait["haveinfo"] for trait in cntrl_traits): + return { + "status": "not-found", + "message": "None of the requested control traits were found."} + for trait in cntrl_traits: + if trait["haveinfo"] is False: + warnings.warn( + (f"Control traits {trait['trait_fullname']} was not found " + "- continuing without it."), + category=UserWarning) + + group = primary_trait["db"]["group"] + primary_trait_data = all_traits_data[primary_trait["trait_name"]] + primary_samples, primary_values, _primary_variances = export_informative( + primary_trait_data) + + cntrl_traits_data = tuple( + data for trait_name, data in all_traits_data.items() + if trait_name != primary_trait["trait_name"]) + species = species_name(conn, group) + + (cntrl_samples, + cntrl_values, + _cntrl_variances, + _cntrl_ns) = control_samples(cntrl_traits_data, primary_samples) + + common_primary_control_samples = primary_samples + fixed_primary_vals = primary_values + fixed_control_vals = cntrl_values + if not all(cnt_smp == primary_samples for cnt_smp in cntrl_samples): + (common_primary_control_samples, + fixed_primary_vals, + fixed_control_vals, + _primary_variances, + _cntrl_variances) = fix_samples(primary_trait, cntrl_traits) + + if len(common_primary_control_samples) < corr_min_informative: + return { + "status": "error", + "message": ( + f"Fewer than {corr_min_informative} samples data entered for " + f"{group} dataset. No calculation of correlation has been " + "attempted."), + "error_type": "Inadequate Samples"} + + identical_traits_names = find_identical_traits( + primary_trait_name, primary_values, control_trait_names, cntrl_values) + if len(identical_traits_names) > 0: + return { + "status": "error", + "message": ( + f"{identical_traits_names[0]} and {identical_traits_names[1]} " + "have the same values for the {len(fixed_primary_vals)} " + "samples that will be used to compute the partial correlation " + "(common for all primary and control traits). In such cases, " + "partial correlation cannot be computed. 
Please re-select your " + "traits."), + "error_type": "Identical Traits"} + + input_trait_geneid = primary_trait.get("geneid", 0) + input_trait_symbol = primary_trait.get("symbol", "") + input_trait_mouse_geneid = translate_to_mouse_gene_id( + species, input_trait_geneid, conn) + + tissue_probeset_freeze_id = 1 + db_type = primary_trait["db"]["dataset_type"] + + if db_type == "ProbeSet" and method.lower() in ( + "sgo literature correlation", + "tissue correlation, pearson's r", + "tissue correlation, spearman's rho"): + return { + "status": "error", + "message": ( + "Wrong correlation type: It is not possible to compute the " + f"{method} between your trait and data in the {target_db_name} " + "database. Please try again after selecting another type of " + "correlation."), + "error_type": "Correlation Type"} + + if (method.lower() == "sgo literature correlation" and ( + bool(input_trait_geneid) is False or + check_for_literature_info(conn, input_trait_mouse_geneid))): + return { + "status": "error", + "message": ( + "No Literature Information: This gene does not have any " + "associated Literature Information."), + "error_type": "Literature Correlation"} + + if ( + method.lower() in ( + "tissue correlation, pearson's r", + "tissue correlation, spearman's rho") + and bool(input_trait_symbol) is False): + return { + "status": "error", + "message": ( + "No Tissue Correlation Information: This gene does not have " + "any associated Tissue Correlation Information."), + "error_type": "Tissue Correlation"} + + if ( + method.lower() in ( + "tissue correlation, pearson's r", + "tissue correlation, spearman's rho") + and check_symbol_for_tissue_correlation( + conn, tissue_probeset_freeze_id, input_trait_symbol)): + return { + "status": "error", + "message": ( + "No Tissue Correlation Information: This gene does not have " + "any associated Tissue Correlation Information."), + "error_type": "Tissue Correlation"} + + target_dataset = retrieve_trait_dataset( + ("Temp" if "Temp" in target_db_name else + ("Publish" if "Publish" in target_db_name else + "Geno" if "Geno" in target_db_name else "ProbeSet")), + {"db": {"dataset_name": target_db_name}, "trait_name": "_"}, + threshold, + conn) + + database_filename = get_filename(conn, target_db_name, TEXTDIR) + _total_traits, all_correlations = partial_corrs( + conn, common_primary_control_samples, fixed_primary_vals, + fixed_control_vals, len(fixed_primary_vals), species, + input_trait_geneid, input_trait_symbol, tissue_probeset_freeze_id, + method, {**target_dataset, "dataset_type": target_dataset["type"]}, database_filename) + + + def __make_sorter__(method): + def __by_lit_or_tiss_corr_then_p_val__(row): + return (row[6], row[3]) + + def __by_partial_corr_p_value__(row): + return row[3] + + if (("literature" in method.lower()) or ("tissue" in method.lower())): + return __by_lit_or_tiss_corr_then_p_val__ + + return __by_partial_corr_p_value__ + + add_lit_corr_and_tiss_corr = compose( + partial(literature_correlation_by_list, conn, species), + partial( + tissue_correlation_by_list, conn, input_trait_symbol, + tissue_probeset_freeze_id, method)) + + selected_results = sorted( + all_correlations, + key=__make_sorter__(method))[:criteria] + traits_list_corr_info = { + f"{target_dataset['dataset_name']}::{item[0]}": { + "noverlap": item[1], + "partial_corr": item[2], + "partial_corr_p_value": item[3], + "corr": item[4], + "corr_p_value": item[5], + "rank_order": (1 if "spearman" in method.lower() else 0), + **({ + "tissue_corr": item[6], + "tissue_p_value": 
item[7]} + if len(item) == 8 else {}), + **({"l_corr": item[6]} + if len(item) == 7 else {}) + } for item in selected_results} + + trait_list = add_lit_corr_and_tiss_corr(tuple( + {**trait, **traits_list_corr_info.get(trait["trait_fullname"], {})} + for trait in traits_info( + conn, threshold, + tuple( + f"{target_dataset['dataset_name']}::{item[0]}" + for item in selected_results)))) + + return { + "status": "success", + "results": { + "primary_trait": trait_for_output(primary_trait), + "control_traits": tuple( + trait_for_output(trait) for trait in cntrl_traits), + "correlations": tuple( + trait_for_output(trait) for trait in trait_list), + "dataset_type": target_dataset["type"], + "method": "spearman" if "spearman" in method.lower() else "pearson" + }} diff --git a/gn3/computations/partial_correlations_optimised.py b/gn3/computations/partial_correlations_optimised.py new file mode 100644 index 0000000..601289c --- /dev/null +++ b/gn3/computations/partial_correlations_optimised.py @@ -0,0 +1,244 @@ +""" +This contains an optimised version of the + `gn3.computations.partial_correlations.partial_correlations_entry` +function. +""" +from functools import partial +from typing import Any, Tuple + +from gn3.settings import TEXTDIR +from gn3.function_helpers import compose +from gn3.db.partial_correlations import traits_info, traits_data +from gn3.db.species import species_name, translate_to_mouse_gene_id +from gn3.db.traits import export_informative, retrieve_trait_dataset +from gn3.db.correlations import ( + get_filename, + check_for_literature_info, + check_symbol_for_tissue_correlation) +from gn3.computations.partial_correlations import ( + fix_samples, + partial_corrs, + control_samples, + trait_for_output, + find_identical_traits, + tissue_correlation_by_list, + literature_correlation_by_list) + +def partial_correlations_entry(# pylint: disable=[R0913, R0914, R0911] + conn: Any, primary_trait_name: str, + control_trait_names: Tuple[str, ...], method: str, + criteria: int, target_db_name: str) -> dict: + """ + This is the 'ochestration' function for the partial-correlation feature. + + This function will dispatch the functions doing data fetches from the + database (and various other places) and feed that data to the functions + doing the conversions and computations. It will then return the results of + all of that work. + + This function is doing way too much. Look into splitting out the + functionality into smaller functions that do fewer things. 
+ """ + threshold = 0 + corr_min_informative = 4 + + all_traits = traits_info( + conn, threshold, (primary_trait_name,) + control_trait_names) + all_traits_data = traits_data(conn, all_traits) + + # primary_trait = retrieve_trait_info(threshold, primary_trait_name, conn) + primary_trait = tuple( + trait for trait in all_traits + if trait["trait_fullname"] == primary_trait_name)[0] + group = primary_trait["db"]["group"] + # primary_trait_data = retrieve_trait_data(primary_trait, conn) + primary_trait_data = all_traits_data[primary_trait["trait_name"]] + primary_samples, primary_values, _primary_variances = export_informative( + primary_trait_data) + + # cntrl_traits = tuple( + # retrieve_trait_info(threshold, trait_full_name, conn) + # for trait_full_name in control_trait_names) + # cntrl_traits_data = tuple( + # retrieve_trait_data(cntrl_trait, conn) + # for cntrl_trait in cntrl_traits) + cntrl_traits = tuple( + trait for trait in all_traits + if trait["trait_fullname"] != primary_trait_name) + cntrl_traits_data = tuple( + data for trait_name, data in all_traits_data.items() + if trait_name != primary_trait["trait_name"]) + species = species_name(conn, group) + + (cntrl_samples, + cntrl_values, + _cntrl_variances, + _cntrl_ns) = control_samples(cntrl_traits_data, primary_samples) + + common_primary_control_samples = primary_samples + fixed_primary_vals = primary_values + fixed_control_vals = cntrl_values + if not all(cnt_smp == primary_samples for cnt_smp in cntrl_samples): + (common_primary_control_samples, + fixed_primary_vals, + fixed_control_vals, + _primary_variances, + _cntrl_variances) = fix_samples(primary_trait, cntrl_traits) + + if len(common_primary_control_samples) < corr_min_informative: + return { + "status": "error", + "message": ( + f"Fewer than {corr_min_informative} samples data entered for " + f"{group} dataset. No calculation of correlation has been " + "attempted."), + "error_type": "Inadequate Samples"} + + identical_traits_names = find_identical_traits( + primary_trait_name, primary_values, control_trait_names, cntrl_values) + if len(identical_traits_names) > 0: + return { + "status": "error", + "message": ( + f"{identical_traits_names[0]} and {identical_traits_names[1]} " + "have the same values for the {len(fixed_primary_vals)} " + "samples that will be used to compute the partial correlation " + "(common for all primary and control traits). In such cases, " + "partial correlation cannot be computed. Please re-select your " + "traits."), + "error_type": "Identical Traits"} + + input_trait_geneid = primary_trait.get("geneid", 0) + input_trait_symbol = primary_trait.get("symbol", "") + input_trait_mouse_geneid = translate_to_mouse_gene_id( + species, input_trait_geneid, conn) + + tissue_probeset_freeze_id = 1 + db_type = primary_trait["db"]["dataset_type"] + + if db_type == "ProbeSet" and method.lower() in ( + "sgo literature correlation", + "tissue correlation, pearson's r", + "tissue correlation, spearman's rho"): + return { + "status": "error", + "message": ( + "Wrong correlation type: It is not possible to compute the " + f"{method} between your trait and data in the {target_db_name} " + "database. 
Please try again after selecting another type of " + "correlation."), + "error_type": "Correlation Type"} + + if (method.lower() == "sgo literature correlation" and ( + bool(input_trait_geneid) is False or + check_for_literature_info(conn, input_trait_mouse_geneid))): + return { + "status": "error", + "message": ( + "No Literature Information: This gene does not have any " + "associated Literature Information."), + "error_type": "Literature Correlation"} + + if ( + method.lower() in ( + "tissue correlation, pearson's r", + "tissue correlation, spearman's rho") + and bool(input_trait_symbol) is False): + return { + "status": "error", + "message": ( + "No Tissue Correlation Information: This gene does not have " + "any associated Tissue Correlation Information."), + "error_type": "Tissue Correlation"} + + if ( + method.lower() in ( + "tissue correlation, pearson's r", + "tissue correlation, spearman's rho") + and check_symbol_for_tissue_correlation( + conn, tissue_probeset_freeze_id, input_trait_symbol)): + return { + "status": "error", + "message": ( + "No Tissue Correlation Information: This gene does not have " + "any associated Tissue Correlation Information."), + "error_type": "Tissue Correlation"} + + target_dataset = retrieve_trait_dataset( + ("Temp" if "Temp" in target_db_name else + ("Publish" if "Publish" in target_db_name else + "Geno" if "Geno" in target_db_name else "ProbeSet")), + {"db": {"dataset_name": target_db_name}, "trait_name": "_"}, + threshold, + conn) + + database_filename = get_filename(conn, target_db_name, TEXTDIR) + _total_traits, all_correlations = partial_corrs( + conn, common_primary_control_samples, fixed_primary_vals, + fixed_control_vals, len(fixed_primary_vals), species, + input_trait_geneid, input_trait_symbol, tissue_probeset_freeze_id, + method, {**target_dataset, "dataset_type": target_dataset["type"]}, database_filename) + + + def __make_sorter__(method): + def __sort_6__(row): + return row[6] + + def __sort_3__(row): + return row[3] + + if "literature" in method.lower(): + return __sort_6__ + + if "tissue" in method.lower(): + return __sort_6__ + + return __sort_3__ + + # sorted_correlations = sorted( + # all_correlations, key=__make_sorter__(method)) + + add_lit_corr_and_tiss_corr = compose( + partial(literature_correlation_by_list, conn, species), + partial( + tissue_correlation_by_list, conn, input_trait_symbol, + tissue_probeset_freeze_id, method)) + + selected_results = sorted( + all_correlations, + key=__make_sorter__(method))[:min(criteria, len(all_correlations))] + traits_list_corr_info = { + "{target_dataset['dataset_name']}::{item[0]}": { + "noverlap": item[1], + "partial_corr": item[2], + "partial_corr_p_value": item[3], + "corr": item[4], + "corr_p_value": item[5], + "rank_order": (1 if "spearman" in method.lower() else 0), + **({ + "tissue_corr": item[6], + "tissue_p_value": item[7]} + if len(item) == 8 else {}), + **({"l_corr": item[6]} + if len(item) == 7 else {}) + } for item in selected_results} + + trait_list = add_lit_corr_and_tiss_corr(tuple( + {**trait, **traits_list_corr_info.get(trait["trait_fullname"], {})} + for trait in traits_info( + conn, threshold, + tuple( + f"{target_dataset['dataset_name']}::{item[0]}" + for item in selected_results)))) + + return { + "status": "success", + "results": { + "primary_trait": trait_for_output(primary_trait), + "control_traits": tuple( + trait_for_output(trait) for trait in cntrl_traits), + "correlations": tuple( + trait_for_output(trait) for trait in trait_list), + "dataset_type": 
target_dataset["type"], + "method": "spearman" if "spearman" in method.lower() else "pearson" + }} diff --git a/gn3/computations/pca.py b/gn3/computations/pca.py new file mode 100644 index 0000000..35c9f03 --- /dev/null +++ b/gn3/computations/pca.py @@ -0,0 +1,189 @@ +"""module contains pca implementation using python""" + + +from typing import Any +from scipy import stats + +from sklearn.decomposition import PCA +from sklearn import preprocessing + +import numpy as np +import redis + + +from typing_extensions import TypeAlias + +fArray: TypeAlias = list[float] + + +def compute_pca(array: list[fArray]) -> dict[str, Any]: + """ + computes the principal component analysis + + Parameters: + + array(list[list]):a list of lists contains data to perform pca + + + Returns: + pca_dict(dict):dict contains the pca_object,pca components,pca scores + + + """ + + corr_matrix = np.array(array) + + pca_obj = PCA() + scaled_data = preprocessing.scale(corr_matrix) + + pca_obj.fit(scaled_data) + + return { + "pca": pca_obj, + "components": pca_obj.components_, + "scores": pca_obj.transform(scaled_data) + } + + +def generate_scree_plot_data(variance_ratio: fArray) -> tuple[list, fArray]: + """ + generates the scree data for plotting + + Parameters: + + variance_ratio(list[floats]):ratios for contribution of each pca + + Returns: + + coordinates(list[(x_coor,y_coord)]) + + + """ + + perc_var = [round(ratio*100, 1) for ratio in variance_ratio] + + x_coordinates = [f"PC{val}" for val in range(1, len(perc_var)+1)] + + return (x_coordinates, perc_var) + + +def generate_pca_traits_vals(trait_data_array: list[fArray], + corr_array: list[fArray]) -> list[list[Any]]: + """ + generates datasets from zscores of the traits and eigen_vectors\ + of correlation matrix + + Parameters: + + trait_data_array(list[floats]):an list of the traits + corr_array(list[list]): list of arrays for computing eigen_vectors + + Returns: + + pca_vals[list[list]]: + + + """ + + trait_zscores = stats.zscore(trait_data_array) + + if len(trait_data_array[0]) < 10: + trait_zscores = trait_data_array + + (eigen_values, corr_eigen_vectors) = np.linalg.eig(np.array(corr_array)) + idx = eigen_values.argsort()[::-1] + + return np.dot(corr_eigen_vectors[:, idx], trait_zscores) + + +def process_factor_loadings_tdata(factor_loadings, traits_num: int): + """ + + transform loadings for tables visualization + + Parameters: + factor_loading(numpy.ndarray) + traits_num(int):number of traits + + Returns: + tabular_loadings(list[list[float]]) + """ + + target_columns = 3 if traits_num > 2 else 2 + + trait_loadings = list(factor_loadings.T) + + return [list(trait_loading[:target_columns]) + for trait_loading in trait_loadings] + + +def generate_pca_temp_traits( + species: str, + group: str, + traits_data: list[fArray], + corr_array: list[fArray], + dataset_samples: list[str], + shared_samples: list[str], + create_time: str +) -> dict[str, list[Any]]: + """ + + + generate pca temp datasets + + """ + + # pylint: disable=too-many-arguments + + pca_trait_dict = {} + + pca_vals = generate_pca_traits_vals(traits_data, corr_array) + + for (idx, pca_trait) in enumerate(list(pca_vals)): + + trait_id = f"PCA{str(idx+1)}_{species}_{group}_{create_time}" + sample_vals = [] + + pointer = 0 + + for sample in dataset_samples: + if sample in shared_samples: + + sample_vals.append(str(pca_trait[pointer])) + pointer += 1 + + else: + sample_vals.append("x") + + pca_trait_dict[trait_id] = sample_vals + + return pca_trait_dict + + +def cache_pca_dataset(redis_conn: Any, exp_days: 
int, + pca_trait_dict: dict[str, list[Any]]): + """ + + caches pca dataset to redis + + Parameters: + + redis_conn(object) + exp_days(int): fo redis cache + pca_trait_dict(Dict): contains traits and traits vals to cache + + Returns: + + boolean(True if correct conn object False incase of exception) + + + """ + + try: + for trait_id, sample_data in pca_trait_dict.items(): + samples_str = " ".join([str(x) for x in sample_data]) + redis_conn.set(trait_id, samples_str, ex=exp_days) + return True + + except (redis.ConnectionError, AttributeError): + return False diff --git a/gn3/computations/qtlreaper.py b/gn3/computations/qtlreaper.py index d1ff4ac..b61bdae 100644 --- a/gn3/computations/qtlreaper.py +++ b/gn3/computations/qtlreaper.py @@ -27,7 +27,7 @@ def generate_traits_file(samples, trait_values, traits_filename): ["{}\t{}".format( len(trait_values), "\t".join([str(i) for i in t])) for t in trait_values[-1:]]) - with open(traits_filename, "w") as outfile: + with open(traits_filename, "w", encoding="utf8") as outfile: outfile.writelines(data) def create_output_directory(path: str): @@ -68,13 +68,13 @@ def run_reaper( The function will raise a `subprocess.CalledProcessError` exception in case of any errors running the `qtlreaper` command. """ - create_output_directory("{}/qtlreaper".format(output_dir)) - output_filename = "{}/qtlreaper/main_output_{}.txt".format( - output_dir, random_string(10)) + create_output_directory(f"{output_dir}/qtlreaper") + output_filename = ( + f"{output_dir}/qtlreaper/main_output_{random_string(10)}.txt") output_list = ["--main_output", output_filename] if separate_nperm_output: - permu_output_filename: Union[None, str] = "{}/qtlreaper/permu_output_{}.txt".format( - output_dir, random_string(10)) + permu_output_filename: Union[None, str] = ( + f"{output_dir}/qtlreaper/permu_output_{random_string(10)}.txt") output_list = output_list + [ "--permu_output", permu_output_filename] # type: ignore[list-item] else: @@ -135,7 +135,7 @@ def parse_reaper_main_results(results_file): """ Parse the results file of running QTLReaper into a list of dicts. """ - with open(results_file, "r") as infile: + with open(results_file, "r", encoding="utf8") as infile: lines = infile.readlines() def __parse_column_float_value(value): @@ -164,7 +164,7 @@ def parse_reaper_permutation_results(results_file): """ Parse the results QTLReaper permutations into a list of values. 
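The new gn3/computations/pca.py module above wraps scikit-learn's PCA; a minimal, hedged usage sketch (assuming scikit-learn and scipy are installed, with invented correlation values) could look like this:

from gn3.computations.pca import compute_pca, generate_scree_plot_data

# A tiny, made-up correlation matrix standing in for real trait correlations.
corr_matrix = [
    [1.00, 0.85, 0.10],
    [0.85, 1.00, 0.15],
    [0.10, 0.15, 1.00],
]
pca_results = compute_pca(corr_matrix)
x_coords, perc_var = generate_scree_plot_data(
    list(pca_results["pca"].explained_variance_ratio_))
# x_coords -> ["PC1", "PC2", "PC3"]; perc_var holds each component's
# percentage contribution, rounded to one decimal place.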
""" - with open(results_file, "r") as infile: + with open(results_file, "r", encoding="utf8") as infile: lines = infile.readlines() return [float(line.strip()) for line in lines] diff --git a/gn3/computations/rqtl.py b/gn3/computations/rqtl.py index e81aba3..65ee6de 100644 --- a/gn3/computations/rqtl.py +++ b/gn3/computations/rqtl.py @@ -53,7 +53,7 @@ def process_rqtl_mapping(file_name: str) -> List: # Later I should probably redo this using csv.read to avoid the # awkwardness with removing quotes with [1:-1] with open(os.path.join(current_app.config.get("TMPDIR", "/tmp"), - "output", file_name), "r") as the_file: + "output", file_name), "r", encoding="utf-8") as the_file: for line in the_file: line_items = line.split(",") if line_items[1][1:-1] == "chr" or not line_items: @@ -118,7 +118,6 @@ def pairscan_for_figure(file_name: str) -> Dict: return figure_data - def get_marker_list(map_file: str) -> List: """ Open the map file with the list of markers/pseudomarkers and create list of marker obs @@ -255,7 +254,7 @@ def process_perm_output(file_name: str) -> Tuple[List, float, float]: perm_results = [] with open(os.path.join(current_app.config.get("TMPDIR", "/tmp"), - "output", "PERM_" + file_name), "r") as the_file: + "output", "PERM_" + file_name), "r", encoding="utf-8") as the_file: for i, line in enumerate(the_file): if i == 0: # Skip header line diff --git a/gn3/computations/wgcna.py b/gn3/computations/wgcna.py index ab12fe7..c985491 100644 --- a/gn3/computations/wgcna.py +++ b/gn3/computations/wgcna.py @@ -19,7 +19,7 @@ def dump_wgcna_data(request_data: dict): request_data["TMPDIR"] = TMPDIR - with open(temp_file_path, "w") as output_file: + with open(temp_file_path, "w", encoding="utf-8") as output_file: json.dump(request_data, output_file) return temp_file_path @@ -31,20 +31,18 @@ def stream_cmd_output(socketio, request_data, cmd: str): socketio.emit("output", {"data": f"calling you script {cmd}"}, namespace="/", room=request_data["socket_id"]) - results = subprocess.Popen( - cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, shell=True) + with subprocess.Popen( + cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, shell=True) as results: + if results.stdout is not None: + for line in iter(results.stdout.readline, b""): + socketio.emit("output", + {"data": line.decode("utf-8").rstrip()}, + namespace="/", room=request_data["socket_id"]) - if results.stdout is not None: - - for line in iter(results.stdout.readline, b""): - socketio.emit("output", - {"data": line.decode("utf-8").rstrip()}, - namespace="/", room=request_data["socket_id"]) - - socketio.emit( - "output", {"data": - "parsing the output results"}, namespace="/", - room=request_data["socket_id"]) + socketio.emit( + "output", {"data": + "parsing the output results"}, namespace="/", + room=request_data["socket_id"]) def process_image(image_loc: str) -> bytes: @@ -75,7 +73,7 @@ def call_wgcna_script(rscript_path: str, request_data: dict): run_cmd_results = run_cmd(cmd) - with open(generated_file, "r") as outputfile: + with open(generated_file, "r", encoding="utf-8") as outputfile: if run_cmd_results["code"] != 0: return run_cmd_results diff --git a/gn3/csvcmp.py b/gn3/csvcmp.py new file mode 100644 index 0000000..8db89ca --- /dev/null +++ b/gn3/csvcmp.py @@ -0,0 +1,146 @@ +"""This module contains functions for manipulating and working with csv +texts""" +from typing import Any, List + +import json +import os +import uuid +from gn3.commands import run_cmd + + +def extract_strain_name(csv_header, data, seek="Strain Name") -> 
str: + """Extract a strain's name given a csv header""" + for column, value in zip(csv_header.split(","), data.split(",")): + if seek in column: + return value + return "" + + +def create_dirs_if_not_exists(dirs: list) -> None: + """Create directories from a list""" + for dir_ in dirs: + if not os.path.exists(dir_): + os.makedirs(dir_) + + +def remove_insignificant_edits(diff_data, epsilon=0.001): + """Remove or ignore edits that are not within ε""" + __mod = [] + if diff_data.get("Modifications"): + for mod in diff_data.get("Modifications"): + original = mod.get("Original").split(",") + current = mod.get("Current").split(",") + for i, (_x, _y) in enumerate(zip(original, current)): + if ( + _x.replace(".", "").isdigit() + and _y.replace(".", "").isdigit() + and abs(float(_x) - float(_y)) < epsilon + ): + current[i] = _x + if not (__o := ",".join(original)) == (__c := ",".join(current)): + __mod.append( + { + "Original": __o, + "Current": __c, + } + ) + diff_data["Modifications"] = __mod + return diff_data + + +def clean_csv_text(csv_text: str) -> str: + """Remove extra white space elements in all elements of the CSV file""" + _csv_text = [] + for line in csv_text.strip().split("\n"): + _csv_text.append( + ",".join([el.strip() for el in line.split(",")])) + return "\n".join(_csv_text) + + +def csv_diff(base_csv, delta_csv, tmp_dir="/tmp") -> dict: + """Diff 2 csv strings""" + base_csv = clean_csv_text(base_csv) + delta_csv = clean_csv_text(delta_csv) + base_csv_list = base_csv.split("\n") + delta_csv_list = delta_csv.split("\n") + + base_csv_header, delta_csv_header = "", "" + for i, line in enumerate(base_csv_list): + if line.startswith("Strain Name,Value,SE,Count"): + base_csv_header, delta_csv_header = line, delta_csv_list[i] + break + longest_header = max(base_csv_header, delta_csv_header) + + if base_csv_header != delta_csv_header: + if longest_header != base_csv_header: + base_csv = base_csv.replace("Strain Name,Value,SE,Count", + longest_header, 1) + else: + delta_csv = delta_csv.replace( + "Strain Name,Value,SE,Count", longest_header, 1 + ) + file_name1 = os.path.join(tmp_dir, str(uuid.uuid4())) + file_name2 = os.path.join(tmp_dir, str(uuid.uuid4())) + + with open(file_name1, "w", encoding="utf-8") as _f: + _l = len(longest_header.split(",")) + _f.write(fill_csv(csv_text=base_csv, width=_l)) + with open(file_name2, "w", encoding="utf-8") as _f: + _f.write(fill_csv(delta_csv, width=_l)) + + # Now we can run the diff! + _r = run_cmd(cmd=('"csvdiff ' + f"{file_name1} {file_name2} " + '--format json"')) + if _r.get("code") == 0: + _r = json.loads(_r.get("output", "")) + if any(_r.values()): + _r["Columns"] = max(base_csv_header, delta_csv_header) + else: + _r = {} + + # Clean Up! 
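The csvcmp helpers above normalise sample-data CSV text before diffing; a short sketch with invented values (the default epsilon of 0.001 is assumed unchanged):

from gn3.csvcmp import clean_csv_text, remove_insignificant_edits

assert clean_csv_text(
    "Strain Name ,Value , SE, Count\nBXD1, 18.700 ,x ,0"
) == "Strain Name,Value,SE,Count\nBXD1,18.700,x,0"

# A modification smaller than epsilon is treated as noise and dropped.
diff_data = {"Modifications": [
    {"Original": "BXD1,18.7001,x,0", "Current": "BXD1,18.7002,x,0"}]}
assert remove_insignificant_edits(diff_data) == {"Modifications": []}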
+ if os.path.exists(file_name1): + os.remove(file_name1) + if os.path.exists(file_name2): + os.remove(file_name2) + return _r + + +def fill_csv(csv_text, width, value="x"): + """Fill a csv text with 'value' if it's length is less than width""" + data = [] + for line in csv_text.strip().split("\n"): + if line.startswith("Strain") or line.startswith("#"): + data.append(line) + elif line: + _n = line.split(",") + for i, val in enumerate(_n): + if not val.strip(): + _n[i] = value + data.append(",".join(_n + [value] * (width - len(_n)))) + return "\n".join(data) + + +def get_allowable_sampledata_headers(conn: Any) -> List: + """Get a list of all the case-attributes stored in the database""" + attributes = ["Strain Name", "Value", "SE", "Count"] + with conn.cursor() as cursor: + cursor.execute("SELECT Name from CaseAttribute") + attributes += [attributes[0] for attributes in + cursor.fetchall()] + return attributes + + +def extract_invalid_csv_headers(allowed_headers: List, csv_text: str) -> List: + """Check whether a csv text's columns contains valid headers""" + csv_header = [] + for line in csv_text.split("\n"): + if line.startswith("Strain Name"): + csv_header = [_l.strip() for _l in line.split(",")] + break + invalid_headers = [] + for header in csv_header: + if header not in allowed_headers: + invalid_headers.append(header) + return invalid_headers diff --git a/gn3/data_helpers.py b/gn3/data_helpers.py index d3f942b..268a0bb 100644 --- a/gn3/data_helpers.py +++ b/gn3/data_helpers.py @@ -5,9 +5,9 @@ data structures. from math import ceil from functools import reduce -from typing import Any, Tuple, Sequence, Optional +from typing import Any, Tuple, Sequence, Optional, Generator -def partition_all(num: int, items: Sequence[Any]) -> Tuple[Tuple[Any, ...], ...]: +def partition_all(num: int, items: Sequence[Any]) -> Generator: """ Given a sequence `items`, return a new sequence of the same type as `items` with the data partitioned into sections of `n` items per partition. @@ -19,10 +19,24 @@ def partition_all(num: int, items: Sequence[Any]) -> Tuple[Tuple[Any, ...], ...] return acc + ((start, start + num),) iterations = range(ceil(len(items) / num)) - return tuple([# type: ignore[misc] - tuple(items[start:stop]) for start, stop # type: ignore[has-type] - in reduce( - __compute_start_stop__, iterations, tuple())]) + for start, stop in reduce(# type: ignore[misc] + __compute_start_stop__, iterations, tuple()): + yield tuple(items[start:stop]) # type: ignore[has-type] + +def partition_by(partition_fn, items): + """ + Given a sequence `items`, return a tuple of tuples, each of which contain + the values in `items` partitioned such that the first item in each internal + tuple, when passed to `partition_function` returns True. + + This is an approximation of Clojure's `partition-by` function. + """ + def __partitioner__(accumulator, item): + if partition_fn(item): + return accumulator + ((item,),) + return accumulator[:-1] + (accumulator[-1] + (item,),) + + return reduce(__partitioner__, items, tuple()) def parse_csv_line( line: str, delimiter: str = ",", @@ -34,4 +48,4 @@ def parse_csv_line( function in GeneNetwork1. 
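partition_all above now yields fixed-size chunks lazily, and the new partition_by splits a sequence wherever a predicate fires; both are shown below with illustrative values. The same partition_all is what later lets the correlation queries fetch sample data in batches of 25 ids.

from gn3.data_helpers import partition_all, partition_by

assert list(partition_all(3, (1, 2, 3, 4, 5, 6, 7))) == [
    (1, 2, 3), (4, 5, 6), (7,)]

# A new partition starts on every item for which the predicate is True.
assert partition_by(lambda x: x == 0, (0, 1, 2, 0, 3, 0, 4, 5)) == (
    (0, 1, 2), (0, 3), (0, 4, 5))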
""" return tuple( - col.strip("{} \t\n".format(quoting)) for col in line.split(delimiter)) + col.strip(f"{quoting} \t\n") for col in line.split(delimiter)) diff --git a/gn3/db/correlations.py b/gn3/db/correlations.py index 06b3310..3ae66ca 100644 --- a/gn3/db/correlations.py +++ b/gn3/db/correlations.py @@ -2,17 +2,16 @@ This module will hold functions that are used in the (partial) correlations feature to access the database to retrieve data needed for computations. """ - +import os from functools import reduce -from typing import Any, Dict, Tuple +from typing import Any, Dict, Tuple, Union from gn3.random import random_string from gn3.data_helpers import partition_all from gn3.db.species import translate_to_mouse_gene_id -from gn3.computations.partial_correlations import correlations_of_all_tissue_traits - -def get_filename(target_db_name: str, conn: Any) -> str: +def get_filename(conn: Any, target_db_name: str, text_files_dir: str) -> Union[ + str, bool]: """ Retrieve the name of the reference database file with which correlations are computed. @@ -23,18 +22,23 @@ def get_filename(target_db_name: str, conn: Any) -> str: """ with conn.cursor() as cursor: cursor.execute( - "SELECT Id, FullName from ProbeSetFreeze WHERE Name-%s", - target_db_name) + "SELECT Id, FullName from ProbeSetFreeze WHERE Name=%s", + (target_db_name,)) result = cursor.fetchone() if result: - return "ProbeSetFreezeId_{tid}_FullName_{fname}.txt".format( - tid=result[0], - fname=result[1].replace(' ', '_').replace('/', '_')) + filename = ( + f"ProbeSetFreezeId_{result[0]}_FullName_" + f"{result[1].replace(' ', '_').replace('/', '_')}.txt") + full_filename = f"{text_files_dir}/{filename}" + return ( + os.path.exists(full_filename) and + (filename in os.listdir(text_files_dir)) and + full_filename) - return "" + return False def build_temporary_literature_table( - species: str, gene_id: int, return_number: int, conn: Any) -> str: + conn: Any, species: str, gene_id: int, return_number: int) -> str: """ Build and populate a temporary table to hold the literature correlation data to be used in computations. @@ -49,7 +53,7 @@ def build_temporary_literature_table( query = { "rat": "SELECT rat FROM GeneIDXRef WHERE mouse=%s", "human": "SELECT human FROM GeneIDXRef WHERE mouse=%d"} - if species in query.keys(): + if species in query: cursor.execute(query[species], row[1]) record = cursor.fetchone() if record: @@ -128,7 +132,7 @@ def fetch_literature_correlations( GeneNetwork1. 
""" temp_table = build_temporary_literature_table( - species, gene_id, return_number, conn) + conn, species, gene_id, return_number) query_fns = { "Geno": fetch_geno_literature_correlations, # "Temp": fetch_temp_literature_correlations, @@ -156,11 +160,14 @@ def fetch_symbol_value_pair_dict( symbol: data_id_dict.get(symbol) for symbol in symbol_list if data_id_dict.get(symbol) is not None } - query = "SELECT Id, value FROM TissueProbeSetData WHERE Id IN %(data_ids)s" + data_ids_fields = (f"%(id{i})s" for i in range(len(data_ids.values()))) + query = ( + "SELECT Id, value FROM TissueProbeSetData " + f"WHERE Id IN ({','.join(data_ids_fields)})") with conn.cursor() as cursor: cursor.execute( query, - data_ids=tuple(data_ids.values())) + **{f"id{i}": did for i, did in enumerate(data_ids.values())}) value_results = cursor.fetchall() return { key: tuple(row[1] for row in value_results if row[0] == key) @@ -234,8 +241,10 @@ def fetch_tissue_probeset_xref_info( "INNER JOIN TissueProbeSetXRef AS t ON t.Symbol = x.Symbol " "AND t.Mean = x.maxmean") cursor.execute( - query, probeset_freeze_id=probeset_freeze_id, - symbols=tuple(gene_name_list)) + query, { + "probeset_freeze_id": probeset_freeze_id, + "symbols": tuple(gene_name_list) + }) results = cursor.fetchall() @@ -268,8 +277,8 @@ def fetch_gene_symbol_tissue_value_dict_for_trait( return {} def build_temporary_tissue_correlations_table( - trait_symbol: str, probeset_freeze_id: int, method: str, - return_number: int, conn: Any) -> str: + conn: Any, trait_symbol: str, probeset_freeze_id: int, method: str, + return_number: int) -> str: """ Build a temporary table to hold the tissue correlations data. @@ -279,6 +288,16 @@ def build_temporary_tissue_correlations_table( # We should probably pass the `correlations_of_all_tissue_traits` function # as an argument to this function and get rid of the one call immediately # following this comment. + from gn3.computations.partial_correlations import (#pylint: disable=[C0415, R0401] + correlations_of_all_tissue_traits) + # This import above is necessary within the function to avoid + # circular-imports. + # + # + # This import above is indicative of convoluted code, with the computation + # being interwoven with the data retrieval. This needs to be changed, such + # that the function being imported here is no longer necessary, or have the + # imported function passed to this function as an argument. symbol_corr_dict, symbol_p_value_dict = correlations_of_all_tissue_traits( fetch_gene_symbol_tissue_value_dict_for_trait( (trait_symbol,), probeset_freeze_id, conn), @@ -320,7 +339,7 @@ def fetch_tissue_correlations(# pylint: disable=R0913 GeneNetwork1. """ temp_table = build_temporary_tissue_correlations_table( - trait_symbol, probeset_freeze_id, method, return_number, conn) + conn, trait_symbol, probeset_freeze_id, method, return_number) with conn.cursor() as cursor: cursor.execute( ( @@ -379,3 +398,176 @@ def check_symbol_for_tissue_correlation( return True return False + +def fetch_sample_ids( + conn: Any, sample_names: Tuple[str, ...], species_name: str) -> Tuple[ + int, ...]: + """ + Given a sequence of sample names, and a species name, return the sample ids + that correspond to both. + + This is a partial migration of the + `web.webqtl.correlation.CorrelationPage.fetchAllDatabaseData` function in + GeneNetwork1. 
+ """ + samples_fields = (f"%(s{i})s" for i in range(len(sample_names))) + query = ( + "SELECT Strain.Id FROM Strain, Species " + f"WHERE Strain.Name IN ({','.join(samples_fields)}) " + "AND Strain.SpeciesId=Species.Id " + "AND Species.name=%(species_name)s") + with conn.cursor() as cursor: + cursor.execute( + query, + { + **{f"s{i}": sname for i, sname in enumerate(sample_names)}, + "species_name": species_name + }) + return tuple(row[0] for row in cursor.fetchall()) + +def build_query_sgo_lit_corr( + db_type: str, temp_table: str, sample_id_columns: str, + joins: Tuple[str, ...]) -> Tuple[str, int]: + """ + Build query for `SGO Literature Correlation` data, when querying the given + `temp_table` temporary table. + + This is a partial migration of the + `web.webqtl.correlation.CorrelationPage.fetchAllDatabaseData` function in + GeneNetwork1. + """ + return ( + (f"SELECT {db_type}.Name, {temp_table}.value, " + + sample_id_columns + + f" FROM ({db_type}, {db_type}XRef, {db_type}Freeze) " + + f"LEFT JOIN {temp_table} ON {temp_table}.GeneId2=ProbeSet.GeneId " + + " ".join(joins) + + " WHERE ProbeSet.GeneId IS NOT NULL " + + f"AND {temp_table}.value IS NOT NULL " + + f"AND {db_type}XRef.{db_type}FreezeId = {db_type}Freeze.Id " + + f"AND {db_type}Freeze.Name = %(db_name)s " + + f"AND {db_type}.Id = {db_type}XRef.{db_type}Id " + + f"ORDER BY {db_type}.Id"), + 2) + +def build_query_tissue_corr(db_type, temp_table, sample_id_columns, joins): + """ + Build query for `Tissue Correlation` data, when querying the given + `temp_table` temporary table. + + This is a partial migration of the + `web.webqtl.correlation.CorrelationPage.fetchAllDatabaseData` function in + GeneNetwork1. + """ + return ( + (f"SELECT {db_type}.Name, {temp_table}.Correlation, " + + f"{temp_table}.PValue, " + + sample_id_columns + + f" FROM ({db_type}, {db_type}XRef, {db_type}Freeze) " + + f"LEFT JOIN {temp_table} ON {temp_table}.Symbol=ProbeSet.Symbol " + + " ".join(joins) + + " WHERE ProbeSet.Symbol IS NOT NULL " + + f"AND {temp_table}.Correlation IS NOT NULL " + + f"AND {db_type}XRef.{db_type}FreezeId = {db_type}Freeze.Id " + + f"AND {db_type}Freeze.Name = %(db_name)s " + + f"AND {db_type}.Id = {db_type}XRef.{db_type}Id " + f"ORDER BY {db_type}.Id"), + 3) + +def fetch_all_database_data(# pylint: disable=[R0913, R0914] + conn: Any, species: str, gene_id: int, trait_symbol: str, + samples: Tuple[str, ...], dataset: dict, method: str, + return_number: int, probeset_freeze_id: int) -> Tuple[ + Tuple[float], int]: + """ + This is a migration of the + `web.webqtl.correlation.CorrelationPage.fetchAllDatabaseData` function in + GeneNetwork1. 
+ """ + db_type = dataset["dataset_type"] + db_name = dataset["dataset_name"] + def __build_query__(sample_ids, temp_table): + sample_id_columns = ", ".join(f"T{smpl}.value" for smpl in sample_ids) + if db_type == "Publish": + joins = tuple( + (f"LEFT JOIN PublishData AS T{item} " + f"ON T{item}.Id = PublishXRef.DataId " + f"AND T{item}.StrainId = %(T{item}_sample_id)s") + for item in sample_ids) + return ( + ("SELECT PublishXRef.Id, " + + sample_id_columns + + " FROM (PublishXRef, PublishFreeze) " + + " ".join(joins) + + " WHERE PublishXRef.InbredSetId = PublishFreeze.InbredSetId " + "AND PublishFreeze.Name = %(db_name)s"), + 1) + if temp_table is not None: + joins = tuple( + (f"LEFT JOIN {db_type}Data AS T{item} " + f"ON T{item}.Id = {db_type}XRef.DataId " + f"AND T{item}.StrainId=%(T{item}_sample_id)s") + for item in sample_ids) + if method.lower() == "sgo literature correlation": + return build_query_sgo_lit_corr( + sample_ids, temp_table, sample_id_columns, joins) + if method.lower() in ( + "tissue correlation, pearson's r", + "tissue correlation, spearman's rho"): + return build_query_tissue_corr( + sample_ids, temp_table, sample_id_columns, joins) + joins = tuple( + (f"LEFT JOIN {db_type}Data AS T{item} " + f"ON T{item}.Id = {db_type}XRef.DataId " + f"AND T{item}.StrainId = %(T{item}_sample_id)s") + for item in sample_ids) + return ( + ( + f"SELECT {db_type}.Name, " + + sample_id_columns + + f" FROM ({db_type}, {db_type}XRef, {db_type}Freeze) " + + " ".join(joins) + + f" WHERE {db_type}XRef.{db_type}FreezeId = {db_type}Freeze.Id " + + f"AND {db_type}Freeze.Name = %(db_name)s " + + f"AND {db_type}.Id = {db_type}XRef.{db_type}Id " + + f"ORDER BY {db_type}.Id"), + 1) + + def __fetch_data__(sample_ids, temp_table): + query, data_start_pos = __build_query__(sample_ids, temp_table) + with conn.cursor() as cursor: + cursor.execute( + query, + {"db_name": db_name, + **{f"T{item}_sample_id": item for item in sample_ids}}) + return (cursor.fetchall(), data_start_pos) + + sample_ids = tuple( + # look into graduating this to an argument and removing the `samples` + # and `species` argument: function currying and compositions might help + # with this + f"{sample_id}" for sample_id in + fetch_sample_ids(conn, samples, species)) + + temp_table = None + if gene_id and db_type == "probeset": + if method.lower() == "sgo literature correlation": + temp_table = build_temporary_literature_table( + conn, species, gene_id, return_number) + if method.lower() in ( + "tissue correlation, pearson's r", + "tissue correlation, spearman's rho"): + temp_table = build_temporary_tissue_correlations_table( + conn, trait_symbol, probeset_freeze_id, method, return_number) + + trait_database = tuple( + item for sublist in + (__fetch_data__(ssample_ids, temp_table) + for ssample_ids in partition_all(25, sample_ids)) + for item in sublist) + + if temp_table: + with conn.cursor() as cursor: + cursor.execute(f"DROP TEMPORARY TABLE {temp_table}") + + return (trait_database[0], trait_database[1]) diff --git a/gn3/db/datasets.py b/gn3/db/datasets.py index 6c328f5..b19db53 100644 --- a/gn3/db/datasets.py +++ b/gn3/db/datasets.py @@ -1,7 +1,11 @@ """ This module contains functions relating to specific trait dataset manipulation """ -from typing import Any +import re +from string import Template +from typing import Any, Dict, List, Optional +from SPARQLWrapper import JSON, SPARQLWrapper +from gn3.settings import SPARQL_ENDPOINT def retrieve_probeset_trait_dataset_name( threshold: int, name: str, connection: Any): @@ -22,10 
+26,13 @@ def retrieve_probeset_trait_dataset_name( "threshold": threshold, "name": name }) - return dict(zip( - ["dataset_id", "dataset_name", "dataset_fullname", - "dataset_shortname", "dataset_datascale"], - cursor.fetchone())) + res = cursor.fetchone() + if res: + return dict(zip( + ["dataset_id", "dataset_name", "dataset_fullname", + "dataset_shortname", "dataset_datascale"], + res)) + return {"dataset_id": None, "dataset_name": name, "dataset_fullname": name} def retrieve_publish_trait_dataset_name( threshold: int, name: str, connection: Any): @@ -75,33 +82,8 @@ def retrieve_geno_trait_dataset_name( "dataset_shortname"], cursor.fetchone())) -def retrieve_temp_trait_dataset_name( - threshold: int, name: str, connection: Any): - """ - Get the ID, DataScale and various name formats for a `Temp` trait. - """ - query = ( - "SELECT Id, Name, FullName, ShortName " - "FROM TempFreeze " - "WHERE " - "public > %(threshold)s " - "AND " - "(Name = %(name)s OR FullName = %(name)s OR ShortName = %(name)s)") - with connection.cursor() as cursor: - cursor.execute( - query, - { - "threshold": threshold, - "name": name - }) - return dict(zip( - ["dataset_id", "dataset_name", "dataset_fullname", - "dataset_shortname"], - cursor.fetchone())) - def retrieve_dataset_name( - trait_type: str, threshold: int, trait_name: str, dataset_name: str, - conn: Any): + trait_type: str, threshold: int, dataset_name: str, conn: Any): """ Retrieve the name of a trait given the trait's name @@ -113,9 +95,7 @@ def retrieve_dataset_name( "ProbeSet": retrieve_probeset_trait_dataset_name, "Publish": retrieve_publish_trait_dataset_name, "Geno": retrieve_geno_trait_dataset_name, - "Temp": retrieve_temp_trait_dataset_name} - if trait_type == "Temp": - return retrieve_temp_trait_dataset_name(threshold, trait_name, conn) + "Temp": lambda threshold, dataset_name, conn: {}} return fn_map[trait_type](threshold, dataset_name, conn) @@ -203,7 +183,6 @@ def retrieve_temp_trait_dataset(): """ Retrieve the dataset that relates to `Temp` traits """ - # pylint: disable=[C0330] return { "searchfield": ["name", "description"], "disfield": ["name", "description"], @@ -217,7 +196,6 @@ def retrieve_geno_trait_dataset(): """ Retrieve the dataset that relates to `Geno` traits """ - # pylint: disable=[C0330] return { "searchfield": ["name", "chr"], "disfield": ["name", "chr", "mb", "source2", "sequence"], @@ -228,7 +206,6 @@ def retrieve_publish_trait_dataset(): """ Retrieve the dataset that relates to `Publish` traits """ - # pylint: disable=[C0330] return { "searchfield": [ "name", "post_publication_description", "abstract", "title", @@ -247,7 +224,6 @@ def retrieve_probeset_trait_dataset(): """ Retrieve the dataset that relates to `ProbeSet` traits """ - # pylint: disable=[C0330] return { "searchfield": [ "name", "description", "probe_target_description", "symbol", @@ -278,8 +254,7 @@ def retrieve_trait_dataset(trait_type, trait, threshold, conn): "dataset_id": None, "dataset_name": trait["db"]["dataset_name"], **retrieve_dataset_name( - trait_type, threshold, trait["trait_name"], - trait["db"]["dataset_name"], conn) + trait_type, threshold, trait["db"]["dataset_name"], conn) } group = retrieve_group_fields( trait_type, trait["trait_name"], dataset_name_info, conn) @@ -289,3 +264,100 @@ def retrieve_trait_dataset(trait_type, trait, threshold, conn): **dataset_fns[trait_type](), **group } + +def sparql_query(query: str) -> List[Dict[str, Any]]: + """Run a SPARQL query and return the bound variables.""" + sparql = SPARQLWrapper(SPARQL_ENDPOINT) + 
sparql.setQuery(query) + sparql.setReturnFormat(JSON) + return sparql.queryAndConvert()['results']['bindings'] + +def dataset_metadata(accession_id: str) -> Optional[Dict[str, Any]]: + """Return info about dataset with ACCESSION_ID.""" + # Check accession_id to protect against query injection. + # TODO: This function doesn't yet return the names of the actual dataset files. + pattern = re.compile(r'GN\d+', re.ASCII) + if not pattern.fullmatch(accession_id): + return None + # KLUDGE: We split the SPARQL query because virtuoso is very slow on a + # single large query. + queries = [""" +PREFIX gn: <http://genenetwork.org/> +SELECT ?name ?dataset_group ?status ?title ?geo_series +WHERE { + ?dataset gn:accessionId "$accession_id" ; + rdf:type gn:dataset ; + gn:name ?name . + OPTIONAL { ?dataset gn:datasetGroup ?dataset_group } . + # FIXME: gn:datasetStatus should not be optional. But, some records don't + # have it. + OPTIONAL { ?dataset gn:datasetStatus ?status } . + OPTIONAL { ?dataset gn:title ?title } . + OPTIONAL { ?dataset gn:geoSeries ?geo_series } . +} +""", + """ +PREFIX gn: <http://genenetwork.org/> +SELECT ?platform_name ?normalization_name ?species_name ?inbred_set_name ?tissue_name +WHERE { + ?dataset gn:accessionId "$accession_id" ; + rdf:type gn:dataset ; + gn:normalization / gn:name ?normalization_name ; + gn:datasetOfSpecies / gn:menuName ?species_name ; + gn:datasetOfInbredSet / gn:name ?inbred_set_name . + OPTIONAL { ?dataset gn:datasetOfTissue / gn:name ?tissue_name } . + OPTIONAL { ?dataset gn:datasetOfPlatform / gn:name ?platform_name } . +} +""", + """ +PREFIX gn: <http://genenetwork.org/> +SELECT ?specifics ?summary ?about_cases ?about_tissue ?about_platform + ?about_data_processing ?notes ?experiment_design ?contributors + ?citation ?acknowledgment +WHERE { + ?dataset gn:accessionId "$accession_id" ; + rdf:type gn:dataset . + OPTIONAL { ?dataset gn:specifics ?specifics . } + OPTIONAL { ?dataset gn:summary ?summary . } + OPTIONAL { ?dataset gn:aboutCases ?about_cases . } + OPTIONAL { ?dataset gn:aboutTissue ?about_tissue . } + OPTIONAL { ?dataset gn:aboutPlatform ?about_platform . } + OPTIONAL { ?dataset gn:aboutDataProcessing ?about_data_processing . } + OPTIONAL { ?dataset gn:notes ?notes . } + OPTIONAL { ?dataset gn:experimentDesign ?experiment_design . } + OPTIONAL { ?dataset gn:contributors ?contributors . } + OPTIONAL { ?dataset gn:citation ?citation . } + OPTIONAL { ?dataset gn:acknowledgment ?acknowledgment . } +} +"""] + result: Dict[str, Any] = {'accession_id': accession_id, + 'investigator': {}} + query_result = {} + for query in queries: + if sparql_result := sparql_query(Template(query).substitute(accession_id=accession_id)): + query_result.update(sparql_result[0]) + else: + return None + for key, value in query_result.items(): + result[key] = value['value'] + investigator_query_result = sparql_query(Template(""" +PREFIX gn: <http://genenetwork.org/> +SELECT ?name ?address ?city ?state ?zip ?phone ?email ?country ?homepage +WHERE { + ?dataset gn:accessionId "$accession_id" ; + rdf:type gn:dataset ; + gn:datasetOfInvestigator ?investigator . + OPTIONAL { ?investigator foaf:name ?name . } + OPTIONAL { ?investigator gn:address ?address . } + OPTIONAL { ?investigator gn:city ?city . } + OPTIONAL { ?investigator gn:state ?state . } + OPTIONAL { ?investigator gn:zipCode ?zip . } + OPTIONAL { ?investigator foaf:phone ?phone . } + OPTIONAL { ?investigator foaf:mbox ?email . } + OPTIONAL { ?investigator gn:country ?country . 
} + OPTIONAL { ?investigator foaf:homepage ?homepage . } +} +""").substitute(accession_id=accession_id))[0] + for key, value in investigator_query_result.items(): + result['investigator'][key] = value['value'] + return result diff --git a/gn3/db/genotypes.py b/gn3/db/genotypes.py index 8f18cac..6f867c7 100644 --- a/gn3/db/genotypes.py +++ b/gn3/db/genotypes.py @@ -2,7 +2,6 @@ import os import gzip -from typing import Union, TextIO from gn3.settings import GENOTYPE_FILES @@ -10,7 +9,7 @@ def build_genotype_file( geno_name: str, base_dir: str = GENOTYPE_FILES, extension: str = "geno"): """Build the absolute path for the genotype file.""" - return "{}/{}.{}".format(os.path.abspath(base_dir), geno_name, extension) + return f"{os.path.abspath(base_dir)}/{geno_name}.{extension}" def load_genotype_samples(genotype_filename: str, file_type: str = "geno"): """ @@ -44,22 +43,23 @@ def __load_genotype_samples_from_geno(genotype_filename: str): Loads samples from '.geno' files. """ - gzipped_filename = "{}.gz".format(genotype_filename) + def __remove_comments_and_empty_lines__(rows): + return( + line for line in rows + if line and not line.startswith(("#", "@"))) + + gzipped_filename = f"{genotype_filename}.gz" if os.path.isfile(gzipped_filename): - genofile: Union[TextIO, gzip.GzipFile] = gzip.open(gzipped_filename) + with gzip.open(gzipped_filename) as gz_genofile: + rows = __remove_comments_and_empty_lines__(gz_genofile.readlines()) else: - genofile = open(genotype_filename) - - for row in genofile: - line = row.strip() - if (not line) or (line.startswith(("#", "@"))): # type: ignore[arg-type] - continue - break + with open(genotype_filename, encoding="utf8") as genofile: + rows = __remove_comments_and_empty_lines__(genofile.readlines()) - headers = line.split("\t") # type: ignore[arg-type] + headers = next(rows).split() # type: ignore[arg-type] if headers[3] == "Mb": - return headers[4:] - return headers[3:] + return tuple(headers[4:]) + return tuple(headers[3:]) def __load_genotype_samples_from_plink(genotype_filename: str): """ @@ -67,8 +67,8 @@ def __load_genotype_samples_from_plink(genotype_filename: str): Loads samples from '.plink' files. """ - genofile = open(genotype_filename) - return [line.split(" ")[1] for line in genofile] + with open(genotype_filename, encoding="utf8") as genofile: + return tuple(line.split()[1] for line in genofile) def parse_genotype_labels(lines: list): """ @@ -129,7 +129,7 @@ def parse_genotype_marker(line: str, geno_obj: dict, parlist: tuple): alleles = marker_row[start_pos:] genotype = tuple( - (geno_table[allele] if allele in geno_table.keys() else "U") + (geno_table[allele] if allele in geno_table else "U") for allele in alleles) if len(parlist) > 0: genotype = (-1, 1) + genotype @@ -164,7 +164,7 @@ def parse_genotype_file(filename: str, parlist: tuple = tuple()): """ Parse the provided genotype file into a usable pytho3 data structure. 
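dataset_metadata above stitches its result together from several smaller SPARQL queries; a hedged usage sketch (the accession id is invented, and SPARQL_ENDPOINT must be reachable):

from gn3.db.datasets import dataset_metadata

metadata = dataset_metadata("GN112")  # ids not matching r'GN\d+' return None
if metadata:
    print(metadata["name"], metadata["investigator"].get("email", ""))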
""" - with open(filename, "r") as infile: + with open(filename, "r", encoding="utf8") as infile: contents = infile.readlines() lines = tuple(line for line in contents if @@ -175,10 +175,10 @@ def parse_genotype_file(filename: str, parlist: tuple = tuple()): data_lines = tuple(line for line in lines if not line.startswith("@")) header = parse_genotype_header(data_lines[0], parlist) geno_obj = dict(labels + header) - markers = tuple( - [parse_genotype_marker(line, geno_obj, parlist) - for line in data_lines[1:]]) + markers = ( + parse_genotype_marker(line, geno_obj, parlist) + for line in data_lines[1:]) chromosomes = tuple( dict(chromosome) for chromosome in - build_genotype_chromosomes(geno_obj, markers)) + build_genotype_chromosomes(geno_obj, tuple(markers))) return {**geno_obj, "chromosomes": chromosomes} diff --git a/gn3/db/partial_correlations.py b/gn3/db/partial_correlations.py new file mode 100644 index 0000000..72dbf1a --- /dev/null +++ b/gn3/db/partial_correlations.py @@ -0,0 +1,791 @@ +""" +This module contains the code and queries for fetching data from the database, +that relates to partial correlations. + +It is intended to replace the functions in `gn3.db.traits` and `gn3.db.datasets` +modules with functions that fetch the data enmasse, rather than one at a time. + +This module is part of the optimisation effort for the partial correlations. +""" + +from functools import reduce, partial +from typing import Any, Dict, Tuple, Union, Sequence + +from MySQLdb.cursors import DictCursor + +from gn3.function_helpers import compose +from gn3.db.traits import ( + build_trait_name, + with_samplelist_data_setup, + without_samplelist_data_setup) + +def organise_trait_data_by_trait( + traits_data_rows: Tuple[Dict[str, Any], ...]) -> Dict[ + str, Dict[str, Any]]: + """ + Organise the trait data items by their trait names. + """ + def __organise__(acc, row): + trait_name = row["trait_name"] + return { + **acc, + trait_name: acc.get(trait_name, tuple()) + ({ + key: val for key, val in row.items() if key != "trait_name"},) + } + if traits_data_rows: + return reduce(__organise__, traits_data_rows, {}) + return {} + +def temp_traits_data(conn, traits): + """ + Retrieve trait data for `Temp` traits. + """ + query = ( + "SELECT " + "Temp.Name AS trait_name, Strain.Name AS sample_name, TempData.value, " + "TempData.SE AS se_error, TempData.NStrain AS nstrain, " + "TempData.Id AS id " + "FROM TempData, Temp, Strain " + "WHERE TempData.StrainId = Strain.Id " + "AND TempData.Id = Temp.DataId " + f"AND Temp.name IN ({', '.join(['%s'] * len(traits))}) " + "ORDER BY Strain.Name") + with conn.cursor(cursorclass=DictCursor) as cursor: + cursor.execute( + query, + tuple(trait["trait_name"] for trait in traits)) + return organise_trait_data_by_trait(cursor.fetchall()) + return {} + +def publish_traits_data(conn, traits): + """ + Retrieve trait data for `Publish` traits. 
+ """ + dataset_ids = tuple(set( + trait["db"]["dataset_id"] for trait in traits + if trait["db"].get("dataset_id") is not None)) + query = ( + "SELECT " + "PublishXRef.Id AS trait_name, Strain.Name AS sample_name, " + "PublishData.value, PublishSE.error AS se_error, " + "NStrain.count AS nstrain, PublishData.Id AS id " + "FROM (PublishData, Strain, PublishXRef, PublishFreeze) " + "LEFT JOIN PublishSE " + "ON (PublishSE.DataId = PublishData.Id " + "AND PublishSE.StrainId = PublishData.StrainId) " + "LEFT JOIN NStrain " + "ON (NStrain.DataId = PublishData.Id " + "AND NStrain.StrainId = PublishData.StrainId) " + "WHERE PublishXRef.InbredSetId = PublishFreeze.InbredSetId " + "AND PublishData.Id = PublishXRef.DataId " + f"AND PublishXRef.Id IN ({', '.join(['%s'] * len(traits))}) " + "AND PublishFreeze.Id IN " + f"({', '.join(['%s'] * len(dataset_ids))}) " + "AND PublishData.StrainId = Strain.Id " + "ORDER BY Strain.Name") + if len(dataset_ids) > 0: + with conn.cursor(cursorclass=DictCursor) as cursor: + cursor.execute( + query, + tuple(trait["trait_name"] for trait in traits) + + tuple(dataset_ids)) + return organise_trait_data_by_trait(cursor.fetchall()) + return {} + +def cellid_traits_data(conn, traits): + """ + Retrieve trait data for `Probe Data` types. + """ + cellids = tuple(trait["cellid"] for trait in traits) + dataset_names = set(trait["db"]["dataset_name"] for trait in traits) + query = ( + "SELECT " + "ProbeSet.Name AS trait_name, Strain.Name AS sample_name, " + "ProbeData.value, ProbeSE.error AS se_error, ProbeData.Id AS id " + "FROM (ProbeData, ProbeFreeze, ProbeSetFreeze, ProbeXRef, Strain, " + "Probe, ProbeSet) " + "LEFT JOIN ProbeSE " + "ON (ProbeSE.DataId = ProbeData.Id " + "AND ProbeSE.StrainId = ProbeData.StrainId) " + f"WHERE Probe.Name IN ({', '.join(['%s'] * len(cellids))}) " + f"AND ProbeSet.Name IN ({', '.join(['%s'] * len(traits))}) " + "AND Probe.ProbeSetId = ProbeSet.Id " + "AND ProbeXRef.ProbeId = Probe.Id " + "AND ProbeXRef.ProbeFreezeId = ProbeFreeze.Id " + "AND ProbeSetFreeze.ProbeFreezeId = ProbeFreeze.Id " + f"AND ProbeSetFreeze.Name IN ({', '.join(['%s'] * len(dataset_names))}) " + "AND ProbeXRef.DataId = ProbeData.Id " + "AND ProbeData.StrainId = Strain.Id " + "ORDER BY Strain.Name") + with conn.cursor(cursorclass=DictCursor) as cursor: + cursor.execute( + query, + cellids + tuple(trait["trait_name"] for trait in traits) + + tuple(dataset_names)) + return organise_trait_data_by_trait(cursor.fetchall()) + return {} + +def probeset_traits_data(conn, traits): + """ + Retrieve trait data for `ProbeSet` traits. 
+ """ + dataset_names = set(trait["db"]["dataset_name"] for trait in traits) + query = ( + "SELECT ProbeSet.Name AS trait_name, Strain.Name AS sample_name, " + "ProbeSetData.value, ProbeSetSE.error AS se_error, " + "ProbeSetData.Id AS id " + "FROM (ProbeSetData, ProbeSetFreeze, Strain, ProbeSet, ProbeSetXRef) " + "LEFT JOIN ProbeSetSE ON " + "(ProbeSetSE.DataId = ProbeSetData.Id " + "AND ProbeSetSE.StrainId = ProbeSetData.StrainId) " + f"WHERE ProbeSet.Name IN ({', '.join(['%s'] * len(traits))})" + "AND ProbeSetXRef.ProbeSetId = ProbeSet.Id " + "AND ProbeSetXRef.ProbeSetFreezeId = ProbeSetFreeze.Id " + f"AND ProbeSetFreeze.Name IN ({', '.join(['%s']*len(dataset_names))}) " + "AND ProbeSetXRef.DataId = ProbeSetData.Id " + "AND ProbeSetData.StrainId = Strain.Id " + "ORDER BY Strain.Name") + with conn.cursor(cursorclass=DictCursor) as cursor: + cursor.execute( + query, + tuple(trait["trait_name"] for trait in traits) + + tuple(dataset_names)) + return organise_trait_data_by_trait(cursor.fetchall()) + return {} + +def species_ids(conn, traits): + """ + Retrieve the IDS of the related species from the given list of traits. + """ + groups = tuple(set( + trait["db"]["group"] for trait in traits + if trait["db"].get("group") is not None)) + query = ( + "SELECT Name AS `group`, SpeciesId AS species_id " + "FROM InbredSet " + f"WHERE Name IN ({', '.join(['%s'] * len(groups))})") + if len(groups) > 0: + with conn.cursor(cursorclass=DictCursor) as cursor: + cursor.execute(query, groups) + return tuple(row for row in cursor.fetchall()) + return tuple() + +def geno_traits_data(conn, traits): + """ + Retrieve trait data for `Geno` traits. + """ + sp_ids = tuple(item["species_id"] for item in species_ids(conn, traits)) + dataset_names = set(trait["db"]["dataset_name"] for trait in traits) + query = ( + "SELECT Geno.Name AS trait_name, Strain.Name AS sample_name, " + "GenoData.value, GenoSE.error AS se_error, GenoData.Id AS id " + "FROM (GenoData, GenoFreeze, Strain, Geno, GenoXRef) " + "LEFT JOIN GenoSE ON " + "(GenoSE.DataId = GenoData.Id AND GenoSE.StrainId = GenoData.StrainId) " + f"WHERE Geno.SpeciesId IN ({', '.join(['%s'] * len(sp_ids))}) " + f"AND Geno.Name IN ({', '.join(['%s'] * len(traits))}) " + "AND GenoXRef.GenoId = Geno.Id " + "AND GenoXRef.GenoFreezeId = GenoFreeze.Id " + f"AND GenoFreeze.Name IN ({', '.join(['%s'] * len(dataset_names))}) " + "AND GenoXRef.DataId = GenoData.Id " + "AND GenoData.StrainId = Strain.Id " + "ORDER BY Strain.Name") + if len(sp_ids) > 0 and len(dataset_names) > 0: + with conn.cursor(cursorclass=DictCursor) as cursor: + cursor.execute( + query, + sp_ids + + tuple(trait["trait_name"] for trait in traits) + + tuple(dataset_names)) + return organise_trait_data_by_trait(cursor.fetchall()) + return {} + +def traits_data( + conn: Any, traits: Tuple[Dict[str, Any], ...], + samplelist: Tuple[str, ...] = tuple()) -> Dict[str, Dict[str, Any]]: + """ + Retrieve trait data for multiple `traits` + + This is a rework of the `gn3.db.traits.retrieve_trait_data` function. 
+ """ + def __organise__(acc, trait): + dataset_type = trait["db"]["dataset_type"] + if dataset_type == "Temp": + return {**acc, "Temp": acc.get("Temp", tuple()) + (trait,)} + if dataset_type == "Publish": + return {**acc, "Publish": acc.get("Publish", tuple()) + (trait,)} + if trait.get("cellid"): + return {**acc, "cellid": acc.get("cellid", tuple()) + (trait,)} + if dataset_type == "ProbeSet": + return {**acc, "ProbeSet": acc.get("ProbeSet", tuple()) + (trait,)} + return {**acc, "Geno": acc.get("Geno", tuple()) + (trait,)} + + def __setup_samplelist__(data): + if samplelist: + return tuple( + item for item in + map(with_samplelist_data_setup(samplelist), data) + if item is not None) + return tuple( + item for item in + map(without_samplelist_data_setup(), data) + if item is not None) + + def __process_results__(results): + flattened = reduce(lambda acc, res: {**acc, **res}, results) + return { + trait_name: {"data": dict(map( + lambda item: ( + item["sample_name"], + { + key: val for key, val in item.items() + if item != "sample_name" + }), + __setup_samplelist__(data)))} + for trait_name, data in flattened.items()} + + traits_data_fns = { + "Temp": temp_traits_data, + "Publish": publish_traits_data, + "cellid": cellid_traits_data, + "ProbeSet": probeset_traits_data, + "Geno": geno_traits_data + } + return __process_results__(tuple(# type: ignore[var-annotated] + traits_data_fns[key](conn, vals) + for key, vals in reduce(__organise__, traits, {}).items())) + +def merge_traits_and_info(traits, info_results): + """ + Utility to merge trait info retrieved from the database with the given traits. + """ + if info_results: + results = { + str(trait["trait_name"]): trait for trait in info_results + } + return tuple( + { + **trait, + **results.get(trait["trait_name"], {}), + "haveinfo": bool(results.get(trait["trait_name"])) + } for trait in traits) + return tuple({**trait, "haveinfo": False} for trait in traits) + +def publish_traits_info( + conn: Any, traits: Tuple[Dict[str, Any], ...]) -> Tuple[ + Dict[str, Any], ...]: + """ + Retrieve trait information for type `Publish` traits. + + This is a rework of `gn3.db.traits.retrieve_publish_trait_info` function: + this one fetches multiple items in a single query, unlike the original that + fetches one item per query. 
+ """ + trait_dataset_ids = set( + trait["db"]["dataset_id"] for trait in traits + if trait["db"].get("dataset_id") is not None) + columns = ( + "PublishXRef.Id, Publication.PubMed_ID, " + "Phenotype.Pre_publication_description, " + "Phenotype.Post_publication_description, " + "Phenotype.Original_description, " + "Phenotype.Pre_publication_abbreviation, " + "Phenotype.Post_publication_abbreviation, " + "Phenotype.Lab_code, Phenotype.Submitter, Phenotype.Owner, " + "Phenotype.Authorized_Users, " + "CAST(Publication.Authors AS BINARY) AS Authors, Publication.Title, " + "Publication.Abstract, Publication.Journal, Publication.Volume, " + "Publication.Pages, Publication.Month, Publication.Year, " + "PublishXRef.Sequence, Phenotype.Units, PublishXRef.comments") + query = ( + "SELECT " + f"PublishXRef.Id AS trait_name, {columns} " + "FROM " + "PublishXRef, Publication, Phenotype, PublishFreeze " + "WHERE " + f"PublishXRef.Id IN ({', '.join(['%s'] * len(traits))}) " + "AND Phenotype.Id = PublishXRef.PhenotypeId " + "AND Publication.Id = PublishXRef.PublicationId " + "AND PublishXRef.InbredSetId = PublishFreeze.InbredSetId " + "AND PublishFreeze.Id IN " + f"({', '.join(['%s'] * len(trait_dataset_ids))})") + if trait_dataset_ids: + with conn.cursor(cursorclass=DictCursor) as cursor: + cursor.execute( + query, + ( + tuple(trait["trait_name"] for trait in traits) + + tuple(trait_dataset_ids))) + return merge_traits_and_info(traits, cursor.fetchall()) + return tuple({**trait, "haveinfo": False} for trait in traits) + +def probeset_traits_info( + conn: Any, traits: Tuple[Dict[str, Any], ...]): + """ + Retrieve information for the probeset traits + """ + dataset_names = set(trait["db"]["dataset_name"] for trait in traits) + columns = ", ".join( + [f"ProbeSet.{x}" for x in + ("name", "symbol", "description", "probe_target_description", "chr", + "mb", "alias", "geneid", "genbankid", "unigeneid", "omim", + "refseq_transcriptid", "blatseq", "targetseq", "chipid", "comments", + "strand_probe", "strand_gene", "probe_set_target_region", "proteinid", + "probe_set_specificity", "probe_set_blat_score", + "probe_set_blat_mb_start", "probe_set_blat_mb_end", + "probe_set_strand", "probe_set_note_by_rw", "flag")]) + query = ( + f"SELECT ProbeSet.Name AS trait_name, {columns} " + "FROM ProbeSet INNER JOIN ProbeSetXRef " + "ON ProbeSetXRef.ProbeSetId = ProbeSet.Id " + "INNER JOIN ProbeSetFreeze " + "ON ProbeSetFreeze.Id = ProbeSetXRef.ProbeSetFreezeId " + "WHERE ProbeSetFreeze.Name IN " + f"({', '.join(['%s'] * len(dataset_names))}) " + f"AND ProbeSet.Name IN ({', '.join(['%s'] * len(traits))})") + with conn.cursor(cursorclass=DictCursor) as cursor: + cursor.execute( + query, + tuple(dataset_names) + tuple( + trait["trait_name"] for trait in traits)) + return merge_traits_and_info(traits, cursor.fetchall()) + return tuple({**trait, "haveinfo": False} for trait in traits) + +def geno_traits_info( + conn: Any, traits: Tuple[Dict[str, Any], ...]): + """ + Retrieve trait information for type `Geno` traits. + + This is a rework of the `gn3.db.traits.retrieve_geno_trait_info` function. 
+ """ + dataset_names = set(trait["db"]["dataset_name"] for trait in traits) + columns = ", ".join([ + f"Geno.{x}" for x in ("name", "chr", "mb", "source2", "sequence")]) + query = ( + "SELECT " + f"Geno.Name AS trait_name, {columns} " + "FROM " + "Geno INNER JOIN GenoXRef ON GenoXRef.GenoId = Geno.Id " + "INNER JOIN GenoFreeze ON GenoFreeze.Id = GenoXRef.GenoFreezeId " + f"WHERE GenoFreeze.Name IN ({', '.join(['%s'] * len(dataset_names))}) " + f"AND Geno.Name IN ({', '.join(['%s'] * len(traits))})") + with conn.cursor(cursorclass=DictCursor) as cursor: + cursor.execute( + query, + tuple(dataset_names) + tuple( + trait["trait_name"] for trait in traits)) + return merge_traits_and_info(traits, cursor.fetchall()) + return tuple({**trait, "haveinfo": False} for trait in traits) + +def temp_traits_info( + conn: Any, traits: Tuple[Dict[str, Any], ...]): + """ + Retrieve trait information for type `Temp` traits. + + A rework of the `gn3.db.traits.retrieve_temp_trait_info` function. + """ + query = ( + "SELECT Name as trait_name, name, description FROM Temp " + f"WHERE Name IN ({', '.join(['%s'] * len(traits))})") + with conn.cursor(cursorclass=DictCursor) as cursor: + cursor.execute( + query, + tuple(trait["trait_name"] for trait in traits)) + return merge_traits_and_info(traits, cursor.fetchall()) + return tuple({**trait, "haveinfo": False} for trait in traits) + +def publish_datasets_names( + conn: Any, threshold: int, dataset_names: Tuple[str, ...]): + """ + Get the ID, DataScale and various name formats for a `Publish` trait. + + Rework of the `gn3.db.datasets.retrieve_publish_trait_dataset_name` + """ + query = ( + "SELECT DISTINCT " + "Id AS dataset_id, Name AS dataset_name, FullName AS dataset_fullname, " + "ShortName AS dataset_shortname " + "FROM PublishFreeze " + "WHERE " + "public > %s " + "AND " + "(Name IN ({names}) OR FullName IN ({names}) OR ShortName IN ({names}))") + with conn.cursor(cursorclass=DictCursor) as cursor: + cursor.execute( + query.format(names=", ".join(["%s"] * len(dataset_names))), + (threshold,) +(dataset_names * 3)) + return {ds["dataset_name"]: ds for ds in cursor.fetchall()} + return {} + +def set_bxd(group_info): + """Set the group value to BXD if it is 'BXD300'.""" + return { + **group_info, + "group": ( + "BXD" if group_info.get("Name") == "BXD300" + else group_info.get("Name", "")), + "groupid": group_info["Id"] + } + +def organise_groups_by_dataset( + group_rows: Union[Sequence[Dict[str, Any]], None]) -> Dict[str, Any]: + """Utility: Organise given groups by their datasets.""" + if group_rows: + return { + row["dataset_name"]: set_bxd({ + key: val for key, val in row.items() + if key != "dataset_name" + }) for row in group_rows + } + return {} + +def publish_datasets_groups(conn: Any, dataset_names: Tuple[str]): + """ + Retrieve the Group, and GroupID values for various Publish trait types. + + Rework of `gn3.db.datasets.retrieve_publish_group_fields` function. 
+ """ + query = ( + "SELECT PublishFreeze.Name AS dataset_name, InbredSet.Name, " + "InbredSet.Id " + "FROM InbredSet, PublishFreeze " + "WHERE PublishFreeze.InbredSetId = InbredSet.Id " + f"AND PublishFreeze.Name IN ({', '.join(['%s'] * len(dataset_names))})") + with conn.cursor(cursorclass=DictCursor) as cursor: + cursor.execute(query, tuple(dataset_names)) + return organise_groups_by_dataset(cursor.fetchall()) + return {} + +def publish_traits_datasets(conn: Any, threshold, traits: Tuple[Dict]): + """Retrieve datasets for 'Publish' traits.""" + dataset_names = tuple(set(trait["db"]["dataset_name"] for trait in traits)) + dataset_names_info = publish_datasets_names(conn, threshold, dataset_names) + dataset_groups = publish_datasets_groups(conn, dataset_names) # type: ignore[arg-type] + return tuple({ + **trait, + "db": { + **trait["db"], + **dataset_names_info.get(trait["db"]["dataset_name"], {}), + **dataset_groups.get(trait["db"]["dataset_name"], {}) + } + } for trait in traits) + +def probeset_datasets_names(conn: Any, threshold: int, dataset_names: Tuple[str, ...]): + """ + Get the ID, DataScale and various name formats for a `ProbeSet` trait. + """ + query = ( + "SELECT Id AS dataset_id, Name AS dataset_name, " + "FullName AS dataset_fullname, ShortName AS dataset_shortname, " + "DataScale AS dataset_datascale " + "FROM ProbeSetFreeze " + "WHERE " + "public > %s " + "AND " + "(Name IN ({names}) OR FullName IN ({names}) OR ShortName IN ({names}))") + with conn.cursor(cursorclass=DictCursor) as cursor: + cursor.execute( + query.format(names=", ".join(["%s"] * len(dataset_names))), + (threshold,) +(dataset_names * 3)) + return {ds["dataset_name"]: ds for ds in cursor.fetchall()} + return {} + +def probeset_datasets_groups(conn, dataset_names): + """ + Retrieve the Group, and GroupID values for various ProbeSet trait types. + """ + query = ( + "SELECT ProbeSetFreeze.Name AS dataset_name, InbredSet.Name, " + "InbredSet.Id " + "FROM InbredSet, ProbeSetFreeze, ProbeFreeze " + "WHERE ProbeFreeze.InbredSetId = InbredSet.Id " + "AND ProbeFreeze.Id = ProbeSetFreeze.ProbeFreezeId " + f"AND ProbeSetFreeze.Name IN ({', '.join(['%s'] * len(dataset_names))})") + with conn.cursor(cursorclass=DictCursor) as cursor: + cursor.execute(query, tuple(dataset_names)) + return organise_groups_by_dataset(cursor.fetchall()) + return {} + +def probeset_traits_datasets(conn: Any, threshold, traits: Tuple[Dict]): + """Retrive datasets for 'ProbeSet' traits.""" + dataset_names = tuple(set(trait["db"]["dataset_name"] for trait in traits)) + dataset_names_info = probeset_datasets_names(conn, threshold, dataset_names) + dataset_groups = probeset_datasets_groups(conn, dataset_names) + return tuple({ + **trait, + "db": { + **trait["db"], + **dataset_names_info.get(trait["db"]["dataset_name"], {}), + **dataset_groups.get(trait["db"]["dataset_name"], {}) + } + } for trait in traits) + +def geno_datasets_names(conn, threshold, dataset_names): + """ + Get the ID, DataScale and various name formats for a `Geno` trait. 
+ """ + query = ( + "SELECT Id AS dataset_id, Name AS dataset_name, " + "FullName AS dataset_fullname, ShortName AS dataset_short_name " + "FROM GenoFreeze " + "WHERE " + "public > %s " + "AND " + "(Name IN ({names}) OR FullName IN ({names}) OR ShortName IN ({names}))") + with conn.cursor(cursorclass=DictCursor) as cursor: + cursor.execute( + query.format(names=", ".join(["%s"] * len(dataset_names))), + (threshold,) + (tuple(dataset_names) * 3)) + return {ds["dataset_name"]: ds for ds in cursor.fetchall()} + return {} + +def geno_datasets_groups(conn, dataset_names): + """ + Retrieve the Group, and GroupID values for various Geno trait types. + """ + query = ( + "SELECT GenoFreeze.Name AS dataset_name, InbredSet.Name, InbredSet.Id " + "FROM InbredSet, GenoFreeze " + "WHERE GenoFreeze.InbredSetId = InbredSet.Id " + f"AND GenoFreeze.Name IN ({', '.join(['%s'] * len(dataset_names))})") + with conn.cursor(cursorclass=DictCursor) as cursor: + cursor.execute(query, tuple(dataset_names)) + return organise_groups_by_dataset(cursor.fetchall()) + return {} + +def geno_traits_datasets(conn: Any, threshold: int, traits: Tuple[Dict]): + """Retrieve datasets for 'Geno' traits.""" + dataset_names = tuple(set(trait["db"]["dataset_name"] for trait in traits)) + dataset_names_info = geno_datasets_names(conn, threshold, dataset_names) + dataset_groups = geno_datasets_groups(conn, dataset_names) + return tuple({ + **trait, + "db": { + **trait["db"], + **dataset_names_info.get(trait["db"]["dataset_name"], {}), + **dataset_groups.get(trait["db"]["dataset_name"], {}) + } + } for trait in traits) + +def temp_datasets_groups(conn, dataset_names): + """ + Retrieve the Group, and GroupID values for `Temp` trait types. + """ + query = ( + "SELECT Temp.Name AS dataset_name, InbredSet.Name, InbredSet.Id " + "FROM InbredSet, Temp " + "WHERE Temp.InbredSetId = InbredSet.Id " + f"AND Temp.Name IN ({', '.join(['%s'] * len(dataset_names))})") + with conn.cursor(cursorclass=DictCursor) as cursor: + cursor.execute(query, tuple(dataset_names)) + return organise_groups_by_dataset(cursor.fetchall()) + return {} + +def temp_traits_datasets(conn: Any, threshold: int, traits: Tuple[Dict]): #pylint: disable=[W0613] + """ + Retrieve datasets for 'Temp' traits. + """ + dataset_names = tuple(set(trait["db"]["dataset_name"] for trait in traits)) + dataset_groups = temp_datasets_groups(conn, dataset_names) + return tuple({ + **trait, + "db": { + **trait["db"], + **dataset_groups.get(trait["db"]["dataset_name"], {}) + } + } for trait in traits) + +def set_confidential(traits): + """ + Set the confidential field for traits of type `Publish`. + """ + return tuple({ + **trait, + "confidential": ( + True if (# pylint: disable=[R1719] + trait.get("pre_publication_description") + and not trait.get("pubmed_id")) + else False) + } for trait in traits) + +def query_qtl_info(conn, query, traits, dataset_ids): + """ + Utility: Run the `query` to get the QTL information for the given `traits`. 
+ """ + with conn.cursor(cursorclass=DictCursor) as cursor: + cursor.execute( + query, + tuple(trait["trait_name"] for trait in traits) + dataset_ids) + results = { + row["trait_name"]: { + key: val for key, val in row if key != "trait_name" + } for row in cursor.fetchall() + } + return tuple( + {**trait, **results.get(trait["trait_name"], {})} + for trait in traits) + +def set_publish_qtl_info(conn, qtl, traits): + """ + Load extra QTL information for `Publish` traits + """ + if qtl: + dataset_ids = set(trait["db"]["dataset_id"] for trait in traits) + query = ( + "SELECT PublishXRef.Id AS trait_name, PublishXRef.Locus, " + "PublishXRef.LRS, PublishXRef.additive " + "FROM PublishXRef, PublishFreeze " + f"WHERE PublishXRef.Id IN ({', '.join(['%s'] * len(traits))}) " + "AND PublishXRef.InbredSetId = PublishFreeze.InbredSetId " + f"AND PublishFreeze.Id IN ({', '.join(['%s'] * len(dataset_ids))})") + return query_qtl_info(conn, query, traits, tuple(dataset_ids)) + return traits + +def set_probeset_qtl_info(conn, qtl, traits): + """ + Load extra QTL information for `ProbeSet` traits + """ + if qtl: + dataset_ids = tuple(set(trait["db"]["dataset_id"] for trait in traits)) + query = ( + "SELECT ProbeSet.Name AS trait_name, ProbeSetXRef.Locus, " + "ProbeSetXRef.LRS, ProbeSetXRef.pValue, " + "ProbeSetXRef.mean, ProbeSetXRef.additive " + "FROM ProbeSetXRef, ProbeSet " + "WHERE ProbeSetXRef.ProbeSetId = ProbeSet.Id " + f"AND ProbeSet.Name IN ({', '.join(['%s'] * len(traits))}) " + "AND ProbeSetXRef.ProbeSetFreezeId IN " + f"({', '.join(['%s'] * len(dataset_ids))})") + return query_qtl_info(conn, query, traits, tuple(dataset_ids)) + return traits + +def set_sequence(conn, traits): + """ + Retrieve 'ProbeSet' traits sequence information + """ + dataset_names = set(trait["db"]["dataset_name"] for trait in traits) + query = ( + "SELECT ProbeSet.Name as trait_name, ProbeSet.BlatSeq " + "FROM ProbeSet, ProbeSetFreeze, ProbeSetXRef " + "WHERE ProbeSet.Id=ProbeSetXRef.ProbeSetId " + "AND ProbeSetFreeze.Id = ProbeSetXRef.ProbeSetFreezeId " + f"AND ProbeSet.Name IN ({', '.join(['%s'] * len(traits))}) " + f"AND ProbeSetFreeze.Name IN ({', '.join(['%s'] * len(dataset_names))})") + with conn.cursor(cursorclass=DictCursor) as cursor: + cursor.execute( + query, + (tuple(trait["trait_name"] for trait in traits) + + tuple(dataset_names))) + results = { + row["trait_name"]: { + key: val for key, val in row.items() if key != "trait_name" + } for row in cursor.fetchall() + } + return tuple( + { + **trait, + **results.get(trait["trait_name"], {}) + } for trait in traits) + return traits + +def set_homologene_id(conn, traits): + """ + Retrieve and set the 'homologene_id' values for ProbeSet traits. 
+ """ + geneids = set(trait.get("geneid") for trait in traits if trait["haveinfo"]) + groups = set( + trait["db"].get("group") for trait in traits if trait["haveinfo"]) + if len(geneids) > 1 and len(groups) > 1: + query = ( + "SELECT InbredSet.Name AS `group`, Homologene.GeneId AS geneid, " + "HomologeneId " + "FROM Homologene, Species, InbredSet " + f"WHERE Homologene.GeneId IN ({', '.join(['%s'] * len(geneids))}) " + f"AND InbredSet.Name IN ({', '.join(['%s'] * len(groups))}) " + "AND InbredSet.SpeciesId = Species.Id " + "AND Species.TaxonomyId = Homologene.TaxonomyId") + with conn.cursor(cursorclass=DictCursor) as cursor: + cursor.execute(query, (tuple(geneids) + tuple(groups))) + results = { + row["group"]: { + row["geneid"]: { + key: val for key, val in row.items() + if key not in ("group", "geneid") + } + } for row in cursor.fetchall() + } + return tuple( + { + **trait, **results.get( + trait["db"]["group"], {}).get(trait["geneid"], {}) + } for trait in traits) + return traits + +def traits_datasets(conn, threshold, traits): + """ + Retrieve datasets for various `traits`. + """ + dataset_fns = { + "Temp": temp_traits_datasets, + "Geno": geno_traits_datasets, + "Publish": publish_traits_datasets, + "ProbeSet": probeset_traits_datasets + } + def __organise_by_type__(acc, trait): + dataset_type = trait["db"]["dataset_type"] + return { + **acc, + dataset_type: acc.get(dataset_type, tuple()) + (trait,) + } + with_datasets = { + trait["trait_fullname"]: trait for trait in ( + item for sublist in ( + dataset_fns[dtype](conn, threshold, ttraits) + for dtype, ttraits + in reduce(__organise_by_type__, traits, {}).items()) + for item in sublist)} + return tuple( + {**trait, **with_datasets.get(trait["trait_fullname"], {})} + for trait in traits) + +def traits_info( + conn: Any, threshold: int, traits_fullnames: Tuple[str, ...], + qtl=None) -> Tuple[Dict[str, Any], ...]: + """ + Retrieve basic trait information for multiple `traits`. + + This is a rework of the `gn3.db.traits.retrieve_trait_info` function. 
+ """ + def __organise_by_dataset_type__(acc, trait): + dataset_type = trait["db"]["dataset_type"] + return { + **acc, + dataset_type: acc.get(dataset_type, tuple()) + (trait,) + } + traits = traits_datasets( + conn, threshold, + tuple(build_trait_name(trait) for trait in traits_fullnames)) + traits_fns = { + "Publish": compose( + set_confidential, partial(set_publish_qtl_info, conn, qtl), + partial(publish_traits_info, conn), + partial(publish_traits_datasets, conn, threshold)), + "ProbeSet": compose( + partial(set_sequence, conn), + partial(set_probeset_qtl_info, conn, qtl), + partial(set_homologene_id, conn), + partial(probeset_traits_info, conn), + partial(probeset_traits_datasets, conn, threshold)), + "Geno": compose( + partial(geno_traits_info, conn), + partial(geno_traits_datasets, conn, threshold)), + "Temp": compose( + partial(temp_traits_info, conn), + partial(temp_traits_datasets, conn, threshold)) + } + return tuple( + trait for sublist in (# type: ignore[var-annotated] + traits_fns[dataset_type](traits) + for dataset_type, traits + in reduce(__organise_by_dataset_type__, traits, {}).items()) + for trait in sublist) diff --git a/gn3/db/sample_data.py b/gn3/db/sample_data.py new file mode 100644 index 0000000..f73954f --- /dev/null +++ b/gn3/db/sample_data.py @@ -0,0 +1,365 @@ +"""Module containing functions that work with sample data""" +from typing import Any, Tuple, Dict, Callable + +import MySQLdb + +from gn3.csvcmp import extract_strain_name + + +_MAP = { + "PublishData": ("StrainId", "Id", "value"), + "PublishSE": ("StrainId", "DataId", "error"), + "NStrain": ("StrainId", "DataId", "count"), +} + + +def __extract_actions(original_data: str, + updated_data: str, + csv_header: str) -> Dict: + """Return a dictionary containing elements that need to be deleted, inserted, +or updated. 
+ + """ + result: Dict[str, Any] = { + "delete": {"data": [], "csv_header": []}, + "insert": {"data": [], "csv_header": []}, + "update": {"data": [], "csv_header": []}, + } + strain_name = "" + for _o, _u, _h in zip(original_data.strip().split(","), + updated_data.strip().split(","), + csv_header.strip().split(",")): + if _h == "Strain Name": + strain_name = _o + if _o == _u: # No change + continue + if _o and _u == "x": # Deletion + result["delete"]["data"].append(_o) + result["delete"]["csv_header"].append(_h) + elif _o == "x" and _u: # Insert + result["insert"]["data"].append(_u) + result["insert"]["csv_header"].append(_h) + elif _o and _u: # Update + result["update"]["data"].append(_u) + result["update"]["csv_header"].append(_h) + for key, val in result.items(): + if not val["data"]: + result[key] = None + else: + result[key]["data"] = (f"{strain_name}," + + ",".join(result[key]["data"])) + result[key]["csv_header"] = ("Strain Name," + + ",".join(result[key]["csv_header"])) + return result + + +def get_trait_csv_sample_data(conn: Any, + trait_name: int, phenotype_id: int) -> str: + """Fetch a trait and return it as a csv string""" + __query = ("SELECT concat(st.Name, ',', ifnull(pd.value, 'x'), ',', " + "ifnull(ps.error, 'x'), ',', ifnull(ns.count, 'x')) as 'Data' " + ",ifnull(ca.Name, 'x') as 'CaseAttr', " + "ifnull(cxref.value, 'x') as 'Value' " + "FROM PublishFreeze pf " + "JOIN PublishXRef px ON px.InbredSetId = pf.InbredSetId " + "JOIN PublishData pd ON pd.Id = px.DataId " + "JOIN Strain st ON pd.StrainId = st.Id " + "LEFT JOIN PublishSE ps ON ps.DataId = pd.Id " + "AND ps.StrainId = pd.StrainId " + "LEFT JOIN NStrain ns ON ns.DataId = pd.Id " + "AND ns.StrainId = pd.StrainId " + "LEFT JOIN CaseAttributeXRefNew cxref ON " + "(cxref.InbredSetId = px.InbredSetId AND " + "cxref.StrainId = st.Id) " + "LEFT JOIN CaseAttribute ca ON ca.Id = cxref.CaseAttributeId " + "WHERE px.Id = %s AND px.PhenotypeId = %s ORDER BY st.Name") + case_attr_columns = set() + csv_data: Dict = {} + with conn.cursor() as cursor: + cursor.execute(__query, (trait_name, phenotype_id)) + for data in cursor.fetchall(): + if data[1] == "x": + csv_data[data[0]] = None + else: + sample, case_attr, value = data[0], data[1], data[2] + if not csv_data.get(sample): + csv_data[sample] = {} + csv_data[sample][case_attr] = None if value == "x" else value + case_attr_columns.add(case_attr) + if not case_attr_columns: + return ("Strain Name,Value,SE,Count\n" + + "\n".join(csv_data.keys())) + columns = sorted(case_attr_columns) + csv = ("Strain Name,Value,SE,Count," + + ",".join(columns) + "\n") + for key, value in csv_data.items(): + if not value: + csv += (key + (len(case_attr_columns) * ",x") + "\n") + else: + vals = [str(value.get(column, "x")) for column in columns] + csv += (key + "," + ",".join(vals) + "\n") + return csv + return "No Sample Data Found" + + +def get_sample_data_ids(conn: Any, publishxref_id: int, + phenotype_id: int, + strain_name: str) -> Tuple: + """Get the strain_id, publishdata_id and inbredset_id for a given strain""" + strain_id, publishdata_id, inbredset_id = None, None, None + with conn.cursor() as cursor: + cursor.execute("SELECT st.id, pd.Id, pf.InbredSetId " + "FROM PublishData pd " + "JOIN Strain st ON pd.StrainId = st.Id " + "JOIN PublishXRef px ON px.DataId = pd.Id " + "JOIN PublishFreeze pf ON pf.InbredSetId " + "= px.InbredSetId WHERE px.Id = %s " + "AND px.PhenotypeId = %s AND st.Name = %s", + (publishxref_id, phenotype_id, strain_name)) + if _result := cursor.fetchone(): + strain_id, 
publishdata_id, inbredset_id = _result
+        if not all([strain_id, publishdata_id, inbredset_id]):
+            # Applies for data to be inserted:
+            cursor.execute("SELECT DataId, InbredSetId FROM PublishXRef "
+                           "WHERE Id = %s AND PhenotypeId = %s",
+                           (publishxref_id, phenotype_id))
+            publishdata_id, inbredset_id = cursor.fetchone()
+            cursor.execute("SELECT Id FROM Strain WHERE Name = %s",
+                           (strain_name,))
+            strain_id = cursor.fetchone()[0]
+    return (strain_id, publishdata_id, inbredset_id)
+
+
+# pylint: disable=[R0913, R0914]
+def update_sample_data(conn: Any,
+                       trait_name: str,
+                       original_data: str,
+                       updated_data: str,
+                       csv_header: str,
+                       phenotype_id: int) -> int:
+    """Given the right parameters, update sample-data from the relevant
+    table."""
+    def __update_data(conn, table, value):
+        if value and value != "x":
+            with conn.cursor() as cursor:
+                sub_query = (" = %s AND ".join(_MAP.get(table)[:2])
+                             + " = %s")
+                _val = _MAP.get(table)[-1]
+                cursor.execute((f"UPDATE {table} SET {_val} = %s "
+                                f"WHERE {sub_query}"),
+                               (value, strain_id, data_id))
+                return cursor.rowcount
+        return 0
+
+    def __update_case_attribute(conn, value, strain_id,
+                                case_attr, inbredset_id):
+        if value != "x":
+            with conn.cursor() as cursor:
+                cursor.execute(
+                    "UPDATE CaseAttributeXRefNew "
+                    "SET Value = %s "
+                    "WHERE StrainId = %s AND CaseAttributeId = "
+                    "(SELECT CaseAttributeId FROM "
+                    "CaseAttribute WHERE Name = %s) "
+                    "AND InbredSetId = %s",
+                    (value, strain_id, case_attr, inbredset_id))
+                return cursor.rowcount
+        return 0
+
+    strain_id, data_id, inbredset_id = get_sample_data_ids(
+        conn=conn, publishxref_id=int(trait_name),
+        phenotype_id=phenotype_id,
+        strain_name=extract_strain_name(csv_header, original_data))
+
+    none_case_attrs: Dict[str, Callable] = {
+        "Strain Name": lambda x: 0,
+        "Value": lambda x: __update_data(conn, "PublishData", x),
+        "SE": lambda x: __update_data(conn, "PublishSE", x),
+        "Count": lambda x: __update_data(conn, "NStrain", x),
+    }
+    count = 0
+    try:
+        __actions = __extract_actions(original_data=original_data,
+                                      updated_data=updated_data,
+                                      csv_header=csv_header)
+        if __actions.get("update"):
+            _csv_header = __actions["update"]["csv_header"]
+            _data = __actions["update"]["data"]
+            # pylint: disable=[E1101]
+            for header, value in zip(_csv_header.split(","),
+                                     _data.split(",")):
+                header = header.strip()
+                value = value.strip()
+                if header in none_case_attrs:
+                    count += none_case_attrs[header](value)
+                else:
+                    count += __update_case_attribute(
+                        conn=conn,
+                        value=value,
+                        strain_id=strain_id,
+                        case_attr=header,
+                        inbredset_id=inbredset_id)
+        if __actions.get("delete"):
+            _rowcount = delete_sample_data(
+                conn=conn,
+                trait_name=trait_name,
+                data=__actions["delete"]["data"],
+                csv_header=__actions["delete"]["csv_header"],
+                phenotype_id=phenotype_id)
+            if _rowcount:
+                count += 1
+        if __actions.get("insert"):
+            _rowcount = insert_sample_data(
+                conn=conn,
+                trait_name=trait_name,
+                data=__actions["insert"]["data"],
+                csv_header=__actions["insert"]["csv_header"],
+                phenotype_id=phenotype_id)
+            if _rowcount:
+                count += 1
+    except Exception as _e:
+        conn.rollback()
+        raise MySQLdb.Error(_e) from _e
+    conn.commit()
+    return count
+
+
+def delete_sample_data(conn: Any,
+                       trait_name: str,
+                       data: str,
+                       csv_header: str,
+                       phenotype_id: int) -> int:
+    """Given the right parameters, delete sample-data from the relevant
+    tables."""
+    def __delete_data(conn, table):
+        sub_query = (" = %s AND ".join(_MAP.get(table)[:2]) + " = %s")
+        with conn.cursor() as cursor:
cursor.execute((f"DELETE FROM {table} " + f"WHERE {sub_query}"), + (strain_id, data_id)) + return cursor.rowcount + + def __delete_case_attribute(conn, strain_id, + case_attr, inbredset_id): + with conn.cursor() as cursor: + cursor.execute( + "DELETE FROM CaseAttributeXRefNew " + "WHERE StrainId = %s AND CaseAttributeId = " + "(SELECT CaseAttributeId FROM " + "CaseAttribute WHERE Name = %s) " + "AND InbredSetId = %s", + (strain_id, case_attr, inbredset_id)) + return cursor.rowcount + + strain_id, data_id, inbredset_id = get_sample_data_ids( + conn=conn, publishxref_id=int(trait_name), + phenotype_id=phenotype_id, + strain_name=extract_strain_name(csv_header, data)) + + none_case_attrs: Dict[str, Any] = { + "Strain Name": lambda: 0, + "Value": lambda: __delete_data(conn, "PublishData"), + "SE": lambda: __delete_data(conn, "PublishSE"), + "Count": lambda: __delete_data(conn, "NStrain"), + } + count = 0 + + try: + for header in csv_header.split(","): + header = header.strip() + if header in none_case_attrs: + count += none_case_attrs[header]() + else: + count += __delete_case_attribute( + conn=conn, + strain_id=strain_id, + case_attr=header, + inbredset_id=inbredset_id) + except Exception as _e: + conn.rollback() + raise MySQLdb.Error(_e) from _e + conn.commit() + return count + + +# pylint: disable=[R0913, R0914] +def insert_sample_data(conn: Any, + trait_name: str, + data: str, + csv_header: str, + phenotype_id: int) -> int: + """Given the right parameters, insert sample-data to the relevant table. + + """ + def __insert_data(conn, table, value): + if value and value != "x": + with conn.cursor() as cursor: + columns = ", ".join(_MAP.get(table)) + cursor.execute((f"INSERT INTO {table} " + f"({columns}) " + f"VALUES (%s, %s, %s)"), + (strain_id, data_id, value)) + return cursor.rowcount + return 0 + + def __insert_case_attribute(conn, case_attr, value): + if value != "x": + with conn.cursor() as cursor: + cursor.execute("SELECT Id FROM " + "CaseAttribute WHERE Name = %s", + (case_attr,)) + if case_attr_id := cursor.fetchone(): + case_attr_id = case_attr_id[0] + cursor.execute("SELECT StrainId FROM " + "CaseAttributeXRefNew WHERE StrainId = %s " + "AND CaseAttributeId = %s " + "AND InbredSetId = %s", + (strain_id, case_attr_id, inbredset_id)) + if (not cursor.fetchone()) and case_attr_id: + cursor.execute( + "INSERT INTO CaseAttributeXRefNew " + "(StrainId, CaseAttributeId, Value, InbredSetId) " + "VALUES (%s, %s, %s, %s)", + (strain_id, case_attr_id, value, inbredset_id)) + row_count = cursor.rowcount + return row_count + return 0 + + strain_id, data_id, inbredset_id = get_sample_data_ids( + conn=conn, publishxref_id=int(trait_name), + phenotype_id=phenotype_id, + strain_name=extract_strain_name(csv_header, data)) + + none_case_attrs: Dict[str, Any] = { + "Strain Name": lambda _: 0, + "Value": lambda x: __insert_data(conn, "PublishData", x), + "SE": lambda x: __insert_data(conn, "PublishSE", x), + "Count": lambda x: __insert_data(conn, "NStrain", x), + } + + try: + count = 0 + + # Check if the data already exists: + with conn.cursor() as cursor: + cursor.execute( + "SELECT Id FROM PublishData where Id = %s " + "AND StrainId = %s", + (data_id, strain_id)) + if cursor.fetchone(): # Data already exists + return count + + for header, value in zip(csv_header.split(","), data.split(",")): + header = header.strip() + value = value.strip() + if header in none_case_attrs: + count += none_case_attrs[header](value) + else: + count += __insert_case_attribute( + conn=conn, + case_attr=header, + 
value=value) + return count + except Exception as _e: + conn.rollback() + raise MySQLdb.Error(_e) from _e diff --git a/gn3/db/species.py b/gn3/db/species.py index 702a9a8..5b8e096 100644 --- a/gn3/db/species.py +++ b/gn3/db/species.py @@ -57,3 +57,20 @@ def translate_to_mouse_gene_id(species: str, geneid: int, conn: Any) -> int: return translated_gene_id[0] return 0 # default if all else fails + +def species_name(conn: Any, group: str) -> str: + """ + Retrieve the name of the species, given the group (RISet). + + This is a migration of the + `web.webqtl.dbFunction.webqtlDatabaseFunction.retrieveSpecies` function in + GeneNetwork1. + """ + with conn.cursor() as cursor: + cursor.execute( + ("SELECT Species.Name FROM Species, InbredSet " + "WHERE InbredSet.Name = %(group_name)s " + "AND InbredSet.SpeciesId = Species.Id"), + {"group_name": group}) + return cursor.fetchone()[0] + return None diff --git a/gn3/db/traits.py b/gn3/db/traits.py index 1c6aaa7..f722e24 100644 --- a/gn3/db/traits.py +++ b/gn3/db/traits.py @@ -1,7 +1,7 @@ """This class contains functions relating to trait data manipulation""" import os from functools import reduce -from typing import Any, Dict, Union, Sequence +from typing import Any, Dict, Sequence from gn3.settings import TMPDIR from gn3.random import random_string @@ -67,7 +67,7 @@ def export_trait_data( return accumulator + (trait_data["data"][sample]["ndata"], ) if dtype == "all": return accumulator + __export_all_types(trait_data["data"], sample) - raise KeyError("Type `%s` is incorrect" % dtype) + raise KeyError(f"Type `{dtype}` is incorrect") if var_exists and n_exists: return accumulator + (None, None, None) if var_exists or n_exists: @@ -76,80 +76,6 @@ def export_trait_data( return reduce(__exporter, samplelist, tuple()) -def get_trait_csv_sample_data(conn: Any, - trait_name: int, phenotype_id: int): - """Fetch a trait and return it as a csv string""" - sql = ("SELECT DISTINCT Strain.Id, PublishData.Id, Strain.Name, " - "PublishData.value, " - "PublishSE.error, NStrain.count FROM " - "(PublishData, Strain, PublishXRef, PublishFreeze) " - "LEFT JOIN PublishSE ON " - "(PublishSE.DataId = PublishData.Id AND " - "PublishSE.StrainId = PublishData.StrainId) " - "LEFT JOIN NStrain ON (NStrain.DataId = PublishData.Id AND " - "NStrain.StrainId = PublishData.StrainId) WHERE " - "PublishXRef.InbredSetId = PublishFreeze.InbredSetId AND " - "PublishData.Id = PublishXRef.DataId AND " - "PublishXRef.Id = %s AND PublishXRef.PhenotypeId = %s " - "AND PublishData.StrainId = Strain.Id Order BY Strain.Name") - csv_data = ["Strain Id,Strain Name,Value,SE,Count"] - publishdata_id = "" - with conn.cursor() as cursor: - cursor.execute(sql, (trait_name, phenotype_id,)) - for record in cursor.fetchall(): - (strain_id, publishdata_id, - strain_name, value, error, count) = record - csv_data.append( - ",".join([str(val) if val else "x" - for val in (strain_id, strain_name, - value, error, count)])) - return f"# Publish Data Id: {publishdata_id}\n\n" + "\n".join(csv_data) - - -def update_sample_data(conn: Any, - strain_name: str, - strain_id: int, - publish_data_id: int, - value: Union[int, float, str], - error: Union[int, float, str], - count: Union[int, str]): - """Given the right parameters, update sample-data from the relevant - table.""" - # pylint: disable=[R0913, R0914, C0103] - STRAIN_ID_SQL: str = "UPDATE Strain SET Name = %s WHERE Id = %s" - PUBLISH_DATA_SQL: str = ("UPDATE PublishData SET value = %s " - "WHERE StrainId = %s AND Id = %s") - PUBLISH_SE_SQL: str = ("UPDATE 
PublishSE SET error = %s " - "WHERE StrainId = %s AND DataId = %s") - N_STRAIN_SQL: str = ("UPDATE NStrain SET count = %s " - "WHERE StrainId = %s AND DataId = %s") - - updated_strains: int = 0 - updated_published_data: int = 0 - updated_se_data: int = 0 - updated_n_strains: int = 0 - - with conn.cursor() as cursor: - # Update the Strains table - cursor.execute(STRAIN_ID_SQL, (strain_name, strain_id)) - updated_strains = cursor.rowcount - # Update the PublishData table - cursor.execute(PUBLISH_DATA_SQL, - (None if value == "x" else value, - strain_id, publish_data_id)) - updated_published_data = cursor.rowcount - # Update the PublishSE table - cursor.execute(PUBLISH_SE_SQL, - (None if error == "x" else error, - strain_id, publish_data_id)) - updated_se_data = cursor.rowcount - # Update the NStrain table - cursor.execute(N_STRAIN_SQL, - (None if count == "x" else count, - strain_id, publish_data_id)) - updated_n_strains = cursor.rowcount - return (updated_strains, updated_published_data, - updated_se_data, updated_n_strains) def retrieve_publish_trait_info(trait_data_source: Dict[str, Any], conn: Any): """Retrieve trait information for type `Publish` traits. @@ -177,24 +103,24 @@ def retrieve_publish_trait_info(trait_data_source: Dict[str, Any], conn: Any): "PublishXRef.comments") query = ( "SELECT " - "{columns} " + f"{columns} " "FROM " - "PublishXRef, Publication, Phenotype, PublishFreeze " + "PublishXRef, Publication, Phenotype " "WHERE " "PublishXRef.Id = %(trait_name)s AND " "Phenotype.Id = PublishXRef.PhenotypeId AND " "Publication.Id = PublishXRef.PublicationId AND " - "PublishXRef.InbredSetId = PublishFreeze.InbredSetId AND " - "PublishFreeze.Id =%(trait_dataset_id)s").format(columns=columns) + "PublishXRef.InbredSetId = %(trait_dataset_id)s") with conn.cursor() as cursor: cursor.execute( query, { - k:v for k, v in trait_data_source.items() + k: v for k, v in trait_data_source.items() if k in ["trait_name", "trait_dataset_id"] }) return dict(zip([k.lower() for k in keys], cursor.fetchone())) + def set_confidential_field(trait_type, trait_info): """Post processing function for 'Publish' trait types. @@ -207,6 +133,7 @@ def set_confidential_field(trait_type, trait_info): and not trait_info.get("pubmed_id", None)) else 0} return trait_info + def retrieve_probeset_trait_info(trait_data_source: Dict[str, Any], conn: Any): """Retrieve trait information for type `ProbeSet` traits. @@ -219,67 +146,68 @@ def retrieve_probeset_trait_info(trait_data_source: Dict[str, Any], conn: Any): "probe_set_specificity", "probe_set_blat_score", "probe_set_blat_mb_start", "probe_set_blat_mb_end", "probe_set_strand", "probe_set_note_by_rw", "flag") + columns = (f"ProbeSet.{x}" for x in keys) query = ( - "SELECT " - "{columns} " + f"SELECT {', '.join(columns)} " "FROM " "ProbeSet, ProbeSetFreeze, ProbeSetXRef " "WHERE " "ProbeSetXRef.ProbeSetFreezeId = ProbeSetFreeze.Id AND " "ProbeSetXRef.ProbeSetId = ProbeSet.Id AND " "ProbeSetFreeze.Name = %(trait_dataset_name)s AND " - "ProbeSet.Name = %(trait_name)s").format( - columns=", ".join(["ProbeSet.{}".format(x) for x in keys])) + "ProbeSet.Name = %(trait_name)s") with conn.cursor() as cursor: cursor.execute( query, { - k:v for k, v in trait_data_source.items() + k: v for k, v in trait_data_source.items() if k in ["trait_name", "trait_dataset_name"] }) return dict(zip(keys, cursor.fetchone())) + def retrieve_geno_trait_info(trait_data_source: Dict[str, Any], conn: Any): """Retrieve trait information for type `Geno` traits. 
https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/base/webqtlTrait.py#L438-L449""" keys = ("name", "chr", "mb", "source2", "sequence") + columns = ", ".join(f"Geno.{x}" for x in keys) query = ( - "SELECT " - "{columns} " + f"SELECT {columns} " "FROM " - "Geno, GenoFreeze, GenoXRef " + "Geno INNER JOIN GenoXRef ON GenoXRef.GenoId = Geno.Id " + "INNER JOIN GenoFreeze ON GenoFreeze.Id = GenoXRef.GenoFreezeId " "WHERE " - "GenoXRef.GenoFreezeId = GenoFreeze.Id AND GenoXRef.GenoId = Geno.Id AND " "GenoFreeze.Name = %(trait_dataset_name)s AND " - "Geno.Name = %(trait_name)s").format( - columns=", ".join(["Geno.{}".format(x) for x in keys])) + "Geno.Name = %(trait_name)s") with conn.cursor() as cursor: cursor.execute( query, { - k:v for k, v in trait_data_source.items() + k: v for k, v in trait_data_source.items() if k in ["trait_name", "trait_dataset_name"] }) return dict(zip(keys, cursor.fetchone())) + def retrieve_temp_trait_info(trait_data_source: Dict[str, Any], conn: Any): """Retrieve trait information for type `Temp` traits. https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/base/webqtlTrait.py#L450-452""" keys = ("name", "description") query = ( - "SELECT {columns} FROM Temp " - "WHERE Name = %(trait_name)s").format(columns=", ".join(keys)) + f"SELECT {', '.join(keys)} FROM Temp " + "WHERE Name = %(trait_name)s") with conn.cursor() as cursor: cursor.execute( query, { - k:v for k, v in trait_data_source.items() + k: v for k, v in trait_data_source.items() if k in ["trait_name"] }) return dict(zip(keys, cursor.fetchone())) + def set_haveinfo_field(trait_info): """ Common postprocessing function for all trait types. @@ -287,6 +215,7 @@ def set_haveinfo_field(trait_info): Sets the value for the 'haveinfo' field.""" return {**trait_info, "haveinfo": 1 if trait_info else 0} + def set_homologene_id_field_probeset(trait_info, conn): """ Postprocessing function for 'ProbeSet' traits. @@ -302,7 +231,7 @@ def set_homologene_id_field_probeset(trait_info, conn): cursor.execute( query, { - k:v for k, v in trait_info.items() + k: v for k, v in trait_info.items() if k in ["geneid", "group"] }) res = cursor.fetchone() @@ -310,12 +239,13 @@ def set_homologene_id_field_probeset(trait_info, conn): return {**trait_info, "homologeneid": res[0]} return {**trait_info, "homologeneid": None} + def set_homologene_id_field(trait_type, trait_info, conn): """ Common postprocessing function for all trait types. 
Sets the value for the 'homologene' key.""" - set_to_null = lambda ti: {**ti, "homologeneid": None} + def set_to_null(ti): return {**ti, "homologeneid": None} # pylint: disable=[C0103, C0321] functions_table = { "Temp": set_to_null, "Geno": set_to_null, @@ -324,6 +254,7 @@ def set_homologene_id_field(trait_type, trait_info, conn): } return functions_table[trait_type](trait_info) + def load_publish_qtl_info(trait_info, conn): """ Load extra QTL information for `Publish` traits @@ -344,6 +275,7 @@ def load_publish_qtl_info(trait_info, conn): return dict(zip(["locus", "lrs", "additive"], cursor.fetchone())) return {"locus": "", "lrs": "", "additive": ""} + def load_probeset_qtl_info(trait_info, conn): """ Load extra QTL information for `ProbeSet` traits @@ -366,6 +298,7 @@ def load_probeset_qtl_info(trait_info, conn): ["locus", "lrs", "pvalue", "mean", "additive"], cursor.fetchone())) return {"locus": "", "lrs": "", "pvalue": "", "mean": "", "additive": ""} + def load_qtl_info(qtl, trait_type, trait_info, conn): """ Load extra QTL information for traits @@ -389,11 +322,12 @@ def load_qtl_info(qtl, trait_type, trait_info, conn): "Publish": load_publish_qtl_info, "ProbeSet": load_probeset_qtl_info } - if trait_info["name"] not in qtl_info_functions.keys(): + if trait_info["name"] not in qtl_info_functions: return trait_info return qtl_info_functions[trait_type](trait_info, conn) + def build_trait_name(trait_fullname): """ Initialises the trait's name, and other values from the search data provided @@ -408,7 +342,7 @@ def build_trait_name(trait_fullname): return "ProbeSet" name_parts = trait_fullname.split("::") - assert len(name_parts) >= 2, "Name format error" + assert len(name_parts) >= 2, f"Name format error: '{trait_fullname}'" dataset_name = name_parts[0] dataset_type = dataset_type(dataset_name) return { @@ -420,6 +354,7 @@ def build_trait_name(trait_fullname): "cellid": name_parts[2] if len(name_parts) == 3 else "" } + def retrieve_probeset_sequence(trait, conn): """ Retrieve a 'ProbeSet' trait's sequence information @@ -441,6 +376,7 @@ def retrieve_probeset_sequence(trait, conn): seq = cursor.fetchone() return {**trait, "sequence": seq[0] if seq else ""} + def retrieve_trait_info( threshold: int, trait_full_name: str, conn: Any, qtl=None): @@ -496,6 +432,7 @@ def retrieve_trait_info( } return trait_info + def retrieve_temp_trait_data(trait_info: dict, conn: Any): """ Retrieve trait data for `Temp` traits. @@ -514,10 +451,12 @@ def retrieve_temp_trait_data(trait_info: dict, conn: Any): query, {"trait_name": trait_info["trait_name"]}) return [dict(zip( - ["sample_name", "value", "se_error", "nstrain", "id"], row)) + ["sample_name", "value", "se_error", "nstrain", "id"], + row)) for row in cursor.fetchall()] return [] + def retrieve_species_id(group, conn: Any): """ Retrieve a species id given the Group value @@ -529,6 +468,7 @@ def retrieve_species_id(group, conn: Any): return cursor.fetchone()[0] return None + def retrieve_geno_trait_data(trait_info: Dict, conn: Any): """ Retrieve trait data for `Geno` traits. 
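A note on the `dataset::trait` naming convention that `build_trait_name` (above) and the new `traits_info` code in gn3/db/partial_correlations.py rely on: a parsed full name is expected to come back as a small dict. The sketch below is illustrative only; the dataset and trait names are invented, and the keys are inferred from how the surrounding code reads the result, not quoted from this hunk.

    # Illustrative sketch, not part of the commit: the rough shape that
    # build_trait_name("HC_M2_0606_P::1436869_at") is expected to produce.
    parsed = {
        "db": {"dataset_name": "HC_M2_0606_P", "dataset_type": "ProbeSet"},
        "trait_fullname": "HC_M2_0606_P::1436869_at",
        "trait_name": "1436869_at",
        "cellid": "",  # only set when the full name has a third "::" part
    }
    assert parsed["trait_fullname"].split("::")[0] == parsed["db"]["dataset_name"]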
@@ -552,11 +492,14 @@ def retrieve_geno_trait_data(trait_info: Dict, conn: Any): "dataset_name": trait_info["db"]["dataset_name"], "species_id": retrieve_species_id( trait_info["db"]["group"], conn)}) - return [dict(zip( - ["sample_name", "value", "se_error", "id"], row)) - for row in cursor.fetchall()] + return [ + dict(zip( + ["sample_name", "value", "se_error", "id"], + row)) + for row in cursor.fetchall()] return [] + def retrieve_publish_trait_data(trait_info: Dict, conn: Any): """ Retrieve trait data for `Publish` traits. @@ -565,17 +508,16 @@ def retrieve_publish_trait_data(trait_info: Dict, conn: Any): "SELECT " "Strain.Name, PublishData.value, PublishSE.error, NStrain.count, " "PublishData.Id " - "FROM (PublishData, Strain, PublishXRef, PublishFreeze) " + "FROM (PublishData, Strain, PublishXRef) " "LEFT JOIN PublishSE ON " "(PublishSE.DataId = PublishData.Id " "AND PublishSE.StrainId = PublishData.StrainId) " "LEFT JOIN NStrain ON " "(NStrain.DataId = PublishData.Id " "AND NStrain.StrainId = PublishData.StrainId) " - "WHERE PublishXRef.InbredSetId = PublishFreeze.InbredSetId " - "AND PublishData.Id = PublishXRef.DataId " + "WHERE PublishData.Id = PublishXRef.DataId " "AND PublishXRef.Id = %(trait_name)s " - "AND PublishFreeze.Id = %(dataset_id)s " + "AND PublishXRef.InbredSetId = %(dataset_id)s " "AND PublishData.StrainId = Strain.Id " "ORDER BY Strain.Name") with conn.cursor() as cursor: @@ -583,11 +525,13 @@ def retrieve_publish_trait_data(trait_info: Dict, conn: Any): query, {"trait_name": trait_info["trait_name"], "dataset_id": trait_info["db"]["dataset_id"]}) - return [dict(zip( - ["sample_name", "value", "se_error", "nstrain", "id"], row)) - for row in cursor.fetchall()] + return [ + dict(zip( + ["sample_name", "value", "se_error", "nstrain", "id"], row)) + for row in cursor.fetchall()] return [] + def retrieve_cellid_trait_data(trait_info: Dict, conn: Any): """ Retrieve trait data for `Probe Data` types. @@ -616,11 +560,13 @@ def retrieve_cellid_trait_data(trait_info: Dict, conn: Any): {"cellid": trait_info["cellid"], "trait_name": trait_info["trait_name"], "dataset_id": trait_info["db"]["dataset_id"]}) - return [dict(zip( - ["sample_name", "value", "se_error", "id"], row)) - for row in cursor.fetchall()] + return [ + dict(zip( + ["sample_name", "value", "se_error", "id"], row)) + for row in cursor.fetchall()] return [] + def retrieve_probeset_trait_data(trait_info: Dict, conn: Any): """ Retrieve trait data for `ProbeSet` traits. @@ -645,11 +591,13 @@ def retrieve_probeset_trait_data(trait_info: Dict, conn: Any): query, {"trait_name": trait_info["trait_name"], "dataset_name": trait_info["db"]["dataset_name"]}) - return [dict(zip( - ["sample_name", "value", "se_error", "id"], row)) - for row in cursor.fetchall()] + return [ + dict(zip( + ["sample_name", "value", "se_error", "id"], row)) + for row in cursor.fetchall()] return [] + def with_samplelist_data_setup(samplelist: Sequence[str]): """ Build function that computes the trait data from provided list of samples. @@ -676,6 +624,7 @@ def with_samplelist_data_setup(samplelist: Sequence[str]): return None return setup_fn + def without_samplelist_data_setup(): """ Build function that computes the trait data. 
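To make the reflowed comprehensions above concrete, here is a self-contained sketch of the per-sample dictionaries `retrieve_publish_trait_data` builds from each SQL row; the strain names and values are invented, and only the key names come from the query in this hunk.

    # Hypothetical cursor rows for a Publish trait and the dicts built from them.
    rows = [("BXD1", 9.3, 0.12, 5, 12345), ("BXD2", 10.1, None, None, 12345)]
    data = [
        dict(zip(["sample_name", "value", "se_error", "nstrain", "id"], row))
        for row in rows]
    print(data[0]["sample_name"], data[0]["value"])  # BXD1 9.3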
@@ -696,6 +645,7 @@ def without_samplelist_data_setup(): return None return setup_fn + def retrieve_trait_data(trait: dict, conn: Any, samplelist: Sequence[str] = tuple()): """ Retrieve trait data @@ -735,14 +685,16 @@ def retrieve_trait_data(trait: dict, conn: Any, samplelist: Sequence[str] = tupl "data": dict(map( lambda x: ( x["sample_name"], - {k:v for k, v in x.items() if x != "sample_name"}), + {k: v for k, v in x.items() if x != "sample_name"}), data))} return {} + def generate_traits_filename(base_path: str = TMPDIR): """Generate a unique filename for use with generated traits files.""" - return "{}/traits_test_file_{}.txt".format( - os.path.abspath(base_path), random_string(10)) + return ( + f"{os.path.abspath(base_path)}/traits_test_file_{random_string(10)}.txt") + def export_informative(trait_data: dict, inc_var: bool = False) -> tuple: """ @@ -765,5 +717,6 @@ def export_informative(trait_data: dict, inc_var: bool = False) -> tuple: return acc return reduce( __exporter__, - filter(lambda td: td["value"] is not None, trait_data["data"].values()), + filter(lambda td: td["value"] is not None, + trait_data["data"].values()), (tuple(), tuple(), tuple())) diff --git a/gn3/db_utils.py b/gn3/db_utils.py index 7263705..3b72d28 100644 --- a/gn3/db_utils.py +++ b/gn3/db_utils.py @@ -14,10 +14,7 @@ def parse_db_url() -> Tuple: parsed_db.password, parsed_db.path[1:]) -def database_connector() -> Tuple: +def database_connector() -> mdb.Connection: """function to create db connector""" host, user, passwd, db_name = parse_db_url() - conn = mdb.connect(host, user, passwd, db_name) - cursor = conn.cursor() - - return (conn, cursor) + return mdb.connect(host, user, passwd, db_name) diff --git a/gn3/fs_helpers.py b/gn3/fs_helpers.py index 73f6567..f313086 100644 --- a/gn3/fs_helpers.py +++ b/gn3/fs_helpers.py @@ -41,7 +41,7 @@ def get_dir_hash(directory: str) -> str: def jsonfile_to_dict(json_file: str) -> Dict: """Give a JSON_FILE, return a python dict""" - with open(json_file) as _file: + with open(json_file, encoding="utf-8") as _file: data = json.load(_file) return data raise FileNotFoundError @@ -71,9 +71,8 @@ contents to TARGET_DIR/<dir-hash>. 
         os.mkdir(os.path.join(target_dir, token))
         gzipped_file.save(tar_target_loc)
         # Extract to "tar_target_loc/token"
-        tar = tarfile.open(tar_target_loc)
-        tar.extractall(path=os.path.join(target_dir, token))
-        tar.close()
+        with tarfile.open(tar_target_loc) as tar:
+            tar.extractall(path=os.path.join(target_dir, token))
 # pylint: disable=W0703
     except Exception:
         return {"status": 128, "error": "gzip failed to unpack file"}
diff --git a/gn3/heatmaps.py b/gn3/heatmaps.py
index bf9dfd1..91437bb 100644
--- a/gn3/heatmaps.py
+++ b/gn3/heatmaps.py
@@ -40,16 +40,15 @@ def trait_display_name(trait: Dict):
         if trait["db"]["dataset_type"] == "Temp":
             desc = trait["description"]
             if desc.find("PCA") >= 0:
-                return "%s::%s" % (
-                    trait["db"]["displayname"],
-                    desc[desc.rindex(':')+1:].strip())
-            return "%s::%s" % (
-                trait["db"]["displayname"],
-                desc[:desc.index('entered')].strip())
-        prefix = "%s::%s" % (
-            trait["db"]["dataset_name"], trait["trait_name"])
+                return (
+                    f'{trait["db"]["displayname"]}::'
+                    f'{desc[desc.rindex(":")+1:].strip()}')
+            return (
+                f'{trait["db"]["displayname"]}::'
+                f'{desc[:desc.index("entered")].strip()}')
+        prefix = f'{trait["db"]["dataset_name"]}::{trait["trait_name"]}'
         if trait["cellid"]:
-            return "%s::%s" % (prefix, trait["cellid"])
+            return f'{prefix}::{trait["cellid"]}'
         return prefix
     return trait["description"]
 
@@ -64,11 +63,7 @@ def cluster_traits(traits_data_list: Sequence[Dict]):
     def __compute_corr(tdata_i, tdata_j):
         if tdata_i[0] == tdata_j[0]:
             return 0.0
-        corr_vals = compute_correlation(tdata_i[1], tdata_j[1])
-        corr = corr_vals[0]
-        if (1 - corr) < 0:
-            return 0.0
-        return 1 - corr
+        return 1 - compute_correlation(tdata_i[1], tdata_j[1])[0]
 
     def __cluster(tdata_i):
         return tuple(
@@ -136,8 +131,7 @@ def build_heatmap(
     traits_order = compute_traits_order(slinked)
     samples_and_values = retrieve_samples_and_values(
         traits_order, samples, exported_traits_data_list)
-    traits_filename = "{}/traits_test_file_{}.txt".format(
-        TMPDIR, random_string(10))
+    traits_filename = f"{TMPDIR}/traits_test_file_{random_string(10)}.txt"
     generate_traits_file(
         samples_and_values[0][1],
         [t[2] for t in samples_and_values],
@@ -314,7 +308,7 @@ def clustered_heatmap(
         vertical_spacing=0.010,
         horizontal_spacing=0.001,
         subplot_titles=["" if vertical else x_axis["label"]] + [
-            "Chromosome: {}".format(chromo) if vertical else chromo
+            f"Chromosome: {chromo}" if vertical else chromo
             for chromo in x_axis_data],#+ x_axis_data,
         figure=ff.create_dendrogram(
             np.array(clustering_data),
@@ -336,7 +330,7 @@ def clustered_heatmap(
             col=(1 if vertical else (i + 2)))
 
     axes_layouts = {
-        "{axis}axis{count}".format(
+        "{axis}axis{count}".format( # pylint: disable=[C0209]
             axis=("y" if vertical else "x"),
             count=(i+1 if i > 0 else "")): {
                 "mirror": False,
@@ -345,12 +339,10 @@
         }
         for i in range(num_plots)}
 
-    print("vertical?: {} ==> {}".format("T" if vertical else "F", axes_layouts))
-
     fig.update_layout({
         "width": 800 if vertical else 4000,
         "height": 4000 if vertical else 800,
-        "{}axis".format("x" if vertical else "y"): {
+        "{}axis".format("x" if vertical else "y"): { # pylint: disable=[C0209]
             "mirror": False,
             "ticks": "",
             "side": "top" if vertical else "left",
@@ -358,7 +350,7 @@
             "tickangle": 90 if vertical else 0,
             "ticklabelposition": "outside top" if vertical else "outside left"
         },
-        "{}axis".format("y" if vertical else "x"): {
+        "{}axis".format("y" if vertical else "x"): { # pylint: disable=[C0209]
             "mirror": False,
             "showgrid": True,
             "title": "Distance",
diff --git a/gn3/responses/__init__.py b/gn3/responses/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/gn3/responses/__init__.py
diff --git a/gn3/responses/pcorrs_responses.py b/gn3/responses/pcorrs_responses.py
new file mode 100644
index 0000000..d6fd9d7
--- /dev/null
+++ b/gn3/responses/pcorrs_responses.py
@@ -0,0 +1,24 @@
+"""Functions and classes that deal with responses and conversion to JSON."""
+import json
+
+from flask import make_response
+
+class OutputEncoder(json.JSONEncoder):
+    """
+    Class to encode output into JSON, for objects which the default
+    json.JSONEncoder class does not have default encoding for.
+    """
+    def default(self, o):
+        if isinstance(o, bytes):
+            return str(o, encoding="utf-8")
+        return json.JSONEncoder.default(self, o)
+
+def build_response(data):
+    """Build the responses for the API"""
+    status_codes = {
+        "error": 400, "not-found": 404, "success": 200, "exception": 500}
+    response = make_response(
+        json.dumps(data, cls=OutputEncoder),
+        status_codes[data["status"]])
+    response.headers["Content-Type"] = "application/json"
+    return response
diff --git a/gn3/settings.py b/gn3/settings.py
index 57c63df..6eec2a1 100644
--- a/gn3/settings.py
+++ b/gn3/settings.py
@@ -13,11 +13,13 @@ REDIS_JOB_QUEUE = "GN3::job-queue"
 TMPDIR = os.environ.get("TMPDIR", tempfile.gettempdir())
 RQTL_WRAPPER = "rqtl_wrapper.R"
 
+# SPARQL endpoint
+SPARQL_ENDPOINT = "http://localhost:8891/sparql"
+
 # SQL confs
 SQL_URI = os.environ.get(
     "SQL_URI", "mysql://webqtlout:webqtlout@localhost/db_webqtl")
 SECRET_KEY = "password"
-SQLALCHEMY_TRACK_MODIFICATIONS = False
 
 # gn2 results only used in fetching dataset info
 GN2_BASE_URL = "http://www.genenetwork.org/"
@@ -25,11 +27,11 @@ GN2_BASE_URL = "http://www.genenetwork.org/"
 # wgcna script
 WGCNA_RSCRIPT = "wgcna_analysis.R"
 # qtlreaper command
-REAPER_COMMAND = "{}/bin/qtlreaper".format(os.environ.get("GUIX_ENVIRONMENT"))
+REAPER_COMMAND = f"{os.environ.get('GUIX_ENVIRONMENT')}/bin/qtlreaper"
 
 # genotype files
 GENOTYPE_FILES = os.environ.get(
-    "GENOTYPE_FILES", "{}/genotype_files/genotype".format(os.environ.get("HOME")))
+    "GENOTYPE_FILES", f"{os.environ.get('HOME')}/genotype_files/genotype")
 
 # CROSS-ORIGIN SETUP
 def parse_env_cors(default):
@@ -53,3 +55,7 @@ CORS_HEADERS = [
 
 GNSHARE = os.environ.get("GNSHARE", "/gnshare/gn/")
 TEXTDIR = f"{GNSHARE}/web/ProbeSetFreeze_DataMatrix"
+
+ROUND_TO = 10
+
+MULTIPROCESSOR_PROCS = 6 # Number of processes to spawn
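MULTIPROCESSOR_PROCS is only declared in this changeset; its consumers are not shown here. A minimal sketch of how a caller might use it, assuming it is meant to cap a multiprocessing pool (the `_square` task below is a stand-in, not gn3 code):

    # Hypothetical usage; the actual gn3 call sites are not part of this diff.
    from multiprocessing import Pool

    from gn3.settings import MULTIPROCESSOR_PROCS

    def _square(num):
        """Stand-in for a CPU-bound task."""
        return num * num

    if __name__ == "__main__":
        with Pool(processes=MULTIPROCESSOR_PROCS) as pool:
            print(pool.map(_square, range(10)))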