 wqflask/wqflask/correlation/pre_computes.py     | 54 ++++++++++++----------
 wqflask/wqflask/correlation/rust_correlation.py | 13 +++++++
 2 files changed, 41 insertions(+), 26 deletions(-)
diff --git a/wqflask/wqflask/correlation/pre_computes.py b/wqflask/wqflask/correlation/pre_computes.py
index 81111a3c..65774326 100644
--- a/wqflask/wqflask/correlation/pre_computes.py
+++ b/wqflask/wqflask/correlation/pre_computes.py
@@ -13,6 +13,9 @@ from base.webqtlConfig import TEXTDIR
 from base.webqtlConfig import TMPDIR
 
 from json.decoder import JSONDecodeError
+from gn3.db_utils import database_connection
+# assumed import: SQL_URI is used below and lives in utility.tools elsewhere in wqflask
+from utility.tools import SQL_URI
 
 
 def cache_trait_metadata(dataset_name, data):
@@ -177,47 +178,46 @@ def write_db_to_textfile(db_name, conn, text_dir=TMPDIR):
                               *__parse_to_dict__(results))
 
 
-# check for file path
-# I need the lmdb path # tmpdir
 def __generate_target_name__(db_name):
     # todo add expiry time and checker
-    with conn.cursor() as cursor:
-        cursor.execute(
-            'SELECT Id, FullName FROM ProbeSetFreeze WHERE Name = %s', (db_name,))
-        results = cursor.fetchone()
-        if (results):
-            return __sanitise_filename__(
-                f"ProbeSetFreezeId_{results[0]}_{results[1]}")
+    with database_connection(SQL_URI) as conn:
+        with conn.cursor() as cursor:
+            cursor.execute(
+                'SELECT Id, FullName FROM ProbeSetFreeze WHERE Name = %s', (db_name,))
+            results = cursor.fetchone()
+            if results:
+                return __sanitise_filename__(
+                    f"ProbeSetFreezeId_{results[0]}_{results[1]}")
 
 
-def fetch_csv_info(db_target_name: str, conn):
+def fetch_csv_info(db_target_name: str):
     """
     Alternative for processing CSV text files with Rust.
     Note: currently experimental.
     """
-    csv_file_path = fetch_text_file(dataset_name, conn)
-    if csv_file_path:
-        return {
-            "file_type": "csv",
-            "lmdb_target_path": "csv_file_path",
-            "lmdb_target_path": "",
-        }
+    with database_connection(SQL_URI) as conn:
+        csv_file_path = fetch_text_file(db_target_name, conn)
+        if csv_file_path:
+            return {
+                "file_type": "csv",
+                "lmdb_target_path": csv_file_path,
+            }
 
 
-def fetch_lmdb_info(db_path: str, db_target_name: str):
+def fetch_lmdb_info(db_target_name: str, lmdb_dataset_path=LMDB_PATH):
     """
     Check whether the LMDB database exists and contains the target dataset,
     e.g. a key of the form ProbeSetFreezeId_112_<dataset full name>
     """
     # open the db read-only; if the lookup returns None, fall back to writing the target file
     try:
-        with lmdb.open(target_file_path, readonly=True, lock=False) as env:
+        with lmdb.open(lmdb_dataset_path, readonly=True, lock=False) as env:
             with env.begin() as txn:
-                target_key = __generate_file_name__(db_target_name)
+                target_key = __generate_target_name__(db_target_name)
                 dataset = txn.get(target_key.encode())
                 if dataset:
                     return {
-                        "lmdb_target_path": f"{db_path}data.mdb",
+                        "lmdb_target_path": f"{db_path}/data.mdb",
                         "lmdb_target_key": target_key,
                         "file_type": "lmdb",
                     }
@@ -225,15 +226,16 @@ def fetch_lmdb_info(db_path: str, db_target_name: str):
         return {}
 
 
-def generate_general_info(trait_name, sample_names,
-                          strains_vals, file_type_info):
+def generate_general_info(trait_name: str, corr_type: str,
+                          sample_dict: dict, file_type_info: dict):
     if not file_type_info:
         #! this branch should be unreachable: file_type_info is expected to be set
         pass
     # TODO: implement the fetch code
     return {
-        "trait_name": target_name,
-        "primary_sample_names": primary_sample_names,
-        "primary_strains_vals": strain_vals,
+        "trait_name": trait_name,
+        "corr_type": corr_type,
+        "primary_sample_names": list(sample_dict.keys()),
+        "primary_strains_vals": list(sample_dict.values()),
         **file_type_info
     }
diff --git a/wqflask/wqflask/correlation/rust_correlation.py b/wqflask/wqflask/correlation/rust_correlation.py
index 41dd77a1..fe1260d7 100644
--- a/wqflask/wqflask/correlation/rust_correlation.py
+++ b/wqflask/wqflask/correlation/rust_correlation.py
@@ -19,6 +19,8 @@ from gn3.computations.correlations import compute_all_lit_correlation
 from gn3.computations.rust_correlation import run_correlation
 from gn3.computations.rust_correlation import get_sample_corr_data
 from gn3.computations.rust_correlation import parse_tissue_corr_data
+from gn3.computations.rust_correlation import run_lmdb_correlation
+from wqflask.correlation.pre_computes import fetch_lmdb_info
 from gn3.db_utils import database_connection
 
 from wqflask.correlation.exceptions import WrongCorrelationType
@@ -258,6 +259,18 @@ def __compute_sample_corr__(
         return {}
 
     if target_dataset.type == "ProbeSet" and start_vars.get("use_cache") == "true":
+
+        # try the lmdb cache first; if anything fails here, fall back to
+        # the normal text-file computation below
+        # TODO: add a merge case for the csv file
+        try:
+            lmdb_info = fetch_lmdb_info(target_dataset.name)
+            if lmdb_info:
+                return run_lmdb_correlation(lmdb_info)
+        except Exception:
+            # fall back to computing the correlation the normal way
+            pass
+
         with database_connection(SQL_URI) as conn:
             file_path = fetch_text_file(target_dataset.name, conn)
             if file_path:
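
The gn3 side of run_lmdb_correlation is not part of this diff, so the sketch below
is only an assumption about what it might do with the lmdb_info payload. The
payload keys (file_type, lmdb_target_path, lmdb_target_key) are the ones produced
by fetch_lmdb_info above; the reader logic and the JSON encoding are guesses.

import json
import os

import lmdb


def run_lmdb_correlation_sketch(lmdb_info):
    # hypothetical stand-in for gn3.computations.rust_correlation.run_lmdb_correlation
    if lmdb_info.get("file_type") != "lmdb":
        return {}
    # lmdb.open() expects the environment directory, not the data.mdb file itself
    env_dir = os.path.dirname(lmdb_info["lmdb_target_path"])
    with lmdb.open(env_dir, readonly=True, lock=False) as env:
        with env.begin() as txn:
            raw = txn.get(lmdb_info["lmdb_target_key"].encode())
            if raw is None:
                return {}
            # assumes the cached dataset was stored JSON-encoded
            dataset = json.loads(raw)
            # ...hand `dataset` to the rust correlation backend here...
            return dataset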