about summary refs log tree commit diff
diff options
context:
space:
mode:
authorFrederick Muriuki Muriithi2022-08-17 07:26:57 +0300
committerFrederick Muriuki Muriithi2022-08-17 07:26:57 +0300
commitbc589b21d21f157ceaa4b35ca511ff2df8fbc85f (patch)
tree0a07ccc670d90b5ad0d0d1cfdc6e0b5cf9127753
parentee4436538b124b1cd311b396998b1d1d9eb641ef (diff)
parentbadec635d70d8befbe8d2bea1a2c7468546836a6 (diff)
downloadgenenetwork2-bc589b21d21f157ceaa4b35ca511ff2df8fbc85f.tar.gz
Merge branch 'Alexanderlacuna-chores/rust-enhancements' into testing
-rw-r--r--wqflask/wqflask/correlation/rust_correlation.py135
-rw-r--r--wqflask/wqflask/correlation/show_corr_results.py6
-rw-r--r--wqflask/wqflask/views.py2
3 files changed, 125 insertions, 18 deletions
diff --git a/wqflask/wqflask/correlation/rust_correlation.py b/wqflask/wqflask/correlation/rust_correlation.py
index 79d08a59..5c22efbf 100644
--- a/wqflask/wqflask/correlation/rust_correlation.py
+++ b/wqflask/wqflask/correlation/rust_correlation.py
@@ -1,7 +1,10 @@
 """module contains integration code for rust-gn3"""
 import json
 from functools import reduce
-from wqflask.correlation.correlation_functions import get_trait_symbol_and_tissue_values
+from utility.db_tools import mescape
+from utility.db_tools import create_in_clause
+from wqflask.correlation.correlation_functions\
+    import get_trait_symbol_and_tissue_values
 from wqflask.correlation.correlation_gn3_api import create_target_this_trait
 from wqflask.correlation.correlation_gn3_api import lit_for_trait_list
 from wqflask.correlation.correlation_gn3_api import do_lit_correlation
@@ -12,11 +15,92 @@ from gn3.computations.rust_correlation import parse_tissue_corr_data
 from gn3.db_utils import database_connector
 
 
+def chunk_dataset(dataset, steps, name):
+
+    results = []
+
+    query = """
+            SELECT ProbeSetXRef.DataId,ProbeSet.Name
+            FROM ProbeSet, ProbeSetXRef, ProbeSetFreeze
+            WHERE ProbeSetFreeze.Name = '{}' AND
+                  ProbeSetXRef.ProbeSetFreezeId = ProbeSetFreeze.Id AND
+                  ProbeSetXRef.ProbeSetId = ProbeSet.Id
+    """.format(name)
+
+    with database_connector() as conn:
+        with conn.cursor() as curr:
+            curr.execute(query)
+            traits_name_dict = dict(curr.fetchall())
+
+    for i in range(0, len(dataset), steps):
+        matrix = list(dataset[i:i + steps])
+        trait_name = traits_name_dict[matrix[0][0]]
+
+        strains = [trait_name] + [str(value)
+                                  for (trait_name, strain, value) in matrix]
+        results.append(",".join(strains))
+
+    return results
+
+
+def compute_top_n_sample(start_vars, dataset, trait_list):
+    """check if dataset is of type probeset"""
+
+    if dataset.type.lower() != "probeset":
+        return {}
+
+    def __fetch_sample_ids__(samples_vals, samples_group):
+        all_samples = json.loads(samples_vals)
+        sample_data = get_sample_corr_data(
+            sample_type=samples_group, all_samples=all_samples,
+            dataset_samples=dataset.group.all_samples_ordered())
+
+        with database_connector() as conn:
+            with conn.cursor() as curr:
+                curr.execute(
+                    """
+                    SELECT Strain.Name, Strain.Id FROM Strain, Species
+                    WHERE Strain.Name IN {}
+                    and Strain.SpeciesId=Species.Id
+                    and Species.name = '{}'
+                    """.format(create_in_clause(list(sample_data.keys())),
+                               *mescape(dataset.group.species)))
+                return (sample_data, dict(curr.fetchall()))
+
+    (sample_data, sample_ids) = __fetch_sample_ids__(
+        start_vars["sample_vals"], start_vars["corr_samples_group"])
+
+    with database_connector() as conn:
+        with conn.cursor() as curr:
+            # fetching strain data in bulk
+            curr.execute(
+                """
+                SELECT * from ProbeSetData
+                where StrainID in {}
+                and id in (SELECT ProbeSetXRef.DataId
+                FROM (ProbeSet, ProbeSetXRef, ProbeSetFreeze)
+                WHERE ProbeSetXRef.ProbeSetFreezeId = ProbeSetFreeze.Id
+                and ProbeSetFreeze.Name = '{}'
+                and ProbeSet.Name in {}
+                and ProbeSet.Id = ProbeSetXRef.ProbeSetId)
+                """.format(
+                    create_in_clause(list(sample_ids.values())),
+                    dataset.name,
+                    create_in_clause(trait_list)))
+
+            corr_data = chunk_dataset(
+                list(curr.fetchall()), len(sample_ids.values()), dataset.name)
+
+        return run_correlation(
+            corr_data, list(sample_data.values()), "pearson", ",")
+
+
 def compute_top_n_lit(corr_results, this_dataset, this_trait) -> dict:
     (this_trait_geneid, geneid_dict, species) = do_lit_correlation(
         this_trait, this_dataset)
 
