import json
import math
import requests
from functools import reduce
from typing import Union, Tuple
from urllib.parse import urljoin
from flask import (
flash,
request,
url_for,
redirect,
current_app,
render_template)
from gn2.wqflask import app
from gn2.utility.tools import get_setting, GN_SERVER_URL
from gn2.wqflask.database import database_connection
from gn3.db.partial_correlations import traits_info
def publish_target_databases(conn, groups, threshold):
    """List the PublishFreeze datasets belonging to any of ``groups``.

    Args:
        conn: An open database connection exposing ``cursor()``.
        groups: InbredSet names whose datasets should be listed.
        threshold: Only datasets whose ``PublishFreeze.public`` value is
            strictly greater than this are returned.

    Returns:
        A tuple of ``{"description": <FullName>, "value": <Name>}`` dicts;
        an empty tuple when ``groups`` is empty or nothing matches.
    """
    groups = tuple(groups)
    if not groups:
        # An empty ``IN ()`` list is an SQL syntax error, and there is
        # nothing to match anyway.
        return tuple()
    placeholders = ", ".join(["%s"] * len(groups))
    query = (
        "SELECT PublishFreeze.FullName,PublishFreeze.Name "
        "FROM PublishFreeze, InbredSet "
        "WHERE PublishFreeze.InbredSetId = InbredSet.Id "
        f"AND InbredSet.Name IN ({placeholders}) "
        "AND PublishFreeze.public > %s")
    with conn.cursor() as cursor:
        cursor.execute(query, groups + (threshold,))
        res = cursor.fetchall()
    if not res:
        return tuple()
    return tuple(dict(zip(("description", "value"), row)) for row in res)
def geno_target_databases(conn, groups, threshold):
    """List the GenoFreeze datasets belonging to any of ``groups``.

    Args:
        conn: An open database connection exposing ``cursor()``.
        groups: InbredSet names whose datasets should be listed.
        threshold: Only datasets whose ``GenoFreeze.public`` value is
            strictly greater than this are returned.

    Returns:
        A tuple of ``{"description": <FullName>, "value": <Name>}`` dicts;
        an empty tuple when ``groups`` is empty or nothing matches.
    """
    groups = tuple(groups)
    if not groups:
        # An empty ``IN ()`` list is an SQL syntax error, and there is
        # nothing to match anyway.
        return tuple()
    placeholders = ", ".join(["%s"] * len(groups))
    query = (
        "SELECT GenoFreeze.FullName,GenoFreeze.Name "
        "FROM GenoFreeze, InbredSet "
        "WHERE GenoFreeze.InbredSetId = InbredSet.Id "
        f"AND InbredSet.Name IN ({placeholders}) "
        "AND GenoFreeze.public > %s")
    with conn.cursor() as cursor:
        cursor.execute(query, groups + (threshold,))
        res = cursor.fetchall()
    if not res:
        return tuple()
    return tuple(dict(zip(("description", "value"), row)) for row in res)
