author     BonfaceKilz  2021-11-24 12:48:33 +0300
committer  GitHub       2021-11-24 12:48:33 +0300
commit     16116373899b44e0f0a3894f1f2e5b7f60a5d498 (patch)
tree       795043155c61640a914f75e3a6094273835605af
parent     41e742904ff4cf35abbd885eeb98902a05d3be80 (diff)
parent     fffeb91789943a3c7db5a72d66405e2a0459ed44 (diff)
download   genenetwork2-16116373899b44e0f0a3894f1f2e5b7f60a5d498.tar.gz
Merge pull request #624 from Alexanderlacuna/feature/correlation-optimization2
Feature/correlation optimization2
-rw-r--r--  wqflask/base/data_set.py                             67
-rw-r--r--  wqflask/wqflask/correlation/correlation_gn3_api.py   51
-rw-r--r--  wqflask/wqflask/correlation/pre_computes.py         158
-rw-r--r--  wqflask/wqflask/correlation/show_corr_results.py     19
-rw-r--r--  wqflask/wqflask/static/gif/waitAnima2.gif           bin 0 -> 54013 bytes
-rw-r--r--  wqflask/wqflask/templates/loading.html                4

6 files changed, 249 insertions(+), 50 deletions(-)
diff --git a/wqflask/base/data_set.py b/wqflask/base/data_set.py
index 768ad49b..49ece9dd 100644
--- a/wqflask/base/data_set.py
+++ b/wqflask/base/data_set.py
@@ -40,6 +40,8 @@ from base import species
 from base import webqtlConfig
 from flask import Flask, g
 from base.webqtlConfig import TMPDIR
+from urllib.parse import urlparse
+from utility.tools import SQL_URI
 import os
 import math
 import string
@@ -747,15 +749,16 @@ class DataSet:
         and Species.name = '{}'
         """.format(create_in_clause(self.samplelist), *mescape(self.group.species))
         results = dict(g.db.execute(query).fetchall())
-        sample_ids = [results[item] for item in self.samplelist]
+        sample_ids = [results.get(item)
+                      for item in self.samplelist if item is not None]
 
         # MySQL limits the number of tables that can be used in a join to 61,
         # so we break the sample ids into smaller chunks
         # Postgres doesn't have that limit, so we can get rid of this after we transition
         chunk_size = 50
         number_chunks = int(math.ceil(len(sample_ids) / chunk_size))
-        # cached_results = fetch_cached_results(self.name, self.type)
-        cached_results = None
+
+        cached_results = fetch_cached_results(self.name, self.type)
         if cached_results is None:
             trait_sample_data = []
             for sample_ids_step in chunks.divide_into_chunks(sample_ids, number_chunks):
@@ -800,21 +803,21 @@ class DataSet:
             results = g.db.execute(query).fetchall()
             trait_sample_data.append([list(result) for result in results])
 
+            trait_count = len(trait_sample_data[0])
+            self.trait_data = collections.defaultdict(list)
-        else:
-            trait_sample_data = cached_results
 
+            data_start_pos = 1
+            for trait_counter in range(trait_count):
+                trait_name = trait_sample_data[0][trait_counter][0]
+                for chunk_counter in range(int(number_chunks)):
+                    self.trait_data[trait_name] += (
+                        trait_sample_data[chunk_counter][trait_counter][data_start_pos:])
-        trait_count = len(trait_sample_data[0])
-        self.trait_data = collections.defaultdict(list)
 
+            cache_dataset_results(
+                self.name, self.type, self.trait_data)
+        else:
-        # put all of the separate data together into a dictionary where the keys are
-        # trait names and values are lists of sample values
-        data_start_pos = 1
-        for trait_counter in range(trait_count):
-            trait_name = trait_sample_data[0][trait_counter][0]
-            for chunk_counter in range(int(number_chunks)):
-                self.trait_data[trait_name] += (
-                    trait_sample_data[chunk_counter][trait_counter][data_start_pos:])
+            self.trait_data = cached_results
 
 
 class PhenotypeDataSet(DataSet):
@@ -1254,25 +1257,30 @@ def geno_mrna_confidentiality(ob):
         return True
 
 
+def parse_db_url():
+    parsed_db = urlparse(SQL_URI)
+
+    return (parsed_db.hostname, parsed_db.username,
+            parsed_db.password, parsed_db.path[1:])
+
+
 def query_table_timestamp(dataset_type: str):
     """function to query the update timestamp of a given dataset_type"""
 
     # computation data and actions
+    fetch_db_name = parse_db_url()
     query_update_time = f"""
         SELECT UPDATE_TIME FROM information_schema.tables
-        WHERE TABLE_SCHEMA = 'db_webqtl_s'
+        WHERE TABLE_SCHEMA = '{fetch_db_name[-1]}'
        AND TABLE_NAME = '{dataset_type}Data'
     """
 
-    # store the timestamp in redis=
     date_time_obj = g.db.execute(query_update_time).fetchone()[0]
-
-    f = "%Y-%m-%d %H:%M:%S"
-    return date_time_obj.strftime(f)
+    return date_time_obj.strftime("%Y-%m-%d %H:%M:%S")
 
 
-def generate_hash_file(dataset_name: str, dataset_timestamp: str):
+def generate_hash_file(dataset_name: str, dataset_type: str, dataset_timestamp: str):
     """given the trait_name generate a unique name for this"""
     string_unicode = f"{dataset_name}{dataset_timestamp}".encode()
     md5hash = hashlib.md5(string_unicode)
@@ -1280,15 +1288,16 @@ def generate_hash_file(dataset_name: str, dataset_timestamp: str):
 
 
 def cache_dataset_results(dataset_name: str, dataset_type: str, query_results: List):
-    """function to cache dataset query results to file"""
+    """function to cache dataset query results to file
+    input dataset_name and type query_results(already processed in default dict format)
+    """
     # data computations actions
    # store the file path on redis
 
     table_timestamp = query_table_timestamp(dataset_type)
-    results = r.set(f"{dataset_type}timestamp", table_timestamp)
 
-    file_name = generate_hash_file(dataset_name, table_timestamp)
+    file_name = generate_hash_file(dataset_name, dataset_type, table_timestamp)
     file_path = os.path.join(TMPDIR, f"{file_name}.json")
 
     with open(file_path, "w") as file_handler:
@@ -1298,19 +1307,13 @@ def cache_dataset_results(dataset_name: str, dataset_type: str, query_results: L
 
 def fetch_cached_results(dataset_name: str, dataset_type: str):
     """function to fetch the cached results"""
 
-    table_timestamp = r.get(f"{dataset_type}timestamp")
-
-    if table_timestamp is not None:
-        table_timestamp = table_timestamp.decode("utf-8")
-    else:
-        table_timestamp = ""
+    table_timestamp = query_table_timestamp(dataset_type)
 
-    file_name = generate_hash_file(dataset_name, table_timestamp)
+    file_name = generate_hash_file(dataset_name, dataset_type, table_timestamp)
     file_path = os.path.join(TMPDIR, f"{file_name}.json")
 
     try:
         with open(file_path, "r") as file_handler:
             return json.load(file_handler)
     except FileNotFoundError:
-        # take actions continue to fetch dataset results and fetch results
         pass
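With this change, get_trait_data() keys its on-disk cache on the dataset name plus the table's MySQL UPDATE_TIME instead of a Redis-stored timestamp, so any write to the underlying table automatically retires stale cache files. A standalone sketch of that round trip, assuming a writable /tmp in place of webqtlConfig.TMPDIR and made-up dataset values:

    import hashlib
    import json
    import os

    TMPDIR = "/tmp"  # stand-in for base.webqtlConfig.TMPDIR

    def cache_path(dataset_name, table_timestamp):
        # mirrors generate_hash_file(): the md5 of name + timestamp picks the
        # file, so a new UPDATE_TIME resolves to a different (missing) file
        digest = hashlib.md5(f"{dataset_name}{table_timestamp}".encode()).hexdigest()
        return os.path.join(TMPDIR, f"{digest}.json")

    path = cache_path("HC_M2_0606_P", "2021-11-24 12:48:33")
    with open(path, "w") as fh:   # what cache_dataset_results() does
        json.dump({"trait_a": [1.2, 3.4]}, fh)
    with open(path) as fh:        # what fetch_cached_results() does
        assert json.load(fh) == {"trait_a": [1.2, 3.4]}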
f"{dataset_name}{dataset_timestamp}".encode() md5hash = hashlib.md5(string_unicode) @@ -1280,15 +1288,16 @@ def generate_hash_file(dataset_name: str, dataset_timestamp: str): def cache_dataset_results(dataset_name: str, dataset_type: str, query_results: List): - """function to cache dataset query results to file""" + """function to cache dataset query results to file + input dataset_name and type query_results(already processed in default dict format) + """ # data computations actions # store the file path on redis table_timestamp = query_table_timestamp(dataset_type) - results = r.set(f"{dataset_type}timestamp", table_timestamp) - file_name = generate_hash_file(dataset_name, table_timestamp) + file_name = generate_hash_file(dataset_name, dataset_type, table_timestamp) file_path = os.path.join(TMPDIR, f"{file_name}.json") with open(file_path, "w") as file_handler: @@ -1298,19 +1307,13 @@ def cache_dataset_results(dataset_name: str, dataset_type: str, query_results: L def fetch_cached_results(dataset_name: str, dataset_type: str): """function to fetch the cached results""" - table_timestamp = r.get(f"{dataset_type}timestamp") - - if table_timestamp is not None: - table_timestamp = table_timestamp.decode("utf-8") - else: - table_timestamp = "" + table_timestamp = query_table_timestamp(dataset_type) - file_name = generate_hash_file(dataset_name, table_timestamp) + file_name = generate_hash_file(dataset_name, dataset_type, table_timestamp) file_path = os.path.join(TMPDIR, f"{file_name}.json") try: with open(file_path, "r") as file_handler: return json.load(file_handler) except FileNotFoundError: - # take actions continue to fetch dataset results and fetch results pass diff --git a/wqflask/wqflask/correlation/correlation_gn3_api.py b/wqflask/wqflask/correlation/correlation_gn3_api.py index 20c0d99a..c2acd648 100644 --- a/wqflask/wqflask/correlation/correlation_gn3_api.py +++ b/wqflask/wqflask/correlation/correlation_gn3_api.py @@ -1,14 +1,18 @@ """module that calls the gn3 api's to do the correlation """ import json +import time +from functools import wraps from wqflask.correlation import correlation_functions - +from wqflask.correlation.pre_computes import fetch_precompute_results +from wqflask.correlation.pre_computes import cache_compute_results from base import data_set from base.trait import create_trait from base.trait import retrieve_sample_data from gn3.computations.correlations import compute_all_sample_correlation +from gn3.computations.correlations import fast_compute_all_sample_correlation from gn3.computations.correlations import map_shared_keys_to_values from gn3.computations.correlations import compute_all_lit_correlation from gn3.computations.correlations import compute_tissue_correlation @@ -19,9 +23,11 @@ def create_target_this_trait(start_vars): """this function creates the required trait and target dataset for correlation""" if start_vars['dataset'] == "Temp": - this_dataset = data_set.create_dataset(dataset_name="Temp", dataset_type="Temp", group_name=start_vars['group']) + this_dataset = data_set.create_dataset( + dataset_name="Temp", dataset_type="Temp", group_name=start_vars['group']) else: - this_dataset = data_set.create_dataset(dataset_name=start_vars['dataset']) + this_dataset = data_set.create_dataset( + dataset_name=start_vars['dataset']) target_dataset = data_set.create_dataset( dataset_name=start_vars['corr_dataset']) this_trait = create_trait(dataset=this_dataset, @@ -58,14 +64,20 @@ def test_process_data(this_trait, dataset, start_vars): return sample_data -def 
diff --git a/wqflask/wqflask/correlation/pre_computes.py b/wqflask/wqflask/correlation/pre_computes.py
new file mode 100644
index 00000000..975a53b8
--- /dev/null
+++ b/wqflask/wqflask/correlation/pre_computes.py
@@ -0,0 +1,158 @@
+import json
+import os
+import hashlib
+from pathlib import Path
+
+from base.data_set import query_table_timestamp
+from base.webqtlConfig import TMPDIR
+
+
+def fetch_all_cached_metadata(dataset_name):
+    """in a gvein dataset fetch all the traits metadata"""
+    file_name = generate_filename(dataset_name, suffix="metadata")
+
+    file_path = os.path.join(TMPDIR, file_name)
+
+    try:
+        with open(file_path, "r+") as file_handler:
+            dataset_metadata = json.load(file_handler)
+
+        return (file_path, dataset_metadata)
+
+    except FileNotFoundError:
+        Path(file_path).touch(exist_ok=True)
+        return (file_path, {})
+
+
+def cache_new_traits_metadata(dataset_metadata: dict, new_traits_metadata, file_path: str):
+    """function to cache the new traits metadata"""
+
+    if bool(new_traits_metadata):
+        dataset_metadata.update(new_traits_metadata)
+
+    with open(file_path, "w+") as file_handler:
+        json.dump(dataset_metadata, file_handler)
+
+
+def generate_filename(*args, suffix="", file_ext="json"):
+    """given a list of args generate a unique filename"""
+
+    string_unicode = f"{*args,}".encode()
+    return f"{hashlib.md5(string_unicode).hexdigest()}_{suffix}.{file_ext}"
+
+
+def cache_compute_results(base_dataset_type,
+                          base_dataset_name,
+                          target_dataset_name,
+                          corr_method,
+                          correlation_results,
+                          trait_name):
+    """function to cache correlation results for heavy computations"""
+
+    base_timestamp = query_table_timestamp(base_dataset_type)
+
+    target_dataset_timestamp = base_timestamp
+
+    file_name = generate_filename(
+        base_dataset_name, target_dataset_name,
+        base_timestamp, target_dataset_timestamp,
+        suffix="corr_precomputes")
+
+    file_path = os.path.join(TMPDIR, file_name)
+
+    try:
+        with open(file_path, "r+") as json_file_handler:
+            data = json.load(json_file_handler)
+
+            data[trait_name] = correlation_results
+
+            json_file_handler.seek(0)
+
+            json.dump(data, json_file_handler)
+
+            json_file_handler.truncate()
+
+    except FileNotFoundError:
+        with open(file_path, "w+") as file_handler:
+            data = {}
+            data[trait_name] = correlation_results
+
+            json.dump(data, file_handler)
+
+
+def fetch_precompute_results(base_dataset_name,
+                             target_dataset_name,
+                             dataset_type,
+                             trait_name):
+    """function to check for precomputed results"""
+
+    base_timestamp = target_dataset_timestamp = query_table_timestamp(
+        dataset_type)
+    file_name = generate_filename(
+        base_dataset_name, target_dataset_name,
+        base_timestamp, target_dataset_timestamp,
+        suffix="corr_precomputes")
+
+    file_path = os.path.join(TMPDIR, file_name)
+
+    try:
+        with open(file_path, "r+") as json_handler:
+            correlation_results = json.load(json_handler)
+
+        return correlation_results.get(trait_name)
+
+    except FileNotFoundError:
+        pass
+
+
+def pre_compute_dataset_vs_dataset(base_dataset,
+                                   target_dataset,
+                                   corr_method):
+    """compute sample correlation between dataset vs dataset
+    wn:heavy function should be invoked less frequently
+    input:datasets_data(two dicts),corr_method
+
+    output:correlation results for entire dataset against entire dataset
+    """
+    dataset_correlation_results = {}
+
+    target_traits_data, base_traits_data = get_datasets_data(
+        base_dataset, target_dataset_data)
+
+    for (primary_trait_name, strain_values) in base_traits_data:
+
+        this_trait_data = {
+            "trait_sample_data": strain_values,
+            "trait_id": primary_trait_name
+        }
+
+        trait_correlation_result = compute_all_sample_correlation(
+            corr_method=corr_method,
+            this_trait=this_trait_data,
+            target_dataset=target_traits_data)
+
+        dataset_correlation_results[primary_trait_name] = trait_correlation_result
+
+    return dataset_correlation_results
+
+
+def get_datasets_data(base_dataset, target_dataset_data):
+    """required to pass data in a given format to the pre compute
+    function
+
+    (works for bxd only probeset datasets)
+
+    output:two dicts for datasets with key==trait and value==strains
+    """
+    samples_fetched = base_dataset.group.all_samples_ordered()
+    target_traits_data = target_dataset.get_trait_data(
+        samples_fetched)
+
+    base_traits_data = base_dataset.get_trait_data(
+        samples_fetched)
+
+    target_results = map_shared_keys_to_values(
+        samples_fetched, target_traits_data)
+    base_results = map_shared_keys_to_values(
+        samples_fetched, base_traits_data)
+
+    return (target_results, base_results)
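Both cache layers in this new module name their files with generate_filename(), which hashes the repr of the whole argument tuple, so the same dataset names and table timestamps always resolve to the same file under TMPDIR. A short demonstration with made-up dataset names and timestamps (the function body is copied from the diff above):

    import hashlib

    def generate_filename(*args, suffix="", file_ext="json"):
        string_unicode = f"{*args,}".encode()  # repr of the args tuple
        return f"{hashlib.md5(string_unicode).hexdigest()}_{suffix}.{file_ext}"

    print(generate_filename("HC_M2_0606_P", "BXDPublish",
                            "2021-11-24 12:48:33", "2021-11-24 12:48:33",
                            suffix="corr_precomputes"))
    # prints "<32-char md5 hex>_corr_precomputes.json"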
diff --git a/wqflask/wqflask/correlation/show_corr_results.py b/wqflask/wqflask/correlation/show_corr_results.py
index 55915a74..2c820658 100644
--- a/wqflask/wqflask/correlation/show_corr_results.py
+++ b/wqflask/wqflask/correlation/show_corr_results.py
@@ -26,6 +26,9 @@ from base.trait import create_trait, jsonable
 from base.data_set import create_dataset
 from base.webqtlConfig import TMPDIR
 
+from wqflask.correlation.pre_computes import fetch_all_cached_metadata
+from wqflask.correlation.pre_computes import cache_new_traits_metadata
+
 from utility import hmac
 
 
@@ -34,7 +37,8 @@ def set_template_vars(start_vars, correlation_data):
     corr_method = start_vars['corr_sample_method']
 
     if start_vars['dataset'] == "Temp":
-        this_dataset_ob = create_dataset(dataset_name="Temp", dataset_type="Temp", group_name=start_vars['group'])
+        this_dataset_ob = create_dataset(
+            dataset_name="Temp", dataset_type="Temp", group_name=start_vars['group'])
     else:
         this_dataset_ob = create_dataset(dataset_name=start_vars['dataset'])
     this_trait = create_trait(dataset=this_dataset_ob,
@@ -86,13 +90,18 @@ def correlation_json_for_table(correlation_data, this_trait, this_dataset, targe
     corr_results = correlation_data['correlation_results']
     results_list = []
 
+    new_traits_metadata = {}
+
+    (file_path, dataset_metadata) = fetch_all_cached_metadata(
+        target_dataset['name'])
+
     for i, trait_dict in enumerate(corr_results):
         trait_name = list(trait_dict.keys())[0]
         trait = trait_dict[trait_name]
 
-        target_trait = None
-        if target_trait is None:
+        target_trait = dataset_metadata.get(trait_name)
+        if target_trait is None:
+
             target_trait_ob = create_trait(dataset=target_dataset_ob,
                                            name=trait_name,
                                            get_qtl_info=True)
@@ -171,6 +180,10 @@ def correlation_json_for_table(correlation_data, this_trait, this_dataset, targe
 
         results_list.append(results_dict)
 
+    cache_new_traits_metadata(dataset_metadata,
+                              new_traits_metadata,
+                              file_path)
+
     return json.dumps(results_list)
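correlation_json_for_table() now consults a per-dataset metadata file before building each correlated trait from the database, then writes back whatever it had to look up. A minimal model of that read-update-write cycle, assuming a writable /tmp; the file name, trait id, and metadata values are invented for illustration:

    import json
    import os

    file_path = os.path.join("/tmp", "demo_dataset_metadata.json")

    try:  # what fetch_all_cached_metadata() does
        with open(file_path) as fh:
            dataset_metadata = json.load(fh)
    except FileNotFoundError:
        dataset_metadata = {}

    # traits that were not cached and needed a database lookup this request
    new_traits_metadata = {"1427571_at": {"symbol": "Shh"}}

    if new_traits_metadata:  # what cache_new_traits_metadata() does
        dataset_metadata.update(new_traits_metadata)
        with open(file_path, "w") as fh:
            json.dump(dataset_metadata, fh)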
diff --git a/wqflask/wqflask/static/gif/waitAnima2.gif b/wqflask/wqflask/static/gif/waitAnima2.gif
new file mode 100644
index 00000000..50aff7f2
Binary files /dev/null and b/wqflask/wqflask/static/gif/waitAnima2.gif differ
diff --git a/wqflask/wqflask/templates/loading.html b/wqflask/wqflask/templates/loading.html
index ccf810b0..b9e31ad0 100644
--- a/wqflask/wqflask/templates/loading.html
+++ b/wqflask/wqflask/templates/loading.html
@@ -66,11 +66,11 @@
           {% endif %}
         {% endif %}
       {% else %}
-        <h1>Loading {{ start_vars.tool_used }} Results...</h1>
+        <h1> {{ start_vars.tool_used }} Computation in progress ...</h1>
       {% endif %}
       <br><br>
       <div style="text-align: center;">
-        <img align="center" src="/static/gif/89.gif">
+        <img align="center" src="/static/gif/waitAnima2.gif">
       </div>
       {% if start_vars.vals_diff|length != 0 and start_vars.transform == "" %}
       <br><br>