about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r-- wqflask/wqflask/correlation/rust_correlation.py 27
1 files changed, 14 insertions, 13 deletions
diff --git a/wqflask/wqflask/correlation/rust_correlation.py b/wqflask/wqflask/correlation/rust_correlation.py
index ef3988e5..92cb362c 100644
--- a/wqflask/wqflask/correlation/rust_correlation.py
+++ b/wqflask/wqflask/correlation/rust_correlation.py
@@ -134,20 +134,21 @@ def compute_top_n_sample(start_vars, dataset, trait_list):
     with database_connector() as conn:
         with conn.cursor() as curr:
             # fetching strain data in bulk
+            query = (
+                "SELECT * from ProbeSetData "
+                f"WHERE StrainID IN ({', '.join(['%s'] * len(sample_ids))})"
+                "AND id IN ("
+                "  SELECT ProbeSetXRef.DataId "
+                "  FROM (ProbeSet, ProbeSetXRef, ProbeSetFreeze) "
+                "  WHERE ProbeSetXRef.ProbeSetFreezeId = ProbeSetFreeze.Id "
+                "  AND ProbeSetFreeze.Name = %s "
+                "  AND ProbeSet.Name "
+                f" IN ({', '.join(['%s'] * len(trait_list))})"
+                "  ProbeSet.Id = ProbeSetXRef.ProbeSetId)"
+                ")")
             curr.execute(
-                """
-                SELECT * from ProbeSetData
-                where StrainID in {}
-                and id in (SELECT ProbeSetXRef.DataId
-                FROM (ProbeSet, ProbeSetXRef, ProbeSetFreeze)
-                WHERE ProbeSetXRef.ProbeSetFreezeId = ProbeSetFreeze.Id
-                and ProbeSetFreeze.Name = '{}'
-                and ProbeSet.Name in {}
-                and ProbeSet.Id = ProbeSetXRef.ProbeSetId)
-                """.format(
-                    create_in_clause(list(sample_ids.values())),
-                    dataset.name,
-                    create_in_clause(trait_list)))
+                query,
+                tuple(sample_ids.values()) + (dataset.name,) + tuple(trait_list))
 
             corr_data = chunk_dataset(
                 list(curr.fetchall()), len(sample_ids.values()), dataset.name)