about summary refs log tree commit diff
path: root/gn3
diff options
context:
space:
mode:
Diffstat (limited to 'gn3')
-rw-r--r--gn3/api/correlation.py31
-rw-r--r--gn3/authentication.py1
-rw-r--r--gn3/computations/correlations.py5
-rw-r--r--gn3/computations/partial_correlations.py137
-rw-r--r--gn3/db/correlations.py22
-rw-r--r--gn3/db/datasets.py12
-rw-r--r--gn3/db/traits.py48
7 files changed, 196 insertions, 60 deletions
diff --git a/gn3/api/correlation.py b/gn3/api/correlation.py
index 46121f8..1caf31f 100644
--- a/gn3/api/correlation.py
+++ b/gn3/api/correlation.py
@@ -1,13 +1,16 @@
 """Endpoints for running correlations"""
+import json
 from flask import jsonify
 from flask import Blueprint
 from flask import request
+from flask import make_response
 
 from gn3.computations.correlations import compute_all_sample_correlation
 from gn3.computations.correlations import compute_all_lit_correlation
 from gn3.computations.correlations import compute_tissue_correlation
 from gn3.computations.correlations import map_shared_keys_to_values
 from gn3.db_utils import database_connector
+from gn3.computations.partial_correlations import partial_correlations_entry
 
 correlation = Blueprint("correlation", __name__)
 
@@ -83,3 +86,31 @@ def compute_tissue_corr(corr_method="pearson"):
                                          corr_method=corr_method)
 
     return jsonify(results)
+
+@correlation.route("/partial", methods=["POST"])
+def partial_correlation():
+    """API endpoint for partial correlations."""
+    def trait_fullname(trait):
+        return f"{trait['dataset']}::{trait['name']}"
+
+    class OutputEncoder(json.JSONEncoder):
+        """
+        Class to encode output into JSON, for objects which the default
+        json.JSONEncoder class does not have default encoding for.
+        """
+        def default(self, obj):
+            if isinstance(obj, bytes):
+                return str(obj, encoding="utf-8")
+            return json.JSONEncoder.default(self, obj)
+
+    args = request.get_json()
+    conn, _cursor_object = database_connector()
+    corr_results = partial_correlations_entry(
+        conn, trait_fullname(args["primary_trait"]),
+        tuple(trait_fullname(trait) for trait in args["control_traits"]),
+        args["method"], int(args["criteria"]), args["target_db"])
+    response = make_response(
+        json.dumps(corr_results, cls=OutputEncoder).replace(": NaN", ": null"),
+        400 if "error" in corr_results.keys() else 200)
+    response.headers["Content-Type"] = "application/json"
+    return response
diff --git a/gn3/authentication.py b/gn3/authentication.py
index a6372c1..d0b35bc 100644
--- a/gn3/authentication.py
+++ b/gn3/authentication.py
@@ -163,3 +163,4 @@ def create_group(conn: Redis, group_name: Optional[str],
         }
         conn.hset("groups", group_id, json.dumps(group))
         return group
+    return None
diff --git a/gn3/computations/correlations.py b/gn3/computations/correlations.py
index 37c70e9..09288c5 100644
--- a/gn3/computations/correlations.py
+++ b/gn3/computations/correlations.py
@@ -8,6 +8,7 @@ from typing import List
 from typing import Tuple
 from typing import Optional
 from typing import Callable
+from typing import Generator
 
 import scipy.stats
 import pingouin as pg
@@ -80,7 +81,7 @@ def compute_sample_r_correlation(trait_name, corr_method, trait_vals,
             zip(*list(normalize_values(trait_vals, target_samples_vals))))
         num_overlap = len(normalized_traits_vals)
     except ValueError:
-        return
+        return None
 
     if num_overlap > 5:
 
