From e41a079f162d19e8089f64308e5de9e810461b3e Mon Sep 17 00:00:00 2001 From: Frederick Muriuki Muriithi Date: Tue, 31 May 2022 14:43:12 +0300 Subject: Extract utility functions from `fetch_all_database_data` Extract the utility functions to help with understanding what the `fetch_all_database_data` function is doing. This helps with maintenance. --- gn3/db/correlations.py | 128 ++++++++++++++++++++++++++----------------------- 1 file changed, 69 insertions(+), 59 deletions(-) (limited to 'gn3') diff --git a/gn3/db/correlations.py b/gn3/db/correlations.py index 3ae66ca..07eaa56 100644 --- a/gn3/db/correlations.py +++ b/gn3/db/correlations.py @@ -4,7 +4,7 @@ feature to access the database to retrieve data needed for computations. """ import os from functools import reduce -from typing import Any, Dict, Tuple, Union +from typing import Any, Dict, Tuple, Union, Optional from gn3.random import random_string from gn3.data_helpers import partition_all @@ -474,6 +474,71 @@ def build_query_tissue_corr(db_type, temp_table, sample_id_columns, joins): f"ORDER BY {db_type}.Id"), 3) +def __build_query__( + sample_ids: tuple, db_type: str, method: str, + temp_table: Optional[str] = None) -> Tuple[str, int]: + """Utility to build the correct query dependent on the `db_type`. 
Do not use + outside of this module.""" + sample_id_columns = ", ".join(f"T{smpl}.value" for smpl in sample_ids) + if db_type == "Publish": + joins = ( + (f"LEFT JOIN PublishData AS T{item} " + f"ON T{item}.Id = PublishXRef.DataId " + f"AND T{item}.StrainId = %(T{item}_sample_id)s") + for item in sample_ids) + return ( + ("SELECT PublishXRef.Id, " + + sample_id_columns + + " FROM (PublishXRef, PublishFreeze) " + + " ".join(joins) + + " WHERE PublishXRef.InbredSetId = PublishFreeze.InbredSetId " + "AND PublishFreeze.Name = %(db_name)s"), + 1) + if temp_table is not None: + joins = ( + (f"LEFT JOIN {db_type}Data AS T{item} " + f"ON T{item}.Id = {db_type}XRef.DataId " + f"AND T{item}.StrainId=%(T{item}_sample_id)s") + for item in sample_ids) + if method.lower() == "sgo literature correlation": + return build_query_sgo_lit_corr( + sample_ids, temp_table, sample_id_columns, joins) + if method.lower() in ( + "tissue correlation, pearson's r", + "tissue correlation, spearman's rho"): + return build_query_tissue_corr( + sample_ids, temp_table, sample_id_columns, joins) + joins = ( + (f"LEFT JOIN {db_type}Data AS T{item} " + f"ON T{item}.Id = {db_type}XRef.DataId " + f"AND T{item}.StrainId = %(T{item}_sample_id)s") + for item in sample_ids) + return ( + ( + f"SELECT {db_type}.Name, " + + sample_id_columns + + f" FROM ({db_type}, {db_type}XRef, {db_type}Freeze) " + + " ".join(joins) + + f" WHERE {db_type}XRef.{db_type}FreezeId = {db_type}Freeze.Id " + + f"AND {db_type}Freeze.Name = %(db_name)s " + + f"AND {db_type}.Id = {db_type}XRef.{db_type}Id " + + f"ORDER BY {db_type}.Id"), + 1) + +def __fetch_data__( + conn, sample_ids: tuple, db_name: str, db_type: str, method: str, + temp_table: Optional[str]) -> Tuple[Tuple[Any], int]: + """Utility to fetch the data. 
Do not use outside of this module.""" + query, data_start_pos = __build_query__( + sample_ids, db_type, method, temp_table) + with conn.cursor() as cursor: + cursor.execute( + query, + {"db_name": db_name, + **{f"T{item}_sample_id": item for item in sample_ids}}) + return (cursor.fetchall(), data_start_pos) + + def fetch_all_database_data(# pylint: disable=[R0913, R0914] conn: Any, species: str, gene_id: int, trait_symbol: str, samples: Tuple[str, ...], dataset: dict, method: str, @@ -485,62 +550,6 @@ def fetch_all_database_data(# pylint: disable=[R0913, R0914] GeneNetwork1. """ db_type = dataset["dataset_type"] - db_name = dataset["dataset_name"] - def __build_query__(sample_ids, temp_table): - sample_id_columns = ", ".join(f"T{smpl}.value" for smpl in sample_ids) - if db_type == "Publish": - joins = tuple( - (f"LEFT JOIN PublishData AS T{item} " - f"ON T{item}.Id = PublishXRef.DataId " - f"AND T{item}.StrainId = %(T{item}_sample_id)s") - for item in sample_ids) - return ( - ("SELECT PublishXRef.Id, " + - sample_id_columns + - " FROM (PublishXRef, PublishFreeze) " + - " ".join(joins) + - " WHERE PublishXRef.InbredSetId = PublishFreeze.InbredSetId " - "AND PublishFreeze.Name = %(db_name)s"), - 1) - if temp_table is not None: - joins = tuple( - (f"LEFT JOIN {db_type}Data AS T{item} " - f"ON T{item}.Id = {db_type}XRef.DataId " - f"AND T{item}.StrainId=%(T{item}_sample_id)s") - for item in sample_ids) - if method.lower() == "sgo literature correlation": - return build_query_sgo_lit_corr( - sample_ids, temp_table, sample_id_columns, joins) - if method.lower() in ( - "tissue correlation, pearson's r", - "tissue correlation, spearman's rho"): - return build_query_tissue_corr( - sample_ids, temp_table, sample_id_columns, joins) - joins = tuple( - (f"LEFT JOIN {db_type}Data AS T{item} " - f"ON T{item}.Id = {db_type}XRef.DataId " - f"AND T{item}.StrainId = %(T{item}_sample_id)s") - for item in sample_ids) - return ( - ( - f"SELECT {db_type}.Name, " + - sample_id_columns + - 
f" FROM ({db_type}, {db_type}XRef, {db_type}Freeze) " + - " ".join(joins) + - f" WHERE {db_type}XRef.{db_type}FreezeId = {db_type}Freeze.Id " + - f"AND {db_type}Freeze.Name = %(db_name)s " + - f"AND {db_type}.Id = {db_type}XRef.{db_type}Id " + - f"ORDER BY {db_type}.Id"), - 1) - - def __fetch_data__(sample_ids, temp_table): - query, data_start_pos = __build_query__(sample_ids, temp_table) - with conn.cursor() as cursor: - cursor.execute( - query, - {"db_name": db_name, - **{f"T{item}_sample_id": item for item in sample_ids}}) - return (cursor.fetchall(), data_start_pos) sample_ids = tuple( # look into graduating this to an argument and removing the `samples` @@ -550,7 +559,7 @@ def fetch_all_database_data(# pylint: disable=[R0913, R0914] fetch_sample_ids(conn, samples, species)) temp_table = None - if gene_id and db_type == "probeset": + if gene_id and db_type.lower() == "probeset": if method.lower() == "sgo literature correlation": temp_table = build_temporary_literature_table( conn, species, gene_id, return_number) @@ -562,7 +571,8 @@ def fetch_all_database_data(# pylint: disable=[R0913, R0914] trait_database = tuple( item for sublist in - (__fetch_data__(ssample_ids, temp_table) + (__fetch_data__( + conn, ssample_ids, dataset["dataset_name"], db_type, method, temp_table) for ssample_ids in partition_all(25, sample_ids)) for item in sublist) -- cgit v1.2.3