about summary refs log tree commit diff
path: root/wqflask
diff options
context:
space:
mode:
Diffstat (limited to 'wqflask')
-rw-r--r--wqflask/wqflask/correlation/correlation_gn3_api.py39
-rw-r--r--wqflask/wqflask/correlation/rust_correlation.py121
2 files changed, 78 insertions, 82 deletions
diff --git a/wqflask/wqflask/correlation/correlation_gn3_api.py b/wqflask/wqflask/correlation/correlation_gn3_api.py
index 6df4eafe..1a375501 100644
--- a/wqflask/wqflask/correlation/correlation_gn3_api.py
+++ b/wqflask/wqflask/correlation/correlation_gn3_api.py
@@ -194,46 +194,13 @@ def compute_correlation(start_vars, method="pearson", compute_all=False):
     method -- Correlation method to be used (pearson, spearman, or bicor)
     compute_all -- Include sample, tissue, and literature correlations (when applicable)
     """
-    # pylint: disable-msg=too-many-locals
+    from wqflask.correlation.rust_correlation import compute_correlation_rust
 
     corr_type = start_vars['corr_type']
-
     method = start_vars['corr_sample_method']
     corr_return_results = int(start_vars.get("corr_return_results", 100))
-    corr_input_data = {}
-
-    from wqflask.correlation.rust_correlation import compute_correlation_rust
-    rust_correlation_results = compute_correlation_rust(
-        start_vars, corr_type, method, corr_return_results)
-    correlation_results = rust_correlation_results["correlation_results"]
-
-    if corr_type == "lit":# elif corr_type == "lit":
-        (this_dataset, this_trait, target_dataset,
-         sample_data) = create_target_this_trait(start_vars)
-        target_dataset_type = target_dataset.type
-        this_dataset_type = this_dataset.type
-        (this_trait_geneid, geneid_dict, species) = do_lit_correlation(
-            this_trait, this_dataset)
-
-        conn = database_connector()
-        with conn:
-            correlation_results = compute_all_lit_correlation(
-                conn=conn, trait_lists=list(geneid_dict.items()),
-                species=species, gene_id=this_trait_geneid)
-
-    correlation_results = correlation_results[0:corr_return_results]
-
-    if (compute_all):
-        correlation_results = compute_corr_for_top_results(
-            start_vars, correlation_results, this_trait, this_dataset,
-            target_dataset, corr_type)
-
-    return {
-        "correlation_results": correlation_results,
-        "this_trait": this_trait.name,
-        "target_dataset": start_vars['corr_dataset'],
-        "return_results": corr_return_results
-    }
+    return compute_correlation_rust(
+        start_vars, corr_type, method, corr_return_results, compute_all)
 
 
 def compute_corr_for_top_results(start_vars,
diff --git a/wqflask/wqflask/correlation/rust_correlation.py b/wqflask/wqflask/correlation/rust_correlation.py
index 4a22af72..b4435887 100644
--- a/wqflask/wqflask/correlation/rust_correlation.py
+++ b/wqflask/wqflask/correlation/rust_correlation.py
@@ -69,62 +69,91 @@ def merge_results(dict_a: dict, dict_b: dict, dict_c: dict) -> list[dict]:
         }
     return [__merge__(tname, tcorrs) for tname, tcorrs in dict_a.items()]
 
+def __compute_sample_corr__(
+        start_vars: dict, corr_type: str, method: str, n_top: int,
+        target_trait_info: tuple):
+    """Compute the sample correlations"""
+    (this_dataset, this_trait, target_dataset, sample_data) = target_trait_info
+    all_samples = json.loads(start_vars["sample_vals"])
+    sample_data = get_sample_corr_data(
+        sample_type=start_vars["corr_samples_group"], all_samples=all_samples,
+        dataset_samples=this_dataset.group.all_samples_ordered())
+    target_dataset.get_trait_data(list(sample_data.keys()))
+
+    target_data = []
+    for (key, val) in target_dataset.trait_data.items():
+        lts = [key] + [str(x) for x in val]
+        r = ",".join(lts)
+        target_data.append(r)
+
+
+    return run_correlation(
+        target_data, list(sample_data.values()), method, ",", corr_type,
+        n_top)
+
+def __compute_tissue_corr__(
+        start_vars: dict, corr_type: str, method: str, n_top: int,
+        target_trait_info: tuple):
+    """Compute the tissue correlations"""
+    (this_dataset, this_trait, target_dataset, sample_data) = target_trait_info
+    trait_symbol_dict = this_dataset.retrieve_genes("Symbol")
+    corr_result_tissue_vals_dict = get_trait_symbol_and_tissue_values(
+        symbol_list=list(trait_symbol_dict.values()))
 
-def compute_correlation_rust(
-        start_vars: dict, corr_type: str, method: str = "pearson",
-        n_top: int = 500):
-    """function to compute correlation"""
-
-    (this_dataset, this_trait, target_dataset,
-     sample_data) = create_target_this_trait(start_vars)
-
-    if corr_type == "sample":
-
-        all_samples = json.loads(start_vars["sample_vals"])
-        sample_data = get_sample_corr_data(sample_type=start_vars["corr_samples_group"],
-                                           all_samples=all_samples,
-                                           dataset_samples=this_dataset.group.all_samples_ordered())
+    data = parse_tissue_corr_data(
+        symbol_name=this_trait.symbol,
+        symbol_dict=get_trait_symbol_and_tissue_values(
+            symbol_list=[this_trait.symbol]),
+        dataset_symbols=trait_symbol_dict,
+        dataset_vals=corr_result_tissue_vals_dict)
 
-        target_dataset.get_trait_data(list(sample_data.keys()))
+    if data:
+        return run_correlation(data[1], data[0], method, ",", "tissue")
+    return {}
 
-        target_data = []
-        for (key, val) in target_dataset.trait_data.items():
-            lts = [key] + [str(x) for x in val]
-            r = ",".join(lts)
-            target_data.append(r)
+def __compute_lit_corr__(
+        start_vars: dict, corr_type: str, method: str, n_top: int,
+        target_trait_info: tuple):
+    """Compute the literature correlations"""
+    (this_dataset, this_trait, target_dataset, sample_data) = target_trait_info
+    target_dataset_type = target_dataset.type
+    this_dataset_type = this_dataset.type
+    (this_trait_geneid, geneid_dict, species) = do_lit_correlation(
+        this_trait, this_dataset)
 
+    with database_connector() as conn:
+        return compute_all_lit_correlation(
+            conn=conn, trait_lists=list(geneid_dict.items()),
+            species=species, gene_id=this_trait_geneid)
+    return {}
 
-        results = run_correlation(
-            target_data, list(sample_data.values()), method, ",", corr_type,
-            n_top)
+def compute_correlation_rust(
+        start_vars: dict, corr_type: str, method: str = "pearson",
+        n_top: int = 500, compute_all: bool = False):
+    """function to compute correlation"""
+    target_trait_info = create_target_this_trait(start_vars)
+    (this_dataset, this_trait, target_dataset, sample_data) = (
+        target_trait_info)
+
+    corr_type_fns = {
+        "sample": __compute_sample_corr__,
+        "tissue": __compute_tissue_corr__,
+        "lit": __compute_lit_corr__
+    }
+    results = corr_type_fns[corr_type](
+        start_vars, corr_type, method, n_top, target_trait_info)
 
+    top_tissue_results = {}
+    top_lit_results = {}
+    if compute_all:
         # example compute of compute both correlation
-        top_tissue_results = compute_top_n_tissue(this_dataset,this_trait,results,method)
+        top_tissue_results = compute_top_n_tissue(
+            this_dataset,this_trait,results,method)
         top_lit_results = compute_top_n_lit(results,this_dataset,this_trait)
 
-        # merging the results
-        results = merge_results(results, top_tissue_results, top_lit_results)
-
-    if corr_type == "tissue":
-
-        trait_symbol_dict = this_dataset.retrieve_genes("Symbol")
-        corr_result_tissue_vals_dict = get_trait_symbol_and_tissue_values(
-            symbol_list=list(trait_symbol_dict.values()))
-
-        data = parse_tissue_corr_data(symbol_name=this_trait.symbol,
-                                      symbol_dict=get_trait_symbol_and_tissue_values(
-                                          symbol_list=[this_trait.symbol]
-                                      ),
-                                      dataset_symbols=trait_symbol_dict,
-                                      dataset_vals=corr_result_tissue_vals_dict)
-
-        if data:
-            results = merge_results(
-                run_correlation(data[1], data[0], method, ",", "tissue"),
-                {}, {})
-
     return {
-        "correlation_results": results,
+        "correlation_results": merge_results(
+            results, top_tissue_results, top_lit_results),
         "this_trait": this_trait.name,
         "target_dataset": start_vars['corr_dataset'],
         "return_results": n_top