about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--wqflask/wqflask/correlation/correlation_gn3_api.py2
-rw-r--r--wqflask/wqflask/correlation/pre_computes.py160
-rw-r--r--wqflask/wqflask/correlation/rust_correlation.py37
3 files changed, 55 insertions, 144 deletions
diff --git a/wqflask/wqflask/correlation/correlation_gn3_api.py b/wqflask/wqflask/correlation/correlation_gn3_api.py
index cffcda60..64a17548 100644
--- a/wqflask/wqflask/correlation/correlation_gn3_api.py
+++ b/wqflask/wqflask/correlation/correlation_gn3_api.py
@@ -6,8 +6,6 @@ from functools import wraps
 from utility.tools import SQL_URI
 
 from wqflask.correlation import correlation_functions
-from wqflask.correlation.pre_computes import fetch_precompute_results
-from wqflask.correlation.pre_computes import cache_compute_results
 from base import data_set
 
 from base.trait import create_trait
diff --git a/wqflask/wqflask/correlation/pre_computes.py b/wqflask/wqflask/correlation/pre_computes.py
index f21ec06a..2831bd39 100644
--- a/wqflask/wqflask/correlation/pre_computes.py
+++ b/wqflask/wqflask/correlation/pre_computes.py
@@ -2,6 +2,10 @@ import csv
 import json
 import os
 import hashlib
+import datetime
+
+import lmdb
+import pickle
 from pathlib import Path
 
 from base.data_set import query_table_timestamp
@@ -10,6 +14,31 @@ from base.webqtlConfig import TMPDIR
 
 from json.decoder import JSONDecodeError
 
+def cache_trait_metadata(dataset_name, data):
+
+
+    try:
+        with lmdb.open(os.path.join(TMPDIR,f"metadata_{dataset_name}"),map_size=20971520) as env:
+            with  env.begin(write=True) as  txn:
+                data_bytes = pickle.dumps(data)
+                txn.put(f"{dataset_name}".encode(), data_bytes)
+                current_date = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
+                txn.put(b"creation_date", current_date.encode())
+                return "success"
+
+    except lmdb.Error as  error:
+        pass
+
+def read_trait_metadata(dataset_name):
+    try:
+        with lmdb.open(os.path.join(TMPDIR,f"metadata_{dataset_name}"),
+            readonly=True, lock=False) as env:
+            with env.begin() as txn:
+                db_name = txn.get(dataset_name.encode())
+                return (pickle.loads(db_name) if db_name else {})
+    except lmdb.Error as error:
+        return {}
+
 
 def fetch_all_cached_metadata(dataset_name):
     """in a gvein dataset fetch all the traits metadata"""
@@ -53,132 +82,15 @@ def generate_filename(*args, suffix="", file_ext="json"):
     return f"{hashlib.md5(string_unicode).hexdigest()}_{suffix}.{file_ext}"
 
 
-def cache_compute_results(base_dataset_type,
-                          base_dataset_name,
-                          target_dataset_name,
-                          corr_method,
-                          correlation_results,
-                          trait_name):
-    """function to cache correlation results for heavy computations"""
-
-    base_timestamp = query_table_timestamp(base_dataset_type)
-
-    target_dataset_timestamp = base_timestamp
-
-    file_name = generate_filename(
-        base_dataset_name, target_dataset_name,
-        base_timestamp, target_dataset_timestamp,
-        suffix="corr_precomputes")
-
-    file_path = os.path.join(TMPDIR, file_name)
-
-    try:
-        with open(file_path, "r+") as json_file_handler:
-            data = json.load(json_file_handler)
-
-            data[trait_name] = correlation_results
-
-            json_file_handler.seek(0)
-
-            json.dump(data, json_file_handler)
-
-            json_file_handler.truncate()
-
-    except FileNotFoundError:
-        with open(file_path, "w+") as file_handler:
-            data = {}
-            data[trait_name] = correlation_results
-
-            json.dump(data, file_handler)
-
-
-def fetch_precompute_results(base_dataset_name,
-                             target_dataset_name,
-                             dataset_type,
-                             trait_name):
-    """function to check for precomputed  results"""
-
-    base_timestamp = target_dataset_timestamp = query_table_timestamp(
-        dataset_type)
-    file_name = generate_filename(
-        base_dataset_name, target_dataset_name,
-        base_timestamp, target_dataset_timestamp,
-        suffix="corr_precomputes")
 
