about summary refs log tree commit diff
path: root/wqflask/base/data_set.py
diff options
context:
space:
mode:
authorzsloan2021-10-18 17:50:26 +0000
committerzsloan2021-10-18 17:50:26 +0000
commite36eaf0003a598bc5aa688803dd1b36c24a4c051 (patch)
treea59b7dadf02241575eb0774f97c6048e2425c053 /wqflask/base/data_set.py
parentbd421438f1f0b4de913fa40cd49cfcda27e6b16f (diff)
parent04f3d13aceeaec2e52b94037d59f08ed6dc6a8bb (diff)
downloadgenenetwork2-e36eaf0003a598bc5aa688803dd1b36c24a4c051.tar.gz
Merge branch 'testing' of github.com:genenetwork/genenetwork2 into feature/remove_trait_creation_from_search
Diffstat (limited to 'wqflask/base/data_set.py')
-rw-r--r--wqflask/base/data_set.py250
1 files changed, 161 insertions, 89 deletions
diff --git a/wqflask/base/data_set.py b/wqflask/base/data_set.py
index 178234fe..8906ab69 100644
--- a/wqflask/base/data_set.py
+++ b/wqflask/base/data_set.py
@@ -17,7 +17,10 @@
 # at rwilliams@uthsc.edu and xzhou15@uthsc.edu
 #
 # This module is used by GeneNetwork project (www.genenetwork.org)
-
+from dataclasses import dataclass
+from dataclasses import field
+from dataclasses import InitVar
+from typing import Optional, Dict
 from db.call import fetchall, fetchone, fetch1
 from utility.logger import getLogger
 from utility.tools import USE_GN_SERVER, USE_REDIS, flat_files, flat_file_exists, GN2_BASE_URL
@@ -59,7 +62,8 @@ logger = getLogger(__name__)
 DS_NAME_MAP = {}
 
 
-def create_dataset(dataset_name, dataset_type=None, get_samplelist=True, group_name=None):
+def create_dataset(dataset_name, dataset_type=None,
+                   get_samplelist=True, group_name=None):
     if dataset_name == "Temp":
         dataset_type = "Temp"
 
