about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--wqflask/wqflask/correlation/rust_correlation.py124
1 files changed, 119 insertions, 5 deletions
diff --git a/wqflask/wqflask/correlation/rust_correlation.py b/wqflask/wqflask/correlation/rust_correlation.py
index 3628f549..94720f54 100644
--- a/wqflask/wqflask/correlation/rust_correlation.py
+++ b/wqflask/wqflask/correlation/rust_correlation.py
@@ -1,6 +1,9 @@
 """module contains integration code for rust-gn3"""
 import json
 from functools import reduce
+from flask import g
+from utility.db_tools import mescape
+from utility.db_tools import create_in_clause
 from wqflask.correlation.correlation_functions import get_trait_symbol_and_tissue_values
 from wqflask.correlation.correlation_gn3_api import create_target_this_trait
 from wqflask.correlation.correlation_gn3_api import lit_for_trait_list
@@ -12,6 +15,106 @@ from gn3.computations.rust_correlation import parse_tissue_corr_data
 from gn3.db_utils import database_connector
 
 
+
+
+def chunk_dataset(dataset,steps,name):
+
+    results = []
+
+    query = """
+            SELECT ProbeSetXRef.DataId,ProbeSet.Name
+            FROM ProbeSet, ProbeSetXRef, ProbeSetFreeze
+            WHERE ProbeSetFreeze.Name = '{}' AND
+                  ProbeSetXRef.ProbeSetFreezeId = ProbeSetFreeze.Id AND
+                  ProbeSetXRef.ProbeSetId = ProbeSet.Id
+    """.format(name)
+
+    traits_name_dict = dict(g.db.execute(query).fetchall())
+
+
+    for i in range(0, len(dataset), steps):
+        matrix = list(dataset[i:i + steps])
+        trait_name = traits_name_dict[matrix[0][0]]
+
+        strains = [trait_name] + [str(value) for (trait_name, strain, value) in matrix]
+        results.append(",".join(strains))
+
+    breakpoint()
+    return results
+
+
+def compute_top_n_sample(start_vars, dataset, trait_list):
+    """only if dataset is of type probeset"""
+
+
+
+
+    def __fetch_sample_ids__(samples_vals, samples_group):
+
+
+        all_samples = json.loads(samples_vals)
+        sample_data = get_sample_corr_data(
+            sample_type=samples_group, all_samples=all_samples,
+            dataset_samples=dataset.group.all_samples_ordered())
+
+
+        with database_connector() as conn:
+
+            curr = conn.cursor()
+
+            curr.execute(
+                """
+                    SELECT Strain.Name, Strain.Id FROM Strain, Species
+                    WHERE Strain.Name IN {}
+                    and Strain.SpeciesId=Species.Id
+                    and Species.name = '{}'
+                    """.format(create_in_clause(list(sample_data.keys())),
+                               *mescape(dataset.group.species))
+
+            )
+
+            return dict(curr.fetchall())
+
+
+
+
+
+
+
+
+
+  
+
+    ty = __fetch_sample_ids__(start_vars["sample_vals"], start_vars["corr_samples_group"])
+
+
+
+    with database_connector() as conn:
+
+        curr = conn.cursor()
+
+        curr.execute(
+
+        """
+            SELECT * from ProbeSetData
+            where StrainID in {}
+            and id in (SELECT ProbeSetXRef.DataId
+            FROM (ProbeSet, ProbeSetXRef, ProbeSetFreeze)
+            WHERE ProbeSetXRef.ProbeSetFreezeId = ProbeSetFreeze.Id
+            and ProbeSetFreeze.Name = '{}'
+            and ProbeSet.Name in {}
+            and ProbeSet.Id = ProbeSetXRef.ProbeSetId)
+           """.format(create_in_clause(list(ty.values())),dataset.name,create_in_clause(trait_list))
+
+
+        )
+
+
+
+
+        return chunk_dataset(list(curr.fetchall()),len(ty.values()),dataset.name)
+
+
 def compute_top_n_lit(corr_results, this_dataset, this_trait) -> dict:
     (this_trait_geneid, geneid_dict, species) = do_lit_correlation(
         this_trait, this_dataset)
@@ -69,6 +172,7 @@ def merge_results(dict_a: dict, dict_b: dict, dict_c: dict) -> list[dict]:
         }
     return [__merge__(tname, tcorrs) for tname, tcorrs in dict_a.items()]
 
+
 def __compute_sample_corr__(
         start_vars: dict, corr_type: str, method: str, n_top: int,
         target_trait_info: tuple):
@@ -86,11 +190,11 @@ def __compute_sample_corr__(
         r = ",".join(lts)
         target_data.append(r)
 
-
     return run_correlation(
         target_data, list(sample_data.values()), method, ",", corr_type,
         n_top)
 
+
 def __compute_tissue_corr__(
         start_vars: dict, corr_type: str, method: str, n_top: int,
         target_trait_info: tuple):
@@ -111,6 +215,7 @@ def __compute_tissue_corr__(
         return run_correlation(data[1], data[0], method, ",", "tissue")
     return {}
 
+
 def __compute_lit_corr__(
         start_vars: dict, corr_type: str, method: str, n_top: int,
         target_trait_info: tuple):
@@ -127,6 +232,7 @@ def __compute_lit_corr__(
             species=species, gene_id=this_trait_geneid)
     return {}
 
+
 def compute_correlation_rust(
         start_vars: dict, corr_type: str, method: str = "pearson",
         n_top: int = 500, compute_all: bool = False):
@@ -135,7 +241,7 @@ def compute_correlation_rust(
     (this_dataset, this_trait, target_dataset, sample_data) = (
         target_trait_info)
 
-    ## Replace this with `match ...` once we hit Python 3.10
+    # Replace this with `match ...` once we hit Python 3.10
     corr_type_fns = {
         "sample": __compute_sample_corr__,
         "tissue": __compute_tissue_corr__,
@@ -143,15 +249,23 @@ def compute_correlation_rust(
     }
     results = corr_type_fns[corr_type](
         start_vars, corr_type, method, n_top, target_trait_info)
-    ## END: Replace this with `match ...` once we hit Python 3.10
+    # END: Replace this with `match ...` once we hit Python 3.10
 
     top_tissue_results = {}
     top_lit_results = {}
+
+
+    results = compute_top_n_sample(start_vars,target_dataset,list(results.keys()))
+
+
+
+    breakpoint()
+
     if compute_all:
         # example compute of compute both correlation
         top_tissue_results = compute_top_n_tissue(
-            this_dataset,this_trait,results,method)
-        top_lit_results = compute_top_n_lit(results,this_dataset,this_trait)
+            this_dataset, this_trait, results, method)
+        top_lit_results = compute_top_n_lit(results, this_dataset, this_trait)
 
     return {
         "correlation_results": merge_results(