@@ -107,7 +108,7 @@ package :not packaged in guix
 
 
 def filter_shared_sample_keys(this_samplelist,
-                              target_samplelist) -> Tuple[List, List]:
+                              target_samplelist) -> Generator:
     """Given primary and target sample-list for two base and target trait select
     filter the values using the shared keys
 
diff --git a/gn3/computations/partial_correlations.py b/gn3/computations/partial_correlations.py
index 719c605..984c15a 100644
--- a/gn3/computations/partial_correlations.py
+++ b/gn3/computations/partial_correlations.py
@@ -18,6 +18,7 @@ from gn3.random import random_string
 from gn3.function_helpers import  compose
 from gn3.data_helpers import parse_csv_line
 from gn3.db.traits import export_informative
+from gn3.db.datasets import retrieve_trait_dataset
 from gn3.db.traits import retrieve_trait_info, retrieve_trait_data
 from gn3.db.species import species_name, translate_to_mouse_gene_id
 from gn3.db.correlations import (
@@ -216,7 +217,7 @@ def good_dataset_samples_indexes(
 def partial_correlations_fast(# pylint: disable=[R0913, R0914]
         samples, primary_vals, control_vals, database_filename,
         fetched_correlations, method: str, correlation_type: str) -> Tuple[
-            float, Tuple[float, ...]]:
+            int, Tuple[float, ...]]:
     """
     Computes partial correlation coefficients using data from a CSV file.
 
@@ -257,8 +258,9 @@ def partial_correlations_fast(# pylint: disable=[R0913, R0914]
     ## `correlation_type` parameter
     return len(all_correlations), tuple(
         corr + (
-            (fetched_correlations[corr[0]],) if correlation_type == "literature"
-            else fetched_correlations[corr[0]][0:2])
+            (fetched_correlations[corr[0]],) # type: ignore[index]
+            if correlation_type == "literature"
+            else fetched_correlations[corr[0]][0:2]) # type: ignore[index]
         for idx, corr in enumerate(all_correlations))
 
 def build_data_frame(
@@ -305,11 +307,19 @@ def compute_partial(
             prim for targ, prim in zip(targ_vals, primary_vals)
             if targ is not None]
 
+        if len(primary) < 3:
+            return None
+
+        def __remove_controls_for_target_nones(cont_targ):
+            return tuple(cont for cont, targ in cont_targ if targ is not None)
+
+        conts_targs = tuple(tuple(
+            zip(control, targ_vals)) for control in control_vals)
         datafrm = build_data_frame(
             primary,
-            tuple(targ for targ in targ_vals if targ is not None),
-            tuple(cont for i, cont in enumerate(control_vals)
-                  if target[0][i] is not None))
+            [targ for targ in targ_vals if targ is not None],
+            [__remove_controls_for_target_nones(cont_targ)
+             for cont_targ in conts_targs])
         covariates = "z" if datafrm.shape[1] == 3 else [
             col for col in datafrm.columns if col not in ("x", "y")]
         ppc = pingouin.partial_corr(
@@ -332,13 +342,17 @@ def compute_partial(
             zero_order_corr["r"][0], zero_order_corr["p-val"][0])
 
     return tuple(
-        __compute_trait_info__(target)
-        for target in zip(target_vals, target_names))
+        result for result in (
+            __compute_trait_info__(target)
+            for target in zip(target_vals, target_names))
+        if result is not None)
 
 def partial_correlations_normal(# pylint: disable=R0913
         primary_vals, control_vals, input_trait_gene_id, trait_database,
         data_start_pos: int, db_type: str, method: str) -> Tuple[
-            float, Tuple[float, ...]]:
+            int, Tuple[Union[
+                Tuple[str, int, float, float, float, float], None],
+                       ...]]:#Tuple[float, ...]
     """
     Computes the correlation coefficients.
 
@@ -360,7 +374,7 @@ def partial_correlations_normal(# pylint: disable=R0913
             return tuple(item) + (trait_database[1], trait_database[2])
         return item
 
-    target_trait_names, target_trait_vals = reduce(
+    target_trait_names, target_trait_vals = reduce(# type: ignore[var-annotated]
         lambda acc, item: (acc[0]+(item[0],), acc[1]+(item[data_start_pos:],)),
         trait_database, (tuple(), tuple()))
 
@@ -413,7 +427,7 @@ def partial_corrs(# pylint: disable=[R0913]
         data_start_pos, dataset, method)
 
 def literature_correlation_by_list(
-        conn: Any, species: str, trait_list: Tuple[dict]) -> Tuple[dict]:
+        conn: Any, species: str, trait_list: Tuple[dict]) -> Tuple[dict, ...]:
     """
     This is a migration of the
     `web.webqtl.correlation.CorrelationPage.getLiteratureCorrelationByList`
