about summary refs log tree commit diff
path: root/gn2/wqflask/api/correlation.py
diff options
context:
space:
mode:
author Arun Isaac 2023-12-29 18:55:37 +0000
committer Arun Isaac 2023-12-29 19:01:46 +0000
commit 204a308be0f741726b9a620d88fbc22b22124c81 (patch)
tree b3cf66906674020b530c844c2bb4982c8a0e2d39 /gn2/wqflask/api/correlation.py
parent 83062c75442160427b50420161bfcae2c5c34c84 (diff)
download genenetwork2-204a308be0f741726b9a620d88fbc22b22124c81.tar.gz
Namespace all modules under gn2.
We move all modules under a gn2 directory. This is important for
"correct" packaging and deployment as a Guix service.
Diffstat (limited to 'gn2/wqflask/api/correlation.py')
-rw-r--r-- gn2/wqflask/api/correlation.py 244
1 files changed, 244 insertions, 0 deletions
diff --git a/gn2/wqflask/api/correlation.py b/gn2/wqflask/api/correlation.py
new file mode 100644
index 00000000..090d13ac
--- /dev/null
+++ b/gn2/wqflask/api/correlation.py
@@ -0,0 +1,244 @@
+import collections
+import scipy
+import numpy
+
+from gn2.base import data_set
+from gn2.base.trait import create_trait, retrieve_sample_data
+from gn2.utility import corr_result_helpers
+from gn2.utility.tools import get_setting
+from gn2.wqflask.correlation import correlation_functions
+from gn2.wqflask.database import database_connection
+
+def do_correlation(start_vars):
+    """Run a correlation for one trait and return the top results as dicts.
+
+    ``start_vars`` must contain 'db', 'target_db' and 'trait_id'; a
+    ValueError is raised for the first of these that is missing.  The
+    remaining options are read by init_corr_params.  The ranked mapping
+    produced by calculate_results is cut down to its first
+    ``return_count`` entries, and each entry becomes a dict whose keys
+    depend on the correlation type ("tissue", "literature"/"lit", or
+    anything else, which is treated as a sample correlation).
+    """
+    # Checked in this order so the first missing key is the one reported.
+    required_keys = (
+        ('db', "'db' not found!"),
+        ('target_db', "'target_db' not found!"),
+        ('trait_id', "'trait_id' not found!"),
+    )
+    for key, message in required_keys:
+        if key not in start_vars:
+            raise ValueError(message)
+
+    this_dataset = data_set.create_dataset(dataset_name=start_vars['db'])
+    target_dataset = data_set.create_dataset(
+        dataset_name=start_vars['target_db'])
+    this_trait = retrieve_sample_data(
+        create_trait(dataset=this_dataset, name=start_vars['trait_id']),
+        this_dataset)
+
+    corr_params = init_corr_params(start_vars)
+    ranked_results = calculate_results(
+        this_trait, this_dataset, target_dataset, corr_params)
+
+    # Only the leading `return_count` entries of the ranking are reported.
+    top_entries = list(ranked_results.items())[:corr_params['return_count']]
+
+    final_results = []
+    for trait, values in top_entries:
+        if corr_params['type'] == "tissue":
+            sample_r, num_overlap, sample_p, symbol = values
+            final_results.append({
+                "trait": trait,
+                "sample_r": sample_r,
+                "#_strains": num_overlap,
+                "p_value": sample_p,
+                "symbol": symbol
+            })
+        elif corr_params['type'] in ("literature", "lit"):
+            gene_id, sample_r = values
+            final_results.append({
+                "trait": trait,
+                "sample_r": sample_r,
+                "gene_id": gene_id
+            })
+        else:
+            sample_r, sample_p, num_overlap = values
+            final_results.append({
+                "trait": trait,
+                "sample_r": sample_r,
+                "#_strains": num_overlap,
+                "p_value": sample_p
+            })
+    return final_results
+
+
+def calculate_results(this_trait, this_dataset, target_dataset, corr_params):
+    """Compute correlations for ``this_trait`` and return them ranked.
+
+    The kind of correlation is selected by ``corr_params['type']``:
+    "tissue", "literature"/"lit", or anything else (sample correlation
+    against every trait in ``target_dataset.trait_data``).
+
+    Returns a collections.OrderedDict mapping trait name to its result
+    list, ordered by descending absolute correlation value.
+    """
+    target_dataset.get_trait_data()
+
+    corr_type = corr_params['type']
+    if corr_type == "tissue":
+        trait_symbol_dict = this_dataset.retrieve_genes("Symbol")
+        # The tissue helper can yield no mapping at all (e.g. when the
+        # primary trait has no tissue values); treat that as "no results"
+        # instead of failing on `.items()` below.
+        corr_results = do_tissue_correlation_for_all_traits(
+            this_trait, trait_symbol_dict, corr_params) or {}
+        # Tissue entries are unpacked by do_correlation as
+        # [sample_r, num_overlap, sample_p, symbol], so the coefficient to
+        # rank on is index 0 (index 1 would rank on the overlap count).
+        # NOTE(review): confirm against cal_zero_order_corr_for_tiss.
+        score_index = 0
+    # ZS: Just so a user can use either "lit" or "literature"
+    elif corr_type in ("literature", "lit"):
+        trait_geneid_dict = this_dataset.retrieve_genes("GeneId")
+        corr_results = do_literature_correlation_for_all_traits(
+            this_trait, this_dataset, trait_geneid_dict, corr_params)
+        # Literature entries are [gene_id, lit_corr].
+        score_index = 1
+    else:
+        corr_results = {}
+        for target_trait, target_vals in target_dataset.trait_data.items():
+            result = get_sample_r_and_p_values(
+                this_trait, this_dataset, target_vals, target_dataset, corr_type)
+            # None means too little overlap or an undefined coefficient.
+            if result is not None:
+                corr_results[target_trait] = result
+        # Sample entries are [sample_r, sample_p, num_overlap].
+        score_index = 0
+
+    return collections.OrderedDict(
+        sorted(corr_results.items(),
+               key=lambda item: -abs(item[1][score_index])))
+
+
+def do_tissue_correlation_for_all_traits(this_trait, trait_symbol_dict, corr_params, tissue_dataset_id=1):
+    """Correlate the primary trait's tissue values with each target symbol.
+
+    ``trait_symbol_dict`` maps trait name to gene symbol.  For every trait
+    whose symbol has tissue values, the result of
+    ``correlation_functions.cal_zero_order_corr_for_tiss`` (computed with
+    ``corr_params['method']``) is stored as
+    ``[result[0], result[1], result[2], symbol]``.
+
+    Always returns a dict.  It is empty when the primary trait has no
+    symbol or no tissue values are found for it, so callers can iterate
+    the result unconditionally.  ``tissue_dataset_id`` is accepted for
+    interface compatibility but is not used here.
+    """
+    primary_symbol = this_trait.symbol
+    if not primary_symbol:
+        # No symbol to look up; `.lower()` below would fail on None.
+        return {}
+
+    # Gets tissue expression values for the primary trait
+    primary_trait_tissue_vals_dict = correlation_functions.get_trait_symbol_and_tissue_values(
+        symbol_list=[primary_symbol])
+
+    primary_key = primary_symbol.lower()
+    if primary_key not in primary_trait_tissue_vals_dict:
+        return {}
+    primary_trait_tissue_values = primary_trait_tissue_vals_dict[primary_key]
+
+    corr_result_tissue_vals_dict = correlation_functions.get_trait_symbol_and_tissue_values(
+        symbol_list=list(trait_symbol_dict.values()))
+
+    tissue_corr_data = {}
+    for trait, symbol in trait_symbol_dict.items():
+        # Traits without a symbol, or without tissue values, are skipped.
+        if not symbol or symbol.lower() not in corr_result_tissue_vals_dict:
+            continue
+        this_trait_tissue_values = corr_result_tissue_vals_dict[symbol.lower()]
+
+        result = correlation_functions.cal_zero_order_corr_for_tiss(
+            primary_trait_tissue_values,
+            this_trait_tissue_values,
+            corr_params['method'])
+
+        tissue_corr_data[trait] = [result[0], result[1], result[2], symbol]
+
+    return tissue_corr_data
+
+
+def _fetch_lit_corr(cursor, mouse_gene_id, input_trait_mouse_gene_id):
+    """Return the LCorrRamin3 value for a gene pair in either order, else 0."""
+    cursor.execute(
+        ("SELECT value FROM LCorrRamin3 "
+         "WHERE GeneId1=%s AND GeneId2=%s"),
+        (mouse_gene_id,
+         input_trait_mouse_gene_id))
+    row = cursor.fetchone()
+    if not row:
+        # The table may store the pair the other way round.
+        cursor.execute(
+            ("SELECT value FROM LCorrRamin3 "
+             "WHERE GeneId2=%s AND GeneId1=%s"),
+            (mouse_gene_id,
+             input_trait_mouse_gene_id))
+        row = cursor.fetchone()
+    return row[0] if row else 0
+
+
+def do_literature_correlation_for_all_traits(this_trait, target_dataset, trait_geneid_dict, corr_params):
+    """Look up the literature correlation between ``this_trait`` and each trait.
+
+    ``trait_geneid_dict`` maps trait name to gene id.  Gene ids are first
+    translated with convert_to_mouse_gene_id using the species of
+    ``target_dataset.group``.  Returns a dict mapping trait name to
+    ``[gene_id, lit_corr]``; ``lit_corr`` is 0 when the gene id cannot be
+    translated, the translated id contains ";", or no row is found in
+    LCorrRamin3.  ``corr_params`` is not used by this function.
+    """
+    species = target_dataset.group.species.lower()
+    input_trait_mouse_gene_id = convert_to_mouse_gene_id(
+        species, this_trait.geneid)
+
+    lit_corr_data = {}
+    if not trait_geneid_dict:
+        # Nothing to look up; avoid opening a connection for no work.
+        return lit_corr_data
+
+    # One connection and cursor serve the whole loop instead of a fresh
+    # pair being opened for every trait.
+    with database_connection(get_setting("SQL_URI")) as conn:
+        with conn.cursor() as cursor:
+            for trait, gene_id in trait_geneid_dict.items():
+                mouse_gene_id = convert_to_mouse_gene_id(species, gene_id)
+
+                lit_corr = 0
+                if mouse_gene_id and ";" not in str(mouse_gene_id):
+                    lit_corr = _fetch_lit_corr(
+                        cursor, mouse_gene_id, input_trait_mouse_gene_id)
+                lit_corr_data[trait] = [gene_id, lit_corr]
+
+    return lit_corr_data
+
+
+def get_sample_r_and_p_values(this_trait, this_dataset, target_vals, target_dataset, type):
+    """
+    Calculates the sample r (or rho) and p-value
+
+    Given a primary trait and a target trait's sample values,
+    calculates either the pearson r or spearman rho and the p-value
+    using the corresponding scipy functions.
+
+    ``target_vals`` is indexed in the order of
+    ``target_dataset.group.samplelist``; only samples also present in
+    ``this_trait.data`` are used.  ``type`` selects the statistic:
+    'pearson' uses scipy.stats.pearsonr, any other value uses
+    scipy.stats.spearmanr.  ``this_dataset`` is not used here.
+
+    Returns ``[sample_r, sample_p, num_overlap]``, or None when the
+    number of overlapping values is 5 or fewer or the coefficient is NaN.
+    """
+
+    this_trait_vals = []
+    shared_target_vals = []
+    for i, sample in enumerate(target_dataset.group.samplelist):
+        if sample in this_trait.data:
+            this_trait_vals.append(this_trait.data[sample].value)
+            shared_target_vals.append(target_vals[i])
+
+    this_trait_vals, shared_target_vals, num_overlap = corr_result_helpers.normalize_values(
+        this_trait_vals, shared_target_vals)
+
+    # Reject small overlaps before calling scipy: the result would be
+    # discarded anyway, and the scipy routines can raise or warn on
+    # inputs that are empty or too short.
+    if num_overlap <= 5:
+        return None
+
+    if type == 'pearson':
+        sample_r, sample_p = scipy.stats.pearsonr(
+            this_trait_vals, shared_target_vals)
+    else:
+        sample_r, sample_p = scipy.stats.spearmanr(
+            this_trait_vals, shared_target_vals)
+
+    # An undefined coefficient (e.g. constant input) is reported as no result.
+    if numpy.isnan(sample_r):
+        return None
+    return [sample_r, sample_p, num_overlap]
+
+
+def convert_to_mouse_gene_id(species=None, gene_id=None):
+    """If the species is rat or human, translate the gene_id to the mouse geneid
+
+    For species 'mouse' the gene_id is returned unchanged.  For 'rat' and
+    'human' the mouse id is looked up in the GeneIDXRef table.
+
+    If there is no input gene_id, the species is not one of the three
+    above, or there's no corresponding mouse gene_id, return None
+
+    """
+    if not gene_id:
+        return None
+
+    # Mouse ids need no translation, so no database access is required.
+    if species == 'mouse':
+        return gene_id
+
+    queries = {
+        'rat': ("SELECT mouse FROM GeneIDXRef "
+                "WHERE rat=%s"),
+        'human': ("SELECT mouse FROM GeneIDXRef "
+                  "WHERE human=%s"),
+    }
+    query = queries.get(species)
+    if query is None:
+        # Unsupported species: nothing to query.
+        return None
+
+    with database_connection(get_setting("SQL_URI")) as conn:
+        with conn.cursor() as cursor:
+            # Parameters are passed as a sequence, as the DB-API expects;
+            # a bare scalar is not accepted by every driver.
+            cursor.execute(query, (gene_id,))
+            row = cursor.fetchone()
+
+    return row[0] if row else None
+
+
+def init_corr_params(start_vars):
+    """Build the correlation parameter dict from the request variables.
+
+    Reads the optional keys 'method' (default "pearson"), 'type'
+    (default "sample") and 'return_count' (default 500) from
+    ``start_vars``.
+
+    Returns a dict with the keys 'method', 'type' and 'return_count',
+    where 'return_count' is an int.
+
+    Raises ValueError when 'return_count' is present but is not a
+    non-negative integer (given either as an int or as a digit string).
+    """
+    method = start_vars.get('method', "pearson") if 'method' in start_vars else "pearson"
+    corr_type = start_vars['type'] if 'type' in start_vars else "sample"
+
+    return_count = 500
+    if 'return_count' in start_vars:
+        raw_count = start_vars['return_count']
+        error_message = (
+            "'return_count' must be a non-negative integer, got {!r}"
+            .format(raw_count))
+        # Explicit validation rather than `assert`, which is removed when
+        # Python runs with -O and would let bad input through.
+        try:
+            return_count = int(str(raw_count))
+        except ValueError:
+            raise ValueError(error_message) from None
+        if return_count < 0:
+            raise ValueError(error_message)
+
+    return {
+        'method': method,
+        'type': corr_type,
+        'return_count': return_count
+    }