@@ -74,11 +78,10 @@ def create_dataset(dataset_name, dataset_type=None, get_samplelist=True, group_n
         return dataset_class(dataset_name, get_samplelist)
 
 
+@dataclass
 class DatasetType:
-
-    def __init__(self, redis_instance):
-        """Create a dictionary of samples where the value is set to Geno,
-Publish or ProbeSet. E.g.
+    """Create a dictionary of samples where the value is set to Geno,
+    Publish or ProbeSet. E.g.
 
         {'AD-cases-controls-MyersGeno': 'Geno',
          'AD-cases-controls-MyersPublish': 'Publish',
@@ -89,21 +92,28 @@ Publish or ProbeSet. E.g.
          'All Phenotypes': 'Publish',
          'B139_K_1206_M': 'ProbeSet',
          'B139_K_1206_R': 'ProbeSet' ...
-
+        }
         """
+    redis_instance: InitVar[Redis]
+    datasets: Optional[Dict] = field(init=False, default_factory=dict)
+    data: Optional[Dict] = field(init=False)
+
+    def __post_init__(self, redis_instance):
         self.redis_instance = redis_instance
-        self.datasets = {}
-        data = self.redis_instance.get("dataset_structure")
+        data = redis_instance.get("dataset_structure")
         if data:
             self.datasets = json.loads(data)
-        else:  # ZS: I don't think this should ever run unless Redis is emptied
+        else:
+            # ZS: I don't think this should ever run unless Redis is
+            # emptied
             try:
                 data = json.loads(requests.get(
-                    GN2_BASE_URL + "/api/v_pre1/gen_dropdown", timeout=5).content)
-                for species in data['datasets']:
-                    for group in data['datasets'][species]:
-                        for dataset_type in data['datasets'][species][group]:
-                            for dataset in data['datasets'][species][group][dataset_type]:
+                    GN2_BASE_URL + "/api/v_pre1/gen_dropdown",
+                    timeout=5).content)
+                for _species in data['datasets']:
+                    for group in data['datasets'][_species]:
+                        for dataset_type in data['datasets'][_species][group]:
+                            for dataset in data['datasets'][_species][group][dataset_type]:
                                 short_dataset_name = dataset[1]
                                 if dataset_type == "Phenotypes":
                                     new_type = "Publish"
@@ -112,15 +122,16 @@ Publish or ProbeSet. E.g.
                                 else:
                                     new_type = "ProbeSet"
                                 self.datasets[short_dataset_name] = new_type
-            except:
+            except Exception:  # Do nothing
                 pass
 
-            self.redis_instance.set("dataset_structure", json.dumps(self.datasets))
+            self.redis_instance.set("dataset_structure",
+                                    json.dumps(self.datasets))
+        self.data = data
 
     def set_dataset_key(self, t, name):
-        """If name is not in the object's dataset dictionary, set it, and update
-        dataset_structure in Redis
-
+        """If name is not in the object's dataset dictionary, set it, and
+        update dataset_structure in Redis
         args:
           t: Type of dataset structure which can be: 'mrna_expr', 'pheno',
              'other_pheno', 'geno'
@@ -128,19 +139,20 @@ Publish or ProbeSet. E.g.
 
         """
         sql_query_mapping = {
-            'mrna_expr': ("""SELECT ProbeSetFreeze.Id FROM """ +
-                          """ProbeSetFreeze WHERE ProbeSetFreeze.Name = "{}" """),
-            'pheno': ("""SELECT InfoFiles.GN_AccesionId """ +
-                      """FROM InfoFiles, PublishFreeze, InbredSet """ +
-                      """WHERE InbredSet.Name = '{}' AND """ +
-                      """PublishFreeze.InbredSetId = InbredSet.Id AND """ +
-                      """InfoFiles.InfoPageName = PublishFreeze.Name"""),
-            'other_pheno': ("""SELECT PublishFreeze.Name """ +
-                            """FROM PublishFreeze, InbredSet """ +
-                            """WHERE InbredSet.Name = '{}' AND """ +
-                            """PublishFreeze.InbredSetId = InbredSet.Id"""),
-            'geno':  ("""SELECT GenoFreeze.Id FROM GenoFreeze WHERE """ +
-                      """GenoFreeze.Name = "{}" """)
+            'mrna_expr': ("SELECT ProbeSetFreeze.Id FROM "
+                          "ProbeSetFreeze WHERE "
+                          "ProbeSetFreeze.Name = \"%s\" "),
+            'pheno': ("SELECT InfoFiles.GN_AccesionId "
+                      "FROM InfoFiles, PublishFreeze, InbredSet "
+                      "WHERE InbredSet.Name = '%s' AND "
+                      "PublishFreeze.InbredSetId = InbredSet.Id AND "
+                      "InfoFiles.InfoPageName = PublishFreeze.Name"),
+            'other_pheno': ("SELECT PublishFreeze.Name "
+                            "FROM PublishFreeze, InbredSet "
+                            "WHERE InbredSet.Name = '%s' AND "
+                            "PublishFreeze.InbredSetId = InbredSet.Id"),
+            'geno': ("SELECT GenoFreeze.Id FROM GenoFreeze WHERE "
+                     "GenoFreeze.Name = \"%s\" ")
         }
 
         dataset_name_mapping = {
@@ -154,22 +166,23 @@ Publish or ProbeSet. E.g.
         if t in ['pheno', 'other_pheno']:
             group_name = name.replace("Publish", "")
 
-        results = g.db.execute(sql_query_mapping[t].format(group_name)).fetchone()
+        results = g.db.execute(sql_query_mapping[t] % group_name).fetchone()
         if results:
             self.datasets[name] = dataset_name_mapping[t]
-            self.redis_instance.set("dataset_structure", json.dumps(self.datasets))
+            self.redis_instance.set(
+                "dataset_structure", json.dumps(self.datasets))
             return True
-
         return None
 
     def __call__(self, name):
-
         if name not in self.datasets:
             for t in ["mrna_expr", "pheno", "other_pheno", "geno"]:
-                # This has side-effects, with the end result being a truth-y value
+                # This has side-effects, with the end result being a
+                # truth-y value
                 if(self.set_dataset_key(t, name)):
                     break
-        return self.datasets.get(name, None)  # Return None if name has not been set
+        # Return None if name has not been set
+        return self.datasets.get(name, None)
 
 
 # Do the intensive work at startup one time only
@@ -204,12 +217,12 @@ def create_datasets_list():
 
         if USE_REDIS:
             r.set(key, pickle.dumps(datasets, pickle.HIGHEST_PROTOCOL))
-            r.expire(key, 60*60)
+            r.expire(key, 60 * 60)
 
     return datasets
 
 
-class Markers(object):
+class Markers:
     """Todo: Build in cacheing so it saves us reading the same file more than once"""
 
     def __init__(self, name):
@@ -228,7 +241,8 @@ class Markers(object):
             for line in bimbam_fh:
                 marker = {}
                 marker['name'] = line.split(delimiter)[0].rstrip()
-                marker['Mb'] = float(line.split(delimiter)[1].rstrip())/1000000
+                marker['Mb'] = float(line.split(delimiter)[
+                                     1].rstrip()) / 1000000
                 marker['chr'] = line.split(delimiter)[2].rstrip()
                 markers.append(marker)
 
@@ -262,10 +276,7 @@ class Markers(object):
         elif isinstance(p_values, dict):
             filtered_markers = []
             for marker in self.markers:
-                #logger.debug("marker[name]", marker['name'])
-                #logger.debug("p_values:", p_values)
                 if marker['name'] in p_values:
-                    #logger.debug("marker {} IS in p_values".format(i))
                     marker['p_value'] = p_values[marker['name']]
                     if math.isnan(marker['p_value']) or (marker['p_value'] <= 0):
                         marker['lod_score'] = 0
@@ -276,10 +287,6 @@ class Markers(object):
                         marker['lrs_value'] = - \
                             math.log10(marker['p_value']) * 4.61
                     filtered_markers.append(marker)
-                # else:
-                    #logger.debug("marker {} NOT in p_values".format(i))
-                    # self.markers.remove(marker)
-                    #del self.markers[i]
             self.markers = filtered_markers
 
 
@@ -290,7 +297,6 @@ class HumanMarkers(Markers):
         self.markers = []
         for line in marker_data_fh:
             splat = line.strip().split()
-            #logger.debug("splat:", splat)
             if len(specified_markers) > 0:
                 if splat[1] in specified_markers:
                     marker = {}
@@ -306,13 +312,11 @@ class HumanMarkers(Markers):
                 marker['Mb'] = float(splat[3]) / 1000000
             self.markers.append(marker)
 
-        #logger.debug("markers is: ", pf(self.markers))
-
     def add_pvalues(self, p_values):
         super(HumanMarkers, self).add_pvalues(p_values)
 
 
-class DatasetGroup(object):
+class DatasetGroup:
     """
     Each group has multiple datasets; each species has multiple groups.
 
@@ -365,8 +369,8 @@ class DatasetGroup(object):
     def get_markers(self):
         def check_plink_gemma():
             if flat_file_exists("mapping"):
-                MAPPING_PATH = flat_files("mapping")+"/"
-                if os.path.isfile(MAPPING_PATH+self.name+".bed"):
+                MAPPING_PATH = flat_files("mapping") + "/"
+                if os.path.isfile(MAPPING_PATH + self.name + ".bed"):
                     return True
             return False
 
@@ -392,6 +396,15 @@ class DatasetGroup(object):
         if maternal and paternal:
             self.parlist = [maternal, paternal]
 
+    def get_study_samplelists(self):
+        study_sample_file = locate_ignore_error(self.name + ".json", 'study_sample_lists')
+        try:
+            f = open(study_sample_file)
+        except:
+            return []
+        study_samples = json.load(f)
+        return study_samples
+
     def get_genofiles(self):
         jsonfile = "%s/%s.json" % (webqtlConfig.GENODIR, self.name)
         try:
@@ -412,7 +425,7 @@ class DatasetGroup(object):
         else:
             logger.debug("Cache not hit")
 
-            genotype_fn = locate_ignore_error(self.name+".geno", 'genotype')
+            genotype_fn = locate_ignore_error(self.name + ".geno", 'genotype')
             if genotype_fn:
                 self.samplelist = get_group_samplelists.get_samplelist(
                     "geno", genotype_fn)
@@ -421,7 +434,7 @@ class DatasetGroup(object):
 
             if USE_REDIS:
                 r.set(key, json.dumps(self.samplelist))
-                r.expire(key, 60*5)
+                r.expire(key, 60 * 5)
 
     def all_samples_ordered(self):
         result = []
@@ -434,7 +447,6 @@ class DatasetGroup(object):
         # genotype_1 is Dataset Object without parents and f1
         # genotype_2 is Dataset Object with parents and f1 (not for intercross)
 
-        #genotype_1 = reaper.Dataset()
 
         # reaper barfs on unicode filenames, so here we ensure it's a string
         if self.genofile:
@@ -520,7 +532,6 @@ def datasets(group_name, this_group=None):
                     break
 
             if tissue_already_exists:
-                #logger.debug("dataset_menu:", dataset_menu[i]['datasets'])
                 dataset_menu[i]['datasets'].append((dataset, dataset_short))
             else:
                 dataset_menu.append(dict(tissue=tissue_name,
@@ -528,7 +539,7 @@ def datasets(group_name, this_group=None):
 
     if USE_REDIS:
         r.set(key, pickle.dumps(dataset_menu, pickle.HIGHEST_PROTOCOL))
-        r.expire(key, 60*5)
+        r.expire(key, 60 * 5)
 
     if this_group != None:
         this_group._datasets = dataset_menu
@@ -537,7 +548,7 @@ def datasets(group_name, this_group=None):
         return dataset_menu
 
 
-class DataSet(object):
+class DataSet:
     """
     DataSet class defines a dataset in webqtl, can be either Microarray,
     Published phenotype, genotype, or user input dataset(temp)
@@ -553,6 +564,7 @@ class DataSet(object):
         self.fullname = None
         self.type = None
         self.data_scale = None  # ZS: For example log2
+        self.accession_id = None
 
         self.setup()
 
@@ -569,14 +581,16 @@ class DataSet(object):
             self.group.get_samplelist()
         self.species = species.TheSpecies(self)
 
-    def get_desc(self):
-        """Gets overridden later, at least for Temp...used by trait's get_given_name"""
-        return None
-
-    # Delete this eventually
-    @property
-    def riset():
-        Weve_Renamed_This_As_Group
+    def as_dict(self):
+        return {
+            'name': self.name,
+            'shortname': self.shortname,
+            'fullname': self.fullname,
+            'type': self.type,
+            'data_scale': self.data_scale,
+            'group': self.group.name,
+            'accession_id': self.accession_id
+        }
 
     def get_accession_id(self):
         if self.type == "Publish":
@@ -628,7 +642,7 @@ class DataSet(object):
     WHERE ProbeSetFreeze.ProbeFreezeId = ProbeFreeze.Id
     AND ProbeFreeze.TissueId = Tissue.Id
     AND (ProbeSetFreeze.Name = '%s' OR ProbeSetFreeze.FullName = '%s' OR ProbeSetFreeze.ShortName = '%s')
-                """ % (query_args), "/dataset/"+self.name+".json",
+                """ % (query_args), "/dataset/" + self.name + ".json",
                     lambda r: (r["id"], r["name"], r["full_name"],
                                r["short_name"], r["data_scale"], r["tissue"])
                 )
@@ -651,6 +665,69 @@ class DataSet(object):
                 "Dataset {} is not yet available in GeneNetwork.".format(self.name))
             pass
 