@@ -473,7 +487,7 @@ def literature_correlation_by_list(
 
 def tissue_correlation_by_list(
         conn: Any, primary_trait_symbol: str, tissue_probeset_freeze_id: int,
-        method: str, trait_list: Tuple[dict]) -> Tuple[dict]:
+        method: str, trait_list: Tuple[dict]) -> Tuple[dict, ...]:
     """
     This is a migration of the
     `web.webqtl.correlation.CorrelationPage.getTissueCorrelationByList`
@@ -496,7 +510,7 @@ def tissue_correlation_by_list(
             primary_trait_value = prim_trait_symbol_value_dict[
                 primary_trait_symbol.lower()]
             gene_symbol_list = tuple(
-                trait for trait in trait_list if "symbol" in trait.keys())
+                trait["symbol"] for trait in trait_list if "symbol" in trait.keys())
             symbol_value_dict = fetch_gene_symbol_tissue_value_dict_for_trait(
                 gene_symbol_list, tissue_probeset_freeze_id, conn)
             return tuple(
@@ -514,6 +528,54 @@ def tissue_correlation_by_list(
         } for trait in trait_list)
     return trait_list
 
+def trait_for_output(trait):
+    """
+    Process a trait for output.
+
+    Removes a lot of extraneous data from the trait, that is not needed for
+    the display of partial correlation results.
+    This function also removes all key-value pairs, for which the value is
+    `None`, because it is a waste of network resources to transmit the key-value
+    pair just to indicate it does not exist.
+    """
+    trait = {
+        "trait_type": trait["trait_type"],
+        "dataset_name": trait["db"]["dataset_name"],
+        "dataset_type": trait["db"]["dataset_type"],
+        "group": trait["db"]["group"],
+        "trait_fullname": trait["trait_fullname"],
+        "trait_name": trait["trait_name"],
+        "symbol": trait.get("symbol"),
+        "description": trait.get("description"),
+        "pre_publication_description": trait.get(
+            "pre_publication_description"),
+        "post_publication_description": trait.get(
+            "post_publication_description"),
+        "original_description": trait.get(
+            "original_description"),
+        "authors": trait.get("authors"),
+        "year": trait.get("year"),
+        "probe_target_description": trait.get(
+            "probe_target_description"),
+        "chr": trait.get("chr"),
+        "mb": trait.get("mb"),
+        "geneid": trait.get("geneid"),
+        "homologeneid": trait.get("homologeneid"),
+        "noverlap": trait.get("noverlap"),
+        "partial_corr": trait.get("partial_corr"),
+        "partial_corr_p_value": trait.get("partial_corr_p_value"),
+        "corr": trait.get("corr"),
+        "corr_p_value": trait.get("corr_p_value"),
+        "rank_order": trait.get("rank_order"),
+        "delta": (
+            None if trait.get("partial_corr") is None
+            else (trait.get("partial_corr") - trait.get("corr"))),
+        "l_corr":  trait.get("l_corr"),
+        "tissue_corr": trait.get("tissue_corr"),
+        "tissue_p_value": trait.get("tissue_p_value")
+    }
+    return {key: val for key, val in trait.items() if val is not None}
+
 def partial_correlations_entry(# pylint: disable=[R0913, R0914, R0911]
         conn: Any, primary_trait_name: str,
         control_trait_names: Tuple[str, ...], method: str,
@@ -640,28 +702,47 @@ def partial_correlations_entry(# pylint: disable=[R0913, R0914, R0911]
                 "any associated Tissue Correlation Information."),
             "error_type": "Tissue Correlation"}
 
+    target_dataset = retrieve_trait_dataset(
+        ("Temp" if "Temp" in target_db_name else
+         ("Publish" if "Publish" in target_db_name else
+          "Geno" if "Geno" in target_db_name else "ProbeSet")),
+        {"db": {"dataset_name": target_db_name}, "trait_name": "_"},
+        threshold,
+        conn)
+
     database_filename = get_filename(conn, target_db_name, TEXTDIR)
     _total_traits, all_correlations = partial_corrs(
         conn, common_primary_control_samples, fixed_primary_vals,
         fixed_control_vals, len(fixed_primary_vals), species,
         input_trait_geneid, input_trait_symbol, tissue_probeset_freeze_id,
-        method, primary_trait["db"], database_filename)
+        method, {**target_dataset, "dataset_type": target_dataset["type"]}, database_filename)
 
 
     def __make_sorter__(method):
-        def __sort_6__(row):
-            return row[6]
-
-        def __sort_3__(row):
+        def __compare_lit_or_tiss_correlation_values_(row):
+            # Index  Content
+            # 0      trait name
+            # 1      N
+            # 2      partial correlation coefficient
+            # 3      p value of partial correlation
+            # 6      literature/tissue correlation value
+            return (row[6], row[3])
+
+        def __compare_partial_correlation_p_values__(row):
+            # Index  Content
+            # 0      trait name
+            # 1      partial correlation coefficient
+            # 2      N
+            # 3      p value of partial correlation
             return row[3]
 
         if "literature" in method.lower():
-            return __sort_6__
+            return __compare_lit_or_tiss_correlation_values_
 
         if "tissue" in method.lower():
-            return __sort_6__
+            return __compare_lit_or_tiss_correlation_values_
 
-        return __sort_3__
+        return __compare_partial_correlation_p_values__
 
     sorted_correlations = sorted(
         all_correlations, key=__make_sorter__(method))
@@ -676,7 +757,7 @@ def partial_correlations_entry(# pylint: disable=[R0913, R0914, R0911]
         {
             **retrieve_trait_info(
                 threshold,
-                f"{primary_trait['db']['dataset_name']}::{item[0]}",
+                f"{target_dataset['dataset_name']}::{item[0]}",
                 conn),
             "noverlap": item[1],
             "partial_corr": item[2],
@@ -694,4 +775,14 @@ def partial_correlations_entry(# pylint: disable=[R0913, R0914, R0911]
         for item in
         sorted_correlations[:min(criteria, len(all_correlations))]))
 
-    return trait_list
+    return {
+        "status": "success",
+        "results": {
+            "primary_trait": trait_for_output(primary_trait),
+            "control_traits": tuple(
+                trait_for_output(trait) for trait in cntrl_traits),
+            "correlations": tuple(
+                trait_for_output(trait) for trait in trait_list),
+            "dataset_type": target_dataset["type"],
+            "method": "spearman" if "spearman" in method.lower() else "pearson"
+        }}
diff --git a/gn3/db/correlations.py b/gn3/db/correlations.py
index 254af10..5361a1e 100644
--- a/gn3/db/correlations.py
+++ b/gn3/db/correlations.py
@@ -157,11 +157,12 @@ def fetch_symbol_value_pair_dict(
         symbol: data_id_dict.get(symbol) for symbol in symbol_list
         if data_id_dict.get(symbol) is not None
     }
-    query = "SELECT Id, value FROM TissueProbeSetData WHERE Id IN %(data_ids)s"
+    query = "SELECT Id, value FROM TissueProbeSetData WHERE Id IN ({})".format(
+        ",".join(f"%(id{i})s" for i in range(len(data_ids.values()))))
     with conn.cursor() as cursor:
         cursor.execute(
             query,
-            data_ids=tuple(data_ids.values()))
+            **{f"id{i}": did for i, did in enumerate(data_ids.values())})
         value_results = cursor.fetchall()
         return {
             key: tuple(row[1] for row in value_results if row[0] == key)
@@ -406,21 +407,22 @@ def fetch_sample_ids(
     """
     query = (
         "SELECT Strain.Id FROM Strain, Species "
-        "WHERE Strain.Name IN %(samples_names)s "
+        "WHERE Strain.Name IN ({}) "
         "AND Strain.SpeciesId=Species.Id "
-        "AND Species.name=%(species_name)s")
+        "AND Species.name=%(species_name)s").format(
+            ",".join(f"%(s{i})s" for i in range(len(sample_names))))
     with conn.cursor() as cursor:
         cursor.execute(
             query,
             {
-                "samples_names": tuple(sample_names),
+                **{f"s{i}": sname for i, sname in enumerate(sample_names)},
                 "species_name": species_name
             })
         return tuple(row[0] for row in cursor.fetchall())
 
 def build_query_sgo_lit_corr(
         db_type: str, temp_table: str, sample_id_columns: str,
-        joins: Tuple[str, ...]) -> str:
+        joins: Tuple[str, ...]) -> Tuple[str, int]:
     """
     Build query for `SGO Literature Correlation` data, when querying the given
     `temp_table` temporary table.
@@ -483,14 +485,14 @@ def fetch_all_database_data(# pylint: disable=[R0913, R0914]
         sample_id_columns = ", ".join(f"T{smpl}.value" for smpl in sample_ids)
         if db_type == "Publish":
             joins = tuple(
-                ("LEFT JOIN PublishData AS T{item} "
-                 "ON T{item}.Id = PublishXRef.DataId "
-                 "AND T{item}.StrainId = %(T{item}_sample_id)s")
+                (f"LEFT JOIN PublishData AS T{item} "
+                 f"ON T{item}.Id = PublishXRef.DataId "
+                 f"AND T{item}.StrainId = %(T{item}_sample_id)s")
                 for item in sample_ids)
             return (
                 ("SELECT PublishXRef.Id, " +
                  sample_id_columns +
-                 "FROM (PublishXRef, PublishFreeze) " +
+                 " FROM (PublishXRef, PublishFreeze) " +
                  " ".join(joins) +
                  " WHERE PublishXRef.InbredSetId = PublishFreeze.InbredSetId "
                  "AND PublishFreeze.Name = %(db_name)s"),
diff --git a/gn3/db/datasets.py b/gn3/db/datasets.py
index c50e148..a41e228 100644
--- a/gn3/db/datasets.py
+++ b/gn3/db/datasets.py
@@ -3,7 +3,7 @@ This module contains functions relating to specific trait dataset manipulation
 """
 import re
 from string import Template
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, Optional
 from SPARQLWrapper import JSON, SPARQLWrapper
 from gn3.settings import SPARQL_ENDPOINT
 
