about summary refs log tree commit diff
diff options
context:
space:
mode:
authorMunyoki Kilyungi2022-09-07 11:47:26 +0300
committerBonfaceKilz2022-09-08 14:26:19 +0300
commitb6fadcc7e8de9d7041f2bb00dcd1ada518a1932c (patch)
treecc0199f48307568ee2bdab16d7ec0c8774fda50e
parentf879f82ad0393a770ed50043c70ee1dd4a12daaa (diff)
downloadgenenetwork2-b6fadcc7e8de9d7041f2bb00dcd1ada518a1932c.tar.gz
Inject database connection to mrna_assay_tissue_data class
* wqflask/base/mrna_assay_tissue_data.py: Imports: Delete
database_connection, escape, and database_connector.
(MrnaAssayTissueData): Inject conn.  Re-format queries.  Rework 'if
... else' logic.  Re-work how results are assigned to
'self.data[symbol]' - remove dot-notation.
(MrnaAssayTissueData.get_symbol_values_pairs): Move box-comments to
doc-string.  Rework how results are assigned to 'symbol_values_dict' -
remove dot-notation.
* wqflask/tests/unit/base/test_mrna_assay_tissue_data.py
(test_mrna_assay_tissue_data_initialisation): New test.
* wqflask/wqflask/correlation/correlation_functions.py: Import
database_connection.
(get_trait_symbol_and_tissue_values): Inject conn object.
-rw-r--r--wqflask/base/mrna_assay_tissue_data.py131
-rw-r--r--wqflask/tests/unit/base/test_mrna_assay_tissue_data.py51
-rw-r--r--wqflask/wqflask/correlation/correlation_functions.py11
3 files changed, 123 insertions, 70 deletions
diff --git a/wqflask/base/mrna_assay_tissue_data.py b/wqflask/base/mrna_assay_tissue_data.py
index b371e39f..a229151d 100644
--- a/wqflask/base/mrna_assay_tissue_data.py
+++ b/wqflask/base/mrna_assay_tissue_data.py
@@ -1,81 +1,82 @@
 import collections
 
-from wqflask.database import database_connection
-
 from utility import db_tools
 from utility import Bunch
 
-from utility.db_tools import escape
-from gn3.db_utils import database_connector
-
 
 class MrnaAssayTissueData:
 
-    def __init__(self, gene_symbols=None):
+    def __init__(self, conn, gene_symbols=None):
         self.gene_symbols = gene_symbols
-        if self.gene_symbols == None:
+        self.conn = conn
+        if self.gene_symbols is None:
             self.gene_symbols = []
 
         self.data = collections.defaultdict(Bunch)
-
-        query = '''select t.Symbol, t.GeneId, t.DataId, t.Chr, t.Mb, t.description, t.Probe_Target_Description
-                        from (
-                        select Symbol, max(Mean) as maxmean
-                        from TissueProbeSetXRef
-                        where TissueProbeSetFreezeId=1 and '''
-
-        # Note that inner join is necessary in this query to get distinct record in one symbol group
-        # with highest mean value
+        results = ()
+        # Note that inner join is necessary in this query to get
+        # distinct record in one symbol group with highest mean value
         # Due to the limit size of TissueProbeSetFreezeId table in DB,
