aboutsummaryrefslogtreecommitdiff
path: root/gn3/db
diff options
context:
space:
mode:
Diffstat (limited to 'gn3/db')
-rw-r--r--gn3/db/correlations.py234
-rw-r--r--gn3/db/datasets.py152
-rw-r--r--gn3/db/genotypes.py44
-rw-r--r--gn3/db/partial_correlations.py791
-rw-r--r--gn3/db/sample_data.py365
-rw-r--r--gn3/db/species.py17
-rw-r--r--gn3/db/traits.py195
7 files changed, 1594 insertions, 204 deletions
diff --git a/gn3/db/correlations.py b/gn3/db/correlations.py
index 06b3310..3ae66ca 100644
--- a/gn3/db/correlations.py
+++ b/gn3/db/correlations.py
@@ -2,17 +2,16 @@
This module will hold functions that are used in the (partial) correlations
feature to access the database to retrieve data needed for computations.
"""
-
+import os
from functools import reduce
-from typing import Any, Dict, Tuple
+from typing import Any, Dict, Tuple, Union
from gn3.random import random_string
from gn3.data_helpers import partition_all
from gn3.db.species import translate_to_mouse_gene_id
-from gn3.computations.partial_correlations import correlations_of_all_tissue_traits
-
-def get_filename(target_db_name: str, conn: Any) -> str:
+def get_filename(conn: Any, target_db_name: str, text_files_dir: str) -> Union[
+ str, bool]:
"""
Retrieve the name of the reference database file with which correlations are
computed.
@@ -23,18 +22,23 @@ def get_filename(target_db_name: str, conn: Any) -> str:
"""
with conn.cursor() as cursor:
cursor.execute(
- "SELECT Id, FullName from ProbeSetFreeze WHERE Name-%s",
- target_db_name)
+ "SELECT Id, FullName from ProbeSetFreeze WHERE Name=%s",
+ (target_db_name,))
result = cursor.fetchone()
if result:
- return "ProbeSetFreezeId_{tid}_FullName_{fname}.txt".format(
- tid=result[0],
- fname=result[1].replace(' ', '_').replace('/', '_'))
+ filename = (
+ f"ProbeSetFreezeId_{result[0]}_FullName_"
+ f"{result[1].replace(' ', '_').replace('/', '_')}.txt")
+            full_filename = f"{text_files_dir}/{filename}"
+ return (
+ os.path.exists(full_filename) and
+ (filename in os.listdir(text_files_dir)) and
+ full_filename)
- return ""
+ return False
def build_temporary_literature_table(
- species: str, gene_id: int, return_number: int, conn: Any) -> str:
+ conn: Any, species: str, gene_id: int, return_number: int) -> str:
"""
Build and populate a temporary table to hold the literature correlation data
to be used in computations.
@@ -49,7 +53,7 @@ def build_temporary_literature_table(
query = {
"rat": "SELECT rat FROM GeneIDXRef WHERE mouse=%s",
"human": "SELECT human FROM GeneIDXRef WHERE mouse=%d"}
- if species in query.keys():
+ if species in query:
cursor.execute(query[species], row[1])
record = cursor.fetchone()
if record:
@@ -128,7 +132,7 @@ def fetch_literature_correlations(
GeneNetwork1.
"""
temp_table = build_temporary_literature_table(
- species, gene_id, return_number, conn)
+ conn, species, gene_id, return_number)
query_fns = {
"Geno": fetch_geno_literature_correlations,
# "Temp": fetch_temp_literature_correlations,
@@ -156,11 +160,14 @@ def fetch_symbol_value_pair_dict(
symbol: data_id_dict.get(symbol) for symbol in symbol_list
if data_id_dict.get(symbol) is not None
}
- query = "SELECT Id, value FROM TissueProbeSetData WHERE Id IN %(data_ids)s"
+ data_ids_fields = (f"%(id{i})s" for i in range(len(data_ids.values())))
+ query = (
+ "SELECT Id, value FROM TissueProbeSetData "
+ f"WHERE Id IN ({','.join(data_ids_fields)})")
with conn.cursor() as cursor:
cursor.execute(
query,
- data_ids=tuple(data_ids.values()))
+ **{f"id{i}": did for i, did in enumerate(data_ids.values())})
value_results = cursor.fetchall()
return {
key: tuple(row[1] for row in value_results if row[0] == key)
@@ -234,8 +241,10 @@ def fetch_tissue_probeset_xref_info(
"INNER JOIN TissueProbeSetXRef AS t ON t.Symbol = x.Symbol "
"AND t.Mean = x.maxmean")
cursor.execute(
- query, probeset_freeze_id=probeset_freeze_id,
- symbols=tuple(gene_name_list))
+ query, {
+ "probeset_freeze_id": probeset_freeze_id,
+ "symbols": tuple(gene_name_list)
+ })
results = cursor.fetchall()
@@ -268,8 +277,8 @@ def fetch_gene_symbol_tissue_value_dict_for_trait(
return {}
def build_temporary_tissue_correlations_table(
- trait_symbol: str, probeset_freeze_id: int, method: str,
- return_number: int, conn: Any) -> str:
+ conn: Any, trait_symbol: str, probeset_freeze_id: int, method: str,
+ return_number: int) -> str:
"""
Build a temporary table to hold the tissue correlations data.
@@ -279,6 +288,16 @@ def build_temporary_tissue_correlations_table(
# We should probably pass the `correlations_of_all_tissue_traits` function
# as an argument to this function and get rid of the one call immediately
# following this comment.
+ from gn3.computations.partial_correlations import (#pylint: disable=[C0415, R0401]
+ correlations_of_all_tissue_traits)
+ # This import above is necessary within the function to avoid
+ # circular-imports.
+ #
+ #
+ # This import above is indicative of convoluted code, with the computation
+ # being interwoven with the data retrieval. This needs to be changed, such
+ # that the function being imported here is no longer necessary, or have the
+ # imported function passed to this function as an argument.
symbol_corr_dict, symbol_p_value_dict = correlations_of_all_tissue_traits(
fetch_gene_symbol_tissue_value_dict_for_trait(
(trait_symbol,), probeset_freeze_id, conn),
@@ -320,7 +339,7 @@ def fetch_tissue_correlations(# pylint: disable=R0913
GeneNetwork1.
"""
temp_table = build_temporary_tissue_correlations_table(
- trait_symbol, probeset_freeze_id, method, return_number, conn)
+ conn, trait_symbol, probeset_freeze_id, method, return_number)
with conn.cursor() as cursor:
cursor.execute(
(
@@ -379,3 +398,176 @@ def check_symbol_for_tissue_correlation(
return True
return False
+
+def fetch_sample_ids(
+ conn: Any, sample_names: Tuple[str, ...], species_name: str) -> Tuple[
+ int, ...]:
+ """
+ Given a sequence of sample names, and a species name, return the sample ids
+ that correspond to both.
+
+ This is a partial migration of the
+ `web.webqtl.correlation.CorrelationPage.fetchAllDatabaseData` function in
+ GeneNetwork1.
+ """
+ samples_fields = (f"%(s{i})s" for i in range(len(sample_names)))
+ query = (
+ "SELECT Strain.Id FROM Strain, Species "
+ f"WHERE Strain.Name IN ({','.join(samples_fields)}) "
+ "AND Strain.SpeciesId=Species.Id "
+ "AND Species.name=%(species_name)s")
+ with conn.cursor() as cursor:
+ cursor.execute(
+ query,
+ {
+ **{f"s{i}": sname for i, sname in enumerate(sample_names)},
+ "species_name": species_name
+ })
+ return tuple(row[0] for row in cursor.fetchall())
+
+def build_query_sgo_lit_corr(
+ db_type: str, temp_table: str, sample_id_columns: str,
+ joins: Tuple[str, ...]) -> Tuple[str, int]:
+ """
+ Build query for `SGO Literature Correlation` data, when querying the given
+ `temp_table` temporary table.
+
+ This is a partial migration of the
+ `web.webqtl.correlation.CorrelationPage.fetchAllDatabaseData` function in
+ GeneNetwork1.
+ """
+ return (
+ (f"SELECT {db_type}.Name, {temp_table}.value, " +
+ sample_id_columns +
+ f" FROM ({db_type}, {db_type}XRef, {db_type}Freeze) " +
+ f"LEFT JOIN {temp_table} ON {temp_table}.GeneId2=ProbeSet.GeneId " +
+ " ".join(joins) +
+ " WHERE ProbeSet.GeneId IS NOT NULL " +
+ f"AND {temp_table}.value IS NOT NULL " +
+ f"AND {db_type}XRef.{db_type}FreezeId = {db_type}Freeze.Id " +
+ f"AND {db_type}Freeze.Name = %(db_name)s " +
+ f"AND {db_type}.Id = {db_type}XRef.{db_type}Id " +
+ f"ORDER BY {db_type}.Id"),
+ 2)
+
+def build_query_tissue_corr(db_type, temp_table, sample_id_columns, joins):
+ """
+ Build query for `Tissue Correlation` data, when querying the given
+ `temp_table` temporary table.
+
+ This is a partial migration of the
+ `web.webqtl.correlation.CorrelationPage.fetchAllDatabaseData` function in
+ GeneNetwork1.
+ """
+ return (
+ (f"SELECT {db_type}.Name, {temp_table}.Correlation, " +
+ f"{temp_table}.PValue, " +
+ sample_id_columns +
+ f" FROM ({db_type}, {db_type}XRef, {db_type}Freeze) " +
+ f"LEFT JOIN {temp_table} ON {temp_table}.Symbol=ProbeSet.Symbol " +
+ " ".join(joins) +
+ " WHERE ProbeSet.Symbol IS NOT NULL " +
+ f"AND {temp_table}.Correlation IS NOT NULL " +
+ f"AND {db_type}XRef.{db_type}FreezeId = {db_type}Freeze.Id " +
+ f"AND {db_type}Freeze.Name = %(db_name)s " +
+ f"AND {db_type}.Id = {db_type}XRef.{db_type}Id "
+ f"ORDER BY {db_type}.Id"),
+ 3)
+
+def fetch_all_database_data(# pylint: disable=[R0913, R0914]
+ conn: Any, species: str, gene_id: int, trait_symbol: str,
+ samples: Tuple[str, ...], dataset: dict, method: str,
+ return_number: int, probeset_freeze_id: int) -> Tuple[
+ Tuple[float], int]:
+ """
+ This is a migration of the
+ `web.webqtl.correlation.CorrelationPage.fetchAllDatabaseData` function in
+ GeneNetwork1.
+ """
+ db_type = dataset["dataset_type"]
+ db_name = dataset["dataset_name"]
+ def __build_query__(sample_ids, temp_table):
+ sample_id_columns = ", ".join(f"T{smpl}.value" for smpl in sample_ids)
+ if db_type == "Publish":
+ joins = tuple(
+ (f"LEFT JOIN PublishData AS T{item} "
+ f"ON T{item}.Id = PublishXRef.DataId "
+ f"AND T{item}.StrainId = %(T{item}_sample_id)s")
+ for item in sample_ids)
+ return (
+ ("SELECT PublishXRef.Id, " +
+ sample_id_columns +
+ " FROM (PublishXRef, PublishFreeze) " +
+ " ".join(joins) +
+ " WHERE PublishXRef.InbredSetId = PublishFreeze.InbredSetId "
+ "AND PublishFreeze.Name = %(db_name)s"),
+ 1)
+ if temp_table is not None:
+ joins = tuple(
+ (f"LEFT JOIN {db_type}Data AS T{item} "
+ f"ON T{item}.Id = {db_type}XRef.DataId "
+ f"AND T{item}.StrainId=%(T{item}_sample_id)s")
+ for item in sample_ids)
+ if method.lower() == "sgo literature correlation":
+ return build_query_sgo_lit_corr(
+ sample_ids, temp_table, sample_id_columns, joins)
+ if method.lower() in (
+ "tissue correlation, pearson's r",
+ "tissue correlation, spearman's rho"):
+ return build_query_tissue_corr(
+ sample_ids, temp_table, sample_id_columns, joins)
+ joins = tuple(
+ (f"LEFT JOIN {db_type}Data AS T{item} "
+ f"ON T{item}.Id = {db_type}XRef.DataId "
+ f"AND T{item}.StrainId = %(T{item}_sample_id)s")
+ for item in sample_ids)
+ return (
+ (
+ f"SELECT {db_type}.Name, " +
+ sample_id_columns +
+ f" FROM ({db_type}, {db_type}XRef, {db_type}Freeze) " +
+ " ".join(joins) +
+ f" WHERE {db_type}XRef.{db_type}FreezeId = {db_type}Freeze.Id " +
+ f"AND {db_type}Freeze.Name = %(db_name)s " +
+ f"AND {db_type}.Id = {db_type}XRef.{db_type}Id " +
+ f"ORDER BY {db_type}.Id"),
+ 1)
+
+ def __fetch_data__(sample_ids, temp_table):
+ query, data_start_pos = __build_query__(sample_ids, temp_table)
+ with conn.cursor() as cursor:
+ cursor.execute(
+ query,
+ {"db_name": db_name,
+ **{f"T{item}_sample_id": item for item in sample_ids}})
+ return (cursor.fetchall(), data_start_pos)
+
+ sample_ids = tuple(
+ # look into graduating this to an argument and removing the `samples`
+ # and `species` argument: function currying and compositions might help
+ # with this
+ f"{sample_id}" for sample_id in
+ fetch_sample_ids(conn, samples, species))
+
+ temp_table = None
+ if gene_id and db_type == "probeset":
+ if method.lower() == "sgo literature correlation":
+ temp_table = build_temporary_literature_table(
+ conn, species, gene_id, return_number)
+ if method.lower() in (
+ "tissue correlation, pearson's r",
+ "tissue correlation, spearman's rho"):
+ temp_table = build_temporary_tissue_correlations_table(
+ conn, trait_symbol, probeset_freeze_id, method, return_number)
+
+ trait_database = tuple(
+ item for sublist in
+ (__fetch_data__(ssample_ids, temp_table)
+ for ssample_ids in partition_all(25, sample_ids))
+ for item in sublist)
+
+ if temp_table:
+ with conn.cursor() as cursor:
+ cursor.execute(f"DROP TEMPORARY TABLE {temp_table}")
+
+ return (trait_database[0], trait_database[1])
diff --git a/gn3/db/datasets.py b/gn3/db/datasets.py
index 6c328f5..b19db53 100644
--- a/gn3/db/datasets.py
+++ b/gn3/db/datasets.py
@@ -1,7 +1,11 @@
"""
This module contains functions relating to specific trait dataset manipulation
"""
-from typing import Any
+import re
+from string import Template
+from typing import Any, Dict, List, Optional
+from SPARQLWrapper import JSON, SPARQLWrapper
+from gn3.settings import SPARQL_ENDPOINT
def retrieve_probeset_trait_dataset_name(
threshold: int, name: str, connection: Any):
@@ -22,10 +26,13 @@ def retrieve_probeset_trait_dataset_name(
"threshold": threshold,
"name": name
})
- return dict(zip(
- ["dataset_id", "dataset_name", "dataset_fullname",
- "dataset_shortname", "dataset_datascale"],
- cursor.fetchone()))
+ res = cursor.fetchone()
+ if res:
+ return dict(zip(
+ ["dataset_id", "dataset_name", "dataset_fullname",
+ "dataset_shortname", "dataset_datascale"],
+ res))
+ return {"dataset_id": None, "dataset_name": name, "dataset_fullname": name}
def retrieve_publish_trait_dataset_name(
threshold: int, name: str, connection: Any):
@@ -75,33 +82,8 @@ def retrieve_geno_trait_dataset_name(
"dataset_shortname"],
cursor.fetchone()))
-def retrieve_temp_trait_dataset_name(
- threshold: int, name: str, connection: Any):
- """
- Get the ID, DataScale and various name formats for a `Temp` trait.
- """
- query = (
- "SELECT Id, Name, FullName, ShortName "
- "FROM TempFreeze "
- "WHERE "
- "public > %(threshold)s "
- "AND "
- "(Name = %(name)s OR FullName = %(name)s OR ShortName = %(name)s)")
- with connection.cursor() as cursor:
- cursor.execute(
- query,
- {
- "threshold": threshold,
- "name": name
- })
- return dict(zip(
- ["dataset_id", "dataset_name", "dataset_fullname",
- "dataset_shortname"],
- cursor.fetchone()))
-
def retrieve_dataset_name(
- trait_type: str, threshold: int, trait_name: str, dataset_name: str,
- conn: Any):
+ trait_type: str, threshold: int, dataset_name: str, conn: Any):
"""
Retrieve the name of a trait given the trait's name
@@ -113,9 +95,7 @@ def retrieve_dataset_name(
"ProbeSet": retrieve_probeset_trait_dataset_name,
"Publish": retrieve_publish_trait_dataset_name,
"Geno": retrieve_geno_trait_dataset_name,
- "Temp": retrieve_temp_trait_dataset_name}
- if trait_type == "Temp":
- return retrieve_temp_trait_dataset_name(threshold, trait_name, conn)
+ "Temp": lambda threshold, dataset_name, conn: {}}
return fn_map[trait_type](threshold, dataset_name, conn)
@@ -203,7 +183,6 @@ def retrieve_temp_trait_dataset():
"""
Retrieve the dataset that relates to `Temp` traits
"""
- # pylint: disable=[C0330]
return {
"searchfield": ["name", "description"],
"disfield": ["name", "description"],
@@ -217,7 +196,6 @@ def retrieve_geno_trait_dataset():
"""
Retrieve the dataset that relates to `Geno` traits
"""
- # pylint: disable=[C0330]
return {
"searchfield": ["name", "chr"],
"disfield": ["name", "chr", "mb", "source2", "sequence"],
@@ -228,7 +206,6 @@ def retrieve_publish_trait_dataset():
"""
Retrieve the dataset that relates to `Publish` traits
"""
- # pylint: disable=[C0330]
return {
"searchfield": [
"name", "post_publication_description", "abstract", "title",
@@ -247,7 +224,6 @@ def retrieve_probeset_trait_dataset():
"""
Retrieve the dataset that relates to `ProbeSet` traits
"""
- # pylint: disable=[C0330]
return {
"searchfield": [
"name", "description", "probe_target_description", "symbol",
@@ -278,8 +254,7 @@ def retrieve_trait_dataset(trait_type, trait, threshold, conn):
"dataset_id": None,
"dataset_name": trait["db"]["dataset_name"],
**retrieve_dataset_name(
- trait_type, threshold, trait["trait_name"],
- trait["db"]["dataset_name"], conn)
+ trait_type, threshold, trait["db"]["dataset_name"], conn)
}
group = retrieve_group_fields(
trait_type, trait["trait_name"], dataset_name_info, conn)
@@ -289,3 +264,100 @@ def retrieve_trait_dataset(trait_type, trait, threshold, conn):
**dataset_fns[trait_type](),
**group
}
+
+def sparql_query(query: str) -> List[Dict[str, Any]]:
+ """Run a SPARQL query and return the bound variables."""
+ sparql = SPARQLWrapper(SPARQL_ENDPOINT)
+ sparql.setQuery(query)
+ sparql.setReturnFormat(JSON)
+ return sparql.queryAndConvert()['results']['bindings']
+
+def dataset_metadata(accession_id: str) -> Optional[Dict[str, Any]]:
+ """Return info about dataset with ACCESSION_ID."""
+ # Check accession_id to protect against query injection.
+ # TODO: This function doesn't yet return the names of the actual dataset files.
+ pattern = re.compile(r'GN\d+', re.ASCII)
+ if not pattern.fullmatch(accession_id):
+ return None
+ # KLUDGE: We split the SPARQL query because virtuoso is very slow on a
+ # single large query.
+ queries = ["""
+PREFIX gn: <http://genenetwork.org/>
+SELECT ?name ?dataset_group ?status ?title ?geo_series
+WHERE {
+ ?dataset gn:accessionId "$accession_id" ;
+ rdf:type gn:dataset ;
+ gn:name ?name .
+ OPTIONAL { ?dataset gn:datasetGroup ?dataset_group } .
+ # FIXME: gn:datasetStatus should not be optional. But, some records don't
+ # have it.
+ OPTIONAL { ?dataset gn:datasetStatus ?status } .
+ OPTIONAL { ?dataset gn:title ?title } .
+ OPTIONAL { ?dataset gn:geoSeries ?geo_series } .
+}
+""",
+ """
+PREFIX gn: <http://genenetwork.org/>
+SELECT ?platform_name ?normalization_name ?species_name ?inbred_set_name ?tissue_name
+WHERE {
+ ?dataset gn:accessionId "$accession_id" ;
+ rdf:type gn:dataset ;
+ gn:normalization / gn:name ?normalization_name ;
+ gn:datasetOfSpecies / gn:menuName ?species_name ;
+ gn:datasetOfInbredSet / gn:name ?inbred_set_name .
+ OPTIONAL { ?dataset gn:datasetOfTissue / gn:name ?tissue_name } .
+ OPTIONAL { ?dataset gn:datasetOfPlatform / gn:name ?platform_name } .
+}
+""",
+ """
+PREFIX gn: <http://genenetwork.org/>
+SELECT ?specifics ?summary ?about_cases ?about_tissue ?about_platform
+ ?about_data_processing ?notes ?experiment_design ?contributors
+ ?citation ?acknowledgment
+WHERE {
+ ?dataset gn:accessionId "$accession_id" ;
+ rdf:type gn:dataset .
+ OPTIONAL { ?dataset gn:specifics ?specifics . }
+ OPTIONAL { ?dataset gn:summary ?summary . }
+ OPTIONAL { ?dataset gn:aboutCases ?about_cases . }
+ OPTIONAL { ?dataset gn:aboutTissue ?about_tissue . }
+ OPTIONAL { ?dataset gn:aboutPlatform ?about_platform . }
+ OPTIONAL { ?dataset gn:aboutDataProcessing ?about_data_processing . }
+ OPTIONAL { ?dataset gn:notes ?notes . }
+ OPTIONAL { ?dataset gn:experimentDesign ?experiment_design . }
+ OPTIONAL { ?dataset gn:contributors ?contributors . }
+ OPTIONAL { ?dataset gn:citation ?citation . }
+ OPTIONAL { ?dataset gn:acknowledgment ?acknowledgment . }
+}
+"""]
+ result: Dict[str, Any] = {'accession_id': accession_id,
+ 'investigator': {}}
+ query_result = {}
+ for query in queries:
+ if sparql_result := sparql_query(Template(query).substitute(accession_id=accession_id)):
+ query_result.update(sparql_result[0])
+ else:
+ return None
+ for key, value in query_result.items():
+ result[key] = value['value']
+ investigator_query_result = sparql_query(Template("""
+PREFIX gn: <http://genenetwork.org/>
+SELECT ?name ?address ?city ?state ?zip ?phone ?email ?country ?homepage
+WHERE {
+ ?dataset gn:accessionId "$accession_id" ;
+ rdf:type gn:dataset ;
+ gn:datasetOfInvestigator ?investigator .
+ OPTIONAL { ?investigator foaf:name ?name . }
+ OPTIONAL { ?investigator gn:address ?address . }
+ OPTIONAL { ?investigator gn:city ?city . }
+ OPTIONAL { ?investigator gn:state ?state . }
+ OPTIONAL { ?investigator gn:zipCode ?zip . }
+ OPTIONAL { ?investigator foaf:phone ?phone . }
+ OPTIONAL { ?investigator foaf:mbox ?email . }
+ OPTIONAL { ?investigator gn:country ?country . }
+ OPTIONAL { ?investigator foaf:homepage ?homepage . }
+}
+""").substitute(accession_id=accession_id))[0]
+ for key, value in investigator_query_result.items():
+ result['investigator'][key] = value['value']
+ return result
diff --git a/gn3/db/genotypes.py b/gn3/db/genotypes.py
index 8f18cac..6f867c7 100644
--- a/gn3/db/genotypes.py
+++ b/gn3/db/genotypes.py
@@ -2,7 +2,6 @@
import os
import gzip
-from typing import Union, TextIO
from gn3.settings import GENOTYPE_FILES
@@ -10,7 +9,7 @@ def build_genotype_file(
geno_name: str, base_dir: str = GENOTYPE_FILES,
extension: str = "geno"):
"""Build the absolute path for the genotype file."""
- return "{}/{}.{}".format(os.path.abspath(base_dir), geno_name, extension)
+ return f"{os.path.abspath(base_dir)}/{geno_name}.{extension}"
def load_genotype_samples(genotype_filename: str, file_type: str = "geno"):
"""
@@ -44,22 +43,23 @@ def __load_genotype_samples_from_geno(genotype_filename: str):
Loads samples from '.geno' files.
"""
- gzipped_filename = "{}.gz".format(genotype_filename)
+ def __remove_comments_and_empty_lines__(rows):
+ return(
+ line for line in rows
+ if line and not line.startswith(("#", "@")))
+
+ gzipped_filename = f"{genotype_filename}.gz"
if os.path.isfile(gzipped_filename):
- genofile: Union[TextIO, gzip.GzipFile] = gzip.open(gzipped_filename)
+ with gzip.open(gzipped_filename) as gz_genofile:
+ rows = __remove_comments_and_empty_lines__(gz_genofile.readlines())
else:
- genofile = open(genotype_filename)
-
- for row in genofile:
- line = row.strip()
- if (not line) or (line.startswith(("#", "@"))): # type: ignore[arg-type]
- continue
- break
+ with open(genotype_filename, encoding="utf8") as genofile:
+ rows = __remove_comments_and_empty_lines__(genofile.readlines())
- headers = line.split("\t") # type: ignore[arg-type]
+ headers = next(rows).split() # type: ignore[arg-type]
if headers[3] == "Mb":
- return headers[4:]
- return headers[3:]
+ return tuple(headers[4:])
+ return tuple(headers[3:])
def __load_genotype_samples_from_plink(genotype_filename: str):
"""
@@ -67,8 +67,8 @@ def __load_genotype_samples_from_plink(genotype_filename: str):
Loads samples from '.plink' files.
"""
- genofile = open(genotype_filename)
- return [line.split(" ")[1] for line in genofile]
+ with open(genotype_filename, encoding="utf8") as genofile:
+ return tuple(line.split()[1] for line in genofile)
def parse_genotype_labels(lines: list):
"""
@@ -129,7 +129,7 @@ def parse_genotype_marker(line: str, geno_obj: dict, parlist: tuple):
alleles = marker_row[start_pos:]
genotype = tuple(
- (geno_table[allele] if allele in geno_table.keys() else "U")
+ (geno_table[allele] if allele in geno_table else "U")
for allele in alleles)
if len(parlist) > 0:
genotype = (-1, 1) + genotype
@@ -164,7 +164,7 @@ def parse_genotype_file(filename: str, parlist: tuple = tuple()):
"""
Parse the provided genotype file into a usable python3 data structure.
"""
- with open(filename, "r") as infile:
+ with open(filename, "r", encoding="utf8") as infile:
contents = infile.readlines()
lines = tuple(line for line in contents if
@@ -175,10 +175,10 @@ def parse_genotype_file(filename: str, parlist: tuple = tuple()):
data_lines = tuple(line for line in lines if not line.startswith("@"))
header = parse_genotype_header(data_lines[0], parlist)
geno_obj = dict(labels + header)
- markers = tuple(
- [parse_genotype_marker(line, geno_obj, parlist)
- for line in data_lines[1:]])
+ markers = (
+ parse_genotype_marker(line, geno_obj, parlist)
+ for line in data_lines[1:])
chromosomes = tuple(
dict(chromosome) for chromosome in
- build_genotype_chromosomes(geno_obj, markers))
+ build_genotype_chromosomes(geno_obj, tuple(markers)))
return {**geno_obj, "chromosomes": chromosomes}
diff --git a/gn3/db/partial_correlations.py b/gn3/db/partial_correlations.py
new file mode 100644
index 0000000..72dbf1a
--- /dev/null
+++ b/gn3/db/partial_correlations.py
@@ -0,0 +1,791 @@
+"""
+This module contains the code and queries for fetching data from the database,
+that relates to partial correlations.
+
+It is intended to replace the functions in `gn3.db.traits` and `gn3.db.datasets`
+modules with functions that fetch the data enmasse, rather than one at a time.
+
+This module is part of the optimisation effort for the partial correlations.
+"""
+
+from functools import reduce, partial
+from typing import Any, Dict, Tuple, Union, Sequence
+
+from MySQLdb.cursors import DictCursor
+
+from gn3.function_helpers import compose
+from gn3.db.traits import (
+ build_trait_name,
+ with_samplelist_data_setup,
+ without_samplelist_data_setup)
+
+def organise_trait_data_by_trait(
+ traits_data_rows: Tuple[Dict[str, Any], ...]) -> Dict[
+ str, Dict[str, Any]]:
+ """
+ Organise the trait data items by their trait names.
+ """
+ def __organise__(acc, row):
+ trait_name = row["trait_name"]
+ return {
+ **acc,
+ trait_name: acc.get(trait_name, tuple()) + ({
+ key: val for key, val in row.items() if key != "trait_name"},)
+ }
+ if traits_data_rows:
+ return reduce(__organise__, traits_data_rows, {})
+ return {}
+
+def temp_traits_data(conn, traits):
+ """
+ Retrieve trait data for `Temp` traits.
+ """
+ query = (
+ "SELECT "
+ "Temp.Name AS trait_name, Strain.Name AS sample_name, TempData.value, "
+ "TempData.SE AS se_error, TempData.NStrain AS nstrain, "
+ "TempData.Id AS id "
+ "FROM TempData, Temp, Strain "
+ "WHERE TempData.StrainId = Strain.Id "
+ "AND TempData.Id = Temp.DataId "
+ f"AND Temp.name IN ({', '.join(['%s'] * len(traits))}) "
+ "ORDER BY Strain.Name")
+ with conn.cursor(cursorclass=DictCursor) as cursor:
+ cursor.execute(
+ query,
+ tuple(trait["trait_name"] for trait in traits))
+ return organise_trait_data_by_trait(cursor.fetchall())
+ return {}
+
+def publish_traits_data(conn, traits):
+ """
+ Retrieve trait data for `Publish` traits.
+ """
+ dataset_ids = tuple(set(
+ trait["db"]["dataset_id"] for trait in traits
+ if trait["db"].get("dataset_id") is not None))
+ query = (
+ "SELECT "
+ "PublishXRef.Id AS trait_name, Strain.Name AS sample_name, "
+ "PublishData.value, PublishSE.error AS se_error, "
+ "NStrain.count AS nstrain, PublishData.Id AS id "
+ "FROM (PublishData, Strain, PublishXRef, PublishFreeze) "
+ "LEFT JOIN PublishSE "
+ "ON (PublishSE.DataId = PublishData.Id "
+ "AND PublishSE.StrainId = PublishData.StrainId) "
+ "LEFT JOIN NStrain "
+ "ON (NStrain.DataId = PublishData.Id "
+ "AND NStrain.StrainId = PublishData.StrainId) "
+ "WHERE PublishXRef.InbredSetId = PublishFreeze.InbredSetId "
+ "AND PublishData.Id = PublishXRef.DataId "
+ f"AND PublishXRef.Id IN ({', '.join(['%s'] * len(traits))}) "
+ "AND PublishFreeze.Id IN "
+ f"({', '.join(['%s'] * len(dataset_ids))}) "
+ "AND PublishData.StrainId = Strain.Id "
+ "ORDER BY Strain.Name")
+ if len(dataset_ids) > 0:
+ with conn.cursor(cursorclass=DictCursor) as cursor:
+ cursor.execute(
+ query,
+ tuple(trait["trait_name"] for trait in traits) +
+ tuple(dataset_ids))
+ return organise_trait_data_by_trait(cursor.fetchall())
+ return {}
+
+def cellid_traits_data(conn, traits):
+ """
+ Retrieve trait data for `Probe Data` types.
+ """
+ cellids = tuple(trait["cellid"] for trait in traits)
+ dataset_names = set(trait["db"]["dataset_name"] for trait in traits)
+ query = (
+ "SELECT "
+ "ProbeSet.Name AS trait_name, Strain.Name AS sample_name, "
+ "ProbeData.value, ProbeSE.error AS se_error, ProbeData.Id AS id "
+ "FROM (ProbeData, ProbeFreeze, ProbeSetFreeze, ProbeXRef, Strain, "
+ "Probe, ProbeSet) "
+ "LEFT JOIN ProbeSE "
+ "ON (ProbeSE.DataId = ProbeData.Id "
+ "AND ProbeSE.StrainId = ProbeData.StrainId) "
+ f"WHERE Probe.Name IN ({', '.join(['%s'] * len(cellids))}) "
+ f"AND ProbeSet.Name IN ({', '.join(['%s'] * len(traits))}) "
+ "AND Probe.ProbeSetId = ProbeSet.Id "
+ "AND ProbeXRef.ProbeId = Probe.Id "
+ "AND ProbeXRef.ProbeFreezeId = ProbeFreeze.Id "
+ "AND ProbeSetFreeze.ProbeFreezeId = ProbeFreeze.Id "
+ f"AND ProbeSetFreeze.Name IN ({', '.join(['%s'] * len(dataset_names))}) "
+ "AND ProbeXRef.DataId = ProbeData.Id "
+ "AND ProbeData.StrainId = Strain.Id "
+ "ORDER BY Strain.Name")
+ with conn.cursor(cursorclass=DictCursor) as cursor:
+ cursor.execute(
+ query,
+ cellids + tuple(trait["trait_name"] for trait in traits) +
+ tuple(dataset_names))
+ return organise_trait_data_by_trait(cursor.fetchall())
+ return {}
+
+def probeset_traits_data(conn, traits):
+ """
+ Retrieve trait data for `ProbeSet` traits.
+ """
+ dataset_names = set(trait["db"]["dataset_name"] for trait in traits)
+ query = (
+ "SELECT ProbeSet.Name AS trait_name, Strain.Name AS sample_name, "
+ "ProbeSetData.value, ProbeSetSE.error AS se_error, "
+ "ProbeSetData.Id AS id "
+ "FROM (ProbeSetData, ProbeSetFreeze, Strain, ProbeSet, ProbeSetXRef) "
+ "LEFT JOIN ProbeSetSE ON "
+ "(ProbeSetSE.DataId = ProbeSetData.Id "
+ "AND ProbeSetSE.StrainId = ProbeSetData.StrainId) "
+ f"WHERE ProbeSet.Name IN ({', '.join(['%s'] * len(traits))})"
+ "AND ProbeSetXRef.ProbeSetId = ProbeSet.Id "
+ "AND ProbeSetXRef.ProbeSetFreezeId = ProbeSetFreeze.Id "
+ f"AND ProbeSetFreeze.Name IN ({', '.join(['%s']*len(dataset_names))}) "
+ "AND ProbeSetXRef.DataId = ProbeSetData.Id "
+ "AND ProbeSetData.StrainId = Strain.Id "
+ "ORDER BY Strain.Name")
+ with conn.cursor(cursorclass=DictCursor) as cursor:
+ cursor.execute(
+ query,
+ tuple(trait["trait_name"] for trait in traits) +
+ tuple(dataset_names))
+ return organise_trait_data_by_trait(cursor.fetchall())
+ return {}
+
+def species_ids(conn, traits):
+ """
+ Retrieve the IDS of the related species from the given list of traits.
+ """
+ groups = tuple(set(
+ trait["db"]["group"] for trait in traits
+ if trait["db"].get("group") is not None))
+ query = (
+ "SELECT Name AS `group`, SpeciesId AS species_id "
+ "FROM InbredSet "
+ f"WHERE Name IN ({', '.join(['%s'] * len(groups))})")
+ if len(groups) > 0:
+ with conn.cursor(cursorclass=DictCursor) as cursor:
+ cursor.execute(query, groups)
+ return tuple(row for row in cursor.fetchall())
+ return tuple()
+
+def geno_traits_data(conn, traits):
+ """
+ Retrieve trait data for `Geno` traits.
+ """
+ sp_ids = tuple(item["species_id"] for item in species_ids(conn, traits))
+ dataset_names = set(trait["db"]["dataset_name"] for trait in traits)
+ query = (
+ "SELECT Geno.Name AS trait_name, Strain.Name AS sample_name, "
+ "GenoData.value, GenoSE.error AS se_error, GenoData.Id AS id "
+ "FROM (GenoData, GenoFreeze, Strain, Geno, GenoXRef) "
+ "LEFT JOIN GenoSE ON "
+ "(GenoSE.DataId = GenoData.Id AND GenoSE.StrainId = GenoData.StrainId) "
+ f"WHERE Geno.SpeciesId IN ({', '.join(['%s'] * len(sp_ids))}) "
+ f"AND Geno.Name IN ({', '.join(['%s'] * len(traits))}) "
+ "AND GenoXRef.GenoId = Geno.Id "
+ "AND GenoXRef.GenoFreezeId = GenoFreeze.Id "
+ f"AND GenoFreeze.Name IN ({', '.join(['%s'] * len(dataset_names))}) "
+ "AND GenoXRef.DataId = GenoData.Id "
+ "AND GenoData.StrainId = Strain.Id "
+ "ORDER BY Strain.Name")
+ if len(sp_ids) > 0 and len(dataset_names) > 0:
+ with conn.cursor(cursorclass=DictCursor) as cursor:
+ cursor.execute(
+ query,
+ sp_ids +
+ tuple(trait["trait_name"] for trait in traits) +
+ tuple(dataset_names))
+ return organise_trait_data_by_trait(cursor.fetchall())
+ return {}
+
def traits_data(
        conn: Any, traits: Tuple[Dict[str, Any], ...],
        samplelist: Tuple[str, ...] = tuple()) -> Dict[str, Dict[str, Any]]:
    """
    Retrieve trait data for multiple `traits`

    This is a rework of the `gn3.db.traits.retrieve_trait_data` function.

    :param conn: a database connection object
    :param traits: trait dicts; each must carry a `db` dict with a
        `dataset_type` key, used to dispatch to the right retrieval function
    :param samplelist: optional sample names to restrict the data to
    :return: a dict mapping each trait name to a `{"data": {...}}` dict keyed
        by sample name
    """
    def __organise__(acc, trait):
        # Bucket traits by retrieval strategy; a trait with a `cellid` is
        # handled by the cellid function even if its dataset is `ProbeSet`.
        dataset_type = trait["db"]["dataset_type"]
        if dataset_type == "Temp":
            return {**acc, "Temp": acc.get("Temp", tuple()) + (trait,)}
        if dataset_type == "Publish":
            return {**acc, "Publish": acc.get("Publish", tuple()) + (trait,)}
        if trait.get("cellid"):
            return {**acc, "cellid": acc.get("cellid", tuple()) + (trait,)}
        if dataset_type == "ProbeSet":
            return {**acc, "ProbeSet": acc.get("ProbeSet", tuple()) + (trait,)}
        return {**acc, "Geno": acc.get("Geno", tuple()) + (trait,)}

    def __setup_samplelist__(data):
        if samplelist:
            return tuple(
                item for item in
                map(with_samplelist_data_setup(samplelist), data)
                if item is not None)
        return tuple(
            item for item in
            map(without_samplelist_data_setup(), data)
            if item is not None)

    def __process_results__(results):
        flattened = reduce(lambda acc, res: {**acc, **res}, results)
        return {
            trait_name: {"data": dict(map(
                lambda item: (
                    item["sample_name"],
                    {
                        # BUGFIX: compare the *key* against "sample_name";
                        # the original compared the whole item, so the
                        # sample_name entry leaked into every data record.
                        key: val for key, val in item.items()
                        if key != "sample_name"
                    }),
                __setup_samplelist__(data)))}
            for trait_name, data in flattened.items()}

    traits_data_fns = {
        "Temp": temp_traits_data,
        "Publish": publish_traits_data,
        "cellid": cellid_traits_data,
        "ProbeSet": probeset_traits_data,
        "Geno": geno_traits_data
    }
    return __process_results__(tuple(# type: ignore[var-annotated]
        traits_data_fns[key](conn, vals)
        for key, vals in reduce(__organise__, traits, {}).items()))
+
def merge_traits_and_info(traits, info_results):
    """
    Utility to merge trait info retrieved from the database with the given traits.

    :param traits: trait dicts, each with a `trait_name` key
    :param info_results: rows from the database, each with a `trait_name` key
    :return: the traits merged with any matching info, each with a `haveinfo`
        flag indicating whether info was found

    BUGFIX: the lookup keys are normalised with `str(...)` on *both* sides;
    previously the index was keyed by `str(trait_name)` but probed with the
    raw value, so integer trait names (e.g. `PublishXRef.Id`) never matched.
    """
    if info_results:
        results = {
            str(trait["trait_name"]): trait for trait in info_results
        }
        return tuple(
            {
                **trait,
                **results.get(str(trait["trait_name"]), {}),
                "haveinfo": bool(results.get(str(trait["trait_name"])))
            } for trait in traits)
    return tuple({**trait, "haveinfo": False} for trait in traits)
+
def publish_traits_info(
        conn: Any, traits: Tuple[Dict[str, Any], ...]) -> Tuple[
            Dict[str, Any], ...]:
    """
    Retrieve trait information for type `Publish` traits.

    This is a rework of `gn3.db.traits.retrieve_publish_trait_info` function:
    this one fetches multiple items in a single query, unlike the original that
    fetches one item per query.
    """
    dataset_ids = {
        trait["db"]["dataset_id"] for trait in traits
        if trait["db"].get("dataset_id") is not None}
    if not dataset_ids:
        # No datasets to query against: mark everything as lacking info.
        return tuple({**trait, "haveinfo": False} for trait in traits)

    columns = (
        "PublishXRef.Id, Publication.PubMed_ID, "
        "Phenotype.Pre_publication_description, "
        "Phenotype.Post_publication_description, "
        "Phenotype.Original_description, "
        "Phenotype.Pre_publication_abbreviation, "
        "Phenotype.Post_publication_abbreviation, "
        "Phenotype.Lab_code, Phenotype.Submitter, Phenotype.Owner, "
        "Phenotype.Authorized_Users, "
        "CAST(Publication.Authors AS BINARY) AS Authors, Publication.Title, "
        "Publication.Abstract, Publication.Journal, Publication.Volume, "
        "Publication.Pages, Publication.Month, Publication.Year, "
        "PublishXRef.Sequence, Phenotype.Units, PublishXRef.comments")
    trait_names = tuple(trait["trait_name"] for trait in traits)
    query = (
        "SELECT "
        f"PublishXRef.Id AS trait_name, {columns} "
        "FROM "
        "PublishXRef, Publication, Phenotype, PublishFreeze "
        "WHERE "
        f"PublishXRef.Id IN ({', '.join(['%s'] * len(trait_names))}) "
        "AND Phenotype.Id = PublishXRef.PhenotypeId "
        "AND Publication.Id = PublishXRef.PublicationId "
        "AND PublishXRef.InbredSetId = PublishFreeze.InbredSetId "
        "AND PublishFreeze.Id IN "
        f"({', '.join(['%s'] * len(dataset_ids))})")
    with conn.cursor(cursorclass=DictCursor) as cursor:
        cursor.execute(query, trait_names + tuple(dataset_ids))
        return merge_traits_and_info(traits, cursor.fetchall())
+
def probeset_traits_info(
        conn: Any, traits: Tuple[Dict[str, Any], ...]):
    """
    Retrieve information for the probeset traits

    :param conn: a database connection object
    :param traits: trait dicts; each `db` dict must carry `dataset_name`
    :return: the traits merged with any info found, each flagged `haveinfo`
    """
    dataset_names = set(trait["db"]["dataset_name"] for trait in traits)
    # Guard against empty input: without it the query contains an invalid
    # `IN ()` clause, and the fallback below was unreachable dead code.
    if traits and dataset_names:
        columns = ", ".join(
            [f"ProbeSet.{x}" for x in
             ("name", "symbol", "description", "probe_target_description", "chr",
              "mb", "alias", "geneid", "genbankid", "unigeneid", "omim",
              "refseq_transcriptid", "blatseq", "targetseq", "chipid", "comments",
              "strand_probe", "strand_gene", "probe_set_target_region", "proteinid",
              "probe_set_specificity", "probe_set_blat_score",
              "probe_set_blat_mb_start", "probe_set_blat_mb_end",
              "probe_set_strand", "probe_set_note_by_rw", "flag")])
        query = (
            f"SELECT ProbeSet.Name AS trait_name, {columns} "
            "FROM ProbeSet INNER JOIN ProbeSetXRef "
            "ON ProbeSetXRef.ProbeSetId = ProbeSet.Id "
            "INNER JOIN ProbeSetFreeze "
            "ON ProbeSetFreeze.Id = ProbeSetXRef.ProbeSetFreezeId "
            "WHERE ProbeSetFreeze.Name IN "
            f"({', '.join(['%s'] * len(dataset_names))}) "
            f"AND ProbeSet.Name IN ({', '.join(['%s'] * len(traits))})")
        with conn.cursor(cursorclass=DictCursor) as cursor:
            cursor.execute(
                query,
                tuple(dataset_names) + tuple(
                    trait["trait_name"] for trait in traits))
            return merge_traits_and_info(traits, cursor.fetchall())
    return tuple({**trait, "haveinfo": False} for trait in traits)
+
def geno_traits_info(
        conn: Any, traits: Tuple[Dict[str, Any], ...]):
    """
    Retrieve trait information for type `Geno` traits.

    This is a rework of the `gn3.db.traits.retrieve_geno_trait_info` function.

    :param conn: a database connection object
    :param traits: trait dicts; each `db` dict must carry `dataset_name`
    :return: the traits merged with any info found, each flagged `haveinfo`
    """
    dataset_names = set(trait["db"]["dataset_name"] for trait in traits)
    # Guard against empty input: without it the query contains an invalid
    # `IN ()` clause, and the fallback below was unreachable dead code.
    if traits and dataset_names:
        columns = ", ".join([
            f"Geno.{x}" for x in ("name", "chr", "mb", "source2", "sequence")])
        query = (
            "SELECT "
            f"Geno.Name AS trait_name, {columns} "
            "FROM "
            "Geno INNER JOIN GenoXRef ON GenoXRef.GenoId = Geno.Id "
            "INNER JOIN GenoFreeze ON GenoFreeze.Id = GenoXRef.GenoFreezeId "
            f"WHERE GenoFreeze.Name IN ({', '.join(['%s'] * len(dataset_names))}) "
            f"AND Geno.Name IN ({', '.join(['%s'] * len(traits))})")
        with conn.cursor(cursorclass=DictCursor) as cursor:
            cursor.execute(
                query,
                tuple(dataset_names) + tuple(
                    trait["trait_name"] for trait in traits))
            return merge_traits_and_info(traits, cursor.fetchall())
    return tuple({**trait, "haveinfo": False} for trait in traits)
+
def temp_traits_info(
        conn: Any, traits: Tuple[Dict[str, Any], ...]):
    """
    Retrieve trait information for type `Temp` traits.

    A rework of the `gn3.db.traits.retrieve_temp_trait_info` function.

    :param conn: a database connection object
    :param traits: trait dicts, each with a `trait_name` key
    :return: the traits merged with any info found, each flagged `haveinfo`
    """
    # Guard against empty input: without it the query contains an invalid
    # `IN ()` clause, and the fallback below was unreachable dead code.
    if traits:
        query = (
            "SELECT Name as trait_name, name, description FROM Temp "
            f"WHERE Name IN ({', '.join(['%s'] * len(traits))})")
        with conn.cursor(cursorclass=DictCursor) as cursor:
            cursor.execute(
                query,
                tuple(trait["trait_name"] for trait in traits))
            return merge_traits_and_info(traits, cursor.fetchall())
    return tuple({**trait, "haveinfo": False} for trait in traits)
+
def publish_datasets_names(
        conn: Any, threshold: int, dataset_names: Tuple[str, ...]):
    """
    Get the ID, DataScale and various name formats for a `Publish` trait.

    Rework of the `gn3.db.datasets.retrieve_publish_trait_dataset_name`

    :param conn: a database connection object
    :param threshold: minimum `public` level a dataset must exceed
    :param dataset_names: names to match against Name/FullName/ShortName
    :return: a dict mapping dataset name to its details
    """
    # Guard against empty input: without it the query contains an invalid
    # `IN ()` clause, and the trailing `return {}` was unreachable dead code.
    if not dataset_names:
        return {}
    query = (
        "SELECT DISTINCT "
        "Id AS dataset_id, Name AS dataset_name, FullName AS dataset_fullname, "
        "ShortName AS dataset_shortname "
        "FROM PublishFreeze "
        "WHERE "
        "public > %s "
        "AND "
        "(Name IN ({names}) OR FullName IN ({names}) OR ShortName IN ({names}))")
    with conn.cursor(cursorclass=DictCursor) as cursor:
        cursor.execute(
            query.format(names=", ".join(["%s"] * len(dataset_names))),
            # tuple(...) also accepts list input, not just tuples
            (threshold,) + (tuple(dataset_names) * 3))
        return {ds["dataset_name"]: ds for ds in cursor.fetchall()}
+
def set_bxd(group_info):
    """
    Normalise a group-info row: map the 'BXD300' group name to 'BXD', pass
    any other name through unchanged, and expose the Id as `groupid`.
    """
    if group_info.get("Name") == "BXD300":
        group = "BXD"
    else:
        group = group_info.get("Name", "")
    return {**group_info, "group": group, "groupid": group_info["Id"]}
+
def organise_groups_by_dataset(
        group_rows: Union[Sequence[Dict[str, Any]], None]) -> Dict[str, Any]:
    """Utility: index the given group rows by their `dataset_name` value."""
    if not group_rows:
        return {}
    organised: Dict[str, Any] = {}
    for row in group_rows:
        remainder = {
            key: val for key, val in row.items() if key != "dataset_name"}
        organised[row["dataset_name"]] = set_bxd(remainder)
    return organised
+
def publish_datasets_groups(conn: Any, dataset_names: Tuple[str]):
    """
    Retrieve the Group, and GroupID values for various Publish trait types.

    Rework of `gn3.db.datasets.retrieve_publish_group_fields` function.

    :param conn: a database connection object
    :param dataset_names: `PublishFreeze` dataset names to look up
    :return: a dict of group info, keyed by dataset name
    """
    # Guard against empty input: without it the query contains an invalid
    # `IN ()` clause, and the trailing `return {}` was unreachable dead code.
    if not dataset_names:
        return {}
    query = (
        "SELECT PublishFreeze.Name AS dataset_name, InbredSet.Name, "
        "InbredSet.Id "
        "FROM InbredSet, PublishFreeze "
        "WHERE PublishFreeze.InbredSetId = InbredSet.Id "
        f"AND PublishFreeze.Name IN ({', '.join(['%s'] * len(dataset_names))})")
    with conn.cursor(cursorclass=DictCursor) as cursor:
        cursor.execute(query, tuple(dataset_names))
        return organise_groups_by_dataset(cursor.fetchall())
+
def publish_traits_datasets(conn: Any, threshold, traits: Tuple[Dict]):
    """Attach dataset information to each of the given 'Publish' traits."""
    names = tuple({trait["db"]["dataset_name"] for trait in traits})
    names_info = publish_datasets_names(conn, threshold, names)
    groups = publish_datasets_groups(conn, names) # type: ignore[arg-type]

    def __with_dataset__(trait):
        dataset_name = trait["db"]["dataset_name"]
        return {
            **trait,
            "db": {
                **trait["db"],
                **names_info.get(dataset_name, {}),
                **groups.get(dataset_name, {})
            }
        }
    return tuple(__with_dataset__(trait) for trait in traits)
+
def probeset_datasets_names(conn: Any, threshold: int, dataset_names: Tuple[str, ...]):
    """
    Get the ID, DataScale and various name formats for a `ProbeSet` trait.

    :param conn: a database connection object
    :param threshold: minimum `public` level a dataset must exceed
    :param dataset_names: names to match against Name/FullName/ShortName
    :return: a dict mapping dataset name to its details
    """
    # Guard against empty input: without it the query contains an invalid
    # `IN ()` clause, and the trailing `return {}` was unreachable dead code.
    if not dataset_names:
        return {}
    query = (
        "SELECT Id AS dataset_id, Name AS dataset_name, "
        "FullName AS dataset_fullname, ShortName AS dataset_shortname, "
        "DataScale AS dataset_datascale "
        "FROM ProbeSetFreeze "
        "WHERE "
        "public > %s "
        "AND "
        "(Name IN ({names}) OR FullName IN ({names}) OR ShortName IN ({names}))")
    with conn.cursor(cursorclass=DictCursor) as cursor:
        cursor.execute(
            query.format(names=", ".join(["%s"] * len(dataset_names))),
            # tuple(...) also accepts list input, not just tuples
            (threshold,) + (tuple(dataset_names) * 3))
        return {ds["dataset_name"]: ds for ds in cursor.fetchall()}
+
def probeset_datasets_groups(conn, dataset_names):
    """
    Retrieve the Group, and GroupID values for various ProbeSet trait types.

    :param conn: a database connection object
    :param dataset_names: `ProbeSetFreeze` dataset names to look up
    :return: a dict of group info, keyed by dataset name
    """
    # Guard against empty input: without it the query contains an invalid
    # `IN ()` clause, and the trailing `return {}` was unreachable dead code.
    if not dataset_names:
        return {}
    query = (
        "SELECT ProbeSetFreeze.Name AS dataset_name, InbredSet.Name, "
        "InbredSet.Id "
        "FROM InbredSet, ProbeSetFreeze, ProbeFreeze "
        "WHERE ProbeFreeze.InbredSetId = InbredSet.Id "
        "AND ProbeFreeze.Id = ProbeSetFreeze.ProbeFreezeId "
        f"AND ProbeSetFreeze.Name IN ({', '.join(['%s'] * len(dataset_names))})")
    with conn.cursor(cursorclass=DictCursor) as cursor:
        cursor.execute(query, tuple(dataset_names))
        return organise_groups_by_dataset(cursor.fetchall())
+
def probeset_traits_datasets(conn: Any, threshold, traits: Tuple[Dict]):
    """Attach dataset information to each of the given 'ProbeSet' traits."""
    names = tuple({trait["db"]["dataset_name"] for trait in traits})
    names_info = probeset_datasets_names(conn, threshold, names)
    groups = probeset_datasets_groups(conn, names)

    def __with_dataset__(trait):
        dataset_name = trait["db"]["dataset_name"]
        return {
            **trait,
            "db": {
                **trait["db"],
                **names_info.get(dataset_name, {}),
                **groups.get(dataset_name, {})
            }
        }
    return tuple(__with_dataset__(trait) for trait in traits)
+
def geno_datasets_names(conn, threshold, dataset_names):
    """
    Get the ID, DataScale and various name formats for a `Geno` trait.

    :param conn: a database connection object
    :param threshold: minimum `public` level a dataset must exceed
    :param dataset_names: names to match against Name/FullName/ShortName
    :return: a dict mapping dataset name to its details
    """
    # Guard against empty input: without it the query contains an invalid
    # `IN ()` clause, and the trailing `return {}` was unreachable dead code.
    if not dataset_names:
        return {}
    query = (
        "SELECT Id AS dataset_id, Name AS dataset_name, "
        "FullName AS dataset_fullname, ShortName AS dataset_short_name "
        "FROM GenoFreeze "
        "WHERE "
        "public > %s "
        "AND "
        "(Name IN ({names}) OR FullName IN ({names}) OR ShortName IN ({names}))")
    with conn.cursor(cursorclass=DictCursor) as cursor:
        cursor.execute(
            query.format(names=", ".join(["%s"] * len(dataset_names))),
            (threshold,) + (tuple(dataset_names) * 3))
        return {ds["dataset_name"]: ds for ds in cursor.fetchall()}
+
def geno_datasets_groups(conn, dataset_names):
    """
    Retrieve the Group, and GroupID values for various Geno trait types.

    :param conn: a database connection object
    :param dataset_names: `GenoFreeze` dataset names to look up
    :return: a dict of group info, keyed by dataset name
    """
    # Guard against empty input: without it the query contains an invalid
    # `IN ()` clause, and the trailing `return {}` was unreachable dead code.
    if not dataset_names:
        return {}
    query = (
        "SELECT GenoFreeze.Name AS dataset_name, InbredSet.Name, InbredSet.Id "
        "FROM InbredSet, GenoFreeze "
        "WHERE GenoFreeze.InbredSetId = InbredSet.Id "
        f"AND GenoFreeze.Name IN ({', '.join(['%s'] * len(dataset_names))})")
    with conn.cursor(cursorclass=DictCursor) as cursor:
        cursor.execute(query, tuple(dataset_names))
        return organise_groups_by_dataset(cursor.fetchall())
+
def geno_traits_datasets(conn: Any, threshold: int, traits: Tuple[Dict]):
    """Attach dataset information to each of the given 'Geno' traits."""
    names = tuple({trait["db"]["dataset_name"] for trait in traits})
    names_info = geno_datasets_names(conn, threshold, names)
    groups = geno_datasets_groups(conn, names)

    def __with_dataset__(trait):
        dataset_name = trait["db"]["dataset_name"]
        return {
            **trait,
            "db": {
                **trait["db"],
                **names_info.get(dataset_name, {}),
                **groups.get(dataset_name, {})
            }
        }
    return tuple(__with_dataset__(trait) for trait in traits)
+
def temp_datasets_groups(conn, dataset_names):
    """
    Retrieve the Group, and GroupID values for `Temp` trait types.

    :param conn: a database connection object
    :param dataset_names: `Temp` dataset names to look up
    :return: a dict of group info, keyed by dataset name
    """
    # Guard against empty input: without it the query contains an invalid
    # `IN ()` clause, and the trailing `return {}` was unreachable dead code.
    if not dataset_names:
        return {}
    query = (
        "SELECT Temp.Name AS dataset_name, InbredSet.Name, InbredSet.Id "
        "FROM InbredSet, Temp "
        "WHERE Temp.InbredSetId = InbredSet.Id "
        f"AND Temp.Name IN ({', '.join(['%s'] * len(dataset_names))})")
    with conn.cursor(cursorclass=DictCursor) as cursor:
        cursor.execute(query, tuple(dataset_names))
        return organise_groups_by_dataset(cursor.fetchall())
+
def temp_traits_datasets(conn: Any, threshold: int, traits: Tuple[Dict]): #pylint: disable=[W0613]
    """
    Retrieve datasets for 'Temp' traits.
    """
    names = tuple({trait["db"]["dataset_name"] for trait in traits})
    groups = temp_datasets_groups(conn, names)

    def __with_dataset__(trait):
        return {
            **trait,
            "db": {
                **trait["db"],
                **groups.get(trait["db"]["dataset_name"], {})
            }
        }
    return tuple(__with_dataset__(trait) for trait in traits)
+
def set_confidential(traits):
    """
    Set the confidential field for traits of type `Publish`.

    A trait is confidential when it has a pre-publication description but no
    PubMed ID (i.e. it has not been published yet).

    :param traits: trait dicts
    :return: the traits, each with a boolean `confidential` key added
    """
    # Idiom fix: `bool(...)` replaces the redundant
    # `True if ... else False` (pylint R1719).
    return tuple({
        **trait,
        "confidential": bool(
            trait.get("pre_publication_description")
            and not trait.get("pubmed_id"))
    } for trait in traits)
+
def query_qtl_info(conn, query, traits, dataset_ids):
    """
    Utility: Run the `query` to get the QTL information for the given `traits`.

    :param conn: a database connection object
    :param query: SQL with placeholders for the trait names then dataset ids
    :param traits: trait dicts, each with a `trait_name` key
    :param dataset_ids: tuple of dataset ids to parameterise the query with
    :return: the traits merged with any QTL info retrieved
    """
    with conn.cursor(cursorclass=DictCursor) as cursor:
        cursor.execute(
            query,
            tuple(trait["trait_name"] for trait in traits) + dataset_ids)
        results = {
            row["trait_name"]: {
                # BUGFIX: iterate `row.items()`; iterating the dict row
                # directly yields keys only and broke the unpacking.
                key: val for key, val in row.items() if key != "trait_name"
            } for row in cursor.fetchall()
        }
    return tuple(
        {**trait, **results.get(trait["trait_name"], {})}
        for trait in traits)
+
def set_publish_qtl_info(conn, qtl, traits):
    """
    Load extra QTL information for `Publish` traits
    """
    if not qtl:
        return traits
    dataset_ids = set(trait["db"]["dataset_id"] for trait in traits)
    trait_clause = ", ".join(["%s"] * len(traits))
    dataset_clause = ", ".join(["%s"] * len(dataset_ids))
    query = (
        "SELECT PublishXRef.Id AS trait_name, PublishXRef.Locus, "
        "PublishXRef.LRS, PublishXRef.additive "
        "FROM PublishXRef, PublishFreeze "
        f"WHERE PublishXRef.Id IN ({trait_clause}) "
        "AND PublishXRef.InbredSetId = PublishFreeze.InbredSetId "
        f"AND PublishFreeze.Id IN ({dataset_clause})")
    return query_qtl_info(conn, query, traits, tuple(dataset_ids))
+
def set_probeset_qtl_info(conn, qtl, traits):
    """
    Load extra QTL information for `ProbeSet` traits
    """
    if not qtl:
        return traits
    dataset_ids = tuple(set(trait["db"]["dataset_id"] for trait in traits))
    trait_clause = ", ".join(["%s"] * len(traits))
    dataset_clause = ", ".join(["%s"] * len(dataset_ids))
    query = (
        "SELECT ProbeSet.Name AS trait_name, ProbeSetXRef.Locus, "
        "ProbeSetXRef.LRS, ProbeSetXRef.pValue, "
        "ProbeSetXRef.mean, ProbeSetXRef.additive "
        "FROM ProbeSetXRef, ProbeSet "
        "WHERE ProbeSetXRef.ProbeSetId = ProbeSet.Id "
        f"AND ProbeSet.Name IN ({trait_clause}) "
        "AND ProbeSetXRef.ProbeSetFreezeId IN "
        f"({dataset_clause})")
    return query_qtl_info(conn, query, traits, tuple(dataset_ids))
+
def set_sequence(conn, traits):
    """
    Retrieve 'ProbeSet' traits sequence information

    :param conn: a database connection object
    :param traits: trait dicts; each `db` dict must carry `dataset_name`
    :return: the traits merged with any BlatSeq values retrieved
    """
    # Guard against empty input: without it the query contains an invalid
    # `IN ()` clause, and the trailing `return traits` was unreachable
    # dead code.
    if not traits:
        return traits
    dataset_names = set(trait["db"]["dataset_name"] for trait in traits)
    query = (
        "SELECT ProbeSet.Name as trait_name, ProbeSet.BlatSeq "
        "FROM ProbeSet, ProbeSetFreeze, ProbeSetXRef "
        "WHERE ProbeSet.Id=ProbeSetXRef.ProbeSetId "
        "AND ProbeSetFreeze.Id = ProbeSetXRef.ProbeSetFreezeId "
        f"AND ProbeSet.Name IN ({', '.join(['%s'] * len(traits))}) "
        f"AND ProbeSetFreeze.Name IN ({', '.join(['%s'] * len(dataset_names))})")
    with conn.cursor(cursorclass=DictCursor) as cursor:
        cursor.execute(
            query,
            (tuple(trait["trait_name"] for trait in traits) +
             tuple(dataset_names)))
        results = {
            row["trait_name"]: {
                key: val for key, val in row.items() if key != "trait_name"
            } for row in cursor.fetchall()
        }
        return tuple(
            {
                **trait,
                **results.get(trait["trait_name"], {})
            } for trait in traits)
+
def set_homologene_id(conn, traits):
    """
    Retrieve and set the 'homologene_id' values for ProbeSet traits.
    """
    # Only traits for which info was already fetched contribute lookup keys.
    geneids = set(trait.get("geneid") for trait in traits if trait["haveinfo"])
    groups = set(
        trait["db"].get("group") for trait in traits if trait["haveinfo"])
    # NOTE(review): the `> 1` guards skip the single-geneid/single-group
    # case entirely — confirm this is intended and `> 0` was not meant.
    if len(geneids) > 1 and len(groups) > 1:
        query = (
            "SELECT InbredSet.Name AS `group`, Homologene.GeneId AS geneid, "
            "HomologeneId "
            "FROM Homologene, Species, InbredSet "
            f"WHERE Homologene.GeneId IN ({', '.join(['%s'] * len(geneids))}) "
            f"AND InbredSet.Name IN ({', '.join(['%s'] * len(groups))}) "
            "AND InbredSet.SpeciesId = Species.Id "
            "AND Species.TaxonomyId = Homologene.TaxonomyId")
        with conn.cursor(cursorclass=DictCursor) as cursor:
            cursor.execute(query, (tuple(geneids) + tuple(groups)))
            # NOTE(review): each group maps to a single {geneid: {...}} dict,
            # so when several geneids share a group only the last row fetched
            # survives — confirm whether the inner dicts should be merged.
            results = {
                row["group"]: {
                    row["geneid"]: {
                        key: val for key, val in row.items()
                        if key not in ("group", "geneid")
                    }
                } for row in cursor.fetchall()
            }
            # NOTE(review): `trait["geneid"]` (and `trait["db"]["group"]`)
            # will raise KeyError for traits that lack those keys — confirm
            # upstream guarantees they are always present here.
            return tuple(
                {
                    **trait, **results.get(
                        trait["db"]["group"], {}).get(trait["geneid"], {})
                } for trait in traits)
    return traits
+
def traits_datasets(conn, threshold, traits):
    """
    Retrieve datasets for various `traits`.
    """
    dataset_fns = {
        "Temp": temp_traits_datasets,
        "Geno": geno_traits_datasets,
        "Publish": publish_traits_datasets,
        "ProbeSet": probeset_traits_datasets
    }
    # Group the traits by dataset type so each retrieval function is called
    # once with all the traits it is responsible for.
    by_type: Dict[str, tuple] = {}
    for trait in traits:
        dataset_type = trait["db"]["dataset_type"]
        by_type[dataset_type] = by_type.get(dataset_type, tuple()) + (trait,)
    with_datasets = {}
    for dataset_type, type_traits in by_type.items():
        for item in dataset_fns[dataset_type](conn, threshold, type_traits):
            with_datasets[item["trait_fullname"]] = item
    return tuple(
        {**trait, **with_datasets.get(trait["trait_fullname"], {})}
        for trait in traits)
+
def traits_info(
        conn: Any, threshold: int, traits_fullnames: Tuple[str, ...],
        qtl=None) -> Tuple[Dict[str, Any], ...]:
    """
    Retrieve basic trait information for multiple `traits`.

    This is a rework of the `gn3.db.traits.retrieve_trait_info` function.
    """
    def __organise_by_dataset_type__(acc, trait):
        # Accumulate the traits into buckets keyed by their dataset type.
        dataset_type = trait["db"]["dataset_type"]
        return {
            **acc,
            dataset_type: acc.get(dataset_type, tuple()) + (trait,)
        }
    # Parse the full names into trait dicts and attach dataset details first;
    # the per-type pipelines below rely on the `db` dict being populated.
    traits = traits_datasets(
        conn, threshold,
        tuple(build_trait_name(trait) for trait in traits_fullnames))
    # Each pipeline runs right-to-left (standard function composition).
    # NOTE(review): the Publish/ProbeSet/Geno/Temp pipelines re-run the
    # corresponding *_traits_datasets step even though `traits_datasets`
    # above already attached the dataset info — confirm the repetition is
    # intentional rather than redundant work.
    traits_fns = {
        "Publish": compose(
            set_confidential, partial(set_publish_qtl_info, conn, qtl),
            partial(publish_traits_info, conn),
            partial(publish_traits_datasets, conn, threshold)),
        "ProbeSet": compose(
            partial(set_sequence, conn),
            partial(set_probeset_qtl_info, conn, qtl),
            partial(set_homologene_id, conn),
            partial(probeset_traits_info, conn),
            partial(probeset_traits_datasets, conn, threshold)),
        "Geno": compose(
            partial(geno_traits_info, conn),
            partial(geno_traits_datasets, conn, threshold)),
        "Temp": compose(
            partial(temp_traits_info, conn),
            partial(temp_traits_datasets, conn, threshold))
    }
    # Run each bucket through its pipeline and flatten the results.
    return tuple(
        trait for sublist in (# type: ignore[var-annotated]
            traits_fns[dataset_type](traits)
            for dataset_type, traits
            in reduce(__organise_by_dataset_type__, traits, {}).items())
        for trait in sublist)
diff --git a/gn3/db/sample_data.py b/gn3/db/sample_data.py
new file mode 100644
index 0000000..f73954f
--- /dev/null
+++ b/gn3/db/sample_data.py
@@ -0,0 +1,365 @@
+"""Module containing functions that work with sample data"""
+from typing import Any, Tuple, Dict, Callable
+
+import MySQLdb
+
+from gn3.csvcmp import extract_strain_name
+
+
# Maps each sample-data table to the (strain-id column, data-id column,
# value column) triple used when composing its UPDATE/DELETE/INSERT SQL.
_MAP = {
    "PublishData": ("StrainId", "Id", "value"),
    "PublishSE": ("StrainId", "DataId", "error"),
    "NStrain": ("StrainId", "DataId", "count"),
}
+
+
def __extract_actions(original_data: str,
                      updated_data: str,
                      csv_header: str) -> Dict:
    """Return a dictionary containing elements that need to be deleted, inserted,
or updated.

    """
    actions: Dict[str, Any] = {
        "delete": {"data": [], "csv_header": []},
        "insert": {"data": [], "csv_header": []},
        "update": {"data": [], "csv_header": []},
    }
    strain_name = ""
    columns = zip(original_data.strip().split(","),
                  updated_data.strip().split(","),
                  csv_header.strip().split(","))
    for original, updated, header in columns:
        if header == "Strain Name":
            strain_name = original
        if original == updated:  # no change for this column
            continue
        if original and updated == "x":  # value removed
            bucket, value = "delete", original
        elif original == "x" and updated:  # value added
            bucket, value = "insert", updated
        elif original and updated:  # value changed
            bucket, value = "update", updated
        else:
            continue
        actions[bucket]["data"].append(value)
        actions[bucket]["csv_header"].append(header)
    # Collapse each non-empty bucket into CSV strings prefixed with the
    # strain name; empty buckets are replaced with None.
    for action, content in actions.items():
        if not content["data"]:
            actions[action] = None
        else:
            actions[action]["data"] = (f"{strain_name}," +
                                       ",".join(content["data"]))
            actions[action]["csv_header"] = ("Strain Name," +
                                             ",".join(content["csv_header"]))
    return actions
+
+
def get_trait_csv_sample_data(conn: Any,
                              trait_name: int, phenotype_id: int) -> str:
    """Fetch a trait and return it as a csv string

    :param conn: a database connection object
    :param trait_name: the `PublishXRef.Id` identifying the trait
    :param phenotype_id: the phenotype the trait data belongs to
    :return: CSV text headed "Strain Name,Value,SE,Count[,<case-attrs>...]",
        or "No Sample Data Found" when the query matches nothing
    """
    __query = ("SELECT concat(st.Name, ',', ifnull(pd.value, 'x'), ',', "
               "ifnull(ps.error, 'x'), ',', ifnull(ns.count, 'x')) as 'Data' "
               ",ifnull(ca.Name, 'x') as 'CaseAttr', "
               "ifnull(cxref.value, 'x') as 'Value' "
               "FROM PublishFreeze pf "
               "JOIN PublishXRef px ON px.InbredSetId = pf.InbredSetId "
               "JOIN PublishData pd ON pd.Id = px.DataId "
               "JOIN Strain st ON pd.StrainId = st.Id "
               "LEFT JOIN PublishSE ps ON ps.DataId = pd.Id "
               "AND ps.StrainId = pd.StrainId "
               "LEFT JOIN NStrain ns ON ns.DataId = pd.Id "
               "AND ns.StrainId = pd.StrainId "
               "LEFT JOIN CaseAttributeXRefNew cxref ON "
               "(cxref.InbredSetId = px.InbredSetId AND "
               "cxref.StrainId = st.Id) "
               "LEFT JOIN CaseAttribute ca ON ca.Id = cxref.CaseAttributeId "
               "WHERE px.Id = %s AND px.PhenotypeId = %s ORDER BY st.Name")
    case_attr_columns = set()
    csv_data: Dict = {}
    with conn.cursor() as cursor:
        cursor.execute(__query, (trait_name, phenotype_id))
        for data in cursor.fetchall():
            if data[1] == "x":
                # No case attribute for this row; the key already holds the
                # "name,value,error,count" CSV fragment built by the query.
                csv_data[data[0]] = None
            else:
                sample, case_attr, value = data[0], data[1], data[2]
                if not csv_data.get(sample):
                    csv_data[sample] = {}
                csv_data[sample][case_attr] = None if value == "x" else value
                case_attr_columns.add(case_attr)
    if not csv_data:
        # BUGFIX: this fallback was previously dead code placed after an
        # unconditional return; report the missing data explicitly.
        return "No Sample Data Found"
    if not case_attr_columns:
        return ("Strain Name,Value,SE,Count\n" +
                "\n".join(csv_data.keys()))
    columns = sorted(case_attr_columns)
    csv = ("Strain Name,Value,SE,Count," +
           ",".join(columns) + "\n")
    for key, value in csv_data.items():
        if not value:
            # Sample has no case-attribute values; pad with 'x' markers.
            csv += (key + (len(case_attr_columns) * ",x") + "\n")
        else:
            vals = [str(value.get(column, "x")) for column in columns]
            csv += (key + "," + ",".join(vals) + "\n")
    return csv
+
+
def get_sample_data_ids(conn: Any, publishxref_id: int,
                        phenotype_id: int,
                        strain_name: str) -> Tuple:
    """Get the strain_id, publishdata_id and inbredset_id for a given strain"""
    strain_id = publishdata_id = inbredset_id = None
    with conn.cursor() as cursor:
        cursor.execute("SELECT st.id, pd.Id, pf.InbredSetId "
                       "FROM PublishData pd "
                       "JOIN Strain st ON pd.StrainId = st.Id "
                       "JOIN PublishXRef px ON px.DataId = pd.Id "
                       "JOIN PublishFreeze pf ON pf.InbredSetId "
                       "= px.InbredSetId WHERE px.Id = %s "
                       "AND px.PhenotypeId = %s AND st.Name = %s",
                       (publishxref_id, phenotype_id, strain_name))
        existing = cursor.fetchone()
        if existing:
            strain_id, publishdata_id, inbredset_id = existing
        if not (strain_id and publishdata_id and inbredset_id):
            # The data point does not exist yet; fetch the ids needed to
            # insert it.
            # NOTE(review): assumes these fallback queries each match a row
            # — an unknown trait/strain raises TypeError here; confirm
            # callers guarantee existence.
            cursor.execute("SELECT DataId, InbredSetId FROM PublishXRef "
                           "WHERE Id = %s AND PhenotypeId = %s",
                           (publishxref_id, phenotype_id))
            publishdata_id, inbredset_id = cursor.fetchone()
            cursor.execute("SELECT Id FROM Strain WHERE Name = %s",
                           (strain_name,))
            strain_id = cursor.fetchone()[0]
    return (strain_id, publishdata_id, inbredset_id)
+
+
+# pylint: disable=[R0913, R0914]
def update_sample_data(conn: Any,
                       trait_name: str,
                       original_data: str,
                       updated_data: str,
                       csv_header: str,
                       phenotype_id: int) -> int:
    """Given the right parameters, update sample-data from the relevant
    table.

    :param conn: a database connection object
    :param trait_name: `PublishXRef.Id` of the trait (stringified integer)
    :param original_data: the CSV row as currently stored
    :param updated_data: the CSV row carrying the user's edits
    :param csv_header: header naming each CSV column
    :param phenotype_id: the phenotype the data belongs to
    :return: number of modified values/rows
    :raises MySQLdb.Error: on any failure; the transaction is rolled back
    """
    def __update_data(conn, table, value):
        # Update one value-holding table (PublishData/PublishSE/NStrain);
        # 'x' marks a missing value and is never written.
        if value and value != "x":
            with conn.cursor() as cursor:
                sub_query = (" = %s AND ".join(_MAP.get(table)[:2]) +
                             " = %s")
                _val = _MAP.get(table)[-1]
                cursor.execute((f"UPDATE {table} SET {_val} = %s "
                                f"WHERE {sub_query}"),
                               (value, strain_id, data_id))
                return cursor.rowcount
        return 0

    def __update_case_attribute(conn, value, strain_id,
                                case_attr, inbredset_id):
        # Update a named case-attribute value for the strain.
        if value != "x":
            with conn.cursor() as cursor:
                cursor.execute(
                    "UPDATE CaseAttributeXRefNew "
                    "SET Value = %s "
                    "WHERE StrainId = %s AND CaseAttributeId = "
                    "(SELECT CaseAttributeId FROM "
                    "CaseAttribute WHERE Name = %s) "
                    "AND InbredSetId = %s",
                    (value, strain_id, case_attr, inbredset_id))
                return cursor.rowcount
        return 0

    strain_id, data_id, inbredset_id = get_sample_data_ids(
        conn=conn, publishxref_id=int(trait_name),
        phenotype_id=phenotype_id,
        strain_name=extract_strain_name(csv_header, original_data))

    none_case_attrs: Dict[str, Callable] = {
        "Strain Name": lambda x: 0,
        "Value": lambda x: __update_data(conn, "PublishData", x),
        "SE": lambda x: __update_data(conn, "PublishSE", x),
        "Count": lambda x: __update_data(conn, "NStrain", x),
    }
    count = 0
    try:
        __actions = __extract_actions(original_data=original_data,
                                      updated_data=updated_data,
                                      csv_header=csv_header)
        if __actions.get("update"):
            _csv_header = __actions["update"]["csv_header"]
            _data = __actions["update"]["data"]
            # pylint: disable=[E1101]
            for header, value in zip(_csv_header.split(","),
                                     _data.split(",")):
                header = header.strip()
                value = value.strip()
                if header in none_case_attrs:
                    count += none_case_attrs[header](value)
                else:
                    # BUGFIX: pass the raw value. The original passed
                    # `none_case_attrs[header](value)`, which always raised
                    # KeyError here since this branch only runs when the
                    # header is NOT in `none_case_attrs`.
                    count += __update_case_attribute(
                        conn=conn,
                        value=value,
                        strain_id=strain_id,
                        case_attr=header,
                        inbredset_id=inbredset_id)
        if __actions.get("delete"):
            _rowcount = delete_sample_data(
                conn=conn,
                trait_name=trait_name,
                data=__actions["delete"]["data"],
                csv_header=__actions["delete"]["csv_header"],
                phenotype_id=phenotype_id)
            if _rowcount:
                count += 1
        if __actions.get("insert"):
            _rowcount = insert_sample_data(
                conn=conn,
                trait_name=trait_name,
                data=__actions["insert"]["data"],
                csv_header=__actions["insert"]["csv_header"],
                phenotype_id=phenotype_id)
            if _rowcount:
                count += 1
    except Exception as _e:
        conn.rollback()
        raise MySQLdb.Error(_e) from _e
    conn.commit()
    return count
+
+
def delete_sample_data(conn: Any,
                       trait_name: str,
                       data: str,
                       csv_header: str,
                       phenotype_id: int) -> int:
    """Given the right parameters, delete sample-data from the relevant
    tables."""
    def __delete_data(conn, table):
        # Delete the strain's row from a value-holding table.
        sub_query = (" = %s AND ".join(_MAP.get(table)[:2]) + " = %s")
        with conn.cursor() as cursor:
            cursor.execute((f"DELETE FROM {table} "
                            f"WHERE {sub_query}"),
                           (strain_id, data_id))
            return cursor.rowcount

    def __delete_case_attribute(conn, strain_id,
                                case_attr, inbredset_id):
        # Delete a named case-attribute value for the strain.
        with conn.cursor() as cursor:
            cursor.execute(
                "DELETE FROM CaseAttributeXRefNew "
                "WHERE StrainId = %s AND CaseAttributeId = "
                "(SELECT CaseAttributeId FROM "
                "CaseAttribute WHERE Name = %s) "
                "AND InbredSetId = %s",
                (strain_id, case_attr, inbredset_id))
            return cursor.rowcount

    strain_id, data_id, inbredset_id = get_sample_data_ids(
        conn=conn, publishxref_id=int(trait_name),
        phenotype_id=phenotype_id,
        strain_name=extract_strain_name(csv_header, data))

    none_case_attrs: Dict[str, Any] = {
        "Strain Name": lambda: 0,
        "Value": lambda: __delete_data(conn, "PublishData"),
        "SE": lambda: __delete_data(conn, "PublishSE"),
        "Count": lambda: __delete_data(conn, "NStrain"),
    }
    count = 0

    try:
        for header in (column.strip() for column in csv_header.split(",")):
            deleter = none_case_attrs.get(header)
            if deleter is not None:
                count += deleter()
            else:
                count += __delete_case_attribute(
                    conn=conn,
                    strain_id=strain_id,
                    case_attr=header,
                    inbredset_id=inbredset_id)
    except Exception as _e:
        conn.rollback()
        raise MySQLdb.Error(_e) from _e
    conn.commit()
    return count
+
+
+# pylint: disable=[R0913, R0914]
def insert_sample_data(conn: Any,
                       trait_name: str,
                       data: str,
                       csv_header: str,
                       phenotype_id: int) -> int:
    """Given the right parameters, insert sample-data to the relevant table.

    """
    def __insert_data(conn, table, value):
        # Insert into a value-holding table; 'x' marks a missing value and
        # is never written.
        if not value or value == "x":
            return 0
        with conn.cursor() as cursor:
            columns = ", ".join(_MAP.get(table))
            cursor.execute((f"INSERT INTO {table} "
                            f"({columns}) "
                            f"VALUES (%s, %s, %s)"),
                           (strain_id, data_id, value))
            return cursor.rowcount

    def __insert_case_attribute(conn, case_attr, value):
        # Insert a named case-attribute value, but only when the attribute
        # exists and no value is recorded for this strain yet.
        if value == "x":
            return 0
        with conn.cursor() as cursor:
            cursor.execute("SELECT Id FROM "
                           "CaseAttribute WHERE Name = %s",
                           (case_attr,))
            if case_attr_id := cursor.fetchone():
                case_attr_id = case_attr_id[0]
            cursor.execute("SELECT StrainId FROM "
                           "CaseAttributeXRefNew WHERE StrainId = %s "
                           "AND CaseAttributeId = %s "
                           "AND InbredSetId = %s",
                           (strain_id, case_attr_id, inbredset_id))
            if (not cursor.fetchone()) and case_attr_id:
                cursor.execute(
                    "INSERT INTO CaseAttributeXRefNew "
                    "(StrainId, CaseAttributeId, Value, InbredSetId) "
                    "VALUES (%s, %s, %s, %s)",
                    (strain_id, case_attr_id, value, inbredset_id))
                return cursor.rowcount
        return 0

    strain_id, data_id, inbredset_id = get_sample_data_ids(
        conn=conn, publishxref_id=int(trait_name),
        phenotype_id=phenotype_id,
        strain_name=extract_strain_name(csv_header, data))

    none_case_attrs: Dict[str, Any] = {
        "Strain Name": lambda _: 0,
        "Value": lambda x: __insert_data(conn, "PublishData", x),
        "SE": lambda x: __insert_data(conn, "PublishSE", x),
        "Count": lambda x: __insert_data(conn, "NStrain", x),
    }

    try:
        count = 0

        # Bail out early when the data point already exists.
        with conn.cursor() as cursor:
            cursor.execute(
                "SELECT Id FROM PublishData where Id = %s "
                "AND StrainId = %s",
                (data_id, strain_id))
            if cursor.fetchone():
                return count

        for header, value in zip(csv_header.split(","), data.split(",")):
            header = header.strip()
            value = value.strip()
            inserter = none_case_attrs.get(header)
            if inserter is not None:
                count += inserter(value)
            else:
                count += __insert_case_attribute(
                    conn=conn,
                    case_attr=header,
                    value=value)
        return count
    except Exception as _e:
        conn.rollback()
        raise MySQLdb.Error(_e) from _e
diff --git a/gn3/db/species.py b/gn3/db/species.py
index 702a9a8..5b8e096 100644
--- a/gn3/db/species.py
+++ b/gn3/db/species.py
@@ -57,3 +57,20 @@ def translate_to_mouse_gene_id(species: str, geneid: int, conn: Any) -> int:
return translated_gene_id[0]
return 0 # default if all else fails
+
def species_name(conn: Any, group: str) -> str:
    """
    Retrieve the name of the species, given the group (RISet).

    This is a migration of the
    `web.webqtl.dbFunction.webqtlDatabaseFunction.retrieveSpecies` function in
    GeneNetwork1.

    :param conn: a database connection object
    :param group: the group (RISet/InbredSet) name
    :return: the species name, or None when the group is unknown
    """
    with conn.cursor() as cursor:
        cursor.execute(
            ("SELECT Species.Name FROM Species, InbredSet "
             "WHERE InbredSet.Name = %(group_name)s "
             "AND InbredSet.SpeciesId = Species.Id"),
            {"group_name": group})
        # BUGFIX: fetchone() returns None for an unknown group; subscripting
        # it raised TypeError, and the trailing `return None` was unreachable.
        result = cursor.fetchone()
        if result:
            return result[0]
    return None
diff --git a/gn3/db/traits.py b/gn3/db/traits.py
index 1c6aaa7..f722e24 100644
--- a/gn3/db/traits.py
+++ b/gn3/db/traits.py
@@ -1,7 +1,7 @@
"""This class contains functions relating to trait data manipulation"""
import os
from functools import reduce
-from typing import Any, Dict, Union, Sequence
+from typing import Any, Dict, Sequence
from gn3.settings import TMPDIR
from gn3.random import random_string
@@ -67,7 +67,7 @@ def export_trait_data(
return accumulator + (trait_data["data"][sample]["ndata"], )
if dtype == "all":
return accumulator + __export_all_types(trait_data["data"], sample)
- raise KeyError("Type `%s` is incorrect" % dtype)
+ raise KeyError(f"Type `{dtype}` is incorrect")
if var_exists and n_exists:
return accumulator + (None, None, None)
if var_exists or n_exists:
@@ -76,80 +76,6 @@ def export_trait_data(
return reduce(__exporter, samplelist, tuple())
-def get_trait_csv_sample_data(conn: Any,
- trait_name: int, phenotype_id: int):
- """Fetch a trait and return it as a csv string"""
- sql = ("SELECT DISTINCT Strain.Id, PublishData.Id, Strain.Name, "
- "PublishData.value, "
- "PublishSE.error, NStrain.count FROM "
- "(PublishData, Strain, PublishXRef, PublishFreeze) "
- "LEFT JOIN PublishSE ON "
- "(PublishSE.DataId = PublishData.Id AND "
- "PublishSE.StrainId = PublishData.StrainId) "
- "LEFT JOIN NStrain ON (NStrain.DataId = PublishData.Id AND "
- "NStrain.StrainId = PublishData.StrainId) WHERE "
- "PublishXRef.InbredSetId = PublishFreeze.InbredSetId AND "
- "PublishData.Id = PublishXRef.DataId AND "
- "PublishXRef.Id = %s AND PublishXRef.PhenotypeId = %s "
- "AND PublishData.StrainId = Strain.Id Order BY Strain.Name")
- csv_data = ["Strain Id,Strain Name,Value,SE,Count"]
- publishdata_id = ""
- with conn.cursor() as cursor:
- cursor.execute(sql, (trait_name, phenotype_id,))
- for record in cursor.fetchall():
- (strain_id, publishdata_id,
- strain_name, value, error, count) = record
- csv_data.append(
- ",".join([str(val) if val else "x"
- for val in (strain_id, strain_name,
- value, error, count)]))
- return f"# Publish Data Id: {publishdata_id}\n\n" + "\n".join(csv_data)
-
-
-def update_sample_data(conn: Any,
- strain_name: str,
- strain_id: int,
- publish_data_id: int,
- value: Union[int, float, str],
- error: Union[int, float, str],
- count: Union[int, str]):
- """Given the right parameters, update sample-data from the relevant
- table."""
- # pylint: disable=[R0913, R0914, C0103]
- STRAIN_ID_SQL: str = "UPDATE Strain SET Name = %s WHERE Id = %s"
- PUBLISH_DATA_SQL: str = ("UPDATE PublishData SET value = %s "
- "WHERE StrainId = %s AND Id = %s")
- PUBLISH_SE_SQL: str = ("UPDATE PublishSE SET error = %s "
- "WHERE StrainId = %s AND DataId = %s")
- N_STRAIN_SQL: str = ("UPDATE NStrain SET count = %s "
- "WHERE StrainId = %s AND DataId = %s")
-
- updated_strains: int = 0
- updated_published_data: int = 0
- updated_se_data: int = 0
- updated_n_strains: int = 0
-
- with conn.cursor() as cursor:
- # Update the Strains table
- cursor.execute(STRAIN_ID_SQL, (strain_name, strain_id))
- updated_strains = cursor.rowcount
- # Update the PublishData table
- cursor.execute(PUBLISH_DATA_SQL,
- (None if value == "x" else value,
- strain_id, publish_data_id))
- updated_published_data = cursor.rowcount
- # Update the PublishSE table
- cursor.execute(PUBLISH_SE_SQL,
- (None if error == "x" else error,
- strain_id, publish_data_id))
- updated_se_data = cursor.rowcount
- # Update the NStrain table
- cursor.execute(N_STRAIN_SQL,
- (None if count == "x" else count,
- strain_id, publish_data_id))
- updated_n_strains = cursor.rowcount
- return (updated_strains, updated_published_data,
- updated_se_data, updated_n_strains)
def retrieve_publish_trait_info(trait_data_source: Dict[str, Any], conn: Any):
"""Retrieve trait information for type `Publish` traits.
@@ -177,24 +103,24 @@ def retrieve_publish_trait_info(trait_data_source: Dict[str, Any], conn: Any):
"PublishXRef.comments")
query = (
"SELECT "
- "{columns} "
+ f"{columns} "
"FROM "
- "PublishXRef, Publication, Phenotype, PublishFreeze "
+ "PublishXRef, Publication, Phenotype "
"WHERE "
"PublishXRef.Id = %(trait_name)s AND "
"Phenotype.Id = PublishXRef.PhenotypeId AND "
"Publication.Id = PublishXRef.PublicationId AND "
- "PublishXRef.InbredSetId = PublishFreeze.InbredSetId AND "
- "PublishFreeze.Id =%(trait_dataset_id)s").format(columns=columns)
+ "PublishXRef.InbredSetId = %(trait_dataset_id)s")
with conn.cursor() as cursor:
cursor.execute(
query,
{
- k:v for k, v in trait_data_source.items()
+ k: v for k, v in trait_data_source.items()
if k in ["trait_name", "trait_dataset_id"]
})
return dict(zip([k.lower() for k in keys], cursor.fetchone()))
+
def set_confidential_field(trait_type, trait_info):
"""Post processing function for 'Publish' trait types.
@@ -207,6 +133,7 @@ def set_confidential_field(trait_type, trait_info):
and not trait_info.get("pubmed_id", None)) else 0}
return trait_info
+
def retrieve_probeset_trait_info(trait_data_source: Dict[str, Any], conn: Any):
"""Retrieve trait information for type `ProbeSet` traits.
@@ -219,67 +146,68 @@ def retrieve_probeset_trait_info(trait_data_source: Dict[str, Any], conn: Any):
"probe_set_specificity", "probe_set_blat_score",
"probe_set_blat_mb_start", "probe_set_blat_mb_end", "probe_set_strand",
"probe_set_note_by_rw", "flag")
+ columns = (f"ProbeSet.{x}" for x in keys)
query = (
- "SELECT "
- "{columns} "
+ f"SELECT {', '.join(columns)} "
"FROM "
"ProbeSet, ProbeSetFreeze, ProbeSetXRef "
"WHERE "
"ProbeSetXRef.ProbeSetFreezeId = ProbeSetFreeze.Id AND "
"ProbeSetXRef.ProbeSetId = ProbeSet.Id AND "
"ProbeSetFreeze.Name = %(trait_dataset_name)s AND "
- "ProbeSet.Name = %(trait_name)s").format(
- columns=", ".join(["ProbeSet.{}".format(x) for x in keys]))
+ "ProbeSet.Name = %(trait_name)s")
with conn.cursor() as cursor:
cursor.execute(
query,
{
- k:v for k, v in trait_data_source.items()
+ k: v for k, v in trait_data_source.items()
if k in ["trait_name", "trait_dataset_name"]
})
return dict(zip(keys, cursor.fetchone()))
+
def retrieve_geno_trait_info(trait_data_source: Dict[str, Any], conn: Any):
"""Retrieve trait information for type `Geno` traits.
https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/base/webqtlTrait.py#L438-L449"""
keys = ("name", "chr", "mb", "source2", "sequence")
+ columns = ", ".join(f"Geno.{x}" for x in keys)
query = (
- "SELECT "
- "{columns} "
+ f"SELECT {columns} "
"FROM "
- "Geno, GenoFreeze, GenoXRef "
+ "Geno INNER JOIN GenoXRef ON GenoXRef.GenoId = Geno.Id "
+ "INNER JOIN GenoFreeze ON GenoFreeze.Id = GenoXRef.GenoFreezeId "
"WHERE "
- "GenoXRef.GenoFreezeId = GenoFreeze.Id AND GenoXRef.GenoId = Geno.Id AND "
"GenoFreeze.Name = %(trait_dataset_name)s AND "
- "Geno.Name = %(trait_name)s").format(
- columns=", ".join(["Geno.{}".format(x) for x in keys]))
+ "Geno.Name = %(trait_name)s")
with conn.cursor() as cursor:
cursor.execute(
query,
{
- k:v for k, v in trait_data_source.items()
+ k: v for k, v in trait_data_source.items()
if k in ["trait_name", "trait_dataset_name"]
})
return dict(zip(keys, cursor.fetchone()))
+
def retrieve_temp_trait_info(trait_data_source: Dict[str, Any], conn: Any):
"""Retrieve trait information for type `Temp` traits.
https://github.com/genenetwork/genenetwork1/blob/master/web/webqtl/base/webqtlTrait.py#L450-452"""
keys = ("name", "description")
query = (
- "SELECT {columns} FROM Temp "
- "WHERE Name = %(trait_name)s").format(columns=", ".join(keys))
+ f"SELECT {', '.join(keys)} FROM Temp "
+ "WHERE Name = %(trait_name)s")
with conn.cursor() as cursor:
cursor.execute(
query,
{
- k:v for k, v in trait_data_source.items()
+ k: v for k, v in trait_data_source.items()
if k in ["trait_name"]
})
return dict(zip(keys, cursor.fetchone()))
+
def set_haveinfo_field(trait_info):
"""
Common postprocessing function for all trait types.
@@ -287,6 +215,7 @@ def set_haveinfo_field(trait_info):
Sets the value for the 'haveinfo' field."""
return {**trait_info, "haveinfo": 1 if trait_info else 0}
+
def set_homologene_id_field_probeset(trait_info, conn):
"""
Postprocessing function for 'ProbeSet' traits.
@@ -302,7 +231,7 @@ def set_homologene_id_field_probeset(trait_info, conn):
cursor.execute(
query,
{
- k:v for k, v in trait_info.items()
+ k: v for k, v in trait_info.items()
if k in ["geneid", "group"]
})
res = cursor.fetchone()
@@ -310,12 +239,13 @@ def set_homologene_id_field_probeset(trait_info, conn):
return {**trait_info, "homologeneid": res[0]}
return {**trait_info, "homologeneid": None}
+
def set_homologene_id_field(trait_type, trait_info, conn):
"""
Common postprocessing function for all trait types.
Sets the value for the 'homologene' key."""
- set_to_null = lambda ti: {**ti, "homologeneid": None}
+ def set_to_null(ti): return {**ti, "homologeneid": None} # pylint: disable=[C0103, C0321]
functions_table = {
"Temp": set_to_null,
"Geno": set_to_null,
@@ -324,6 +254,7 @@ def set_homologene_id_field(trait_type, trait_info, conn):
}
return functions_table[trait_type](trait_info)
+
def load_publish_qtl_info(trait_info, conn):
"""
Load extra QTL information for `Publish` traits
@@ -344,6 +275,7 @@ def load_publish_qtl_info(trait_info, conn):
return dict(zip(["locus", "lrs", "additive"], cursor.fetchone()))
return {"locus": "", "lrs": "", "additive": ""}
+
def load_probeset_qtl_info(trait_info, conn):
"""
Load extra QTL information for `ProbeSet` traits
@@ -366,6 +298,7 @@ def load_probeset_qtl_info(trait_info, conn):
["locus", "lrs", "pvalue", "mean", "additive"], cursor.fetchone()))
return {"locus": "", "lrs": "", "pvalue": "", "mean": "", "additive": ""}
+
def load_qtl_info(qtl, trait_type, trait_info, conn):
"""
Load extra QTL information for traits
@@ -389,11 +322,12 @@ def load_qtl_info(qtl, trait_type, trait_info, conn):
"Publish": load_publish_qtl_info,
"ProbeSet": load_probeset_qtl_info
}
- if trait_info["name"] not in qtl_info_functions.keys():
+ if trait_info["name"] not in qtl_info_functions:
return trait_info
return qtl_info_functions[trait_type](trait_info, conn)
+
def build_trait_name(trait_fullname):
"""
Initialises the trait's name, and other values from the search data provided
@@ -408,7 +342,7 @@ def build_trait_name(trait_fullname):
return "ProbeSet"
name_parts = trait_fullname.split("::")
- assert len(name_parts) >= 2, "Name format error"
+ assert len(name_parts) >= 2, f"Name format error: '{trait_fullname}'"
dataset_name = name_parts[0]
dataset_type = dataset_type(dataset_name)
return {
@@ -420,6 +354,7 @@ def build_trait_name(trait_fullname):
"cellid": name_parts[2] if len(name_parts) == 3 else ""
}
+
def retrieve_probeset_sequence(trait, conn):
"""
Retrieve a 'ProbeSet' trait's sequence information
@@ -441,6 +376,7 @@ def retrieve_probeset_sequence(trait, conn):
seq = cursor.fetchone()
return {**trait, "sequence": seq[0] if seq else ""}
+
def retrieve_trait_info(
threshold: int, trait_full_name: str, conn: Any,
qtl=None):
@@ -496,6 +432,7 @@ def retrieve_trait_info(
}
return trait_info
+
def retrieve_temp_trait_data(trait_info: dict, conn: Any):
"""
Retrieve trait data for `Temp` traits.
@@ -514,10 +451,12 @@ def retrieve_temp_trait_data(trait_info: dict, conn: Any):
query,
{"trait_name": trait_info["trait_name"]})
return [dict(zip(
- ["sample_name", "value", "se_error", "nstrain", "id"], row))
+ ["sample_name", "value", "se_error", "nstrain", "id"],
+ row))
for row in cursor.fetchall()]
return []
+
def retrieve_species_id(group, conn: Any):
"""
Retrieve a species id given the Group value
@@ -529,6 +468,7 @@ def retrieve_species_id(group, conn: Any):
return cursor.fetchone()[0]
return None
+
def retrieve_geno_trait_data(trait_info: Dict, conn: Any):
"""
Retrieve trait data for `Geno` traits.
@@ -552,11 +492,14 @@ def retrieve_geno_trait_data(trait_info: Dict, conn: Any):
"dataset_name": trait_info["db"]["dataset_name"],
"species_id": retrieve_species_id(
trait_info["db"]["group"], conn)})
- return [dict(zip(
- ["sample_name", "value", "se_error", "id"], row))
- for row in cursor.fetchall()]
+ return [
+ dict(zip(
+ ["sample_name", "value", "se_error", "id"],
+ row))
+ for row in cursor.fetchall()]
return []
+
def retrieve_publish_trait_data(trait_info: Dict, conn: Any):
"""
Retrieve trait data for `Publish` traits.
@@ -565,17 +508,16 @@ def retrieve_publish_trait_data(trait_info: Dict, conn: Any):
"SELECT "
"Strain.Name, PublishData.value, PublishSE.error, NStrain.count, "
"PublishData.Id "
- "FROM (PublishData, Strain, PublishXRef, PublishFreeze) "
+ "FROM (PublishData, Strain, PublishXRef) "
"LEFT JOIN PublishSE ON "
"(PublishSE.DataId = PublishData.Id "
"AND PublishSE.StrainId = PublishData.StrainId) "
"LEFT JOIN NStrain ON "
"(NStrain.DataId = PublishData.Id "
"AND NStrain.StrainId = PublishData.StrainId) "
- "WHERE PublishXRef.InbredSetId = PublishFreeze.InbredSetId "
- "AND PublishData.Id = PublishXRef.DataId "
+ "WHERE PublishData.Id = PublishXRef.DataId "
"AND PublishXRef.Id = %(trait_name)s "
- "AND PublishFreeze.Id = %(dataset_id)s "
+ "AND PublishXRef.InbredSetId = %(dataset_id)s "
"AND PublishData.StrainId = Strain.Id "
"ORDER BY Strain.Name")
with conn.cursor() as cursor:
@@ -583,11 +525,13 @@ def retrieve_publish_trait_data(trait_info: Dict, conn: Any):
query,
{"trait_name": trait_info["trait_name"],
"dataset_id": trait_info["db"]["dataset_id"]})
- return [dict(zip(
- ["sample_name", "value", "se_error", "nstrain", "id"], row))
- for row in cursor.fetchall()]
+ return [
+ dict(zip(
+ ["sample_name", "value", "se_error", "nstrain", "id"], row))
+ for row in cursor.fetchall()]
return []
+
def retrieve_cellid_trait_data(trait_info: Dict, conn: Any):
"""
Retrieve trait data for `Probe Data` types.
@@ -616,11 +560,13 @@ def retrieve_cellid_trait_data(trait_info: Dict, conn: Any):
{"cellid": trait_info["cellid"],
"trait_name": trait_info["trait_name"],
"dataset_id": trait_info["db"]["dataset_id"]})
- return [dict(zip(
- ["sample_name", "value", "se_error", "id"], row))
- for row in cursor.fetchall()]
+ return [
+ dict(zip(
+ ["sample_name", "value", "se_error", "id"], row))
+ for row in cursor.fetchall()]
return []
+
def retrieve_probeset_trait_data(trait_info: Dict, conn: Any):
"""
Retrieve trait data for `ProbeSet` traits.
@@ -645,11 +591,13 @@ def retrieve_probeset_trait_data(trait_info: Dict, conn: Any):
query,
{"trait_name": trait_info["trait_name"],
"dataset_name": trait_info["db"]["dataset_name"]})
- return [dict(zip(
- ["sample_name", "value", "se_error", "id"], row))
- for row in cursor.fetchall()]
+ return [
+ dict(zip(
+ ["sample_name", "value", "se_error", "id"], row))
+ for row in cursor.fetchall()]
return []
+
def with_samplelist_data_setup(samplelist: Sequence[str]):
"""
Build function that computes the trait data from provided list of samples.
@@ -676,6 +624,7 @@ def with_samplelist_data_setup(samplelist: Sequence[str]):
return None
return setup_fn
+
def without_samplelist_data_setup():
"""
Build function that computes the trait data.
@@ -696,6 +645,7 @@ def without_samplelist_data_setup():
return None
return setup_fn
+
def retrieve_trait_data(trait: dict, conn: Any, samplelist: Sequence[str] = tuple()):
"""
Retrieve trait data
@@ -735,14 +685,16 @@ def retrieve_trait_data(trait: dict, conn: Any, samplelist: Sequence[str] = tupl
"data": dict(map(
lambda x: (
x["sample_name"],
- {k:v for k, v in x.items() if x != "sample_name"}),
+ {k: v for k, v in x.items() if x != "sample_name"}),
data))}
return {}
+
def generate_traits_filename(base_path: str = TMPDIR):
"""Generate a unique filename for use with generated traits files."""
- return "{}/traits_test_file_{}.txt".format(
- os.path.abspath(base_path), random_string(10))
+ return (
+ f"{os.path.abspath(base_path)}/traits_test_file_{random_string(10)}.txt")
+
def export_informative(trait_data: dict, inc_var: bool = False) -> tuple:
"""
@@ -765,5 +717,6 @@ def export_informative(trait_data: dict, inc_var: bool = False) -> tuple:
return acc
return reduce(
__exporter__,
- filter(lambda td: td["value"] is not None, trait_data["data"].values()),
+ filter(lambda td: td["value"] is not None,
+ trait_data["data"].values()),
(tuple(), tuple(), tuple()))