@@ -297,7 +297,7 @@ def retrieve_trait_dataset(trait_type, trait, threshold, conn):
         **group
     }
 
-def sparql_query(query: str) -> Dict[str, Any]:
+def sparql_query(query: str) -> List[Dict[str, Any]]:
     """Run a SPARQL query and return the bound variables."""
     sparql = SPARQLWrapper(SPARQL_ENDPOINT)
     sparql.setQuery(query)
@@ -328,7 +328,7 @@ WHERE {
   OPTIONAL { ?dataset gn:geoSeries ?geo_series } .
 }
 """,
-             """
+               """
 PREFIX gn: <http://genenetwork.org/>
 SELECT ?platform_name ?normalization_name ?species_name ?inbred_set_name ?tissue_name
 WHERE {
@@ -341,7 +341,7 @@ WHERE {
   OPTIONAL { ?dataset gn:datasetOfPlatform / gn:name ?platform_name } .
 }
 """,
-             """
+               """
 PREFIX gn: <http://genenetwork.org/>
 SELECT ?specifics ?summary ?about_cases ?about_tissue ?about_platform
        ?about_data_processing ?notes ?experiment_design ?contributors
@@ -362,8 +362,8 @@ WHERE {
   OPTIONAL { ?dataset gn:acknowledgment ?acknowledgment . }
 }
 """]
-    result = {'accession_id': accession_id,
-              'investigator': {}}
+    result: Dict[str, Any] = {'accession_id': accession_id,
+                              'investigator': {}}
     query_result = {}
     for query in queries:
         if sparql_result := sparql_query(Template(query).substitute(accession_id=accession_id)):
diff --git a/gn3/db/traits.py b/gn3/db/traits.py
index 7994aef..338b320 100644
--- a/gn3/db/traits.py
+++ b/gn3/db/traits.py
@@ -1,6 +1,5 @@
 """This class contains functions relating to trait data manipulation"""
 import os
-import MySQLdb
 from functools import reduce
 from typing import Any, Dict, Union, Sequence
 
@@ -111,7 +110,6 @@ def get_trait_csv_sample_data(conn: Any,
 
 
 def update_sample_data(conn: Any, #pylint: disable=[R0913]
-
                        trait_name: str,
                        strain_name: str,
                        phenotype_id: int,
@@ -204,25 +202,30 @@ def delete_sample_data(conn: Any,
                  "AND Strain.Name = \"%s\"") % (trait_name,
                                                 phenotype_id,
                                                 str(strain_name)))
-            strain_id, data_id = cursor.fetchone()
 
-            cursor.execute(("DELETE FROM PublishData "
+            # Check if it exists if the data was already deleted:
+            if _result := cursor.fetchone():
+                strain_id, data_id = _result
+
+            # Only run if the strain_id and data_id exist
+            if strain_id and data_id:
+                cursor.execute(("DELETE FROM PublishData "
                             "WHERE StrainId = %s AND Id = %s")
-                           % (strain_id, data_id))
-            deleted_published_data = cursor.rowcount
-
-            # Delete the PublishSE table
-            cursor.execute(("DELETE FROM PublishSE "
-                            "WHERE StrainId = %s AND DataId = %s") %
-                           (strain_id, data_id))
-            deleted_se_data = cursor.rowcount
-
-            # Delete the NStrain table
-            cursor.execute(("DELETE FROM NStrain "
-                            "WHERE StrainId = %s AND DataId = %s" %
-                            (strain_id, data_id)))
-            deleted_n_strains = cursor.rowcount
-        except Exception as e: #pylint: disable=[C0103, W0612]
+                               % (strain_id, data_id))
+                deleted_published_data = cursor.rowcount
+
+                # Delete the PublishSE table
+                cursor.execute(("DELETE FROM PublishSE "
+                                "WHERE StrainId = %s AND DataId = %s") %
+                               (strain_id, data_id))
+                deleted_se_data = cursor.rowcount
+
+                # Delete the NStrain table
+                cursor.execute(("DELETE FROM NStrain "
+                                "WHERE StrainId = %s AND DataId = %s" %
+                                (strain_id, data_id)))
+                deleted_n_strains = cursor.rowcount
+        except Exception as e:  #pylint: disable=[C0103, W0612]
             conn.rollback()
             raise MySQLdb.Error
         conn.commit()
@@ -255,6 +258,13 @@ def insert_sample_data(conn: Any, #pylint: disable=[R0913]
                            (strain_name,))
             strain_id = cursor.fetchone()
 
+            # Return early if an insert already exists!
+            cursor.execute("SELECT Id FROM PublishData where Id = %s "
+                           "AND StrainId = %s",
+                           (data_id, strain_id))
+            if cursor.fetchone():  # This strain already exists
+                return (0, 0, 0)
+
             # Insert the PublishData table
             cursor.execute(("INSERT INTO PublishData (Id, StrainId, value)"
                             "VALUES (%s, %s, %s)"),