about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--wqflask/wqflask/correlation/correlation_gn3_api.py103
-rw-r--r--wqflask/wqflask/templates/test_correlation_page.html5
-rw-r--r--wqflask/wqflask/views.py4
3 files changed, 68 insertions, 44 deletions
diff --git a/wqflask/wqflask/correlation/correlation_gn3_api.py b/wqflask/wqflask/correlation/correlation_gn3_api.py
index 98d52591..e5638b5a 100644
--- a/wqflask/wqflask/correlation/correlation_gn3_api.py
+++ b/wqflask/wqflask/correlation/correlation_gn3_api.py
@@ -41,27 +41,33 @@ def process_samples(start_vars, sample_names, excluded_samples=None):
     return sample_data
 
 
+def merge_correlation_results(correlation_results, target_correlation_results):
+
+    corr_dict = {}
+
+    for trait_dict in target_correlation_results:
+        for trait_name, values in trait_dict.items():
+
+            corr_dict[trait_name] = values
+    for trait_dict in correlation_results:
+        for trait_name, values in trait_dict.items():
+
+            if corr_dict.get(trait_name):
+
+                trait_dict[trait_name].update(corr_dict.get(trait_name))
+
+    return correlation_results
+
+
 def sample_for_trait_lists(corr_results, target_dataset,
                            this_trait, this_dataset, start_vars):
     """interface function for correlation on top results"""
 
-    sample_data = process_samples(
-        start_vars, this_dataset.group.samplelist)
-    target_dataset.get_trait_data(list(sample_data.keys()))
-    # should filter target traits from here
-    _corr_results = corr_results
-
-    this_trait = retrieve_sample_data(this_trait, this_dataset)
-
-    this_trait_data = {
-        "trait_sample_data": sample_data,
-        "trait_id": start_vars["trait_id"]
-    }
-    results = map_shared_keys_to_values(
-        target_dataset.samplelist, target_dataset.trait_data)
+    (this_trait_data, target_dataset) = fetch_sample_data(
+        start_vars, this_trait, this_dataset, target_dataset)
     correlation_results = compute_all_sample_correlation(corr_method="pearson",
                                                          this_trait=this_trait_data,
-                                                         target_dataset=results)
+                                                         target_dataset=target_dataset)
 
     return correlation_results
 
@@ -105,6 +111,23 @@ def lit_for_trait_list(corr_results, this_dataset, this_trait):
     return correlation_results
 
 
+def fetch_sample_data(start_vars, this_trait, this_dataset, target_dataset):
+
+    sample_data = process_samples(
+        start_vars, this_dataset.group.samplelist)
+    target_dataset.get_trait_data(list(sample_data.keys()))
+    this_trait = retrieve_sample_data(this_trait, this_dataset)
+    this_trait_data = {
+        "trait_sample_data": sample_data,
+        "trait_id": start_vars["trait_id"]
+    }
+
+    results = map_shared_keys_to_values(
+        target_dataset.samplelist, target_dataset.trait_data)
+
+    return (this_trait_data, results)
+
+
 def compute_correlation(start_vars, method="pearson"):
     """compute correlation for to call gn3  api"""
     # pylint: disable-msg=too-many-locals
@@ -119,31 +142,11 @@ def compute_correlation(start_vars, method="pearson"):
     corr_input_data = {}
 
     if corr_type == "sample":
-
-        sample_data = process_samples(
-            start_vars, this_dataset.group.samplelist)
-        target_dataset.get_trait_data(list(sample_data.keys()))
-        this_trait = retrieve_sample_data(this_trait, this_dataset)
-        this_trait_data = {
-            "trait_sample_data": sample_data,
-            "trait_id": start_vars["trait_id"]
-        }
-        results = map_shared_keys_to_values(
-            target_dataset.samplelist, target_dataset.trait_data)
+        (this_trait_data, target_dataset) = fetch_sample_data(
+            start_vars, this_trait, this_dataset, target_dataset)
         correlation_results = compute_all_sample_correlation(corr_method=method,
                                                              this_trait=this_trait_data,
-                                                             target_dataset=results)
-
-        # do tissue correaltion
-
-        # code to be use later
-
-        # tissue_result = tissue_for_trait_lists(
-        #     correlation_results, this_dataset, this_trait)
-        # # lit spoils the party so slow
-        # lit_result = lit_for_trait_list(
-        #     correlation_results, this_dataset, this_trait)
-
+                                                             target_dataset=target_dataset)
 
     elif corr_type == "tissue":
         trait_symbol_dict = this_dataset.retrieve_genes("Symbol")