-        # performance of inner join is acceptable.MrnaAssayTissueData(gene_symbols=symbol_list)
-        if len(gene_symbols) == 0:
-            query += '''Symbol!='' and Symbol Is Not Null group by Symbol)
-                as x inner join TissueProbeSetXRef as t on t.Symbol = x.Symbol
-                and t.Mean = x.maxmean;
-                    '''
-        else:
-            in_clause = db_tools.create_in_clause(gene_symbols)
-
-            # ZS: This was in the query, not sure why: http://docs.python.org/2/library/string.html?highlight=lower#string.lower
-            query += ''' Symbol in {} group by Symbol)
-                as x inner join TissueProbeSetXRef as t on t.Symbol = x.Symbol
-                and t.Mean = x.maxmean;
-                    '''.format(in_clause)
-
-
-        # lower_symbols = []
+        # performance of inner join is
+        # acceptable.MrnaAssayTissueData(gene_symbols=symbol_list)
+        with conn.cursor() as cursor:
+            if len(self.gene_symbols) == 0:
+                cursor.execute(
+                    "SELECT t.Symbol, t.GeneId, t.DataId, "
+                    "t.Chr, t.Mb, t.description, "
+                    "t.Probe_Target_Description FROM (SELECT Symbol, "
+                    "max(Mean) AS maxmean "
+                    "FROM TissueProbeSetXRef WHERE "
+                    "TissueProbeSetFreezeId=1 AND "
+                    "Symbol != '' AND Symbol IS NOT "
+                    "Null GROUP BY Symbol) "
+                    "AS x INNER JOIN "
+                    "TissueProbeSetXRef AS t ON "
+                    "t.Symbol = x.Symbol "
+                    "AND t.Mean = x.maxmean")
+            else:
+                cursor.execute(
+                    "SELECT t.Symbol, t.GeneId, t.DataId, "
+                    "t.Chr, t.Mb, t.description, "
+                    "t.Probe_Target_Description FROM (SELECT Symbol, "
+                    "max(Mean) AS maxmean "
+                    "FROM TissueProbeSetXRef WHERE "
+                    "TissueProbeSetFreezeId=1 AND "
+                    "Symbol IN "
+                    f"({', '.join(['%s'] * len(self.gene_symbols))}) "
+                    "GROUP BY Symbol) AS x INNER JOIN "
+                    "TissueProbeSetXRef AS t ON t.Symbol = x.Symbol "
+                    "AND t.Mean = x.maxmean",
+                    tuple(self.gene_symbols))
+            results = list(cursor.fetchall())
         lower_symbols = {}
-        for gene_symbol in gene_symbols:
-            # lower_symbols[gene_symbol.lower()] = True
-            if gene_symbol != None:
+        for gene_symbol in self.gene_symbols:
+            if gene_symbol is not None:
                 lower_symbols[gene_symbol.lower()] = True
 
-        results = None
-        with database_connection() as conn, conn.cursor() as cursor:
-            cursor.execute(query)
-            results = cursor.fetchall()
         for result in results:
-            symbol = result[0]
-            if symbol  is not None and lower_symbols.get(symbol.lower()):
-
+            (symbol, gene_id, data_id, _chr, _mb,
+             descr, probeset_target_descr) = result
+            if symbol is not None and lower_symbols.get(symbol.lower()):
                 symbol = symbol.lower()
+                self.data[symbol].gene_id = gene_id
+                self.data[symbol].data_id = data_id
+                self.data[symbol].chr = _chr
+                self.data[symbol].mb = _mb
+                self.data[symbol].description = descr
+                (self.data[symbol]
+                 .probe_target_description) = probeset_target_descr
 
-                self.data[symbol].gene_id = result.GeneId
-                self.data[symbol].data_id = result.DataId
-                self.data[symbol].chr = result.Chr
-                self.data[symbol].mb = result.Mb
-                self.data[symbol].description = result.description
-                self.data[symbol].probe_target_description = result.Probe_Target_Description
-
-    ###########################################################################
-    # Input: cursor, symbolList (list), dataIdDict(Dict)
-    # output: symbolValuepairDict (dictionary):one dictionary of Symbol and Value Pair,
-    #        key is symbol, value is one list of expression values of one probeSet;
-    # function: get one dictionary whose key is gene symbol and value is tissue expression data (list type).
-    # Attention! All keys are lower case!
-    ###########################################################################
 
     def get_symbol_values_pairs(self):
+        """Get one dictionary whose key is gene symbol and value is
+        tissue expression data (list type).  All keys are lower case.
+
+        The output is a symbolValuepairDict (dictionary): one
+        dictionary of Symbol and Value Pair; key is symbol, value is
+        one list of expression values of one probeSet;
+
+        """
         id_list = [self.data[symbol].data_id for symbol in self.data]
 
         symbol_values_dict = {}