+    def chunk_dataset(self, dataset, n):
+
+        results = {}
+
+        query = """
+                SELECT ProbeSetXRef.DataId,ProbeSet.Name
+                FROM ProbeSet, ProbeSetXRef, ProbeSetFreeze
+                WHERE ProbeSetFreeze.Name = '{}' AND
+                      ProbeSetXRef.ProbeSetFreezeId = ProbeSetFreeze.Id AND
+                      ProbeSetXRef.ProbeSetId = ProbeSet.Id
+        """.format(self.name)
+
+        # should cache this
+
+        traits_name_dict = dict(g.db.execute(query).fetchall())
+
+        for i in range(0, len(dataset), n):
+            matrix = list(dataset[i:i + n])
+            trait_name = traits_name_dict[matrix[0][0]]
+
+            my_values = [value for (trait_name, strain, value) in matrix]
+            results[trait_name] = my_values
+        return results
+
+    def get_probeset_data(self, sample_list=None, trait_ids=None):
+
+        # improvement of get trait data--->>>
+        if sample_list:
+            self.samplelist = sample_list
+
+        else:
+            self.samplelist = self.group.samplelist
+
+        if self.group.parlist != None and self.group.f1list != None:
+            if (self.group.parlist + self.group.f1list) in self.samplelist:
+                self.samplelist += self.group.parlist + self.group.f1list
+
+        query = """
+            SELECT Strain.Name, Strain.Id FROM Strain, Species
+            WHERE Strain.Name IN {}
+            and Strain.SpeciesId=Species.Id
+            and Species.name = '{}'
+            """.format(create_in_clause(self.samplelist), *mescape(self.group.species))
+        results = dict(g.db.execute(query).fetchall())
+        sample_ids = [results[item] for item in self.samplelist]
+
+        sorted_samplelist = [strain_name for strain_name, strain_id in sorted(
+            results.items(), key=lambda item: item[1])]
+
+        query = """SELECT * from ProbeSetData
+                where StrainID in {}
+                and id in (SELECT ProbeSetXRef.DataId
+                FROM (ProbeSet, ProbeSetXRef, ProbeSetFreeze)
+                WHERE ProbeSetXRef.ProbeSetFreezeId = ProbeSetFreeze.Id
+                and ProbeSetFreeze.Name = '{}'
+                and ProbeSet.Id = ProbeSetXRef.ProbeSetId)""".format(create_in_clause(sample_ids), self.name)
+
+        query_results = list(g.db.execute(query).fetchall())
+        data_results = self.chunk_dataset(query_results, len(sample_ids))
+        self.samplelist = sorted_samplelist
+        self.trait_data = data_results
+        
+
     def get_trait_data(self, sample_list=None):
         if sample_list:
             self.samplelist = sample_list
@@ -667,7 +744,6 @@ class DataSet(object):
             and Strain.SpeciesId=Species.Id
             and Species.name = '{}'
             """.format(create_in_clause(self.samplelist), *mescape(self.group.species))
-        logger.sql(query)
         results = dict(g.db.execute(query).fetchall())
         sample_ids = [results[item] for item in self.samplelist]
 
@@ -735,9 +811,6 @@ class PhenotypeDataSet(DataSet):
     DS_NAME_MAP['Publish'] = 'PhenotypeDataSet'
 
     def setup(self):
-
-        #logger.debug("IS A PHENOTYPEDATASET")
-
         # Fields in the database table
         self.search_fields = ['Phenotype.Post_publication_description',
                               'Phenotype.Pre_publication_description',
@@ -841,7 +914,6 @@ class PhenotypeDataSet(DataSet):
                         Geno.Name = '%s' and
                         Geno.SpeciesId = Species.Id
                 """ % (species, this_trait.locus)
-                logger.sql(query)
                 result = g.db.execute(query).fetchone()
 
                 if result:
@@ -871,7 +943,6 @@ class PhenotypeDataSet(DataSet):
                     Order BY
                             Strain.Name
                     """
-        logger.sql(query)
         results = g.db.execute(query, (trait, self.id)).fetchall()
         return results
 
@@ -938,7 +1009,6 @@ class GenotypeDataSet(DataSet):
                     Order BY
                             Strain.Name
                     """
-        logger.sql(query)
         results = g.db.execute(query,
                                (webqtlDatabaseFunction.retrieve_species_id(self.group.name),
                                 trait, self.name)).fetchall()
@@ -1040,8 +1110,8 @@ class MrnaAssayDataSet(DataSet):
             else:
                 description_display = this_trait.symbol
 
-            if (len(description_display) > 1 and description_display != 'N/A' and
-                    len(target_string) > 1 and target_string != 'None'):
+            if (len(description_display) > 1 and description_display != 'N/A'
+                    and len(target_string) > 1 and target_string != 'None'):
                 description_display = description_display + '; ' + target_string.strip()
 
             # Save it for the jinja2 template
@@ -1059,9 +1129,6 @@ class MrnaAssayDataSet(DataSet):
                 ProbeSet.Name = '%s'
             """ % (escape(str(this_trait.dataset.id)),
                    escape(this_trait.name)))
-
-            #logger.debug("query is:", pf(query))
-            logger.sql(query)
             result = g.db.execute(query).fetchone()
 
             mean = result[0] if result else 0
@@ -1081,7 +1148,6 @@ class MrnaAssayDataSet(DataSet):
                         Geno.Name = '{}' and
                         Geno.SpeciesId = Species.Id
                 """.format(species, this_trait.locus)
