about summary refs log tree commit diff
path: root/wqflask/base
diff options
context:
space:
mode:
authorBonfaceKilz2021-11-24 12:48:33 +0300
committerGitHub2021-11-24 12:48:33 +0300
commit16116373899b44e0f0a3894f1f2e5b7f60a5d498 (patch)
tree795043155c61640a914f75e3a6094273835605af /wqflask/base
parent41e742904ff4cf35abbd885eeb98902a05d3be80 (diff)
parentfffeb91789943a3c7db5a72d66405e2a0459ed44 (diff)
downloadgenenetwork2-16116373899b44e0f0a3894f1f2e5b7f60a5d498.tar.gz
Merge pull request #624 from Alexanderlacuna/feature/correlation-optimization2
Feature/correlation optimization2
Diffstat (limited to 'wqflask/base')
-rw-r--r--wqflask/base/data_set.py67
1 files changed, 35 insertions, 32 deletions
diff --git a/wqflask/base/data_set.py b/wqflask/base/data_set.py
index 768ad49b..49ece9dd 100644
--- a/wqflask/base/data_set.py
+++ b/wqflask/base/data_set.py
@@ -40,6 +40,8 @@ from base import species
 from base import webqtlConfig
 from flask import Flask, g
 from base.webqtlConfig import TMPDIR
+from urllib.parse import urlparse
+from utility.tools import SQL_URI
 import os
 import math
 import string
@@ -747,15 +749,16 @@ class DataSet:
             and Species.name = '{}'
             """.format(create_in_clause(self.samplelist), *mescape(self.group.species))
         results = dict(g.db.execute(query).fetchall())
-        sample_ids = [results[item] for item in self.samplelist]
+        sample_ids = [results.get(item)
+                      for item in self.samplelist if item is not None]
 
         # MySQL limits the number of tables that can be used in a join to 61,
         # so we break the sample ids into smaller chunks
         # Postgres doesn't have that limit, so we can get rid of this after we transition
         chunk_size = 50
         number_chunks = int(math.ceil(len(sample_ids) / chunk_size))
-        # cached_results = fetch_cached_results(self.name, self.type)
-        cached_results =  None
+
+        cached_results = fetch_cached_results(self.name, self.type)
         if cached_results is None:
             trait_sample_data = []
             for sample_ids_step in chunks.divide_into_chunks(sample_ids, number_chunks):
@@ -800,21 +803,21 @@ class DataSet:
                 results = g.db.execute(query).fetchall()
                 trait_sample_data.append([list(result) for result in results])
 
+            trait_count = len(trait_sample_data[0])
+            self.trait_data = collections.defaultdict(list)
 
-        else:
-            trait_sample_data = cached_results
+            data_start_pos = 1
+            for trait_counter in range(trait_count):
+                trait_name = trait_sample_data[0][trait_counter][0]
+                for chunk_counter in range(int(number_chunks)):
+                    self.trait_data[trait_name] += (
+                        trait_sample_data[chunk_counter][trait_counter][data_start_pos:])
 
-        trait_count = len(trait_sample_data[0])
-        self.trait_data = collections.defaultdict(list)
+            cache_dataset_results(
+                self.name, self.type, self.trait_data)
+        else:
 
-        # put all of the separate data together into a dictionary where the keys are
-        # trait names and values are lists of sample values
-        data_start_pos = 1
-        for trait_counter in range(trait_count):
-            trait_name = trait_sample_data[0][trait_counter][0]
-            for chunk_counter in range(int(number_chunks)):
-                self.trait_data[trait_name] += (
-                    trait_sample_data[chunk_counter][trait_counter][data_start_pos:])
+            self.trait_data = cached_results
 
 
 class PhenotypeDataSet(DataSet):
@@ -1254,25 +1257,30 @@ def geno_mrna_confidentiality(ob):
         return True
 
 
+def parse_db_url():
+    parsed_db = urlparse(SQL_URI)
+
+    return (parsed_db.hostname, parsed_db.username,
+            parsed_db.password, parsed_db.path[1:])
+
+
 def query_table_timestamp(dataset_type: str):
     """function to query the update timestamp of a given dataset_type"""
 
     # computation data and actions
 
+    fetch_db_name = parse_db_url()
     query_update_time = f"""
                     SELECT UPDATE_TIME FROM   information_schema.tables
-                    WHERE  TABLE_SCHEMA = 'db_webqtl_s'
+                    WHERE  TABLE_SCHEMA = '{fetch_db_name[-1]}'
                     AND TABLE_NAME = '{dataset_type}Data'
                 """
 
-    # store the timestamp in redis=
     date_time_obj = g.db.execute(query_update_time).fetchone()[0]
-
-    f = "%Y-%m-%d %H:%M:%S"
-    return date_time_obj.strftime(f)
+    return date_time_obj.strftime("%Y-%m-%d %H:%M:%S")
 
 
-def generate_hash_file(dataset_name: str, dataset_timestamp: str):
+def generate_hash_file(dataset_name: str, dataset_type: str, dataset_timestamp: str):
     """given the trait_name generate a unique name for this"""
     string_unicode = f"{dataset_name}{dataset_timestamp}".encode()
     md5hash = hashlib.md5(string_unicode)
@@ -1280,15 +1288,16 @@ def generate_hash_file(dataset_name: str, dataset_timestamp: str):
 
 
 def cache_dataset_results(dataset_name: str, dataset_type: str, query_results: List):
-    """function to cache dataset query results to file"""
+    """function to cache dataset query results to file
+    input dataset_name and type query_results(already processed in default dict format)
+    """
     # data computations actions
     # store the file path on redis
 
     table_timestamp = query_table_timestamp(dataset_type)
 
-    results = r.set(f"{dataset_type}timestamp", table_timestamp)
 
-    file_name = generate_hash_file(dataset_name, table_timestamp)
+    file_name = generate_hash_file(dataset_name, dataset_type, table_timestamp)
     file_path = os.path.join(TMPDIR, f"{file_name}.json")
 
     with open(file_path, "w") as file_handler:
@@ -1298,19 +1307,13 @@ def cache_dataset_results(dataset_name: str, dataset_type: str, query_results: L
 def fetch_cached_results(dataset_name: str, dataset_type: str):
     """function to fetch the cached results"""
 
-    table_timestamp = r.get(f"{dataset_type}timestamp")
-
-    if table_timestamp is not None:
-        table_timestamp = table_timestamp.decode("utf-8")
-    else:
-        table_timestamp = ""
+    table_timestamp = query_table_timestamp(dataset_type)
 
-    file_name = generate_hash_file(dataset_name, table_timestamp)
+    file_name = generate_hash_file(dataset_name, dataset_type, table_timestamp)
     file_path = os.path.join(TMPDIR, f"{file_name}.json")
     try:
         with open(file_path, "r") as file_handler:
 
             return json.load(file_handler)
     except FileNotFoundError:
-        # take actions continue to fetch dataset results and fetch results
         pass