@@ -86,14 +87,14 @@ class MrnaAssayTissueData:
                        WHERE TissueProbeSetData.Id IN {} and
                              TissueProbeSetXRef.DataId = TissueProbeSetData.Id""".format(db_tools.create_in_clause(id_list))
             results = []
-            with database_connection() as conn, conn.cursor() as cursor:
+            with self.conn.cursor() as cursor:
                 cursor.execute(query)
                 results = cursor.fetchall()
-            for result in results:
-                if result.Symbol.lower() not in symbol_values_dict:
-                    symbol_values_dict[result.Symbol.lower()] = [result.value]
-                else:
-                    symbol_values_dict[result.Symbol.lower()].append(
-                        result.value)
-
+                for result in results:
+                    (symbol, value) = result
+                    if symbol.lower() not in symbol_values_dict:
+                        symbol_values_dict[symbol.lower()] = [value]
+                    else:
+                        symbol_values_dict[symbol.lower()].append(
+                            value)
         return symbol_values_dict
diff --git a/wqflask/tests/unit/base/test_mrna_assay_tissue_data.py b/wqflask/tests/unit/base/test_mrna_assay_tissue_data.py
new file mode 100644
index 00000000..7a21124a
--- /dev/null
+++ b/wqflask/tests/unit/base/test_mrna_assay_tissue_data.py
@@ -0,0 +1,51 @@
+import pytest
+from base.mrna_assay_tissue_data import MrnaAssayTissueData
+
+
+@pytest.mark.parametrize(
+    ('gene_symbols', 'expected_query', 'sql_fetch_all_results'),
+    (
+        (None,
+         (("SELECT t.Symbol, t.GeneId, t.DataId, "
+           "t.Chr, t.Mb, t.description, "
+           "t.Probe_Target_Description "
+           "FROM (SELECT Symbol, "
+           "max(Mean) AS maxmean "
+           "FROM TissueProbeSetXRef WHERE "
+           "TissueProbeSetFreezeId=1 AND "
+           "Symbol != '' AND Symbol IS NOT "
+           "Null GROUP BY Symbol) "
+           "AS x INNER JOIN TissueProbeSetXRef "
+           "AS t ON t.Symbol = x.Symbol "
+           "AND t.Mean = x.maxmean"),),
+         (("symbol", "gene_id",
+           "data_id", "chr", "mb",
+           "description",
+           "probe_target_description"),)),
+        (["k1", "k2", "k3"],
+         ("SELECT t.Symbol, t.GeneId, t.DataId, "
+          "t.Chr, t.Mb, t.description, "
+          "t.Probe_Target_Description FROM (SELECT Symbol, "
+          "max(Mean) AS maxmean "
+          "FROM TissueProbeSetXRef WHERE "
+          "TissueProbeSetFreezeId=1 AND "
+          "Symbol IN (%s, %s, %s) "
+          "GROUP BY Symbol) AS x INNER JOIN "
+          "TissueProbeSetXRef AS "
+          "t ON t.Symbol = x.Symbol "
+          "AND t.Mean = x.maxmean",
+          ("k1", "k2", "k3")),
+         (("k1", "203",
+           "112", "xy", "20.11",
+           "Sample Description",
+           "Sample Probe Target Description"),)),
+    ),
+)
+def test_mrna_assay_tissue_data_initialisation(mocker, gene_symbols,
+                                               expected_query,
+                                               sql_fetch_all_results):
+    mock_conn = mocker.MagicMock()
+    with mock_conn.cursor() as cursor:
+        cursor.fetchall.return_value = sql_fetch_all_results
+        MrnaAssayTissueData(conn=mock_conn, gene_symbols=gene_symbols)
+        cursor.execute.assert_called_with(*expected_query)
diff --git a/wqflask/wqflask/correlation/correlation_functions.py b/wqflask/wqflask/correlation/correlation_functions.py
index 85b25d60..5c01b0ac 100644
--- a/wqflask/wqflask/correlation/correlation_functions.py
+++ b/wqflask/wqflask/correlation/correlation_functions.py
@@ -25,7 +25,7 @@
 
 from base.mrna_assay_tissue_data import MrnaAssayTissueData
 from gn3.computations.correlations import compute_corr_coeff_p_value
-
+from wqflask.database import database_connection
 
 #####################################################################################
 # Input: primaryValue(list): one list of expression values of one probeSet,
@@ -60,7 +60,8 @@ def cal_zero_order_corr_for_tiss(primary_values, target_values, method="pearson"
 
 
 def get_trait_symbol_and_tissue_values(symbol_list=None):
-    tissue_data = MrnaAssayTissueData(gene_symbols=symbol_list)
-    if len(tissue_data.gene_symbols) > 0:
-        results = tissue_data.get_symbol_values_pairs()
-        return results
+    with database_connection() as conn:
+        tissue_data = MrnaAssayTissueData(gene_symbols=symbol_list, conn=conn)
+        if len(tissue_data.gene_symbols) > 0:
+            results = tissue_data.get_symbol_values_pairs()
+            return results