def probeset_target_databases(conn, groups, threshold):
    """List ProbeSetFreeze datasets for ``groups``, grouped by tissue.

    Args:
        conn: An open database connection exposing ``cursor()``.
        groups: InbredSet names, matched with SQL ``LIKE``.
        threshold: Only datasets whose ``ProbeSetFreeze.public`` value is
            strictly greater than this are returned.

    Returns:
        One single-key dict per Tissue row (in ``Tissue.Name`` order),
        mapping the tissue name to a tuple of
        ``{"value": <Name>, "description": <FullName>}`` dicts. Tissues
        without datasets map to an empty tuple. Returns an empty tuple when
        ``groups`` is empty, no tissues exist, or no dataset matches.
    """
    groups = tuple(groups)
    if not groups:
        # The OR-clause would render as ``AND ()``, which is invalid SQL.
        return tuple()
    with conn.cursor() as cursor:
        cursor.execute("SELECT Id, Name FROM Tissue order by Name")
        tissue_res = cursor.fetchall()
        if not tissue_res:
            return tuple()
        tissue_ids = tuple(row[0] for row in tissue_res)
        tissue_placeholders = ", ".join(["%s"] * len(tissue_ids))
        groups_clause = " OR ".join(["InbredSet.Name like %s"] * len(groups))
        query = (
            "SELECT ProbeFreeze.TissueId, ProbeSetFreeze.FullName, "
            "ProbeSetFreeze.Name "
            "FROM ProbeSetFreeze, ProbeFreeze, InbredSet "
            "WHERE ProbeSetFreeze.ProbeFreezeId = ProbeFreeze.Id "
            "AND ProbeFreeze.TissueId IN "
            f"({tissue_placeholders}) "
            "AND ProbeSetFreeze.public > %s "
            "AND ProbeFreeze.InbredSetId = InbredSet.Id "
            f"AND ({groups_clause}) "
            "ORDER BY ProbeSetFreeze.CreateTime desc, ProbeSetFreeze.AvgId")
        cursor.execute(query, tissue_ids + (threshold,) + groups)
        db_res = cursor.fetchall()
    if not db_res:
        return tuple()
    # Bucket datasets by tissue in a single pass; appending preserves the
    # query's ORDER BY within each tissue (no per-tissue rescan).
    by_tissue = {}
    for tissue_id, description, value in db_res:
        by_tissue.setdefault(tissue_id, []).append(
            {"value": value, "description": description})
    return tuple(
        {tissue_name: tuple(by_tissue.get(tissue_id, ()))}
        for tissue_id, tissue_name in tissue_res)
def target_databases(conn, traits, threshold):
    """
    Retrieves the names of possible target databases from the database.

    Args:
        conn: An open database connection.
        traits: Parsed trait dicts; each must carry ``dataset`` and
            ``trait_name`` keys.
        threshold: Minimum (exclusive) ``public`` value passed on to the
            per-type dataset queries.

    Returns:
        The publish, geno and probeset dataset entries concatenated into a
        single tuple, or an empty tuple when no traits/groups are available.
    """
    if not traits:
        return tuple()
    trait_info = traits_info(
        conn, threshold,
        tuple(f"{trait['dataset']}::{trait['trait_name']}" for trait in traits))
    groups = tuple(set(row["db"]["group"] for row in trait_info))
    if not groups:
        # The per-type queries build ``IN (...)``/``OR`` lists from the
        # groups; an empty list would produce invalid SQL.
        return tuple()
    return (
        publish_target_databases(conn, groups, threshold) +
        geno_target_databases(conn, groups, threshold) +
        probeset_target_databases(conn, groups, threshold))
def primary_error(args):
    """Validate that exactly one primary trait was selected.

    Returns ``args`` itself when valid; otherwise a new dict whose
    ``errors`` tuple has the validation message appended.
    """
    selected = args["primary_trait"]
    if len(selected) == 1:
        return args
    previous_errors = args.get("errors", tuple())
    message = "You must provide one, and only one primary trait"
    return {**args, "errors": previous_errors + (message,)}
def controls_error(args):
    """Validate that between one and three control traits were selected.

    Returns ``args`` itself when valid; otherwise a new dict whose
    ``errors`` tuple has the validation message appended.
    """
    if 1 <= len(args["control_traits"]) <= 3:
        return args
    message = ("You must provide at least one control trait, and a maximum "
               "of three control traits")
    return {**args, "errors": args.get("errors", tuple()) + (message,)}
def target_traits_error(args, with_target_traits):
    """Require at least one target trait when ``with_target_traits`` is set.

    Returns ``args`` itself when valid (or when target traits are not
    required); otherwise a new dict whose ``errors`` tuple has the
    validation message appended.
    """
    chosen = args.get("target_traits")
    have_targets = chosen is not None and len(chosen) > 0
    if have_targets or not with_target_traits:
        return args
    return {
        **args,
        "errors": args.get("errors", tuple()) + (
            "You must provide at least one target trait",),
    }
