aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--gn3/db/correlations.py128
1 files changed, 69 insertions, 59 deletions
diff --git a/gn3/db/correlations.py b/gn3/db/correlations.py
index 3ae66ca..07eaa56 100644
--- a/gn3/db/correlations.py
+++ b/gn3/db/correlations.py
@@ -4,7 +4,7 @@ feature to access the database to retrieve data needed for computations.
"""
import os
from functools import reduce
-from typing import Any, Dict, Tuple, Union
+from typing import Any, Dict, Tuple, Union, Optional
from gn3.random import random_string
from gn3.data_helpers import partition_all
@@ -474,6 +474,71 @@ def build_query_tissue_corr(db_type, temp_table, sample_id_columns, joins):
f"ORDER BY {db_type}.Id"),
3)
+def __build_query__(
+ sample_ids: tuple, db_type: str, method: str,
+ temp_table: Optional[str] = None) -> Tuple[str, int]:
+ """Utility to build the correct query dependent on the `db_type`. Do not use
+ outside of this module."""
+ sample_id_columns = ", ".join(f"T{smpl}.value" for smpl in sample_ids)
+ if db_type == "Publish":
+ joins = (
+ (f"LEFT JOIN PublishData AS T{item} "
+ f"ON T{item}.Id = PublishXRef.DataId "
+ f"AND T{item}.StrainId = %(T{item}_sample_id)s")
+ for item in sample_ids)
+ return (
+ ("SELECT PublishXRef.Id, " +
+ sample_id_columns +
+ " FROM (PublishXRef, PublishFreeze) " +
+ " ".join(joins) +
+ " WHERE PublishXRef.InbredSetId = PublishFreeze.InbredSetId "
+ "AND PublishFreeze.Name = %(db_name)s"),
+ 1)
+ if temp_table is not None:
+ joins = (
+ (f"LEFT JOIN {db_type}Data AS T{item} "
+ f"ON T{item}.Id = {db_type}XRef.DataId "
+ f"AND T{item}.StrainId=%(T{item}_sample_id)s")
+ for item in sample_ids)
+ if method.lower() == "sgo literature correlation":
+ return build_query_sgo_lit_corr(
+ sample_ids, temp_table, sample_id_columns, joins)
+ if method.lower() in (
+ "tissue correlation, pearson's r",
+ "tissue correlation, spearman's rho"):
+ return build_query_tissue_corr(
+ sample_ids, temp_table, sample_id_columns, joins)
+ joins = (
+ (f"LEFT JOIN {db_type}Data AS T{item} "
+ f"ON T{item}.Id = {db_type}XRef.DataId "
+ f"AND T{item}.StrainId = %(T{item}_sample_id)s")
+ for item in sample_ids)
+ return (
+ (
+ f"SELECT {db_type}.Name, " +
+ sample_id_columns +
+ f" FROM ({db_type}, {db_type}XRef, {db_type}Freeze) " +
+ " ".join(joins) +
+ f" WHERE {db_type}XRef.{db_type}FreezeId = {db_type}Freeze.Id " +
+ f"AND {db_type}Freeze.Name = %(db_name)s " +
+ f"AND {db_type}.Id = {db_type}XRef.{db_type}Id " +
+ f"ORDER BY {db_type}.Id"),
+ 1)
+
+def __fetch_data__(
+ conn, sample_ids: tuple, db_name: str, db_type: str, method: str,
+ temp_table: Optional[str]) -> Tuple[Tuple[Any], int]:
+ """Utility to fetch the data. Do not use outside of this module."""
+ query, data_start_pos = __build_query__(
+ sample_ids, db_type, method, temp_table)
+ with conn.cursor() as cursor:
+ cursor.execute(
+ query,
+ {"db_name": db_name,
+ **{f"T{item}_sample_id": item for item in sample_ids}})
+ return (cursor.fetchall(), data_start_pos)
+
+
def fetch_all_database_data(# pylint: disable=[R0913, R0914]
conn: Any, species: str, gene_id: int, trait_symbol: str,
samples: Tuple[str, ...], dataset: dict, method: str,
@@ -485,62 +550,6 @@ def fetch_all_database_data(# pylint: disable=[R0913, R0914]
GeneNetwork1.
"""
db_type = dataset["dataset_type"]
- db_name = dataset["dataset_name"]
- def __build_query__(sample_ids, temp_table):
- sample_id_columns = ", ".join(f"T{smpl}.value" for smpl in sample_ids)
- if db_type == "Publish":
- joins = tuple(
- (f"LEFT JOIN PublishData AS T{item} "
- f"ON T{item}.Id = PublishXRef.DataId "
- f"AND T{item}.StrainId = %(T{item}_sample_id)s")
- for item in sample_ids)
- return (
- ("SELECT PublishXRef.Id, " +
- sample_id_columns +
- " FROM (PublishXRef, PublishFreeze) " +
- " ".join(joins) +
- " WHERE PublishXRef.InbredSetId = PublishFreeze.InbredSetId "
- "AND PublishFreeze.Name = %(db_name)s"),
- 1)
- if temp_table is not None:
- joins = tuple(
- (f"LEFT JOIN {db_type}Data AS T{item} "
- f"ON T{item}.Id = {db_type}XRef.DataId "
- f"AND T{item}.StrainId=%(T{item}_sample_id)s")
- for item in sample_ids)
- if method.lower() == "sgo literature correlation":
- return build_query_sgo_lit_corr(
- sample_ids, temp_table, sample_id_columns, joins)
- if method.lower() in (
- "tissue correlation, pearson's r",
- "tissue correlation, spearman's rho"):
- return build_query_tissue_corr(
- sample_ids, temp_table, sample_id_columns, joins)
- joins = tuple(
- (f"LEFT JOIN {db_type}Data AS T{item} "
- f"ON T{item}.Id = {db_type}XRef.DataId "
- f"AND T{item}.StrainId = %(T{item}_sample_id)s")
- for item in sample_ids)
- return (
- (
- f"SELECT {db_type}.Name, " +
- sample_id_columns +
- f" FROM ({db_type}, {db_type}XRef, {db_type}Freeze) " +
- " ".join(joins) +
- f" WHERE {db_type}XRef.{db_type}FreezeId = {db_type}Freeze.Id " +
- f"AND {db_type}Freeze.Name = %(db_name)s " +
- f"AND {db_type}.Id = {db_type}XRef.{db_type}Id " +
- f"ORDER BY {db_type}.Id"),
- 1)
-
- def __fetch_data__(sample_ids, temp_table):
- query, data_start_pos = __build_query__(sample_ids, temp_table)
- with conn.cursor() as cursor:
- cursor.execute(
- query,
- {"db_name": db_name,
- **{f"T{item}_sample_id": item for item in sample_ids}})
- return (cursor.fetchall(), data_start_pos)
sample_ids = tuple(
# look into graduating this to an argument and removing the `samples`
@@ -550,7 +559,7 @@ def fetch_all_database_data(# pylint: disable=[R0913, R0914]
fetch_sample_ids(conn, samples, species))
temp_table = None
- if gene_id and db_type == "probeset":
+ if gene_id and db_type.lower() == "probeset":
if method.lower() == "sgo literature correlation":
temp_table = build_temporary_literature_table(
conn, species, gene_id, return_number)
@@ -562,7 +571,8 @@ def fetch_all_database_data(# pylint: disable=[R0913, R0914]
trait_database = tuple(
item for sublist in
- (__fetch_data__(ssample_ids, temp_table)
+ (__fetch_data__(
+ conn, ssample_ids, dataset["dataset_name"], db_type, method, temp_table)
for ssample_ids in partition_all(25, sample_ids))
for item in sublist)