-    file_path = os.path.join(TMPDIR, file_name)
 
-    try:
-        with open(file_path, "r+") as json_handler:
-            correlation_results = json.load(json_handler)
-
-        return correlation_results.get(trait_name)
-
-    except FileNotFoundError:
-        pass
-
-
-def pre_compute_dataset_vs_dataset(base_dataset,
-                                   target_dataset,
-                                   corr_method):
-    """compute sample correlation between dataset vs dataset
-    wn:heavy function should be invoked less frequently
-    input:datasets_data(two dicts),corr_method
-
-    output:correlation results for entire dataset against entire dataset
-    """
-    dataset_correlation_results = {}
-
-    target_traits_data, base_traits_data = get_datasets_data(
-        base_dataset, target_dataset_data)
-
-    for (primary_trait_name, strain_values) in base_traits_data:
-
-        this_trait_data = {
-            "trait_sample_data": strain_values,
-            "trait_id": primary_trait_name
-        }
-
-        trait_correlation_result = compute_all_sample_correlation(
-            corr_method=corr_method,
-            this_trait=this_trait_data,
-            target_dataset=target_traits_data)
-
-        dataset_correlation_results[primary_trait_name] = trait_correlation_result
-
-    return dataset_correlation_results
-
-
-def get_datasets_data(base_dataset, target_dataset_data):
-    """required to pass data in a given format to the pre compute
-    function
-
-    (works for bxd only probeset datasets)
-
-    output:two dicts for datasets with key==trait and value==strains
-    """
-    samples_fetched = base_dataset.group.all_samples_ordered()
-    target_traits_data = target_dataset.get_trait_data(
-        samples_fetched)
-
-    base_traits_data = base_dataset.get_trait_data(
-        samples_fetched)
-
-    target_results = map_shared_keys_to_values(
-        samples_fetched, target_traits_data)
-    base_results = map_shared_keys_to_values(
-        samples_fetched, base_traits_data)
-
-    return (target_results, base_results)
-
-
-def fetch_text_file(dataset_name, conn, text_dir=TEXTDIR):
+def fetch_text_file(dataset_name, conn, text_dir=TMPDIR):
     """fetch textfiles with strain vals if exists"""
 
-
-    def __file_scanner__(text_dir,target_file):
-        for file  in os.listdir(text_dir):
-            if file.startswith(f"ProbeSetFreezeId_{results[0]}_"):
-                return os.path.join(text_dir,file)
+    def __file_scanner__(text_dir, target_file):
+        for file in os.listdir(text_dir):
+            if file.startswith(f"ProbeSetFreezeId_{target_file}_"):
+                return os.path.join(text_dir, file)
 
     with conn.cursor() as cursor:
         cursor.execute(
@@ -186,9 +98,9 @@ def fetch_text_file(dataset_name, conn, text_dir=TEXTDIR):
         results = cursor.fetchone()
     if results:
         try:
-            # addition check for matrix file in gn_matrix folder
+            # checks first for recently generated textfiles if not use gn1 datamatrix
 
-            return __file_scanner__(text_dir,results) or __file_scanner__(TEXTDIR,results)
+            return __file_scanner__(text_dir, results[0]) or __file_scanner__(TEXTDIR, results[0])
 
         except Exception:
             pass
diff --git a/wqflask/wqflask/correlation/rust_correlation.py b/wqflask/wqflask/correlation/rust_correlation.py
index 67bd5ff5..41dd77a1 100644
--- a/wqflask/wqflask/correlation/rust_correlation.py
+++ b/wqflask/wqflask/correlation/rust_correlation.py
@@ -13,6 +13,8 @@ from wqflask.correlation.correlation_gn3_api import do_lit_correlation
 from wqflask.correlation.pre_computes import fetch_text_file
 from wqflask.correlation.pre_computes import read_text_file
 from wqflask.correlation.pre_computes import write_db_to_textfile
+from wqflask.correlation.pre_computes import read_trait_metadata
+from wqflask.correlation.pre_computes import cache_trait_metadata
 from gn3.computations.correlations import compute_all_lit_correlation
 from gn3.computations.rust_correlation import run_correlation
 from gn3.computations.rust_correlation import get_sample_corr_data
@@ -25,7 +27,7 @@ from wqflask.correlation.exceptions import WrongCorrelationType
 def query_probes_metadata(dataset, trait_list):
     """query traits metadata in bulk for probeset"""
 
-    if not bool(trait_list) or dataset.type!="ProbeSet":
+    if not bool(trait_list) or dataset.type != "ProbeSet":
         return []
 
     with database_connection(SQL_URI) as conn:
@@ -63,8 +65,11 @@ def get_metadata(dataset, traits):
         if probe_mb:
             return f"Chr{probe_chr}: {probe_mb:.6f}"
         return f"Chr{probe_chr}: ???"
-
-    return {trait_name: {
+    cached_metadata = read_trait_metadata(dataset.name)
+    to_fetch_metadata = list(
+        set(traits).difference(list(cached_metadata.keys())))
+    if to_fetch_metadata:
+        results = {**({trait_name: {
             "name": trait_name,
             "view": True,
             "symbol": symbol,
@@ -77,13 +82,16 @@ def get_metadata(dataset, traits):
             "location": __location__(probe_chr, probe_mb),
             "chr": probe_chr,
             "mb": probe_mb,
-            "lrs_location":f'Chr{chr_score}: {mb:{".6f" if mb  else ""}}',
+            "lrs_location": f'Chr{chr_score}: {mb:{".6f" if mb  else ""}}',
             "lrs_chr": chr_score,
             "lrs_mb": mb
 
-            } for trait_name, probe_chr, probe_mb, symbol, mean, description,
+        } for trait_name, probe_chr, probe_mb, symbol, mean, description,
             additive, lrs, chr_score, mb
-            in query_probes_metadata(dataset, traits)}
+            in query_probes_metadata(dataset, to_fetch_metadata)}), **cached_metadata}
+        cache_trait_metadata(dataset.name, results)
+        return results
+    return cached_metadata
 
 
 def chunk_dataset(dataset, steps, name):
@@ -235,21 +243,20 @@ def __compute_sample_corr__(
     """Compute the sample correlations"""
     (this_dataset, this_trait, target_dataset, sample_data) = target_trait_info
 
-    if this_dataset.group.f1list !=None:
-        this_dataset.group.samplelist+= this_dataset.group.f1list
+    if this_dataset.group.f1list != None:
+        this_dataset.group.samplelist += this_dataset.group.f1list
 
-    if this_dataset.group.parlist!= None:
-        this_dataset.group.samplelist+= this_dataset.group.parlist
+    if this_dataset.group.parlist != None:
+        this_dataset.group.samplelist += this_dataset.group.parlist
 
     sample_data = get_sample_corr_data(
         sample_type=start_vars["corr_samples_group"],
-        sample_data= json.loads(start_vars["sample_vals"]),
+        sample_data=json.loads(start_vars["sample_vals"]),
         dataset_samples=this_dataset.group.all_samples_ordered())
 
     if not bool(sample_data):
         return {}
 
-
     if target_dataset.type == "ProbeSet" and start_vars.get("use_cache") == "true":
         with database_connection(SQL_URI) as conn:
             file_path = fetch_text_file(target_dataset.name, conn)
@@ -257,23 +264,18 @@ def __compute_sample_corr__(
                 (sample_vals, target_data) = read_text_file(
                     sample_data, file_path)
 
-        
                 return run_correlation(target_data, sample_vals,
                                        method, ",", corr_type, n_top)
 
-
             write_db_to_textfile(target_dataset.name, conn)
             file_path = fetch_text_file(target_dataset.name, conn)
             if file_path:
                 (sample_vals, target_data) = read_text_file(
                     sample_data, file_path)
 
-
                 return run_correlation(target_data, sample_vals,
                                        method, ",", corr_type, n_top)
 
-
-
     target_dataset.get_trait_data(list(sample_data.keys()))
 
     def __merge_key_and_values__(rows, current):
@@ -288,7 +290,6 @@ def __compute_sample_corr__(
     if len(target_data) == 0:
         return {}
 
-
     return run_correlation(
         target_data, list(sample_data.values()), method, ",", corr_type,
         n_top)