about summary refs log tree commit diff
path: root/gn3/db
diff options
context:
space:
mode:
authorFrederick Muriuki Muriithi2022-05-31 14:43:12 +0300
committerFrederick Muriuki Muriithi2022-05-31 14:43:12 +0300
commite41a079f162d19e8089f64308e5de9e810461b3e (patch)
treebb3536790b16411b7ebc4aad110fbd5e7922211e /gn3/db
parent85e369311b60faa2490f25c88a2ef87042b91738 (diff)
downloadgenenetwork3-e41a079f162d19e8089f64308e5de9e810461b3e.tar.gz
Extract utility functions from `fetch_all_database_data`
Extract the utility functions to help with understanding the what the
`fetch_all_database_data` function is doing. This helps with maintenance.
Diffstat (limited to 'gn3/db')
-rw-r--r--gn3/db/correlations.py128
1 files changed, 69 insertions, 59 deletions
diff --git a/gn3/db/correlations.py b/gn3/db/correlations.py
index 3ae66ca..07eaa56 100644
--- a/gn3/db/correlations.py
+++ b/gn3/db/correlations.py
@@ -4,7 +4,7 @@ feature to access the database to retrieve data needed for computations.
 """
 import os
 from functools import reduce
-from typing import Any, Dict, Tuple, Union
+from typing import Any, Dict, Tuple, Union, Optional
 
 from gn3.random import random_string
 from gn3.data_helpers import partition_all
@@ -474,6 +474,71 @@ def build_query_tissue_corr(db_type, temp_table, sample_id_columns, joins):
          f"ORDER BY {db_type}.Id"),
         3)
 
+def __build_query__(
+        sample_ids: tuple, db_type: str, method: str,
+        temp_table: Optional[str] = None) -> Tuple[str, int]:
+    """Utility to build the correct query dependent on the `db_type`. Do not use
+    outside of this module."""
+    sample_id_columns = ", ".join(f"T{smpl}.value" for smpl in sample_ids)
+    if db_type == "Publish":
+        joins = (
+            (f"LEFT JOIN PublishData AS T{item} "
+             f"ON T{item}.Id = PublishXRef.DataId "
+             f"AND T{item}.StrainId = %(T{item}_sample_id)s")
+            for item in sample_ids)
+        return (
+            ("SELECT PublishXRef.Id, " +
+             sample_id_columns +
+             " FROM (PublishXRef, PublishFreeze) " +
+             " ".join(joins) +
+             " WHERE PublishXRef.InbredSetId = PublishFreeze.InbredSetId "
+             "AND PublishFreeze.Name = %(db_name)s"),
+            1)
+    if temp_table is not None:
+        joins = (
+            (f"LEFT JOIN {db_type}Data AS T{item} "
+             f"ON T{item}.Id = {db_type}XRef.DataId "
+             f"AND T{item}.StrainId=%(T{item}_sample_id)s")
+            for item in sample_ids)
+        if method.lower() == "sgo literature correlation":
+            return build_query_sgo_lit_corr(
+                sample_ids, temp_table, sample_id_columns, joins)
+        if method.lower() in (
+                "tissue correlation, pearson's r",
+                "tissue correlation, spearman's rho"):
+            return build_query_tissue_corr(
+                sample_ids, temp_table, sample_id_columns, joins)
+    joins = (
+        (f"LEFT JOIN {db_type}Data AS T{item} "
+         f"ON T{item}.Id = {db_type}XRef.DataId "
+         f"AND T{item}.StrainId = %(T{item}_sample_id)s")
+        for item in sample_ids)
+    return (
+        (
+            f"SELECT {db_type}.Name, " +
+            sample_id_columns +
+            f" FROM ({db_type}, {db_type}XRef, {db_type}Freeze) " +
+            " ".join(joins) +
+            f" WHERE {db_type}XRef.{db_type}FreezeId = {db_type}Freeze.Id " +
+            f"AND {db_type}Freeze.Name = %(db_name)s " +
+            f"AND {db_type}.Id = {db_type}XRef.{db_type}Id " +
+            f"ORDER BY {db_type}.Id"),
+        1)
+
+def __fetch_data__(
+        conn, sample_ids: tuple, db_name: str, db_type: str, method: str,
+        temp_table: Optional[str]) -> Tuple[Tuple[Any], int]:
+    """Utility to fetch the data. Do not use outside of this module."""
+    query, data_start_pos = __build_query__(
+        sample_ids, db_type, method, temp_table)
+    with conn.cursor() as cursor:
+        cursor.execute(
+            query,
+            {"db_name": db_name,
+             **{f"T{item}_sample_id": item for item in sample_ids}})
+        return (cursor.fetchall(), data_start_pos)
+
+
 def fetch_all_database_data(# pylint: disable=[R0913, R0914]
         conn: Any, species: str, gene_id: int, trait_symbol: str,
         samples: Tuple[str, ...], dataset: dict, method: str,
@@ -485,62 +550,6 @@ def fetch_all_database_data(# pylint: disable=[R0913, R0914]
     GeneNetwork1.
     """
     db_type = dataset["dataset_type"]
