import json
import math
import requests
from functools import reduce
from typing import Union, Tuple
from urllib.parse import urljoin

from flask import (
    flash,
    request,
    url_for,
    redirect,
    current_app,
    render_template)

from gn2.wqflask import app
from gn2.utility.tools import get_setting, GN_SERVER_URL
from gn2.wqflask.database import database_connection
from gn3.db.partial_correlations import traits_info

def publish_target_databases(conn, groups, threshold):
    query = (
        "SELECT PublishFreeze.FullName,PublishFreeze.Name "
        "FROM PublishFreeze, InbredSet "
        "WHERE PublishFreeze.InbredSetId = InbredSet.Id "
        f"AND InbredSet.Name IN ({', '.join(['%s'] * len(groups))}) "
        "AND PublishFreeze.public > %s")
    with conn.cursor() as cursor:
        cursor.execute(query, tuple(groups) + (threshold,))
        res = cursor.fetchall()
        if res:
            return tuple(
                dict(zip(("description", "value"), row)) for row in res)

    return tuple()

def geno_target_databases(conn, groups, threshold):
    query = (
        "SELECT GenoFreeze.FullName,GenoFreeze.Name "
        "FROM GenoFreeze, InbredSet "
        "WHERE GenoFreeze.InbredSetId = InbredSet.Id "
        f"AND InbredSet.Name IN ({', '.join(['%s'] * len(groups))}) "
        "AND GenoFreeze.public > %s")
    with conn.cursor() as cursor:
        cursor.execute(query, tuple(groups) + (threshold,))
        res = cursor.fetchall()
        if res:
            return tuple(
                dict(zip(("description", "value"), row)) for row in res)

    return tuple()

def probeset_target_databases(conn, groups, threshold):
    query1 = "SELECT Id, Name FROM Tissue order by Name"
    with conn.cursor() as cursor:
        cursor.execute(query1)
        tissue_res = cursor.fetchall()
        if tissue_res:
            tissue_ids = tuple(row[0] for row in tissue_res)
            groups_clauses = ["InbredSet.Name like %s"] * len(groups)
            query2 = (
                "SELECT ProbeFreeze.TissueId, ProbeSetFreeze.FullName, "
                "ProbeSetFreeze.Name "
                "FROM ProbeSetFreeze, ProbeFreeze, InbredSet "
                "WHERE ProbeSetFreeze.ProbeFreezeId = ProbeFreeze.Id "
                "AND ProbeFreeze.TissueId IN "
                f"({', '.join(['%s'] * len(tissue_ids))}) "
                "AND ProbeSetFreeze.public > %s "
                "AND ProbeFreeze.InbredSetId = InbredSet.Id "
                f"AND ({' OR '.join(groups_clauses)}) "
                "ORDER BY ProbeSetFreeze.CreateTime desc, ProbeSetFreeze.AvgId")
            cursor.execute(query2, tissue_ids + (threshold,) + tuple(groups))
            db_res = cursor.fetchall()
            if db_res:
                databases = tuple(
                    dict(zip(("tissue_id", "description", "value"), row))
                    for row in db_res)
                return tuple(
                    {tissue_name: tuple(
                        {
                            "value": item["value"],
                            "description": item["description"]
                         } for item in databases
                        if item["tissue_id"] == tissue_id)}
                    for tissue_id, tissue_name in tissue_res)

    return tuple()

def target_databases(conn, traits, threshold):
    """
    Retrieves the names of possible target databases from the database.
    """
    trait_info = traits_info(
        conn, threshold,
        tuple(f"{trait['dataset']}::{trait['trait_name']}" for trait in traits))
    groups = tuple(set(row["db"]["group"] for row in trait_info))
    return (
        publish_target_databases(conn, groups, threshold) +
        geno_target_databases(conn, groups, threshold) +
        probeset_target_databases(conn, groups, threshold))

def primary_error(args):
    if len(args["primary_trait"]) == 0 or len(args["primary_trait"]) > 1:
        return {
            **args,
            "errors": (args.get("errors", tuple()) +
                       ("You must provide one, and only one primary trait",))}
    return args