def target_db_error(args, with_target_db: bool):
    """Require a target database when ``with_target_db`` is set.

    The ``target_db`` key is only present when the submitted form carried
    that field, so it is looked up with ``.get`` to report a missing value
    as a validation error rather than raising ``KeyError``.

    Returns ``args`` itself when valid; otherwise a new dict whose
    ``errors`` tuple has the validation message appended.
    """
    if with_target_db and not args.get("target_db"):
        return {
            **args,
            "errors": (
                args.get("errors", tuple()) +
                ("The target database must be provided",))}
    return args
def method_error(args):
    """Validate the requested correlation method (case-insensitive).

    A missing ``method`` key (form field absent) is reported as an invalid
    method instead of raising ``KeyError``.

    Returns ``args`` itself when valid; otherwise a new dict whose
    ``errors`` tuple has the validation message appended.
    """
    methods = (
        "pearson's r", "spearman's rho",
        "genetic correlation, pearson's r",
        "genetic correlation, spearman's rho",
        "sgo literature correlation",
        "tissue correlation, pearson's r",
        "tissue correlation, spearman's rho")
    method = args.get("method")
    if not method or method.lower() not in methods:
        return {
            **args,
            "errors": (
                args.get("errors", tuple()) +
                ("Invalid correlation method provided",))}
    return args
def criteria_error(args):
    """Validate that ``criteria`` (the number of results) parses as an int.

    A missing ``criteria`` key is treated as invalid. Returns ``args``
    itself when valid; otherwise a new dict whose ``errors`` tuple has the
    validation message appended.
    """
    raw_criteria = args.get("criteria", "invalid")
    try:
        int(raw_criteria)
    except ValueError:
        earlier = args.get("errors", tuple())
        return {**args, "errors": earlier + ("Invalid return number provided",)}
    else:
        return args
def errors(args, with_target_db: bool):
    """Run every form validator over ``args`` and record the request mode.

    Validators run in a fixed order, each appending to ``errors`` as
    needed. Target traits are only required when no target database is
    used. The returned dict also carries ``with_target_db``.
    """
    validated = primary_error(args)
    validated = controls_error(validated)
    validated = target_db_error(validated, with_target_db)
    validated = target_traits_error(validated, not with_target_db)
    validated = method_error(validated)
    validated = criteria_error(validated)
    return {**validated, "with_target_db": with_target_db}
def __classify_args(acc, item):
    """Reducer that sorts one ``(key, value)`` form item into ``acc``.

    Items whose *value* starts with ``primary_``, ``controls_`` or
    ``targets_`` are appended (as the whole item) to the matching trait
    bucket; otherwise the ``target_db``, ``method`` and ``criteria`` keys
    are copied through. Anything else leaves ``acc`` unchanged. A new dict
    is always returned; ``acc`` is never mutated.
    """
    field, value = item[0], item[1]
    trait_buckets = (
        ("primary_", "primary_trait"),
        ("controls_", "control_traits"),
        ("targets_", "target_traits"),
    )
    for prefix, bucket in trait_buckets:
        if value.startswith(prefix):
            return {**acc, bucket: acc.get(bucket, tuple()) + (item,)}
    if field in ("target_db", "method", "criteria"):
        return {**acc, field: value}
    return acc
def __build_args(raw_form, traits):
    """Turn the raw form into validated-ready arguments.

    Form items are classified with ``__classify_args``, then each trait
    bucket is resolved to the parsed trait dicts from ``traits`` whose
    ``trait_name`` matches a selected value with its prefix
    (``primary_``/``controls_``/``targets_``) stripped.

    A bucket with no selections resolves to an empty list (instead of
    raising ``KeyError``), so the validators can report the problem.
    """
    args = reduce(__classify_args, raw_form.items(), {})

    def _selected(bucket, prefix):
        # Build the name set once instead of re-scanning per trait.
        names = {item[1][len(prefix):] for item in args.get(bucket, tuple())}
        return [trait for trait in traits if trait["trait_name"] in names]

    return {
        **args,
        "primary_trait": _selected("primary_trait", "primary_"),
        "control_traits": _selected("control_traits", "controls_"),
        "target_traits": _selected("target_traits", "targets_"),
    }
