about summary refs log tree commit diff
path: root/wqflask
diff options
context:
space:
mode:
Diffstat (limited to 'wqflask')
-rw-r--r-- wqflask/wqflask/correlation/rust_correlation.py 71
1 file changed, 29 insertions, 42 deletions
diff --git a/wqflask/wqflask/correlation/rust_correlation.py b/wqflask/wqflask/correlation/rust_correlation.py
index 37e2ba76..5c22efbf 100644
--- a/wqflask/wqflask/correlation/rust_correlation.py
+++ b/wqflask/wqflask/correlation/rust_correlation.py
@@ -28,11 +28,9 @@ def chunk_dataset(dataset, steps, name):
     """.format(name)
 
     with database_connector() as conn:
-        curr = conn.cursor()
-
-        curr.execute(query)
-
-        traits_name_dict = dict(curr.fetchall())
+        with conn.cursor() as curr:
+            curr.execute(query)
+            traits_name_dict = dict(curr.fetchall())
 
     for i in range(0, len(dataset), steps):
         matrix = list(dataset[i:i + steps])
@@ -52,60 +50,49 @@ def compute_top_n_sample(start_vars, dataset, trait_list):
         return {}
 
     def __fetch_sample_ids__(samples_vals, samples_group):
-
         all_samples = json.loads(samples_vals)
         sample_data = get_sample_corr_data(
             sample_type=samples_group, all_samples=all_samples,
             dataset_samples=dataset.group.all_samples_ordered())
 
         with database_connector() as conn:
-
-            curr = conn.cursor()
-
-            curr.execute(
-                """
+            with conn.cursor() as curr:
+                curr.execute(
+                    """
                     SELECT Strain.Name, Strain.Id FROM Strain, Species
                     WHERE Strain.Name IN {}
                     and Strain.SpeciesId=Species.Id
                     and Species.name = '{}'
                     """.format(create_in_clause(list(sample_data.keys())),
-                               *mescape(dataset.group.species))
-
-            )
-
-            return (sample_data, dict(curr.fetchall()))
+                               *mescape(dataset.group.species)))
+                return (sample_data, dict(curr.fetchall()))
 
     (sample_data, sample_ids) = __fetch_sample_ids__(
         start_vars["sample_vals"], start_vars["corr_samples_group"])
 
     with database_connector() as conn:
+        with conn.cursor() as curr:
+            # fetching strain data in bulk
+            curr.execute(
+                """
+                SELECT * from ProbeSetData
+                where StrainID in {}
+                and id in (SELECT ProbeSetXRef.DataId
+                FROM (ProbeSet, ProbeSetXRef, ProbeSetFreeze)
+                WHERE ProbeSetXRef.ProbeSetFreezeId = ProbeSetFreeze.Id
+                and ProbeSetFreeze.Name = '{}'
+                and ProbeSet.Name in {}
+                and ProbeSet.Id = ProbeSetXRef.ProbeSetId)
+                """.format(
+                    create_in_clause(list(sample_ids.values())),
+                    dataset.name,
+                    create_in_clause(trait_list)))
+
+            corr_data = chunk_dataset(
+                list(curr.fetchall()), len(sample_ids.values()), dataset.name)
 
-        curr = conn.cursor()
-
-        # fetching strain data in bulk
-
-        curr.execute(
-
-            """
-            SELECT * from ProbeSetData
-            where StrainID in {}
-            and id in (SELECT ProbeSetXRef.DataId
-            FROM (ProbeSet, ProbeSetXRef, ProbeSetFreeze)
-            WHERE ProbeSetXRef.ProbeSetFreezeId = ProbeSetFreeze.Id
-            and ProbeSetFreeze.Name = '{}'
-            and ProbeSet.Name in {}
-            and ProbeSet.Id = ProbeSetXRef.ProbeSetId)
-           """.format(create_in_clause(list(sample_ids.values())), dataset.name, create_in_clause(trait_list))
-
-
-        )
-
-        corr_data = chunk_dataset(list(curr.fetchall()), len(
-            sample_ids.values()), dataset.name)
-
-        return run_correlation(corr_data,
-                               list(sample_data.values()),
-                               "pearson", ",")
+        return run_correlation(
+            corr_data, list(sample_data.values()), "pearson", ",")
 
 
 def compute_top_n_lit(corr_results, this_dataset, this_trait) -> dict: