diff options
Diffstat (limited to 'wqflask/base/data_set/dataset.py')
-rw-r--r-- | wqflask/base/data_set/dataset.py | 292 |
1 files changed, 292 insertions, 0 deletions
diff --git a/wqflask/base/data_set/dataset.py b/wqflask/base/data_set/dataset.py new file mode 100644 index 00000000..f035e028 --- /dev/null +++ b/wqflask/base/data_set/dataset.py @@ -0,0 +1,292 @@ +"Base Dataset class ..." + +import math +import collections + + +from redis import Redis + + +from base import species +from utility import chunks +from .datasetgroup import DatasetGroup +from wqflask.database import database_connection +from utility.db_tools import escape, mescape, create_in_clause +from .utils import fetch_cached_results, cache_dataset_results + +class DataSet: + """ + DataSet class defines a dataset in webqtl, can be either Microarray, + Published phenotype, genotype, or user input dataset(temp) + + """ + + def __init__(self, name, get_samplelist=True, group_name=None, redis_conn=Redis()): + + assert name, "Need a name" + self.name = name + self.id = None + self.shortname = None + self.fullname = None + self.type = None + self.data_scale = None # ZS: For example log2 + self.accession_id = None + + self.setup() + + if self.type == "Temp": # Need to supply group name as input if temp trait + # sets self.group and self.group_id and gets genotype + self.group = DatasetGroup(self, name=group_name) + else: + self.check_confidentiality() + self.retrieve_other_names() + # sets self.group and self.group_id and gets genotype + self.group = DatasetGroup(self) + self.accession_id = self.get_accession_id() + if get_samplelist == True: + self.group.get_samplelist(redis_conn) + self.species = species.TheSpecies(self) + + def as_dict(self): + return { + 'name': self.name, + 'shortname': self.shortname, + 'fullname': self.fullname, + 'type': self.type, + 'data_scale': self.data_scale, + 'group': self.group.name, + 'accession_id': self.accession_id + } + + def get_accession_id(self): + results = None + with database_connection() as conn, conn.cursor() as cursor: + if self.type == "Publish": + cursor.execute( + "SELECT InfoFiles.GN_AccesionId FROM " + "InfoFiles, PublishFreeze, InbredSet " + "WHERE InbredSet.Name = %s AND " + "PublishFreeze.InbredSetId = InbredSet.Id " + "AND InfoFiles.InfoPageName = PublishFreeze.Name " + "AND PublishFreeze.public > 0 AND " + "PublishFreeze.confidentiality < 1 " + "ORDER BY PublishFreeze.CreateTime DESC", + (self.group.name,) + ) + results = cursor.fetchone() + elif self.type == "Geno": + cursor.execute( + "SELECT InfoFiles.GN_AccesionId FROM " + "InfoFiles, GenoFreeze, InbredSet " + "WHERE InbredSet.Name = %s AND " + "GenoFreeze.InbredSetId = InbredSet.Id " + "AND InfoFiles.InfoPageName = GenoFreeze.ShortName " + "AND GenoFreeze.public > 0 AND " + "GenoFreeze.confidentiality < 1 " + "ORDER BY GenoFreeze.CreateTime DESC", + (self.group.name,) + ) + results = cursor.fetchone() + + if results: + return str(results[0]) + return "None" + + def retrieve_other_names(self): + """This method fetches the the dataset names in search_result. + + If the data set name parameter is not found in the 'Name' field of + the data set table, check if it is actually the FullName or + ShortName instead. + + This is not meant to retrieve the data set info if no name at + all is passed. + + """ + with database_connection() as conn, conn.cursor() as cursor: + try: + if self.type == "ProbeSet": + cursor.execute( + "SELECT ProbeSetFreeze.Id, ProbeSetFreeze.Name, " + "ProbeSetFreeze.FullName, ProbeSetFreeze.ShortName, " + "ProbeSetFreeze.DataScale, Tissue.Name " + "FROM ProbeSetFreeze, ProbeFreeze, Tissue " + "WHERE ProbeSetFreeze.ProbeFreezeId = ProbeFreeze.Id " + "AND ProbeFreeze.TissueId = Tissue.Id " + "AND (ProbeSetFreeze.Name = %s OR " + "ProbeSetFreeze.FullName = %s " + "OR ProbeSetFreeze.ShortName = %s)", + (self.name,)*3) + (self.id, self.name, self.fullname, self.shortname, + self.data_scale, self.tissue) = cursor.fetchone() + else: + self.tissue = "N/A" + cursor.execute( + "SELECT Id, Name, FullName, ShortName " + f"FROM {self.type}Freeze " + "WHERE (Name = %s OR FullName = " + "%s OR ShortName = %s)", + (self.name,)*3) + (self.id, self.name, self.fullname, + self.shortname) = cursor.fetchone() + except TypeError: + pass + + def chunk_dataset(self, dataset, n): + + results = {} + traits_name_dict = () + with database_connection() as conn, conn.cursor() as cursor: + cursor.execute( + "SELECT ProbeSetXRef.DataId,ProbeSet.Name " + "FROM ProbeSet, ProbeSetXRef, ProbeSetFreeze " + "WHERE ProbeSetFreeze.Name = %s AND " + "ProbeSetXRef.ProbeSetFreezeId = ProbeSetFreeze.Id " + "AND ProbeSetXRef.ProbeSetId = ProbeSet.Id", + (self.name,)) + # should cache this + traits_name_dict = dict(cursor.fetchall()) + + for i in range(0, len(dataset), n): + matrix = list(dataset[i:i + n]) + trait_name = traits_name_dict[matrix[0][0]] + + my_values = [value for (trait_name, strain, value) in matrix] + results[trait_name] = my_values + return results + + def get_probeset_data(self, sample_list=None, trait_ids=None): + + # improvement of get trait data--->>> + if sample_list: + self.samplelist = sample_list + + else: + self.samplelist = self.group.samplelist + + if self.group.parlist != None and self.group.f1list != None: + if (self.group.parlist + self.group.f1list) in self.samplelist: + self.samplelist += self.group.parlist + self.group.f1list + with database_connection() as conn, conn.cursor() as cursor: + cursor.execute( + "SELECT Strain.Name, Strain.Id FROM " + "Strain, Species WHERE Strain.Name IN " + f"{create_in_clause(self.samplelist)} " + "AND Strain.SpeciesId=Species.Id AND " + "Species.name = %s", (self.group.species,) + ) + results = dict(cursor.fetchall()) + sample_ids = [results[item] for item in self.samplelist] + + sorted_samplelist = [strain_name for strain_name, strain_id in sorted( + results.items(), key=lambda item: item[1])] + + cursor.execute( + "SELECT * from ProbeSetData WHERE StrainID IN " + f"{create_in_clause(sample_ids)} AND id IN " + "(SELECT ProbeSetXRef.DataId FROM " + "(ProbeSet, ProbeSetXRef, ProbeSetFreeze) " + "WHERE ProbeSetXRef.ProbeSetFreezeId = ProbeSetFreeze.Id " + "AND ProbeSetFreeze.Name = %s AND " + "ProbeSet.Id = ProbeSetXRef.ProbeSetId)", + (self.name,) + ) + + query_results = list(cursor.fetchall()) + data_results = self.chunk_dataset(query_results, len(sample_ids)) + self.samplelist = sorted_samplelist + self.trait_data = data_results + + def get_trait_data(self, sample_list=None): + if sample_list: + self.samplelist = sample_list + else: + self.samplelist = self.group.samplelist + + if self.group.parlist != None and self.group.f1list != None: + if (self.group.parlist + self.group.f1list) in self.samplelist: + self.samplelist += self.group.parlist + self.group.f1list + + with database_connection() as conn, conn.cursor() as cursor: + cursor.execute( + "SELECT Strain.Name, Strain.Id FROM Strain, Species " + f"WHERE Strain.Name IN {create_in_clause(self.samplelist)} " + "AND Strain.SpeciesId=Species.Id " + "AND Species.name = %s", + (self.group.species,) + ) + results = dict(cursor.fetchall()) + sample_ids = [ + sample_id for sample_id in + (results.get(item) for item in self.samplelist + if item is not None) + if sample_id is not None + ] + + # MySQL limits the number of tables that can be used in a join to 61, + # so we break the sample ids into smaller chunks + # Postgres doesn't have that limit, so we can get rid of this after we transition + chunk_size = 50 + number_chunks = int(math.ceil(len(sample_ids) / chunk_size)) + + cached_results = fetch_cached_results(self.name, self.type, self.samplelist) + + if cached_results is None: + trait_sample_data = [] + for sample_ids_step in chunks.divide_into_chunks(sample_ids, number_chunks): + if self.type == "Publish": + dataset_type = "Phenotype" + else: + dataset_type = self.type + temp = ['T%s.value' % item for item in sample_ids_step] + if self.type == "Publish": + query = "SELECT {}XRef.Id".format(escape(self.type)) + else: + query = "SELECT {}.Name".format(escape(dataset_type)) + data_start_pos = 1 + if len(temp) > 0: + query = query + ", " + ', '.join(temp) + query += ' FROM ({}, {}XRef, {}Freeze) '.format(*mescape(dataset_type, + self.type, + self.type)) + + for item in sample_ids_step: + query += """ + left join {}Data as T{} on T{}.Id = {}XRef.DataId + and T{}.StrainId={}\n + """.format(*mescape(self.type, item, item, self.type, item, item)) + + if self.type == "Publish": + query += """ + WHERE {}XRef.InbredSetId = {}Freeze.InbredSetId + and {}Freeze.Name = '{}' + and {}.Id = {}XRef.{}Id + order by {}.Id + """.format(*mescape(self.type, self.type, self.type, self.name, + dataset_type, self.type, dataset_type, dataset_type)) + else: + query += """ + WHERE {}XRef.{}FreezeId = {}Freeze.Id + and {}Freeze.Name = '{}' + and {}.Id = {}XRef.{}Id + order by {}.Id + """.format(*mescape(self.type, self.type, self.type, self.type, + self.name, dataset_type, self.type, self.type, dataset_type)) + cursor.execute(query) + results = cursor.fetchall() + trait_sample_data.append([list(result) for result in results]) + + trait_count = len(trait_sample_data[0]) + self.trait_data = collections.defaultdict(list) + + data_start_pos = 1 + for trait_counter in range(trait_count): + trait_name = trait_sample_data[0][trait_counter][0] + for chunk_counter in range(int(number_chunks)): + self.trait_data[trait_name] += ( + trait_sample_data[chunk_counter][trait_counter][data_start_pos:]) + + cache_dataset_results( + self.name, self.type, self.samplelist, self.trait_data) + else: + self.trait_data = cached_results |