def parse_trait(trait_str):
    """Split a ``|||``-delimited trait string into a dict of named fields.

    Fields are assigned positionally; missing trailing fields are simply
    absent and extra fields are ignored.
    """
    field_names = (
        "trait_name", "dataset", "description", "symbol", "location",
        "mean", "lrs", "lrs_location")
    pieces = trait_str.strip().split("|||")
    return {name: piece for name, piece in zip(field_names, pieces)}
def response_error_message(response):
    """Pick a user-facing message for a non-200 API response.

    404 and 500 get specific messages; any other status gets a generic one.
    """
    status = response.status_code
    if status == 404:
        return ("We could not connect to the API server at this time. "
                "Try again later.")
    if status == 500:
        return ("The API server experienced a problem. We will be working on a "
                "fix. Please try again later.")
    return "General API server error!!"
def render_error(error_message, command_id=None):
    """Render the partial-correlations error page.

    ``command_id`` is forwarded to the template (``None`` when the error is
    not tied to an asynchronous command).
    """
    error_template = "partial_correlations/pcorrs_error.html"
    context = {"message": error_message, "command_id": command_id}
    return render_template(error_template, **context)
def __format_number(num):
    """Format a number for display in the results tables.

    ``None`` and NaN render as an empty string. Magnitudes at or below
    1.04e-4 use two-digit scientific notation; everything else uses five
    decimal places.
    """
    if num is None or math.isnan(num):
        return ""
    spec = ".2e" if abs(num) <= 1.04E-4 else ".5f"
    return format(num, spec)
def handle_200_response(response):
    """Turn a decoded successful API reply into a Flask response.

    Queued commands redirect (303) to the polling view, using
    ``response["results"]`` as the command id. A ``"success"`` status
    renders the results sorted by ``partial_corr_p_value``; any other
    status renders ``response["results"]`` as an error message.
    """
    if response.get("queued", False):
        poll_url = url_for(
            "poll_partial_correlation_results",
            command_id=response["results"])
        return redirect(poll_url, code=303)
    if response["status"] != "success":
        return render_error(response["results"])
    payload = response["results"]["results"]
    return render_template(
        "partial_correlations/pcorrs_results_with_target_traits.html",
        primary=payload["primary_trait"],
        controls=payload["control_traits"],
        pcorrs=sorted(
            payload["correlations"],
            key=lambda corr: corr["partial_corr_p_value"]),
        method=payload["method"],
        enumerate=enumerate,
        format_number=__format_number)
def handle_response(response):
    """Dispatch an HTTP response from the API server.

    A 200 reply is JSON-decoded and passed to ``handle_200_response``;
    any other status renders the error page with a status-based message.
    """
    if response.status_code == 200:
        return handle_200_response(response.json())
    return render_template(
        "partial_correlations/pcorrs_error.html",
        message=response_error_message(response))
def _post_partial_correlation(args):
    """Submit validated ``args`` to the API server's partial-correlation endpoint.

    Validation guarantees exactly one primary trait, so only that trait
    (not the one-element list) is sent. ``args`` already carries
    ``with_target_db`` (set by ``errors``).
    """
    post_data = {**args, "primary_trait": args["primary_trait"][0]}
    return handle_response(requests.post(
        url=urljoin(GN_SERVER_URL, "correlation/partial"),
        json=post_data))


@app.route("/partial_correlations", methods=["POST"])
def partial_correlations():
    """Handle the partial-correlations form.

    When the form was submitted via one of the run buttons, the selection
    is validated and, if valid, posted to the API server. Otherwise, or
    when validation fails (errors are flashed), the operation-selection
    page is rendered with the available target databases.
    """
    form = request.form
    # Skip blank segments (e.g. a trailing ";;;"): they would parse to a
    # trait without a "dataset" and crash target_databases.
    traits = tuple(
        parse_trait(trait)
        for trait in (form.get("trait_list") or "").split(";;;")
        if trait.strip())
    if not traits:
        return render_error(
            "You must provide at least one trait to run partial correlations.")

    submit = form.get("submit")
    args = None
    if submit in ("with_target_pearsons", "with_target_spearmans"):
        args = {
            **errors(__build_args(form, traits), with_target_db=False),
            # The button, not the form's method field, decides the method.
            "method": "pearsons" if "pearsons" in submit else "spearmans"
        }
    elif submit == "Run Partial Correlations":
        args = errors(__build_args(form, traits), with_target_db=True)

    if args is not None:
        if not args.get("errors"):
            return _post_partial_correlation(args)
        for error in args["errors"]:
            flash(error, "alert-danger")

    with database_connection(get_setting("SQL_URI")) as conn:
        target_dbs = target_databases(conn, traits, threshold=0)
    return render_template(
        "partial_correlations/pcorrs_select_operations.html",
        trait_list_str=form.get("trait_list"),
        traits=traits,
        target_dbs=target_dbs)
def process_pcorrs_command_output(result):
    """Render the page for a finished asynchronous partial-correlation run.

    A ``"success"`` result renders either the target-traits results page
    (when ``dataset_type`` is ``"NOT SET YET"``, sorted by
    ``partial_corr_p_value``) or the presentation page. An ``"error"``
    result renders its ``error_type`` and ``message``. Any other status
    renders a generic error instead of returning ``None``, which Flask
    rejects as a view response.
    """
    status = result.get("status")
    if status == "success":
        results = result["results"]
        if results["dataset_type"] == "NOT SET YET":
            return render_template(
                "partial_correlations/pcorrs_results_with_target_traits.html",
                primary=results["primary_trait"],
                controls=results["control_traits"],
                pcorrs=sorted(
                    results["correlations"],
                    key=lambda item: item["partial_corr_p_value"]),
                method=results["method"],
                enumerate=enumerate,
                format_number=__format_number)
        return render_template(
            "partial_correlations/pcorrs_results_presentation.html",
            primary=results["primary_trait"],
            controls=results["control_traits"],
            correlations=results["correlations"],
            dataset_type=results["dataset_type"],
            method=results["method"],
            enumerate=enumerate,
            format_number=__format_number)
    if status == "error":
        return render_error(
            f"({result['error_type']}: {result['message']})")
    return render_error(
        f"The computation finished with an unexpected status ({status}).")
@app.route("/partial_correlations/<command_id>", methods=["GET"])
def poll_partial_correlation_results(command_id):
    """Poll the API server for the state of an asynchronous command.

    Renders the results once the command succeeds, an error page when the
    command or the API call fails, and otherwise the polling page.
    """
    system_error = (
        "We messed up, and the computation failed due to a system "
        "error.")
    response = requests.get(
        url=urljoin(GN_SERVER_URL, f"async_commands/state/{command_id}"))
    if response.status_code != 200:
        return render_error(system_error, command_id)

    data = response.json()
    raw_result = data.get("result")
    # Without a stored result the command is treated as still computing.
    result = json.loads(raw_result) if raw_result else {"status": "computing"}
    if result["status"].lower() in ("error", "exception"):
        return render_error(system_error, command_id)
    if data["status"] == "success":
        if not raw_result:
            # Success with nothing stored would otherwise crash json.loads.
            return render_error(system_error, command_id)
        # Reuse the already-decoded payload rather than parsing it again.
        return process_pcorrs_command_output(result)
    return render_template(
        "partial_correlations/pcorrs_poll_results.html",
        command_id=command_id)