Diffstat (limited to 'wqflask')
-rw-r--r--  wqflask/base/data_set.py                            |  67
-rw-r--r--  wqflask/wqflask/correlation/correlation_gn3_api.py  |  51
-rw-r--r--  wqflask/wqflask/correlation/pre_computes.py         | 158
-rw-r--r--  wqflask/wqflask/correlation/show_corr_results.py    |  19
-rw-r--r--  wqflask/wqflask/static/gif/waitAnima2.gif           | bin 0 -> 54013 bytes
-rw-r--r--  wqflask/wqflask/templates/loading.html              |   4
6 files changed, 249 insertions, 50 deletions
diff --git a/wqflask/base/data_set.py b/wqflask/base/data_set.py
index 768ad49b..49ece9dd 100644
--- a/wqflask/base/data_set.py
+++ b/wqflask/base/data_set.py
@@ -40,6 +40,8 @@ from base import species
 from base import webqtlConfig
 from flask import Flask, g
 from base.webqtlConfig import TMPDIR
+from urllib.parse import urlparse
+from utility.tools import SQL_URI
 import os
 import math
 import string
@@ -747,15 +749,16 @@ class DataSet:
             and Species.name = '{}'
             """.format(create_in_clause(self.samplelist), *mescape(self.group.species))
         results = dict(g.db.execute(query).fetchall())
-        sample_ids = [results[item] for item in self.samplelist]
+        sample_ids = [results.get(item)
+                      for item in self.samplelist if item in results]
 
         # MySQL limits the number of tables that can be used in a join to 61,
         # so we break the sample ids into smaller chunks
         # Postgres doesn't have that limit, so we can get rid of this after we transition
         chunk_size = 50
         number_chunks = int(math.ceil(len(sample_ids) / chunk_size))
-        # cached_results = fetch_cached_results(self.name, self.type)
-        cached_results =  None
+
+        cached_results = fetch_cached_results(self.name, self.type)
         if cached_results is None:
             trait_sample_data = []
             for sample_ids_step in chunks.divide_into_chunks(sample_ids, number_chunks):
@@ -800,21 +803,21 @@ class DataSet:
                 results = g.db.execute(query).fetchall()
                 trait_sample_data.append([list(result) for result in results])
 
+            trait_count = len(trait_sample_data[0])
+            self.trait_data = collections.defaultdict(list)
 
-        else:
-            trait_sample_data = cached_results
+            data_start_pos = 1
+            for trait_counter in range(trait_count):
+                trait_name = trait_sample_data[0][trait_counter][0]
+                for chunk_counter in range(int(number_chunks)):
+                    self.trait_data[trait_name] += (
+                        trait_sample_data[chunk_counter][trait_counter][data_start_pos:])
 
-        trait_count = len(trait_sample_data[0])
-        self.trait_data = collections.defaultdict(list)
+            cache_dataset_results(
+                self.name, self.type, self.trait_data)
+        else:
 
-        # put all of the separate data together into a dictionary where the keys are
-        # trait names and values are lists of sample values
-        data_start_pos = 1
-        for trait_counter in range(trait_count):
-            trait_name = trait_sample_data[0][trait_counter][0]
-            for chunk_counter in range(int(number_chunks)):
-                self.trait_data[trait_name] += (
-                    trait_sample_data[chunk_counter][trait_counter][data_start_pos:])
+            self.trait_data = cached_results
 
 
 class PhenotypeDataSet(DataSet):
@@ -1254,25 +1257,30 @@ def geno_mrna_confidentiality(ob):
         return True
 
 
+def parse_db_url():
+    parsed_db = urlparse(SQL_URI)
+
+    return (parsed_db.hostname, parsed_db.username,
+            parsed_db.password, parsed_db.path[1:])
+
+
 def query_table_timestamp(dataset_type: str):
     """function to query the update timestamp of a given dataset_type"""
 
     # computation data and actions
 
+    fetch_db_name = parse_db_url()
     query_update_time = f"""
                     SELECT UPDATE_TIME FROM   information_schema.tables
-                    WHERE  TABLE_SCHEMA = 'db_webqtl_s'
+                    WHERE  TABLE_SCHEMA = '{fetch_db_name[-1]}'
                     AND TABLE_NAME = '{dataset_type}Data'
                 """
 
-    # store the timestamp in redis=
     date_time_obj = g.db.execute(query_update_time).fetchone()[0]
-
-    f = "%Y-%m-%d %H:%M:%S"
-    return date_time_obj.strftime(f)
+    return date_time_obj.strftime("%Y-%m-%d %H:%M:%S")
 
 
-def generate_hash_file(dataset_name: str, dataset_timestamp: str):
+def generate_hash_file(dataset_name: str, dataset_type: str, dataset_timestamp: str):
     """given the trait_name generate a unique name for this"""
     string_unicode = f"{dataset_name}{dataset_timestamp}".encode()
     md5hash = hashlib.md5(string_unicode)
@@ -1280,15 +1288,16 @@ def generate_hash_file(dataset_name: str, dataset_timestamp: str):
 
 
 def cache_dataset_results(dataset_name: str, dataset_type: str, query_results: List):
-    """function to cache dataset query results to file"""
+    """function to cache dataset query results to file
+    input: dataset_name, dataset_type and query_results (already processed into a defaultdict)
+    """
     # data computations actions
     # store the file path on redis
 
     table_timestamp = query_table_timestamp(dataset_type)
 
-    results = r.set(f"{dataset_type}timestamp", table_timestamp)
 
-    file_name = generate_hash_file(dataset_name, table_timestamp)
+    file_name = generate_hash_file(dataset_name, dataset_type, table_timestamp)
     file_path = os.path.join(TMPDIR, f"{file_name}.json")
 
     with open(file_path, "w") as file_handler:
@@ -1298,19 +1307,13 @@ def cache_dataset_results(dataset_name: str, dataset_type: str, query_results: L
 def fetch_cached_results(dataset_name: str, dataset_type: str):
     """function to fetch the cached results"""
 
-    table_timestamp = r.get(f"{dataset_type}timestamp")
-
-    if table_timestamp is not None:
-        table_timestamp = table_timestamp.decode("utf-8")
-    else:
-        table_timestamp = ""
+    table_timestamp = query_table_timestamp(dataset_type)
 
-    file_name = generate_hash_file(dataset_name, table_timestamp)
+    file_name = generate_hash_file(dataset_name, dataset_type, table_timestamp)
     file_path = os.path.join(TMPDIR, f"{file_name}.json")
     try:
         with open(file_path, "r") as file_handler:
 
             return json.load(file_handler)
     except FileNotFoundError:
-        # take actions continue to fetch dataset results and fetch results
         pass
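
The data_set.py hunks above re-enable the JSON result cache and key it on the table's UPDATE_TIME rather than a Redis-stored timestamp, so the cache invalidates itself whenever the underlying *Data table changes. A minimal standalone sketch of that scheme (not the data_set.py code itself; TMPDIR here stands in for base.webqtlConfig.TMPDIR):

# Sketch: cache key is an MD5 of the dataset name plus the table's
# UPDATE_TIME, so a table update automatically points at a fresh file.
import hashlib
import json
import os

TMPDIR = "/tmp"  # assumption: plays the role of base.webqtlConfig.TMPDIR


def cache_file_for(dataset_name: str, table_timestamp: str) -> str:
    digest = hashlib.md5(f"{dataset_name}{table_timestamp}".encode()).hexdigest()
    return os.path.join(TMPDIR, f"{digest}.json")


def fetch_cached(dataset_name: str, table_timestamp: str):
    try:
        with open(cache_file_for(dataset_name, table_timestamp)) as handle:
            return json.load(handle)
    except FileNotFoundError:
        return None  # caller falls back to querying MySQL and re-caching


def cache_results(dataset_name: str, table_timestamp: str, trait_data: dict):
    with open(cache_file_for(dataset_name, table_timestamp), "w") as handle:
        json.dump(trait_data, handle)
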
diff --git a/wqflask/wqflask/correlation/correlation_gn3_api.py b/wqflask/wqflask/correlation/correlation_gn3_api.py
index 20c0d99a..c2acd648 100644
--- a/wqflask/wqflask/correlation/correlation_gn3_api.py
+++ b/wqflask/wqflask/correlation/correlation_gn3_api.py
@@ -1,14 +1,18 @@
 """module that calls the gn3 api's to do the correlation """
 import json
+import time
+from functools import wraps
 
 from wqflask.correlation import correlation_functions
-
+from wqflask.correlation.pre_computes import fetch_precompute_results
+from wqflask.correlation.pre_computes import cache_compute_results
 from base import data_set
 
 from base.trait import create_trait
 from base.trait import retrieve_sample_data
 
 from gn3.computations.correlations import compute_all_sample_correlation
+from gn3.computations.correlations import fast_compute_all_sample_correlation
 from gn3.computations.correlations import map_shared_keys_to_values
 from gn3.computations.correlations import compute_all_lit_correlation
 from gn3.computations.correlations import compute_tissue_correlation
@@ -19,9 +23,11 @@ def create_target_this_trait(start_vars):
     """this function creates the required trait and target dataset for correlation"""
 
     if start_vars['dataset'] == "Temp":
-        this_dataset = data_set.create_dataset(dataset_name="Temp", dataset_type="Temp", group_name=start_vars['group'])
+        this_dataset = data_set.create_dataset(
+            dataset_name="Temp", dataset_type="Temp", group_name=start_vars['group'])
     else:
-        this_dataset = data_set.create_dataset(dataset_name=start_vars['dataset'])
+        this_dataset = data_set.create_dataset(
+            dataset_name=start_vars['dataset'])
     target_dataset = data_set.create_dataset(
         dataset_name=start_vars['corr_dataset'])
     this_trait = create_trait(dataset=this_dataset,
@@ -58,14 +64,20 @@ def test_process_data(this_trait, dataset, start_vars):
     return sample_data
 
 
-def process_samples(start_vars, sample_names, excluded_samples=None):
-    """process samples"""
+def process_samples(start_vars, sample_names=[], excluded_samples=[]):
+    """code to fetch correct samples"""
     sample_data = {}
-    if not excluded_samples:
-        excluded_samples = ()
-        sample_vals_dict = json.loads(start_vars["sample_vals"])
+    sample_vals_dict = json.loads(start_vars["sample_vals"])
+    if sample_names:
         for sample in sample_names:
-            if sample not in excluded_samples and sample in sample_vals_dict:
+            if sample in sample_vals_dict and sample not in excluded_samples:
+                val = sample_vals_dict[sample]
+                if not val.strip().lower() == "x":
+                    sample_data[str(sample)] = float(val)
+
+    else:
+        for sample in sample_vals_dict.keys():
+            if sample not in excluded_samples:
                 val = sample_vals_dict[sample]
                 if not val.strip().lower() == "x":
                     sample_data[str(sample)] = float(val)
@@ -147,6 +159,18 @@ def lit_for_trait_list(corr_results, this_dataset, this_trait):
 
 def fetch_sample_data(start_vars, this_trait, this_dataset, target_dataset):
 
+    corr_samples_group = start_vars["corr_samples_group"]
+    if corr_samples_group == "samples_primary":
+        sample_data = process_samples(
+            start_vars, this_dataset.group.all_samples_ordered())
+
+    elif corr_samples_group == "samples_other":
+        sample_data = process_samples(
+            start_vars, excluded_samples=this_dataset.group.samplelist)
+
+    else:
+        sample_data = process_samples(start_vars)
+
-    sample_data = process_samples(
-        start_vars, this_dataset.group.all_samples_ordered())
 
@@ -187,9 +211,9 @@ def compute_correlation(start_vars, method="pearson", compute_all=False):
     if corr_type == "sample":
         (this_trait_data, target_dataset_data) = fetch_sample_data(
             start_vars, this_trait, this_dataset, target_dataset)
-        correlation_results = compute_all_sample_correlation(corr_method=method,
-                                                             this_trait=this_trait_data,
-                                                             target_dataset=target_dataset_data)
+
+        correlation_results = compute_all_sample_correlation(
+            corr_method=method, this_trait=this_trait_data, target_dataset=target_dataset_data)
 
     elif corr_type == "tissue":
         trait_symbol_dict = this_dataset.retrieve_genes("Symbol")
@@ -290,7 +314,8 @@ def get_tissue_correlation_input(this_trait, trait_symbol_dict):
     """Gets tissue expression values for the primary trait and target tissues values"""
     primary_trait_tissue_vals_dict = correlation_functions.get_trait_symbol_and_tissue_values(
         symbol_list=[this_trait.symbol])
-    if this_trait.symbol  and this_trait.symbol.lower() in primary_trait_tissue_vals_dict:
+    if this_trait.symbol and this_trait.symbol.lower() in primary_trait_tissue_vals_dict:
+
         primary_trait_tissue_values = primary_trait_tissue_vals_dict[this_trait.symbol.lower(
         )]
         corr_result_tissue_vals_dict = correlation_functions.get_trait_symbol_and_tissue_values(
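
The reworked process_samples and fetch_sample_data above select which posted sample values enter the correlation: the primary samples, the non-primary samples, or all of them, always skipping values marked "x". A hedged standalone sketch of that filtering (filter_samples is a hypothetical name, not the wqflask function):

# Sketch: sample values arrive as a JSON object of {sample: value};
# "x" (any case) marks a missing value, and optional sample_names /
# excluded_samples narrow the selection.
import json


def filter_samples(sample_vals_json, sample_names=(), excluded_samples=()):
    sample_vals = json.loads(sample_vals_json)
    chosen = sample_names if sample_names else sample_vals.keys()
    sample_data = {}
    for sample in chosen:
        val = sample_vals.get(sample)
        if val is None or sample in excluded_samples:
            continue
        if val.strip().lower() != "x":
            sample_data[str(sample)] = float(val)
    return sample_data


print(filter_samples('{"BXD1": "8.1", "BXD2": "x"}'))  # {'BXD1': 8.1}
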
diff --git a/wqflask/wqflask/correlation/pre_computes.py b/wqflask/wqflask/correlation/pre_computes.py
new file mode 100644
index 00000000..975a53b8
--- /dev/null
+++ b/wqflask/wqflask/correlation/pre_computes.py
@@ -0,0 +1,158 @@
+import json
+import os
+import hashlib
+from pathlib import Path
+
+from base.data_set import query_table_timestamp
+from base.webqtlConfig import TMPDIR
+
+from gn3.computations.correlations import compute_all_sample_correlation
+from gn3.computations.correlations import map_shared_keys_to_values
+
+
+def fetch_all_cached_metadata(dataset_name):
+    """in a gvein dataset fetch all the traits metadata"""
+    file_name = generate_filename(dataset_name, suffix="metadata")
+
+    file_path = os.path.join(TMPDIR, file_name)
+
+    try:
+        with open(file_path, "r+") as file_handler:
+            dataset_metadata = json.load(file_handler)
+            return (file_path, dataset_metadata)
+
+    except FileNotFoundError:
+        Path(file_path).touch(exist_ok=True)
+        return (file_path, {})
+
+
+def cache_new_traits_metadata(dataset_metadata: dict, new_traits_metadata, file_path: str):
+    """function to cache the new traits metadata"""
+
+    if new_traits_metadata:
+        dataset_metadata.update(new_traits_metadata)
+
+    with open(file_path, "w+") as file_handler:
+        json.dump(dataset_metadata, file_handler)
+
+
+def generate_filename(*args, suffix="", file_ext="json"):
+    """given a list of args generate a unique filename"""
+
+    string_unicode = f"{*args,}".encode()
+    return f"{hashlib.md5(string_unicode).hexdigest()}_{suffix}.{file_ext}"
+
+
+def cache_compute_results(base_dataset_type,
+                          base_dataset_name,
+                          target_dataset_name,
+                          corr_method,
+                          correlation_results,
+                          trait_name):
+    """function to cache correlation results for heavy computations"""
+
+    base_timestamp = query_table_timestamp(base_dataset_type)
+
+    target_dataset_timestamp = base_timestamp
+
+    file_name = generate_filename(
+        base_dataset_name, target_dataset_name,
+        base_timestamp, target_dataset_timestamp,
+        suffix="corr_precomputes")
+
+    file_path = os.path.join(TMPDIR, file_name)
+
+    try:
+        with open(file_path, "r+") as json_file_handler:
+            data = json.load(json_file_handler)
+
+            data[trait_name] = correlation_results
+
+            json_file_handler.seek(0)
+
+            json.dump(data, json_file_handler)
+
+            json_file_handler.truncate()
+
+    except FileNotFoundError:
+        with open(file_path, "w+") as file_handler:
+            data = {}
+            data[trait_name] = correlation_results
+
+            json.dump(data, file_handler)
+
+
+def fetch_precompute_results(base_dataset_name,
+                             target_dataset_name,
+                             dataset_type,
+                             trait_name):
+    """function to check for precomputed  results"""
+
+    base_timestamp = target_dataset_timestamp = query_table_timestamp(
+        dataset_type)
+    file_name = generate_filename(
+        base_dataset_name, target_dataset_name,
+        base_timestamp, target_dataset_timestamp,
+        suffix="corr_precomputes")
+
+    file_path = os.path.join(TMPDIR, file_name)
+
+    try:
+        with open(file_path, "r+") as json_handler:
+            correlation_results = json.load(json_handler)
+
+        return correlation_results.get(trait_name)
+
+    except FileNotFoundError:
+        pass
+
+
+def pre_compute_dataset_vs_dataset(base_dataset,
+                                   target_dataset,
+                                   corr_method):
+    """compute sample correlation between dataset vs dataset
+    wn:heavy function should be invoked less frequently
+    input:datasets_data(two dicts),corr_method
+
+    output:correlation results for entire dataset against entire dataset
+    """
+    dataset_correlation_results = {}
+
+    target_traits_data, base_traits_data = get_datasets_data(
+        base_dataset, target_dataset)
+
+    for (primary_trait_name, strain_values) in base_traits_data:
+
+        this_trait_data = {
+            "trait_sample_data": strain_values,
+            "trait_id": primary_trait_name
+        }
+
+        trait_correlation_result = compute_all_sample_correlation(
+            corr_method=corr_method,
+            this_trait=this_trait_data,
+            target_dataset=target_traits_data)
+
+        dataset_correlation_results[primary_trait_name] = trait_correlation_result
+
+    return dataset_correlation_results
+
+
+def get_datasets_data(base_dataset, target_dataset):
+    """required to pass data in a given format to the pre compute
+    function
+
+    (works for bxd only probeset datasets)
+
+    output:two dicts for datasets with key==trait and value==strains
+    """
+    samples_fetched = base_dataset.group.all_samples_ordered()
+    target_traits_data = target_dataset.get_trait_data(
+        samples_fetched)
+
+    base_traits_data = base_dataset.get_trait_data(
+        samples_fetched)
+
+    target_results = map_shared_keys_to_values(
+        samples_fetched, target_traits_data)
+    base_results = map_shared_keys_to_values(
+        samples_fetched, base_traits_data)
+
+    return (target_results, base_results)
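
pre_computes.py keys its cache files on an MD5 of the dataset names plus their table timestamps, so a table update naturally retires the old precomputed file. A small illustration of that behaviour, reusing the generate_filename definition above with example (not real) dataset names and timestamps:

# The same (base dataset, target dataset, timestamps) tuple always maps to
# the same MD5-based file name; a newer timestamp maps to a different file.
import hashlib


def generate_filename(*args, suffix="", file_ext="json"):
    string_unicode = f"{*args,}".encode()
    return f"{hashlib.md5(string_unicode).hexdigest()}_{suffix}.{file_ext}"


name_v1 = generate_filename("HC_M2_0606_P", "HC_M2_0606_P",
                            "2021-09-01 10:00:00", "2021-09-01 10:00:00",
                            suffix="corr_precomputes")
name_v2 = generate_filename("HC_M2_0606_P", "HC_M2_0606_P",
                            "2021-10-01 09:30:00", "2021-10-01 09:30:00",
                            suffix="corr_precomputes")
assert name_v1 != name_v2  # a table update changes the timestamp, hence the file
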
diff --git a/wqflask/wqflask/correlation/show_corr_results.py b/wqflask/wqflask/correlation/show_corr_results.py
index 55915a74..2c820658 100644
--- a/wqflask/wqflask/correlation/show_corr_results.py
+++ b/wqflask/wqflask/correlation/show_corr_results.py
@@ -26,6 +26,9 @@ from base.trait import create_trait, jsonable
 from base.data_set import create_dataset
 from base.webqtlConfig import TMPDIR
 
+from wqflask.correlation.pre_computes import fetch_all_cached_metadata
+from wqflask.correlation.pre_computes import cache_new_traits_metadata
+
 from utility import hmac
 
 
@@ -34,7 +37,8 @@ def set_template_vars(start_vars, correlation_data):
     corr_method = start_vars['corr_sample_method']
 
     if start_vars['dataset'] == "Temp":
-        this_dataset_ob = create_dataset(dataset_name="Temp", dataset_type="Temp", group_name=start_vars['group'])
+        this_dataset_ob = create_dataset(
+            dataset_name="Temp", dataset_type="Temp", group_name=start_vars['group'])
     else:
         this_dataset_ob = create_dataset(dataset_name=start_vars['dataset'])
     this_trait = create_trait(dataset=this_dataset_ob,
@@ -86,13 +90,18 @@ def correlation_json_for_table(correlation_data, this_trait, this_dataset, targe
     corr_results = correlation_data['correlation_results']
     results_list = []
 
+    new_traits_metadata = {}
+
+    (file_path, dataset_metadata) = fetch_all_cached_metadata(
+        target_dataset['name'])
 
     for i, trait_dict in enumerate(corr_results):
         trait_name = list(trait_dict.keys())[0]
         trait = trait_dict[trait_name]
 
-        target_trait = None
-        if  target_trait is None: 
+        target_trait = dataset_metadata.get(trait_name)
+        if target_trait is None:
+
             target_trait_ob = create_trait(dataset=target_dataset_ob,
                                            name=trait_name,
                                            get_qtl_info=True)
@@ -171,6 +180,10 @@ def correlation_json_for_table(correlation_data, this_trait, this_dataset, targe
 
         results_list.append(results_dict)
 
+    cache_new_traits_metadata(dataset_metadata,
+                              new_traits_metadata,
+                              file_path)
+
     return json.dumps(results_list)
 
 
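
The show_corr_results.py changes consult a per-dataset JSON metadata cache before building trait objects, collect the misses in new_traits_metadata, and write them back once the results table is assembled. A hedged, self-contained sketch of that pattern (lookup_trait_from_db is a placeholder for the real create_trait/jsonable calls):

# Sketch: serve trait metadata from a JSON cache, fetch only the misses,
# and persist the merged dictionary at the end of the request.
import json
import os
import tempfile


def lookup_trait_from_db(trait_name):
    return {"name": trait_name, "symbol": trait_name.lower()}  # placeholder


def correlated_trait_metadata(trait_names, cache_path):
    try:
        with open(cache_path) as handle:
            cached = json.load(handle)
    except FileNotFoundError:
        cached = {}

    new_entries = {}
    for name in trait_names:
        if name not in cached:
            new_entries[name] = lookup_trait_from_db(name)

    if new_entries:  # persist only when something new was fetched
        cached.update(new_entries)
        with open(cache_path, "w") as handle:
            json.dump(cached, handle)
    return {name: cached[name] for name in trait_names}


path = os.path.join(tempfile.gettempdir(), "example_metadata.json")
print(correlated_trait_metadata(["Trait_0001", "Trait_0002"], path))
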
diff --git a/wqflask/wqflask/static/gif/waitAnima2.gif b/wqflask/wqflask/static/gif/waitAnima2.gif
new file mode 100644
index 00000000..50aff7f2
--- /dev/null
+++ b/wqflask/wqflask/static/gif/waitAnima2.gif
Binary files differ
diff --git a/wqflask/wqflask/templates/loading.html b/wqflask/wqflask/templates/loading.html
index ccf810b0..b9e31ad0 100644
--- a/wqflask/wqflask/templates/loading.html
+++ b/wqflask/wqflask/templates/loading.html
@@ -66,11 +66,11 @@
           {% endif %}
           {% endif %}
           {% else %}
-          <h1>Loading&nbsp;{{ start_vars.tool_used }}&nbsp;Results...</h1>
+          <h1>&nbsp;{{ start_vars.tool_used }}&nbsp;Computation in progress ...</h1>
           {% endif %}
           <br><br>
           <div style="text-align: center;">
-            <img align="center" src="/static/gif/89.gif">
+            <img align="center" src="/static/gif/waitAnima2.gif">
           </div>
           {% if start_vars.vals_diff|length != 0 and start_vars.transform == "" %}
           <br><br>