@@ -172,7 +175,29 @@ def compute_correlation(start_vars, method="pearson"):
                 conn=conn, trait_lists=list(geneid_dict.items()),
                 species=species, gene_id=this_trait_geneid)
 
-    return correlation_results[0:corr_return_results]
+    # correlation_results = correlation_results[0:corr_return_results]
+    # if corr_type != "tissue" and this_dataset.type == "ProbeSet" and target_dataset.type == "ProbeSet":
+    #     pass
+    #     # tissue_result = tissue_for_trait_lists(
+    #     #     correlation_results, this_dataset, this_trait)
+
+    #     # correlation_results = merge_correlation_results(
+    #     #     correlation_results, tissue_result)
+
+    # if corr_type != "lit" and this_dataset.type == "ProbeSet" and target_dataset.type == "ProbeSet":
+    #     pass
+    #     # lit  is very slow
+    #     # lit_result = lit_for_trait_list(correlation_results, this_dataset, this_trait)
+    #     # correlation_results = merge_correlation_results(correlation_results,lit_result)
+    # if corr_type != "sample":
+    #     # do sample correlation
+    #     pass
+    correlation_results = correlation_results[0:corr_return_results]
+    correlation_data = {"correlation_results": correlation_results,
+                        "this_trait": this_trait.name,
+                        "target_dataset": start_vars['corr_dataset']}
+
+    return correlation_data
 
 
 def do_lit_correlation(this_trait, this_dataset):
diff --git a/wqflask/wqflask/templates/test_correlation_page.html b/wqflask/wqflask/templates/test_correlation_page.html
index 037e9735..186de4b7 100644
--- a/wqflask/wqflask/templates/test_correlation_page.html
+++ b/wqflask/wqflask/templates/test_correlation_page.html
@@ -42,7 +42,7 @@
 {% block content %}
 
 <div class="correlation-title">
-	<h3>Correlation Results for <span>Dataset_name</span> against <span><a href="">trait_name</a></span> for the top <span>all</span> Results</h3>
+	<h3>Correlation Results for <span>{{target_dataset}}</span> against <span><a href="">{{this_trait}}</a></span> for the top <span>all</span> Results</h3>
 </div>
 <div class="header-toggle-vis">
 	      <h4 style="font-weight: bolder;padding: 5px 3px;">Toggle Columns</h4>
@@ -84,7 +84,6 @@
 <script language="javascript" type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.15.1/js/all.min.js"></script>
 <script language="javascript" type="text/javascript" src="{{ url_for('js', filename='DataTablesExtensions/scroller/js/dataTables.scroller.min.js') }}"></script>
 <script type="text/javascript">
-	console.log("running this script")
 	let correlationResults = {{correlation_results|safe}}
 	// document.querySelector(".content").innerHTML =correlationResults
 	// parse the data
@@ -102,7 +101,7 @@
 		return new_dict;
 	})
 
-console.log(correlationResults) 
+console.log(correlationResults)
 	
 </script>
 
diff --git a/wqflask/wqflask/views.py b/wqflask/wqflask/views.py
index 3c875163..4834ee63 100644
--- a/wqflask/wqflask/views.py
+++ b/wqflask/wqflask/views.py
@@ -972,8 +972,8 @@ def corr_compute_page():
 
 @app.route("/test_corr_compute", methods=["POST"])
 def test_corr_compute_page():
-    correlation_results = compute_correlation(request.form)
-    return render_template("test_correlation_page.html", correlation_results=correlation_results)
+    correlation_data = compute_correlation(request.form)
+    return render_template("test_correlation_page.html", **correlation_data)
     
 @app.route("/corr_matrix", methods=('POST',))
 def corr_matrix_page():