diff options
author | Munyoki Kilyungi | 2022-09-07 11:47:26 +0300 |
---|---|---|
committer | BonfaceKilz | 2022-09-08 14:26:19 +0300 |
commit | b6fadcc7e8de9d7041f2bb00dcd1ada518a1932c (patch) | |
tree | cc0199f48307568ee2bdab16d7ec0c8774fda50e | |
parent | f879f82ad0393a770ed50043c70ee1dd4a12daaa (diff) | |
download | genenetwork2-b6fadcc7e8de9d7041f2bb00dcd1ada518a1932c.tar.gz |
Inject database connection into the MrnaAssayTissueData class
* wqflask/base/mrna_assay_tissue_data.py: Imports: Delete
database_connection, escape, and database_connector.
(MrnaAssayTissueData.__init__): Inject conn. Reformat queries. Rework
'if ... else' logic. Rework how results are assigned to
'self.data[symbol]': unpack each result row as a tuple instead of
using dot-notation on it.
(MrnaAssayTissueData.get_symbol_values_pairs): Move box-comments to
docstring. Rework how results are assigned to 'symbol_values_dict':
unpack each result row as a tuple instead of using dot-notation on it.
* wqflask/tests/unit/base/test_mrna_assay_tissue_data.py
(test_mrna_assay_tissue_data_initialisation): New test.
* wqflask/wqflask/correlation/correlation_functions.py: Import
database_connection.
(get_trait_symbol_and_tissue_values): Inject conn object.
-rw-r--r-- | wqflask/base/mrna_assay_tissue_data.py | 131 | ||||
-rw-r--r-- | wqflask/tests/unit/base/test_mrna_assay_tissue_data.py | 51 | ||||
-rw-r--r-- | wqflask/wqflask/correlation/correlation_functions.py | 11 |
3 files changed, 123 insertions, 70 deletions
diff --git a/wqflask/base/mrna_assay_tissue_data.py b/wqflask/base/mrna_assay_tissue_data.py index b371e39f..a229151d 100644 --- a/wqflask/base/mrna_assay_tissue_data.py +++ b/wqflask/base/mrna_assay_tissue_data.py @@ -1,81 +1,82 @@ import collections -from wqflask.database import database_connection - from utility import db_tools from utility import Bunch -from utility.db_tools import escape -from gn3.db_utils import database_connector - class MrnaAssayTissueData: - def __init__(self, gene_symbols=None): + def __init__(self, conn, gene_symbols=None): self.gene_symbols = gene_symbols - if self.gene_symbols == None: + self.conn = conn + if self.gene_symbols is None: self.gene_symbols = [] self.data = collections.defaultdict(Bunch) - - query = '''select t.Symbol, t.GeneId, t.DataId, t.Chr, t.Mb, t.description, t.Probe_Target_Description - from ( - select Symbol, max(Mean) as maxmean - from TissueProbeSetXRef - where TissueProbeSetFreezeId=1 and ''' - - # Note that inner join is necessary in this query to get distinct record in one symbol group - # with highest mean value + results = () + # Note that inner join is necessary in this query to get + # distinct record in one symbol group with highest mean value # Due to the limit size of TissueProbeSetFreezeId table in DB, - # performance of inner join is acceptable.MrnaAssayTissueData(gene_symbols=symbol_list) - if len(gene_symbols) == 0: - query += '''Symbol!='' and Symbol Is Not Null group by Symbol) - as x inner join TissueProbeSetXRef as t on t.Symbol = x.Symbol - and t.Mean = x.maxmean; - ''' - else: - in_clause = db_tools.create_in_clause(gene_symbols) - - # ZS: This was in the query, not sure why: http://docs.python.org/2/library/string.html?highlight=lower#string.lower - query += ''' Symbol in {} group by Symbol) - as x inner join TissueProbeSetXRef as t on t.Symbol = x.Symbol - and t.Mean = x.maxmean; - '''.format(in_clause) - - - # lower_symbols = [] + # performance of inner join is + # 
acceptable.MrnaAssayTissueData(gene_symbols=symbol_list) + with conn.cursor() as cursor: + if len(self.gene_symbols) == 0: + cursor.execute( + "SELECT t.Symbol, t.GeneId, t.DataId, " + "t.Chr, t.Mb, t.description, " + "t.Probe_Target_Description FROM (SELECT Symbol, " + "max(Mean) AS maxmean " + "FROM TissueProbeSetXRef WHERE " + "TissueProbeSetFreezeId=1 AND " + "Symbol != '' AND Symbol IS NOT " + "Null GROUP BY Symbol) " + "AS x INNER JOIN " + "TissueProbeSetXRef AS t ON " + "t.Symbol = x.Symbol " + "AND t.Mean = x.maxmean") + else: + cursor.execute( + "SELECT t.Symbol, t.GeneId, t.DataId, " + "t.Chr, t.Mb, t.description, " + "t.Probe_Target_Description FROM (SELECT Symbol, " + "max(Mean) AS maxmean " + "FROM TissueProbeSetXRef WHERE " + "TissueProbeSetFreezeId=1 AND " + "Symbol IN " + f"({', '.join(['%s'] * len(self.gene_symbols))}) " + "GROUP BY Symbol) AS x INNER JOIN " + "TissueProbeSetXRef AS t ON t.Symbol = x.Symbol " + "AND t.Mean = x.maxmean", + tuple(self.gene_symbols)) + results = list(cursor.fetchall()) lower_symbols = {} - for gene_symbol in gene_symbols: - # lower_symbols[gene_symbol.lower()] = True - if gene_symbol != None: + for gene_symbol in self.gene_symbols: + if gene_symbol is not None: lower_symbols[gene_symbol.lower()] = True - results = None - with database_connection() as conn, conn.cursor() as cursor: - cursor.execute(query) - results = cursor.fetchall() for result in results: - symbol = result[0] - if symbol is not None and lower_symbols.get(symbol.lower()): - + (symbol, gene_id, data_id, _chr, _mb, + descr, probeset_target_descr) = result + if symbol is not None and lower_symbols.get(symbol.lower()): symbol = symbol.lower() + self.data[symbol].gene_id = gene_id + self.data[symbol].data_id = data_id + self.data[symbol].chr = _chr + self.data[symbol].mb = _mb + self.data[symbol].description = descr + (self.data[symbol] + .probe_target_description) = probeset_target_descr - self.data[symbol].gene_id = result.GeneId - 
self.data[symbol].data_id = result.DataId - self.data[symbol].chr = result.Chr - self.data[symbol].mb = result.Mb - self.data[symbol].description = result.description - self.data[symbol].probe_target_description = result.Probe_Target_Description - - ########################################################################### - # Input: cursor, symbolList (list), dataIdDict(Dict) - # output: symbolValuepairDict (dictionary):one dictionary of Symbol and Value Pair, - # key is symbol, value is one list of expression values of one probeSet; - # function: get one dictionary whose key is gene symbol and value is tissue expression data (list type). - # Attention! All keys are lower case! - ########################################################################### def get_symbol_values_pairs(self): + """Get one dictionary whose key is gene symbol and value is + tissue expression data (list type). All keys are lower case. + + The output is a symbolValuepairDict (dictionary): one + dictionary of Symbol and Value Pair; key is symbol, value is + one list of expression values of one probeSet; + + """ id_list = [self.data[symbol].data_id for symbol in self.data] symbol_values_dict = {} @@ -86,14 +87,14 @@ class MrnaAssayTissueData: WHERE TissueProbeSetData.Id IN {} and TissueProbeSetXRef.DataId = TissueProbeSetData.Id""".format(db_tools.create_in_clause(id_list)) results = [] - with database_connection() as conn, conn.cursor() as cursor: + with self.conn.cursor() as cursor: cursor.execute(query) results = cursor.fetchall() - for result in results: - if result.Symbol.lower() not in symbol_values_dict: - symbol_values_dict[result.Symbol.lower()] = [result.value] - else: - symbol_values_dict[result.Symbol.lower()].append( - result.value) - + for result in results: + (symbol, value) = result + if symbol.lower() not in symbol_values_dict: + symbol_values_dict[symbol.lower()] = [value] + else: + symbol_values_dict[symbol.lower()].append( + value) return symbol_values_dict diff --git 
a/wqflask/tests/unit/base/test_mrna_assay_tissue_data.py b/wqflask/tests/unit/base/test_mrna_assay_tissue_data.py new file mode 100644 index 00000000..7a21124a --- /dev/null +++ b/wqflask/tests/unit/base/test_mrna_assay_tissue_data.py @@ -0,0 +1,51 @@ +import pytest +from base.mrna_assay_tissue_data import MrnaAssayTissueData + + +@pytest.mark.parametrize( + ('gene_symbols', 'expected_query', 'sql_fetch_all_results'), + ( + (None, + (("SELECT t.Symbol, t.GeneId, t.DataId, " + "t.Chr, t.Mb, t.description, " + "t.Probe_Target_Description " + "FROM (SELECT Symbol, " + "max(Mean) AS maxmean " + "FROM TissueProbeSetXRef WHERE " + "TissueProbeSetFreezeId=1 AND " + "Symbol != '' AND Symbol IS NOT " + "Null GROUP BY Symbol) " + "AS x INNER JOIN TissueProbeSetXRef " + "AS t ON t.Symbol = x.Symbol " + "AND t.Mean = x.maxmean"),), + (("symbol", "gene_id", + "data_id", "chr", "mb", + "description", + "probe_target_description"),)), + (["k1", "k2", "k3"], + ("SELECT t.Symbol, t.GeneId, t.DataId, " + "t.Chr, t.Mb, t.description, " + "t.Probe_Target_Description FROM (SELECT Symbol, " + "max(Mean) AS maxmean " + "FROM TissueProbeSetXRef WHERE " + "TissueProbeSetFreezeId=1 AND " + "Symbol IN (%s, %s, %s) " + "GROUP BY Symbol) AS x INNER JOIN " + "TissueProbeSetXRef AS " + "t ON t.Symbol = x.Symbol " + "AND t.Mean = x.maxmean", + ("k1", "k2", "k3")), + (("k1", "203", + "112", "xy", "20.11", + "Sample Description", + "Sample Probe Target Description"),)), + ), +) +def test_mrna_assay_tissue_data_initialisation(mocker, gene_symbols, + expected_query, + sql_fetch_all_results): + mock_conn = mocker.MagicMock() + with mock_conn.cursor() as cursor: + cursor.fetchall.return_value = sql_fetch_all_results + MrnaAssayTissueData(conn=mock_conn, gene_symbols=gene_symbols) + cursor.execute.assert_called_with(*expected_query) diff --git a/wqflask/wqflask/correlation/correlation_functions.py b/wqflask/wqflask/correlation/correlation_functions.py index 85b25d60..5c01b0ac 100644 --- 
a/wqflask/wqflask/correlation/correlation_functions.py +++ b/wqflask/wqflask/correlation/correlation_functions.py @@ -25,7 +25,7 @@ from base.mrna_assay_tissue_data import MrnaAssayTissueData from gn3.computations.correlations import compute_corr_coeff_p_value - +from wqflask.database import database_connection ##################################################################################### # Input: primaryValue(list): one list of expression values of one probeSet, @@ -60,7 +60,8 @@ def cal_zero_order_corr_for_tiss(primary_values, target_values, method="pearson" def get_trait_symbol_and_tissue_values(symbol_list=None): - tissue_data = MrnaAssayTissueData(gene_symbols=symbol_list) - if len(tissue_data.gene_symbols) > 0: - results = tissue_data.get_symbol_values_pairs() - return results + with database_connection() as conn: + tissue_data = MrnaAssayTissueData(gene_symbols=symbol_list, conn=conn) + if len(tissue_data.gene_symbols) > 0: + results = tissue_data.get_symbol_values_pairs() + return results |