-                logger.sql(query)
                 result = g.db.execute(query).fetchone()
 
                 if result:
@@ -1097,7 +1163,8 @@ class MrnaAssayDataSet(DataSet):
                     SELECT
                             Strain.Name, ProbeSetData.value, ProbeSetSE.error, NStrain.count, Strain.Name2
                     FROM
-                            (ProbeSetData, ProbeSetFreeze, Strain, ProbeSet, ProbeSetXRef)
+                            (ProbeSetData, ProbeSetFreeze,
+                             Strain, ProbeSet, ProbeSetXRef)
                     left join ProbeSetSE on
                             (ProbeSetSE.DataId = ProbeSetData.Id AND ProbeSetSE.StrainId = ProbeSetData.StrainId)
                     left join NStrain on
@@ -1112,9 +1179,7 @@ class MrnaAssayDataSet(DataSet):
                     Order BY
                             Strain.Name
                     """ % (escape(trait), escape(self.name))
-        logger.sql(query)
         results = g.db.execute(query).fetchall()
-        #logger.debug("RETRIEVED RESULTS HERE:", results)
         return results
 
     def retrieve_genes(self, column_name):
@@ -1124,7 +1189,6 @@ class MrnaAssayDataSet(DataSet):
                     where ProbeSetXRef.ProbeSetFreezeId = %s and
                     ProbeSetXRef.ProbeSetId=ProbeSet.Id;
                 """ % (column_name, escape(str(self.id)))
-        logger.sql(query)
         results = g.db.execute(query).fetchall()
 
         return dict(results)
@@ -1155,11 +1219,19 @@ class TempDataSet(DataSet):
 
 def geno_mrna_confidentiality(ob):
     dataset_table = ob.type + "Freeze"
-    #logger.debug("dataset_table [%s]: %s" % (type(dataset_table), dataset_table))
 
     query = '''SELECT Id, Name, FullName, confidentiality,
                         AuthorisedUsers FROM %s WHERE Name = "%s"''' % (dataset_table, ob.name)
-    logger.sql(query)
+    result = g.db.execute(query)
+
+    (dataset_id,
+     name,
+     full_name,
+     confidential,
+     authorized_users) = result.fetchall()[0]
+
+    if confidential:
+        return True
     result = g.db.execute(query)
 
     (dataset_id,