def controls_error(args):
    if len(args["control_traits"]) == 0 or len(args["control_traits"]) > 3:
        return {
            **args,
            "errors": (
                args.get("errors", tuple()) +
                (("You must provide at least one control trait, and a maximum "
                  "of three control traits"),))}
    return args

def target_traits_error(args, with_target_traits):
    target_traits_present = (
        (args.get("target_traits") is not None) and
        (len(args["target_traits"]) > 0))
    if with_target_traits and not target_traits_present:
        return {
            **args,
            "errors": (
                args.get("errors", tuple()) +
                (("You must provide at least one target trait"),))}
    return args

def target_db_error(args, with_target_db: bool):
    if with_target_db and not args["target_db"]:
        return {
            **args,
            "errors": (
                args.get("errors", tuple()) +
                ("The target database must be provided",))}
    return args

def method_error(args):
    methods = (
        "pearson's r", "spearman's rho",
        "genetic correlation, pearson's r",
        "genetic correlation, spearman's rho",
        "sgo literature correlation",
        "tissue correlation, pearson's r",
        "tissue correlation, spearman's rho")
    if not args["method"] or args["method"].lower() not in methods:
        return {
            **args,
            "errors": (
                args.get("errors", tuple()) +
                ("Invalid correlation method provided",))}
    return args

def criteria_error(args):
    try:
        int(args.get("criteria", "invalid"))
        return args
    except ValueError:
        return {
            **args,
            "errors": (
                args.get("errors", tuple()) +
                ("Invalid return number provided",))}

def errors(args, with_target_db: bool):
    return {
        **criteria_error(
            method_error(
                target_traits_error(
                    target_db_error(
                        controls_error(primary_error(args)),
                        with_target_db),
                    not with_target_db))),
        "with_target_db": with_target_db
    }

def __classify_args(acc, item):
    if item[1].startswith("primary_"):
        return {
            **acc,
            "primary_trait": (acc.get("primary_trait", tuple()) + (item,))}
    if item[1].startswith("controls_"):
        return {**acc, "control_traits": (acc.get("control_traits", tuple()) + (item,))}
    if item[1].startswith("targets_"):
        return {**acc, "target_traits": (acc.get("target_traits", tuple()) + (item,))}
    if item[0] == "target_db":
        return {**acc, "target_db": item[1]}
    if item[0] == "method":
        return {**acc, "method": item[1]}
    if item[0] == "criteria":
        return {**acc, "criteria": item[1]}
    return acc

def __build_args(raw_form, traits):
    args = reduce(__classify_args, raw_form.items(), {})
    return {
        **args,
        "primary_trait": [
            item for item in traits if item["trait_name"] in
            (name[1][8:] for name in args["primary_trait"])],
        "control_traits": [
            item for item in traits if item["trait_name"] in
            (name[1][9:] for name in args["control_traits"])],
        "target_traits": [
            item for item in traits if item["trait_name"] in
            (name[1][8:] for name in args.get("target_traits", tuple()))]
    }

def parse_trait(trait_str):
    return dict(zip(
        ("trait_name", "dataset", "description", "symbol", "location", "mean",
         "lrs", "lrs_location"),
        trait_str.strip().split("|||")))

def response_error_message(response):
    error_messages = {
        404: ("We could not connect to the API server at this time. "
              "Try again later."),
        500: ("The API server experienced a problem. We will be working on a "
              "fix. Please try again later.")
    }
    return error_messages.get(
        response.status_code,
        "General API server error!!")

def render_error(error_message, command_id = None):
    return render_template(
        "partial_correlations/pcorrs_error.html",
        message = error_message,
        command_id = command_id)

def __format_number(num):
    if num is None or math.isnan(num):
        return ""
    if abs(num) <= 1.04E-4:
        return f"{num:.2e}"
    return f"{num:.5f}"