-    db_name = dataset["dataset_name"]
-    def __build_query__(sample_ids, temp_table):
-        sample_id_columns = ", ".join(f"T{smpl}.value" for smpl in sample_ids)
-        if db_type == "Publish":
-            joins = tuple(
-                (f"LEFT JOIN PublishData AS T{item} "
-                 f"ON T{item}.Id = PublishXRef.DataId "
-                 f"AND T{item}.StrainId = %(T{item}_sample_id)s")
-                for item in sample_ids)
-            return (
-                ("SELECT PublishXRef.Id, " +
-                 sample_id_columns +
-                 " FROM (PublishXRef, PublishFreeze) " +
-                 " ".join(joins) +
-                 " WHERE PublishXRef.InbredSetId = PublishFreeze.InbredSetId "
-                 "AND PublishFreeze.Name = %(db_name)s"),
-                1)
-        if temp_table is not None:
-            joins = tuple(
-                (f"LEFT JOIN {db_type}Data AS T{item} "
-                 f"ON T{item}.Id = {db_type}XRef.DataId "
-                 f"AND T{item}.StrainId=%(T{item}_sample_id)s")
-                for item in sample_ids)
-            if method.lower() == "sgo literature correlation":
-                return build_query_sgo_lit_corr(
-                    sample_ids, temp_table, sample_id_columns, joins)
-            if method.lower() in (
-                    "tissue correlation, pearson's r",
-                    "tissue correlation, spearman's rho"):
-                return build_query_tissue_corr(
-                    sample_ids, temp_table, sample_id_columns, joins)
-        joins = tuple(
-            (f"LEFT JOIN {db_type}Data AS T{item} "
-             f"ON T{item}.Id = {db_type}XRef.DataId "
-             f"AND T{item}.StrainId = %(T{item}_sample_id)s")
-            for item in sample_ids)
-        return (
-            (
-                f"SELECT {db_type}.Name, " +
-                sample_id_columns +
-                f" FROM ({db_type}, {db_type}XRef, {db_type}Freeze) " +
-                " ".join(joins) +
-                f" WHERE {db_type}XRef.{db_type}FreezeId = {db_type}Freeze.Id " +
-                f"AND {db_type}Freeze.Name = %(db_name)s " +
-                f"AND {db_type}.Id = {db_type}XRef.{db_type}Id " +
-                f"ORDER BY {db_type}.Id"),
-            1)
-
-    def __fetch_data__(sample_ids, temp_table):
-        query, data_start_pos = __build_query__(sample_ids, temp_table)
-        with conn.cursor() as cursor:
-            cursor.execute(
-                query,
-                {"db_name": db_name,
-                 **{f"T{item}_sample_id": item for item in sample_ids}})
-            return (cursor.fetchall(), data_start_pos)
 
     sample_ids = tuple(
         # look into graduating this to an argument and removing the `samples`
@@ -550,7 +559,7 @@ def fetch_all_database_data(# pylint: disable=[R0913, R0914]
         fetch_sample_ids(conn, samples, species))
 
     temp_table = None
-    if gene_id and db_type == "probeset":
+    if gene_id and db_type.lower() == "probeset":
         if method.lower() == "sgo literature correlation":
             temp_table = build_temporary_literature_table(
                 conn, species, gene_id, return_number)
@@ -562,7 +571,8 @@ def fetch_all_database_data(# pylint: disable=[R0913, R0914]
 
     trait_database = tuple(
         item for sublist in
-        (__fetch_data__(ssample_ids, temp_table)
+        (__fetch_data__(
+            conn, ssample_ids, dataset["dataset_name"], db_type, method, temp_table)
          for ssample_ids in partition_all(25, sample_ids))
         for item in sublist)