-    geneid_dict = {trait_name: geneid for (trait_name, geneid) in geneid_dict.items() if
+    geneid_dict = {trait_name: geneid for (trait_name, geneid)
+                   in geneid_dict.items() if
                    corr_results.get(trait_name)}
     with database_connector() as conn:
         return reduce(
@@ -69,6 +153,7 @@ def merge_results(dict_a: dict, dict_b: dict, dict_c: dict) -> list[dict]:
         }
     return [__merge__(tname, tcorrs) for tname, tcorrs in dict_a.items()]
 
+
 def __compute_sample_corr__(
         start_vars: dict, corr_type: str, method: str, n_top: int,
         target_trait_info: tuple):
@@ -86,11 +171,11 @@ def __compute_sample_corr__(
         r = ",".join(lts)
         target_data.append(r)
 
-
     return run_correlation(
         target_data, list(sample_data.values()), method, ",", corr_type,
         n_top)
 
+
 def __compute_tissue_corr__(
         start_vars: dict, corr_type: str, method: str, n_top: int,
         target_trait_info: tuple):
@@ -111,6 +196,7 @@ def __compute_tissue_corr__(
         return run_correlation(data[1], data[0], method, ",", "tissue")
     return {}
 
+
 def __compute_lit_corr__(
         start_vars: dict, corr_type: str, method: str, n_top: int,
         target_trait_info: tuple):
@@ -130,15 +216,16 @@ def __compute_lit_corr__(
             {})
     return {}
 
+
 def compute_correlation_rust(
         start_vars: dict, corr_type: str, method: str = "pearson",
-        n_top: int = 500, compute_all: bool = False):
+        n_top: int = 500, should_compute_all: bool = False):
     """function to compute correlation"""
     target_trait_info = create_target_this_trait(start_vars)
     (this_dataset, this_trait, target_dataset, sample_data) = (
         target_trait_info)
 
-    ## Replace this with `match ...` once we hit Python 3.10
+    # Replace this with `match ...` once we hit Python 3.10
     corr_type_fns = {
         "sample": __compute_sample_corr__,
         "tissue": __compute_tissue_corr__,
@@ -146,19 +233,39 @@ def compute_correlation_rust(
     }
     results = corr_type_fns[corr_type](
         start_vars, corr_type, method, n_top, target_trait_info)
-    ## END: Replace this with `match ...` once we hit Python 3.10
 
-    top_tissue_results = {}
-    top_lit_results = {}
-    if compute_all:
-        # example compute of compute both correlation
-        top_tissue_results = compute_top_n_tissue(
-            this_dataset,this_trait,results,method)
-        top_lit_results = compute_top_n_lit(results,this_dataset,this_trait)
+    # END: Replace this with `match ...` once we hit Python 3.10
+
+    top_a = top_b = {}
+
+    if should_compute_all:
+
+        if corr_type == "sample":
+
+            top_a = compute_top_n_tissue(
+                this_dataset, this_trait, results, method)
+
+            top_b = compute_top_n_lit(results, this_dataset, this_trait)
+
+        elif corr_type == "lit":
+
+            # currently fails for lit
+
+            top_a = compute_top_n_sample(
+                start_vars, target_dataset, list(results.keys()))
+            top_b = compute_top_n_tissue(
+                this_dataset, this_trait, results, method)
+
+        else:
+
+            top_a = compute_top_n_sample(
+                start_vars, target_dataset, list(results.keys()))
+
+            top_b = compute_top_n_lit(results, this_dataset, this_trait)
 
     return {
         "correlation_results": merge_results(
-            results, top_tissue_results, top_lit_results),
+            results, top_a, top_b),
         "this_trait": this_trait.name,
         "target_dataset": start_vars['corr_dataset'],
         "return_results": n_top
diff --git a/wqflask/wqflask/correlation/show_corr_results.py b/wqflask/wqflask/correlation/show_corr_results.py
index 1c391386..f5fdd9b3 100644
--- a/wqflask/wqflask/correlation/show_corr_results.py
+++ b/wqflask/wqflask/correlation/show_corr_results.py
@@ -121,9 +121,9 @@ def correlation_json_for_table(correlation_data, this_trait, this_dataset, targe
         results_dict['dataset'] = target_dataset['name']
         results_dict['hmac'] = hmac.data_hmac(
             '{}:{}'.format(target_trait['name'], target_dataset['name']))
-        results_dict['sample_r'] = f"{float(trait['corr_coefficient']):.3f}"
-        results_dict['num_overlap'] = trait['num_overlap']
-        results_dict['sample_p'] = f"{float(trait['p_value']):.3e}"
+        results_dict['sample_r'] = f"{float(trait.get('corr_coefficient',0.0)):.3f}"
+        results_dict['num_overlap'] = trait.get('num_overlap',0)
+        results_dict['sample_p'] = f"{float(trait.get('p_value',0)):.3e}"
         if target_dataset['type'] == "ProbeSet":
             results_dict['symbol'] = target_trait['symbol']
             results_dict['description'] = "N/A"
diff --git a/wqflask/wqflask/views.py b/wqflask/wqflask/views.py
index 2e13451d..e054cd49 100644
--- a/wqflask/wqflask/views.py
+++ b/wqflask/wqflask/views.py
@@ -876,7 +876,7 @@ def test_corr_compute_page():
     correlation_results = compute_correlation_rust(start_vars,
                                                    start_vars["corr_type"],
                                                    start_vars['corr_sample_method'],
-                                                   int(start_vars.get("corr_return_results", 500)))
+                                                   int(start_vars.get("corr_return_results", 500)),True)
 
     correlation_results = set_template_vars(request.form, correlation_results)