about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--wqflask/base/data_set.py51
-rw-r--r--wqflask/wqflask/correlation/correlation_gn3_api.py2
-rw-r--r--wqflask/wqflask/views.py3
3 files changed, 51 insertions, 5 deletions
diff --git a/wqflask/base/data_set.py b/wqflask/base/data_set.py
index 178234fe..468c4da0 100644
--- a/wqflask/base/data_set.py
+++ b/wqflask/base/data_set.py
@@ -115,7 +115,8 @@ Publish or ProbeSet. E.g.
             except:
                 pass
 
-            self.redis_instance.set("dataset_structure", json.dumps(self.datasets))
+            self.redis_instance.set(
+                "dataset_structure", json.dumps(self.datasets))
 
     def set_dataset_key(self, t, name):
         """If name is not in the object's dataset dictionary, set it, and update
@@ -154,10 +155,12 @@ Publish or ProbeSet. E.g.
         if t in ['pheno', 'other_pheno']:
             group_name = name.replace("Publish", "")
 
-        results = g.db.execute(sql_query_mapping[t].format(group_name)).fetchone()
+        results = g.db.execute(
+            sql_query_mapping[t].format(group_name)).fetchone()
         if results:
             self.datasets[name] = dataset_name_mapping[t]
-            self.redis_instance.set("dataset_structure", json.dumps(self.datasets))
+            self.redis_instance.set(
+                "dataset_structure", json.dumps(self.datasets))
             return True
 
         return None
@@ -169,7 +172,8 @@ Publish or ProbeSet. E.g.
                 # This has side-effects, with the end result being a truth-y value
                 if(self.set_dataset_key(t, name)):
                     break
-        return self.datasets.get(name, None)  # Return None if name has not been set
+        # Return None if name has not been set
+        return self.datasets.get(name, None)
 
 
 # Do the intensive work at startup one time only
@@ -651,6 +655,43 @@ class DataSet(object):
                 "Dataset {} is not yet available in GeneNetwork.".format(self.name))
             pass
 
+    def fetch_probe_trait_data(self, sample_list=None):
+        if sample_list:
+            self.samplelist = sample_list
+        else:
+            self.samplelist = self.group.samplelist
+
+        if self.group.parlist != None and self.group.f1list != None:
+            if (self.group.parlist + self.group.f1list) in self.samplelist:
+                self.samplelist += self.group.parlist + self.group.f1list
+
+        query = """
+            SELECT Strain.Name, Strain.Id FROM Strain, Species
+            WHERE Strain.Name IN {}
+            and Strain.SpeciesId=Species.Id
+            and Species.name = '{}'
+            """.format(create_in_clause(self.samplelist), *mescape(self.group.species))
+        logger.sql(query)
+        results = dict(g.db.execute(query).fetchall())
+        sample_ids = [results[item] for item in self.samplelist]
+
+        query = """SELECT * from ProbeSetData WHERE Id in ( SELECT ProbeSetXRef.DataId FROM (ProbeSet, ProbeSetXRef, ProbeSetFreeze) WHERE ProbeSetXRef.ProbeSetFreezeId = ProbeSetFreeze.Id  and ProbeSetFreeze.Name = 'HC_M2_0606_P'  and ProbeSet.Id = ProbeSetXRef.ProbeSetId  order by ProbeSet.Id )    and  StrainId in ({})""".format(
+            ",".join(str(sample_id) for sample_id in sample_ids))
+
+        results = g.db.execute(query).fetchall()
+
+        # with conn:
+        #     cursor = conn.cursor()
+        #     cursor.execute(query)
+        #     results = cursor.fetchall()
+        trait_data = {}
+        for trait_id, StrainId, value in results:
+            if trait_id in trait_data:
+                trait_data[trait_id].append(value)
+            else:
+                trait_data[trait_id] = [value]
+        self.trait_data = trait_data
+
     def get_trait_data(self, sample_list=None):
         if sample_list:
             self.samplelist = sample_list
@@ -670,6 +711,7 @@ class DataSet(object):
         logger.sql(query)
         results = dict(g.db.execute(query).fetchall())
         sample_ids = [results[item] for item in self.samplelist]
+        print("the number of sample ids are", len(sample_ids))
 
         # MySQL limits the number of tables that can be used in a join to 61,
         # so we break the sample ids into smaller chunks
@@ -720,6 +762,7 @@ class DataSet(object):
             trait_sample_data.append(results)
 
         trait_count = len(trait_sample_data[0])
+        print("the trait count is >>>", trait_count)
         self.trait_data = collections.defaultdict(list)
 
         # put all of the separate data together into a dictionary where the keys are
diff --git a/wqflask/wqflask/correlation/correlation_gn3_api.py b/wqflask/wqflask/correlation/correlation_gn3_api.py
index e7394647..51bf5fb5 100644
--- a/wqflask/wqflask/correlation/correlation_gn3_api.py
+++ b/wqflask/wqflask/correlation/correlation_gn3_api.py
@@ -78,7 +78,7 @@ def compute_correlation(start_vars, method="pearson"):
         # }
         sample_data = process_samples(
             start_vars, this_dataset.group.samplelist)
-        target_dataset.get_trait_data(list(sample_data.keys()))
+        target_dataset.fetch_probe_trait_data(list(sample_data.keys()))
         this_trait = retrieve_sample_data(this_trait, this_dataset)
 
         print("Creating dataset and trait took", time.time()-initial_time)
diff --git a/wqflask/wqflask/views.py b/wqflask/wqflask/views.py
index 072db466..2c239425 100644
--- a/wqflask/wqflask/views.py
+++ b/wqflask/wqflask/views.py
@@ -881,7 +881,10 @@ def network_graph_page():
 def corr_compute_page():
     logger.info("In corr_compute, request.form is:", pf(request.form))
     logger.info(request.url)
+    import time
+    initial_time = time.time()
     correlation_results = compute_correlation(request.form)
+    print(">>>>Time taken by this endpoint",time.time()-initial_time)
     return render_template("demo_correlation_page.html",correlation_results=correlation_results[1:20])
 
 @app.route("/corr_matrix", methods=('POST',))