def handle_200_response(response):
    if response.get("queued", False):
        return redirect(
            url_for(
                "poll_partial_correlation_results",
                command_id=response["results"]),
            code=303)
    if response["status"] == "success":
        return render_template(
            "partial_correlations/pcorrs_results_with_target_traits.html",
            primary = response["results"]["results"]["primary_trait"],
            controls = response["results"]["results"]["control_traits"],
            pcorrs = sorted(
                response["results"]["results"]["correlations"],
                key = lambda item: item["partial_corr_p_value"]),
            method = response["results"]["results"]["method"],
            enumerate = enumerate,
            format_number = __format_number)
    return render_error(response["results"])

def handle_response(response):
    if response.status_code != 200:
        return render_template(
            "partial_correlations/pcorrs_error.html",
            message = response_error_message(response))
    return handle_200_response(response.json())

@app.route("/partial_correlations", methods=["POST"])
def partial_correlations():
    form = request.form
    traits = tuple(
        parse_trait(trait) for trait in
        form.get("trait_list").split(";;;"))

    submit = form.get("submit")

    if submit in ("with_target_pearsons", "with_target_spearmans"):
        method = "pearsons" if "pearsons" in submit else "spearmans"
        args = {
            **errors(__build_args(form, traits), with_target_db=False),
            "method": method
        }
        if len(args.get("errors", [])) == 0:
            post_data = {
                **args,
                "primary_trait": args["primary_trait"][0],
                "with_target_db": args["with_target_db"]
            }
            return handle_response(requests.post(
                url=urljoin(GN_SERVER_URL, "correlation/partial"),
                json=post_data))

        for error in args["errors"]:
            flash(error, "alert-danger")

    if submit == "Run Partial Correlations":
        args = errors(__build_args(form, traits), with_target_db=True)
        if len(args.get("errors", [])) == 0:
            post_data = {
                **args,
                "primary_trait": args["primary_trait"][0],
                "with_target_db": args["with_target_db"]
            }
            return handle_response(requests.post(
                url=urljoin(GN_SERVER_URL, "correlation/partial"),
                json=post_data))

        for error in args["errors"]:
            flash(error, "alert-danger")

    with database_connection(get_setting("SQL_URI")) as conn:
        target_dbs = target_databases(conn, traits, threshold=0)
        return render_template(
            "partial_correlations/pcorrs_select_operations.html",
            trait_list_str=form.get("trait_list"),
            traits=traits,
            target_dbs=target_dbs)

def process_pcorrs_command_output(result):
    if result["status"] == "success":

        if result["results"]["dataset_type"] == "NOT SET YET":
            return render_template(
                "partial_correlations/pcorrs_results_with_target_traits.html",
                primary = result["results"]["primary_trait"],
                controls = result["results"]["control_traits"],
                pcorrs = sorted(
                    result["results"]["correlations"],
                    key = lambda item: item["partial_corr_p_value"]),
                method = result["results"]["method"],
                enumerate = enumerate,
                format_number = __format_number)

        return render_template(
            "partial_correlations/pcorrs_results_presentation.html",
            primary=result["results"]["primary_trait"],
            controls=result["results"]["control_traits"],
            correlations=result["results"]["correlations"],
            dataset_type=result["results"]["dataset_type"],
            method=result["results"]["method"],
            enumerate = enumerate,
            format_number=__format_number)
    if result["status"] == "error":
        return render_error(
            f"({result['error_type']}: {result['message']})")

@app.route("/partial_correlations/<command_id>", methods=["GET"])
def poll_partial_correlation_results(command_id):
    response = requests.get(
        url=urljoin(GN_SERVER_URL, f"async_commands/state/{command_id}"))

    if response.status_code == 200:
        data = response.json()
        raw_result = data["result"]
        result = {"status": "computing"}
        if raw_result:
            result = json.loads(raw_result)
        if result["status"].lower() in ("error", "exception"):
            return render_error(
                "We messed up, and the computation failed due to a system "
                "error.",
                command_id)
        if data["status"] == "success":
            return process_pcorrs_command_output(json.loads(data["result"]))
        return render_template(
            "partial_correlations/pcorrs_poll_results.html",
            command_id = command_id)
    return render_error(
        "We messed up, and the computation failed due to a system "
        "error.",
        command_id)