about | summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
-rw-r--r-- wqflask/base/mrna_assay_tissue_data.py | 131
-rw-r--r-- wqflask/tests/unit/base/test_mrna_assay_tissue_data.py | 51
-rw-r--r-- wqflask/wqflask/correlation/correlation_functions.py | 11
3 files changed, 123 insertions, 70 deletions
diff --git a/wqflask/base/mrna_assay_tissue_data.py b/wqflask/base/mrna_assay_tissue_data.py
index b371e39f..a229151d 100644
--- a/wqflask/base/mrna_assay_tissue_data.py
+++ b/wqflask/base/mrna_assay_tissue_data.py
@@ -1,81 +1,82 @@
import collections
-from wqflask.database import database_connection
-
from utility import db_tools
from utility import Bunch
-from utility.db_tools import escape
-from gn3.db_utils import database_connector
-
class MrnaAssayTissueData:
- def __init__(self, gene_symbols=None):
+ def __init__(self, conn, gene_symbols=None):
self.gene_symbols = gene_symbols
- if self.gene_symbols == None:
+ self.conn = conn
+ if self.gene_symbols is None:
self.gene_symbols = []
self.data = collections.defaultdict(Bunch)
-
- query = '''select t.Symbol, t.GeneId, t.DataId, t.Chr, t.Mb, t.description, t.Probe_Target_Description
- from (
- select Symbol, max(Mean) as maxmean
- from TissueProbeSetXRef
- where TissueProbeSetFreezeId=1 and '''
-
- # Note that inner join is necessary in this query to get distinct record in one symbol group
- # with highest mean value
+ results = ()
+ # Note that inner join is necessary in this query to get
+ # distinct record in one symbol group with highest mean value
# Due to the limit size of TissueProbeSetFreezeId table in DB,
- # performance of inner join is acceptable.MrnaAssayTissueData(gene_symbols=symbol_list)
- if len(gene_symbols) == 0:
- query += '''Symbol!='' and Symbol Is Not Null group by Symbol)
- as x inner join TissueProbeSetXRef as t on t.Symbol = x.Symbol
- and t.Mean = x.maxmean;
- '''
- else:
- in_clause = db_tools.create_in_clause(gene_symbols)
-
- # ZS: This was in the query, not sure why: http://docs.python.org/2/library/string.html?highlight=lower#string.lower
- query += ''' Symbol in {} group by Symbol)
- as x inner join TissueProbeSetXRef as t on t.Symbol = x.Symbol
- and t.Mean = x.maxmean;
- '''.format(in_clause)
-
-
- # lower_symbols = []
+ # performance of inner join is
+ # acceptable.MrnaAssayTissueData(gene_symbols=symbol_list)
+ with conn.cursor() as cursor:
+ if len(self.gene_symbols) == 0:
+ cursor.execute(
+ "SELECT t.Symbol, t.GeneId, t.DataId, "
+ "t.Chr, t.Mb, t.description, "
+ "t.Probe_Target_Description FROM (SELECT Symbol, "
+ "max(Mean) AS maxmean "
+ "FROM TissueProbeSetXRef WHERE "
+ "TissueProbeSetFreezeId=1 AND "
+ "Symbol != '' AND Symbol IS NOT "
+ "Null GROUP BY Symbol) "
+ "AS x INNER JOIN "
+ "TissueProbeSetXRef AS t ON "
+ "t.Symbol = x.Symbol "
+ "AND t.Mean = x.maxmean")
+ else:
+ cursor.execute(
+ "SELECT t.Symbol, t.GeneId, t.DataId, "
+ "t.Chr, t.Mb, t.description, "
+ "t.Probe_Target_Description FROM (SELECT Symbol, "
+ "max(Mean) AS maxmean "
+ "FROM TissueProbeSetXRef WHERE "
+ "TissueProbeSetFreezeId=1 AND "
+ "Symbol IN "
+ f"({', '.join(['%s'] * len(self.gene_symbols))}) "
+ "GROUP BY Symbol) AS x INNER JOIN "
+ "TissueProbeSetXRef AS t ON t.Symbol = x.Symbol "
+ "AND t.Mean = x.maxmean",
+ tuple(self.gene_symbols))
+ results = list(cursor.fetchall())
lower_symbols = {}
- for gene_symbol in gene_symbols:
- # lower_symbols[gene_symbol.lower()] = True
- if gene_symbol != None:
+ for gene_symbol in self.gene_symbols:
+ if gene_symbol is not None:
lower_symbols[gene_symbol.lower()] = True
- results = None
- with database_connection() as conn, conn.cursor() as cursor:
- cursor.execute(query)
- results = cursor.fetchall()
for result in results:
- symbol = result[0]
- if symbol is not None and lower_symbols.get(symbol.lower()):
-
+ (symbol, gene_id, data_id, _chr, _mb,
+ descr, probeset_target_descr) = result
+ if symbol is not None and lower_symbols.get(symbol.lower()):
symbol = symbol.lower()
+ self.data[symbol].gene_id = gene_id
+ self.data[symbol].data_id = data_id
+ self.data[symbol].chr = _chr
+ self.data[symbol].mb = _mb
+ self.data[symbol].description = descr
+ (self.data[symbol]
+ .probe_target_description) = probeset_target_descr
- self.data[symbol].gene_id = result.GeneId
- self.data[symbol].data_id = result.DataId
- self.data[symbol].chr = result.Chr
- self.data[symbol].mb = result.Mb
- self.data[symbol].description = result.description
- self.data[symbol].probe_target_description = result.Probe_Target_Description
-
- ###########################################################################
- # Input: cursor, symbolList (list), dataIdDict(Dict)
- # output: symbolValuepairDict (dictionary):one dictionary of Symbol and Value Pair,
- # key is symbol, value is one list of expression values of one probeSet;
- # function: get one dictionary whose key is gene symbol and value is tissue expression data (list type).
- # Attention! All keys are lower case!
- ###########################################################################
def get_symbol_values_pairs(self):
+ """Get one dictionary whose key is gene symbol and value is
+ tissue expression data (list type). All keys are lower case.
+
+ The output is a symbolValuepairDict (dictionary): one
+ dictionary of Symbol and Value Pair; key is symbol, value is
+ one list of expression values of one probeSet;
+
+ """
id_list = [self.data[symbol].data_id for symbol in self.data]
symbol_values_dict = {}
@@ -86,14 +87,14 @@ class MrnaAssayTissueData:
WHERE TissueProbeSetData.Id IN {} and
TissueProbeSetXRef.DataId = TissueProbeSetData.Id""".format(db_tools.create_in_clause(id_list))
results = []
- with database_connection() as conn, conn.cursor() as cursor:
+ with self.conn.cursor() as cursor:
cursor.execute(query)
results = cursor.fetchall()
- for result in results:
- if result.Symbol.lower() not in symbol_values_dict:
- symbol_values_dict[result.Symbol.lower()] = [result.value]
- else:
- symbol_values_dict[result.Symbol.lower()].append(
- result.value)
-
+ for result in results:
+ (symbol, value) = result
+ if symbol.lower() not in symbol_values_dict:
+ symbol_values_dict[symbol.lower()] = [value]
+ else:
+ symbol_values_dict[symbol.lower()].append(
+ value)
return symbol_values_dict
diff --git a/wqflask/tests/unit/base/test_mrna_assay_tissue_data.py b/wqflask/tests/unit/base/test_mrna_assay_tissue_data.py
new file mode 100644
index 00000000..7a21124a
--- /dev/null
+++ b/wqflask/tests/unit/base/test_mrna_assay_tissue_data.py
@@ -0,0 +1,51 @@
+import pytest
+from base.mrna_assay_tissue_data import MrnaAssayTissueData
+
+
+@pytest.mark.parametrize(
+ ('gene_symbols', 'expected_query', 'sql_fetch_all_results'),
+ (
+ (None,
+ (("SELECT t.Symbol, t.GeneId, t.DataId, "
+ "t.Chr, t.Mb, t.description, "
+ "t.Probe_Target_Description "
+ "FROM (SELECT Symbol, "
+ "max(Mean) AS maxmean "
+ "FROM TissueProbeSetXRef WHERE "
+ "TissueProbeSetFreezeId=1 AND "
+ "Symbol != '' AND Symbol IS NOT "
+ "Null GROUP BY Symbol) "
+ "AS x INNER JOIN TissueProbeSetXRef "
+ "AS t ON t.Symbol = x.Symbol "
+ "AND t.Mean = x.maxmean"),),
+ (("symbol", "gene_id",
+ "data_id", "chr", "mb",
+ "description",
+ "probe_target_description"),)),
+ (["k1", "k2", "k3"],
+ ("SELECT t.Symbol, t.GeneId, t.DataId, "
+ "t.Chr, t.Mb, t.description, "
+ "t.Probe_Target_Description FROM (SELECT Symbol, "
+ "max(Mean) AS maxmean "
+ "FROM TissueProbeSetXRef WHERE "
+ "TissueProbeSetFreezeId=1 AND "
+ "Symbol IN (%s, %s, %s) "
+ "GROUP BY Symbol) AS x INNER JOIN "
+ "TissueProbeSetXRef AS "
+ "t ON t.Symbol = x.Symbol "
+ "AND t.Mean = x.maxmean",
+ ("k1", "k2", "k3")),
+ (("k1", "203",
+ "112", "xy", "20.11",
+ "Sample Description",
+ "Sample Probe Target Description"),)),
+ ),
+)
+def test_mrna_assay_tissue_data_initialisation(mocker, gene_symbols,
+ expected_query,
+ sql_fetch_all_results):
+ mock_conn = mocker.MagicMock()
+ with mock_conn.cursor() as cursor:
+ cursor.fetchall.return_value = sql_fetch_all_results
+ MrnaAssayTissueData(conn=mock_conn, gene_symbols=gene_symbols)
+ cursor.execute.assert_called_with(*expected_query)
diff --git a/wqflask/wqflask/correlation/correlation_functions.py b/wqflask/wqflask/correlation/correlation_functions.py
index 85b25d60..5c01b0ac 100644
--- a/wqflask/wqflask/correlation/correlation_functions.py
+++ b/wqflask/wqflask/correlation/correlation_functions.py
@@ -25,7 +25,7 @@
from base.mrna_assay_tissue_data import MrnaAssayTissueData
from gn3.computations.correlations import compute_corr_coeff_p_value
-
+from wqflask.database import database_connection
#####################################################################################
# Input: primaryValue(list): one list of expression values of one probeSet,
@@ -60,7 +60,8 @@ def cal_zero_order_corr_for_tiss(primary_values, target_values, method="pearson"
def get_trait_symbol_and_tissue_values(symbol_list=None):
- tissue_data = MrnaAssayTissueData(gene_symbols=symbol_list)
- if len(tissue_data.gene_symbols) > 0:
- results = tissue_data.get_symbol_values_pairs()
- return results
+ with database_connection() as conn:
+ tissue_data = MrnaAssayTissueData(gene_symbols=symbol_list, conn=conn)
+ if len(tissue_data.gene_symbols) > 0:
+ results = tissue_data.get_symbol_values_pairs()
+ return results