about summary refs log tree commit diff
diff options
context:
space:
mode:
authorFrederick Muriuki Muriithi2022-09-20 21:24:53 +0300
committerFrederick Muriuki Muriithi2022-09-20 21:24:53 +0300
commit3a8d99868cbc03e5ad6edced016504ed549ef468 (patch)
tree2de48ac76b4e1a25e04246f51ff14e29f4db2d69
parent8f732461b897a7c229c3b49a74fd831c2e440989 (diff)
parent3458808a65d5e55644ea23aa00982973230ac556 (diff)
downloadgenenetwork2-3a8d99868cbc03e5ad6edced016504ed549ef468.tar.gz
Merge branch 'Alexanderlacuna-feature/generate-text-files' into testing
-rw-r--r--wqflask/wqflask/correlation/pre_computes.py59
-rw-r--r--wqflask/wqflask/correlation/rust_correlation.py15
2 files changed, 69 insertions, 5 deletions
diff --git a/wqflask/wqflask/correlation/pre_computes.py b/wqflask/wqflask/correlation/pre_computes.py
index 1c52a0f5..afcea88f 100644
--- a/wqflask/wqflask/correlation/pre_computes.py
+++ b/wqflask/wqflask/correlation/pre_computes.py
@@ -171,11 +171,12 @@ def get_datasets_data(base_dataset, target_dataset_data):
     return (target_results, base_results)
 
 
-def fetch_text_file(dataset_name, conn, text_dir=TEXTDIR):
+def fetch_text_file(dataset_name, conn, text_dir=TMPDIR):
     """fetch textfiles with strain vals if exists"""
 
     with conn.cursor() as cursor:
-        cursor.execute('SELECT Id, FullName FROM ProbeSetFreeze WHERE Name = %s', (dataset_name,))
+        cursor.execute(
+            'SELECT Id, FullName FROM ProbeSetFreeze WHERE Name = %s', (dataset_name,))
         results = cursor.fetchone()
     if results:
         try:
@@ -204,3 +205,57 @@ def read_text_file(sample_dict, file_path):
         _posit, sample_vals = __fetch_id_positions__(
             next(csv_reader)[1:], sample_dict)
         return (sample_vals, [",".join([line[i] for i in _posit]) for line in csv_reader])
+
+
+def write_db_to_textfile(db_name, conn, text_dir=TMPDIR):
+
+    def __generate_file_name__(db_name):
+        # todo add expiry time and checker
+        with conn.cursor() as cursor:
+            cursor.execute(
+                'SELECT Id, FullName FROM ProbeSetFreeze WHERE Name = %s', (db_name,))
+            results = cursor.fetchone()
+            if (results):
+                return f"ProbeSetFreezeId_{results[0]}_{results[1]}"
+
+    def __parse_to_dict__(results):
+        ids = ["ID"]
+        data = {}
+        for (trait, strain, val) in results:
+            if strain not in ids:
+                ids.append(strain)
+            if trait in data:
+                data[trait].append(val)
+            else:
+                data[trait] = [trait, val]
+        return (data, ids)
+
+    def __write_to_file__(file_path, data, col_names):
+        with open(file_path, 'w+', encoding='UTF8') as file_handler:
+
+            writer = csv.writer(file_handler)
+            writer.writerow(col_names)
+            writer.writerows(data.values())
+    with conn.cursor() as cursor:
+        cursor.execute(
+            "SELECT ProbeSet.Name,Strain.Name, ProbeSetData.value "
+            "FROM (ProbeSetData, ProbeSetFreeze, Strain, ProbeSet, "
+            "ProbeSetXRef) LEFT JOIN ProbeSetSE ON "
+            "(ProbeSetSE.DataId = ProbeSetData.Id AND "
+            "ProbeSetSE.StrainId = ProbeSetData.StrainId) "
+            "LEFT JOIN NStrain ON "
+            "(NStrain.DataId = ProbeSetData.Id AND "
+            "NStrain.StrainId = ProbeSetData.StrainId) "
+            "WHERE ProbeSetXRef.ProbeSetId = ProbeSet.Id "
+            "AND ProbeSetXRef.ProbeSetFreezeId = ProbeSetFreeze.Id "
+            "AND ProbeSetFreeze.Name = %s AND "
+            "ProbeSetXRef.DataId = ProbeSetData.Id "
+            "AND ProbeSetData.StrainId = Strain.Id "
+            "ORDER BY Strain.Name",
+            (db_name,))
+        results = cursor.fetchall()
+        file_name = __generate_file_name__(
+            db_name)
+        if (results and file_name):
+            __write_to_file__(os.path.join(text_dir, file_name),
+                              *__parse_to_dict__(results))
diff --git a/wqflask/wqflask/correlation/rust_correlation.py b/wqflask/wqflask/correlation/rust_correlation.py
index 5b39c871..d9193459 100644
--- a/wqflask/wqflask/correlation/rust_correlation.py
+++ b/wqflask/wqflask/correlation/rust_correlation.py
@@ -10,6 +10,7 @@ from wqflask.correlation.correlation_gn3_api import lit_for_trait_list
 from wqflask.correlation.correlation_gn3_api import do_lit_correlation
 from wqflask.correlation.pre_computes import fetch_text_file
 from wqflask.correlation.pre_computes import read_text_file
+from wqflask.correlation.pre_computes import write_db_to_textfile
 from gn3.computations.correlations import compute_all_lit_correlation
 from gn3.computations.rust_correlation import run_correlation
 from gn3.computations.rust_correlation import get_sample_corr_data
@@ -195,7 +196,7 @@ def compute_top_n_tissue(this_dataset, this_trait, traits, method):
                                   symbol_dict=get_trait_symbol_and_tissue_values(
                                       symbol_list=[this_trait.symbol]),
                                   dataset_symbols=trait_symbol_dict,
-                                  dataset_vals=corr_result_tissue_vals_dict)    
+                                  dataset_vals=corr_result_tissue_vals_dict)
 
     if data and data[0]:
         return run_correlation(
@@ -237,7 +238,15 @@ def __compute_sample_corr__(
             if file_path:
                 (sample_vals, target_data) = read_text_file(
                     sample_data, file_path)
-                return run_correlation(target_data, sample_vals, method, ",", corr_type, n_top)
+                return run_correlation(target_data, sample_vals,
+                                       method, ",", corr_type, n_top)
+            write_db_to_textfile(target_dataset.name, conn)
+            file_path = fetch_text_file(target_dataset.name, conn)
+            if file_path:
+                (sample_vals, target_data) = read_text_file(
+                    sample_data, file_path)
+                return run_correlation(target_data, sample_vals,
+                                       method, ",", corr_type, n_top)
 
     target_dataset.get_trait_data(list(sample_data.keys()))
 
@@ -248,7 +257,7 @@ def __compute_sample_corr__(
         target_data.append(r)
 
     if len(target_data) == 0:
-        return  {}
+        return {}
 
     return run_correlation(
         target_data, list(sample_data.values()), method, ",", corr_type,