 .guix_deploy                                |   8
 .pylintrc                                   |   6
 README.md                                   |   3
 gn3/api/correlation.py                      |  31
 gn3/authentication.py                       |   1
 gn3/computations/correlations.py            |   5
 gn3/computations/partial_correlations.py    | 137
 gn3/db/correlations.py                      |  22
 gn3/db/datasets.py                          |  12
 gn3/db/traits.py                            |  48
 mypy.ini                                    |  12
 scripts/wgcna_analysis.R                    |   7
 tests/unit/computations/test_correlation.py |   1
 tests/unit/db/test_datasets.py              |  14
 tests/unit/db/test_traits.py                |  33
 15 files changed, 258 insertions(+), 82 deletions(-)
diff --git a/.guix_deploy b/.guix_deploy
new file mode 100644
index 0000000..c7bbb5b
--- /dev/null
+++ b/.guix_deploy
@@ -0,0 +1,8 @@
+# Deploy script on tux01
+#
+# echo Run tests:
+# echo python -m unittest discover -v
+# echo Run service (single process):
+# echo flask run --port=8080
+
+/home/wrk/opt/guix-pull/bin/guix shell -L /home/wrk/guix-bioinformatics/ --expose=$HOME/production/genotype_files/ -C -N -Df guix.scm
diff --git a/.pylintrc b/.pylintrc
--- a/.pylintrc
+++ b/.pylintrc
@@ -1,3 +1,7 @@
 [SIMILARITIES]
-ignore-imports=yes
\ No newline at end of file
+ignore-imports=yes
+
+[MESSAGES CONTROL]
+
+disable=fixme
\ No newline at end of file
diff --git a/README.md b/README.md
--- a/README.md
+++ b/README.md
@@ -38,7 +38,6 @@ python3
 guix shell -C --network --expose=$HOME/genotype_files/ -Df guix.scm
 ```
-
 #### Using a Guix profile (or rolling back)
 
 Create a new profile with
@@ -128,6 +127,8 @@ And for the scalable production version run
 
 gunicorn --bind 0.0.0.0:8080 --workers 8 --keep-alive 6000 --max-requests 10 --max-requests-jitter 5 --timeout 1200 wsgi:app
 ```
+(see also the [.guix_deploy](./.guix_deploy) script)
+
 ## Using python-pip
 
 IMPORTANT NOTE: we do not recommend using pip tools, use Guix instead
diff --git a/gn3/api/correlation.py b/gn3/api/correlation.py
index 46121f8..1caf31f 100644
--- a/gn3/api/correlation.py
+++ b/gn3/api/correlation.py
@@ -1,13 +1,16 @@
 """Endpoints for running correlations"""
+import json
 from flask import jsonify
 from flask import Blueprint
 from flask import request
+from flask import make_response
 
 from gn3.computations.correlations import compute_all_sample_correlation
 from gn3.computations.correlations import compute_all_lit_correlation
 from gn3.computations.correlations import compute_tissue_correlation
 from gn3.computations.correlations import map_shared_keys_to_values
 from gn3.db_utils import database_connector
+from gn3.computations.partial_correlations import partial_correlations_entry
 
 correlation = Blueprint("correlation", __name__)
@@ -83,3 +86,31 @@ def compute_tissue_corr(corr_method="pearson"):
                                       corr_method=corr_method)
 
     return jsonify(results)
+
+@correlation.route("/partial", methods=["POST"])
+def partial_correlation():
+    """API endpoint for partial correlations."""
+    def trait_fullname(trait):
+        return f"{trait['dataset']}::{trait['name']}"
+
+    class OutputEncoder(json.JSONEncoder):
+        """
+        Class to encode output into JSON, for objects which the default
+        json.JSONEncoder class does not have default encoding for.
+ """ + def default(self, obj): + if isinstance(obj, bytes): + return str(obj, encoding="utf-8") + return json.JSONEncoder.default(self, obj) + + args = request.get_json() + conn, _cursor_object = database_connector() + corr_results = partial_correlations_entry( + conn, trait_fullname(args["primary_trait"]), + tuple(trait_fullname(trait) for trait in args["control_traits"]), + args["method"], int(args["criteria"]), args["target_db"]) + response = make_response( + json.dumps(corr_results, cls=OutputEncoder).replace(": NaN", ": null"), + 400 if "error" in corr_results.keys() else 200) + response.headers["Content-Type"] = "application/json" + return response diff --git a/gn3/authentication.py b/gn3/authentication.py index a6372c1..d0b35bc 100644 --- a/gn3/authentication.py +++ b/gn3/authentication.py @@ -163,3 +163,4 @@ def create_group(conn: Redis, group_name: Optional[str], } conn.hset("groups", group_id, json.dumps(group)) return group + return None diff --git a/gn3/computations/correlations.py b/gn3/computations/correlations.py index 37c70e9..09288c5 100644 --- a/gn3/computations/correlations.py +++ b/gn3/computations/correlations.py @@ -8,6 +8,7 @@ from typing import List from typing import Tuple from typing import Optional from typing import Callable +from typing import Generator import scipy.stats import pingouin as pg @@ -80,7 +81,7 @@ def compute_sample_r_correlation(trait_name, corr_method, trait_vals, zip(*list(normalize_values(trait_vals, target_samples_vals)))) num_overlap = len(normalized_traits_vals) except ValueError: - return + return None if num_overlap > 5: @@ -107,7 +108,7 @@ package :not packaged in guix def filter_shared_sample_keys(this_samplelist, - target_samplelist) -> Tuple[List, List]: + target_samplelist) -> Generator: """Given primary and target sample-list for two base and target trait select filter the values using the shared keys diff --git a/gn3/computations/partial_correlations.py b/gn3/computations/partial_correlations.py index 719c605..984c15a 100644 --- a/gn3/computations/partial_correlations.py +++ b/gn3/computations/partial_correlations.py @@ -18,6 +18,7 @@ from gn3.random import random_string from gn3.function_helpers import compose from gn3.data_helpers import parse_csv_line from gn3.db.traits import export_informative +from gn3.db.datasets import retrieve_trait_dataset from gn3.db.traits import retrieve_trait_info, retrieve_trait_data from gn3.db.species import species_name, translate_to_mouse_gene_id from gn3.db.correlations import ( @@ -216,7 +217,7 @@ def good_dataset_samples_indexes( def partial_correlations_fast(# pylint: disable=[R0913, R0914] samples, primary_vals, control_vals, database_filename, fetched_correlations, method: str, correlation_type: str) -> Tuple[ - float, Tuple[float, ...]]: + int, Tuple[float, ...]]: """ Computes partial correlation coefficients using data from a CSV file. 
@@ -257,8 +258,9 @@ def partial_correlations_fast(# pylint: disable=[R0913, R0914]
     ## `correlation_type` parameter
     return len(all_correlations), tuple(
         corr + (
-            (fetched_correlations[corr[0]],) if correlation_type == "literature"
-            else fetched_correlations[corr[0]][0:2])
+            (fetched_correlations[corr[0]],) # type: ignore[index]
+            if correlation_type == "literature"
+            else fetched_correlations[corr[0]][0:2]) # type: ignore[index]
         for idx, corr in enumerate(all_correlations))
 
 def build_data_frame(
@@ -305,11 +307,19 @@ def compute_partial(
             prim for targ, prim in zip(targ_vals, primary_vals)
             if targ is not None]
 
+        if len(primary) < 3:
+            return None
+
+        def __remove_controls_for_target_nones(cont_targ):
+            return tuple(cont for cont, targ in cont_targ if targ is not None)
+
+        conts_targs = tuple(tuple(
+            zip(control, targ_vals)) for control in control_vals)
         datafrm = build_data_frame(
             primary,
-            tuple(targ for targ in targ_vals if targ is not None),
-            tuple(cont for i, cont in enumerate(control_vals)
-                  if target[0][i] is not None))
+            [targ for targ in targ_vals if targ is not None],
+            [__remove_controls_for_target_nones(cont_targ)
+             for cont_targ in conts_targs])
         covariates = "z" if datafrm.shape[1] == 3 else [
             col for col in datafrm.columns if col not in ("x", "y")]
         ppc = pingouin.partial_corr(
@@ -332,13 +342,17 @@ def compute_partial(
             zero_order_corr["r"][0], zero_order_corr["p-val"][0])
 
     return tuple(
-        __compute_trait_info__(target)
-        for target in zip(target_vals, target_names))
+        result for result in (
+            __compute_trait_info__(target)
+            for target in zip(target_vals, target_names))
+        if result is not None)
 
 def partial_correlations_normal(# pylint: disable=R0913
         primary_vals, control_vals, input_trait_gene_id, trait_database,
         data_start_pos: int, db_type: str, method: str) -> Tuple[
-            float, Tuple[float, ...]]:
+            int, Tuple[Union[
+                Tuple[str, int, float, float, float, float], None],
+                       ...]]:#Tuple[float, ...]
     """
     Computes the correlation coefficients.
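A minimal sketch of the `pingouin.partial_corr` call that `compute_partial()` builds up to, using a toy data frame in the same x/y/z column convention (x = primary trait, y = target trait, z = a single control). The values are made up:

```python
import pandas
import pingouin

datafrm = pandas.DataFrame({
    "x": [7.1, 7.8, 6.9, 7.5, 8.2, 7.0],   # primary trait values
    "y": [5.2, 5.9, 5.1, 5.6, 6.3, 5.0],   # target trait values
    "z": [1.0, 1.4, 0.9, 1.2, 1.8, 1.1],   # control trait values
})

# With a single control, compute_partial() passes covar="z";
# with several controls it passes the list of covariate columns.
ppc = pingouin.partial_corr(
    data=datafrm, x="x", y="y", covar="z", method="pearson")
print(ppc[["n", "r", "p-val"]])
```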
@@ -360,7 +374,7 @@ def partial_correlations_normal(# pylint: disable=R0913
             return tuple(item) + (trait_database[1], trait_database[2])
         return item
 
-    target_trait_names, target_trait_vals = reduce(
+    target_trait_names, target_trait_vals = reduce(# type: ignore[var-annotated]
         lambda acc, item: (acc[0]+(item[0],), acc[1]+(item[data_start_pos:],)),
         trait_database, (tuple(), tuple()))
 
@@ -413,7 +427,7 @@ def partial_corrs(# pylint: disable=[R0913]
         data_start_pos, dataset, method)
 
 def literature_correlation_by_list(
-        conn: Any, species: str, trait_list: Tuple[dict]) -> Tuple[dict]:
+        conn: Any, species: str, trait_list: Tuple[dict]) -> Tuple[dict, ...]:
     """
     This is a migration of the
     `web.webqtl.correlation.CorrelationPage.getLiteratureCorrelationByList`
@@ -473,7 +487,7 @@ def literature_correlation_by_list(
 
 def tissue_correlation_by_list(
         conn: Any, primary_trait_symbol: str, tissue_probeset_freeze_id: int,
-        method: str, trait_list: Tuple[dict]) -> Tuple[dict]:
+        method: str, trait_list: Tuple[dict]) -> Tuple[dict, ...]:
     """
     This is a migration of the
     `web.webqtl.correlation.CorrelationPage.getTissueCorrelationByList`
@@ -496,7 +510,7 @@ def tissue_correlation_by_list(
             primary_trait_value = prim_trait_symbol_value_dict[
                 primary_trait_symbol.lower()]
             gene_symbol_list = tuple(
-                trait for trait in trait_list if "symbol" in trait.keys())
+                trait["symbol"] for trait in trait_list if "symbol" in trait.keys())
             symbol_value_dict = fetch_gene_symbol_tissue_value_dict_for_trait(
                 gene_symbol_list, tissue_probeset_freeze_id, conn)
             return tuple(
@@ -514,6 +528,54 @@ def tissue_correlation_by_list(
                 } for trait in trait_list)
     return trait_list
 
+def trait_for_output(trait):
+    """
+    Process a trait for output.
+
+    Removes a lot of extraneous data from the trait, that is not needed for
+    the display of partial correlation results.
+    This function also removes all key-value pairs, for which the value is
+    `None`, because it is a waste of network resources to transmit the key-value
+    pair just to indicate it does not exist.
+ """ + trait = { + "trait_type": trait["trait_type"], + "dataset_name": trait["db"]["dataset_name"], + "dataset_type": trait["db"]["dataset_type"], + "group": trait["db"]["group"], + "trait_fullname": trait["trait_fullname"], + "trait_name": trait["trait_name"], + "symbol": trait.get("symbol"), + "description": trait.get("description"), + "pre_publication_description": trait.get( + "pre_publication_description"), + "post_publication_description": trait.get( + "post_publication_description"), + "original_description": trait.get( + "original_description"), + "authors": trait.get("authors"), + "year": trait.get("year"), + "probe_target_description": trait.get( + "probe_target_description"), + "chr": trait.get("chr"), + "mb": trait.get("mb"), + "geneid": trait.get("geneid"), + "homologeneid": trait.get("homologeneid"), + "noverlap": trait.get("noverlap"), + "partial_corr": trait.get("partial_corr"), + "partial_corr_p_value": trait.get("partial_corr_p_value"), + "corr": trait.get("corr"), + "corr_p_value": trait.get("corr_p_value"), + "rank_order": trait.get("rank_order"), + "delta": ( + None if trait.get("partial_corr") is None + else (trait.get("partial_corr") - trait.get("corr"))), + "l_corr": trait.get("l_corr"), + "tissue_corr": trait.get("tissue_corr"), + "tissue_p_value": trait.get("tissue_p_value") + } + return {key: val for key, val in trait.items() if val is not None} + def partial_correlations_entry(# pylint: disable=[R0913, R0914, R0911] conn: Any, primary_trait_name: str, control_trait_names: Tuple[str, ...], method: str, @@ -640,28 +702,47 @@ def partial_correlations_entry(# pylint: disable=[R0913, R0914, R0911] "any associated Tissue Correlation Information."), "error_type": "Tissue Correlation"} + target_dataset = retrieve_trait_dataset( + ("Temp" if "Temp" in target_db_name else + ("Publish" if "Publish" in target_db_name else + "Geno" if "Geno" in target_db_name else "ProbeSet")), + {"db": {"dataset_name": target_db_name}, "trait_name": "_"}, + threshold, + conn) + database_filename = get_filename(conn, target_db_name, TEXTDIR) _total_traits, all_correlations = partial_corrs( conn, common_primary_control_samples, fixed_primary_vals, fixed_control_vals, len(fixed_primary_vals), species, input_trait_geneid, input_trait_symbol, tissue_probeset_freeze_id, - method, primary_trait["db"], database_filename) + method, {**target_dataset, "dataset_type": target_dataset["type"]}, database_filename) def __make_sorter__(method): - def __sort_6__(row): - return row[6] - - def __sort_3__(row): + def __compare_lit_or_tiss_correlation_values_(row): + # Index Content + # 0 trait name + # 1 N + # 2 partial correlation coefficient + # 3 p value of partial correlation + # 6 literature/tissue correlation value + return (row[6], row[3]) + + def __compare_partial_correlation_p_values__(row): + # Index Content + # 0 trait name + # 1 partial correlation coefficient + # 2 N + # 3 p value of partial correlation return row[3] if "literature" in method.lower(): - return __sort_6__ + return __compare_lit_or_tiss_correlation_values_ if "tissue" in method.lower(): - return __sort_6__ + return __compare_lit_or_tiss_correlation_values_ - return __sort_3__ + return __compare_partial_correlation_p_values__ sorted_correlations = sorted( all_correlations, key=__make_sorter__(method)) @@ -676,7 +757,7 @@ def partial_correlations_entry(# pylint: disable=[R0913, R0914, R0911] { **retrieve_trait_info( threshold, - f"{primary_trait['db']['dataset_name']}::{item[0]}", + f"{target_dataset['dataset_name']}::{item[0]}", 
conn), "noverlap": item[1], "partial_corr": item[2], @@ -694,4 +775,14 @@ def partial_correlations_entry(# pylint: disable=[R0913, R0914, R0911] for item in sorted_correlations[:min(criteria, len(all_correlations))])) - return trait_list + return { + "status": "success", + "results": { + "primary_trait": trait_for_output(primary_trait), + "control_traits": tuple( + trait_for_output(trait) for trait in cntrl_traits), + "correlations": tuple( + trait_for_output(trait) for trait in trait_list), + "dataset_type": target_dataset["type"], + "method": "spearman" if "spearman" in method.lower() else "pearson" + }} diff --git a/gn3/db/correlations.py b/gn3/db/correlations.py index 254af10..5361a1e 100644 --- a/gn3/db/correlations.py +++ b/gn3/db/correlations.py @@ -157,11 +157,12 @@ def fetch_symbol_value_pair_dict( symbol: data_id_dict.get(symbol) for symbol in symbol_list if data_id_dict.get(symbol) is not None } - query = "SELECT Id, value FROM TissueProbeSetData WHERE Id IN %(data_ids)s" + query = "SELECT Id, value FROM TissueProbeSetData WHERE Id IN ({})".format( + ",".join(f"%(id{i})s" for i in range(len(data_ids.values())))) with conn.cursor() as cursor: cursor.execute( query, - data_ids=tuple(data_ids.values())) + **{f"id{i}": did for i, did in enumerate(data_ids.values())}) value_results = cursor.fetchall() return { key: tuple(row[1] for row in value_results if row[0] == key) @@ -406,21 +407,22 @@ def fetch_sample_ids( """ query = ( "SELECT Strain.Id FROM Strain, Species " - "WHERE Strain.Name IN %(samples_names)s " + "WHERE Strain.Name IN ({}) " "AND Strain.SpeciesId=Species.Id " - "AND Species.name=%(species_name)s") + "AND Species.name=%(species_name)s").format( + ",".join(f"%(s{i})s" for i in range(len(sample_names)))) with conn.cursor() as cursor: cursor.execute( query, { - "samples_names": tuple(sample_names), + **{f"s{i}": sname for i, sname in enumerate(sample_names)}, "species_name": species_name }) return tuple(row[0] for row in cursor.fetchall()) def build_query_sgo_lit_corr( db_type: str, temp_table: str, sample_id_columns: str, - joins: Tuple[str, ...]) -> str: + joins: Tuple[str, ...]) -> Tuple[str, int]: """ Build query for `SGO Literature Correlation` data, when querying the given `temp_table` temporary table. 
@@ -483,14 +485,14 @@ def fetch_all_database_data(# pylint: disable=[R0913, R0914]
     sample_id_columns = ", ".join(f"T{smpl}.value" for smpl in sample_ids)
     if db_type == "Publish":
         joins = tuple(
-            ("LEFT JOIN PublishData AS T{item} "
-             "ON T{item}.Id = PublishXRef.DataId "
-             "AND T{item}.StrainId = %(T{item}_sample_id)s")
+            (f"LEFT JOIN PublishData AS T{item} "
+             f"ON T{item}.Id = PublishXRef.DataId "
+             f"AND T{item}.StrainId = %(T{item}_sample_id)s")
             for item in sample_ids)
         return (
             ("SELECT PublishXRef.Id, " + sample_id_columns +
-             "FROM (PublishXRef, PublishFreeze) " +
+             " FROM (PublishXRef, PublishFreeze) " +
              " ".join(joins) +
              " WHERE PublishXRef.InbredSetId = PublishFreeze.InbredSetId "
              "AND PublishFreeze.Name = %(db_name)s"),
diff --git a/gn3/db/datasets.py b/gn3/db/datasets.py
index c50e148..a41e228 100644
--- a/gn3/db/datasets.py
+++ b/gn3/db/datasets.py
@@ -3,7 +3,7 @@ This module contains functions relating to specific trait dataset manipulation
 """
 import re
 from string import Template
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, Optional
 from SPARQLWrapper import JSON, SPARQLWrapper
 from gn3.settings import SPARQL_ENDPOINT
 
@@ -297,7 +297,7 @@ def retrieve_trait_dataset(trait_type, trait, threshold, conn):
         **group
     }
 
-def sparql_query(query: str) -> Dict[str, Any]:
+def sparql_query(query: str) -> List[Dict[str, Any]]:
     """Run a SPARQL query and return the bound variables."""
     sparql = SPARQLWrapper(SPARQL_ENDPOINT)
     sparql.setQuery(query)
@@ -328,7 +328,7 @@ WHERE {
     OPTIONAL { ?dataset gn:geoSeries ?geo_series } .
 }
 """,
-             """
+        """
 PREFIX gn: <http://genenetwork.org/>
 SELECT ?platform_name ?normalization_name ?species_name ?inbred_set_name ?tissue_name
 WHERE {
@@ -341,7 +341,7 @@ WHERE {
     OPTIONAL { ?dataset gn:datasetOfPlatform / gn:name ?platform_name } .
 }
 """,
-             """
+        """
 PREFIX gn: <http://genenetwork.org/>
 SELECT ?specifics ?summary ?about_cases ?about_tissue ?about_platform
        ?about_data_processing ?notes ?experiment_design ?contributors
@@ -362,8 +362,8 @@ WHERE {
     OPTIONAL { ?dataset gn:acknowledgment ?acknowledgment .
} } """] - result = {'accession_id': accession_id, - 'investigator': {}} + result: Dict[str, Any] = {'accession_id': accession_id, + 'investigator': {}} query_result = {} for query in queries: if sparql_result := sparql_query(Template(query).substitute(accession_id=accession_id)): diff --git a/gn3/db/traits.py b/gn3/db/traits.py index 7994aef..338b320 100644 --- a/gn3/db/traits.py +++ b/gn3/db/traits.py @@ -1,6 +1,5 @@ """This class contains functions relating to trait data manipulation""" import os -import MySQLdb from functools import reduce from typing import Any, Dict, Union, Sequence @@ -111,7 +110,6 @@ def get_trait_csv_sample_data(conn: Any, def update_sample_data(conn: Any, #pylint: disable=[R0913] - trait_name: str, strain_name: str, phenotype_id: int, @@ -204,25 +202,30 @@ def delete_sample_data(conn: Any, "AND Strain.Name = \"%s\"") % (trait_name, phenotype_id, str(strain_name))) - strain_id, data_id = cursor.fetchone() - cursor.execute(("DELETE FROM PublishData " + # Check if it exists if the data was already deleted: + if _result := cursor.fetchone(): + strain_id, data_id = _result + + # Only run if the strain_id and data_id exist + if strain_id and data_id: + cursor.execute(("DELETE FROM PublishData " "WHERE StrainId = %s AND Id = %s") - % (strain_id, data_id)) - deleted_published_data = cursor.rowcount - - # Delete the PublishSE table - cursor.execute(("DELETE FROM PublishSE " - "WHERE StrainId = %s AND DataId = %s") % - (strain_id, data_id)) - deleted_se_data = cursor.rowcount - - # Delete the NStrain table - cursor.execute(("DELETE FROM NStrain " - "WHERE StrainId = %s AND DataId = %s" % - (strain_id, data_id))) - deleted_n_strains = cursor.rowcount - except Exception as e: #pylint: disable=[C0103, W0612] + % (strain_id, data_id)) + deleted_published_data = cursor.rowcount + + # Delete the PublishSE table + cursor.execute(("DELETE FROM PublishSE " + "WHERE StrainId = %s AND DataId = %s") % + (strain_id, data_id)) + deleted_se_data = cursor.rowcount + + # Delete the NStrain table + cursor.execute(("DELETE FROM NStrain " + "WHERE StrainId = %s AND DataId = %s" % + (strain_id, data_id))) + deleted_n_strains = cursor.rowcount + except Exception as e: #pylint: disable=[C0103, W0612] conn.rollback() raise MySQLdb.Error conn.commit() @@ -255,6 +258,13 @@ def insert_sample_data(conn: Any, #pylint: disable=[R0913] (strain_name,)) strain_id = cursor.fetchone() + # Return early if an insert already exists! + cursor.execute("SELECT Id FROM PublishData where Id = %s " + "AND StrainId = %s", + (data_id, strain_id)) + if cursor.fetchone(): # This strain already exists + return (0, 0, 0) + # Insert the PublishData table cursor.execute(("INSERT INTO PublishData (Id, StrainId, value)" "VALUES (%s, %s, %s)"), @@ -19,4 +19,16 @@ ignore_missing_imports = True ignore_missing_imports = True [mypy-requests.*] +ignore_missing_imports = True + +[mypy-flask.*] +ignore_missing_imports = True + +[mypy-werkzeug.*] +ignore_missing_imports = True + +[mypy-SPARQLWrapper.*] +ignore_missing_imports = True + +[mypy-pandas.*] ignore_missing_imports = True
\ No newline at end of file
diff --git a/scripts/wgcna_analysis.R b/scripts/wgcna_analysis.R
index b0d25a9..d368013 100644
--- a/scripts/wgcna_analysis.R
+++ b/scripts/wgcna_analysis.R
@@ -25,6 +25,7 @@ if (length(args)==0) {
 
 inputData <- fromJSON(file = json_file_path)
 imgDir = inputData$TMPDIR
+inputData
 
 trait_sample_data <- do.call(rbind, inputData$trait_sample_data)
 
@@ -51,10 +52,8 @@ names(dataExpr) = inputData$trait_names
 # Allow multi-threading within WGCNA
 enableWGCNAThreads()
 
-# choose softthreshhold (Calculate soft threshold)
-# xtodo allow users to pass args
-
-powers <- c(c(1:10), seq(from = 12, to=20, by=2))
+# powers <- c(c(1:10), seq(from = 12, to=20, by=2))
+powers <- unlist(c(inputData$SoftThresholds))
 sft <- pickSoftThreshold(dataExpr, powerVector = powers, verbose = 5)
 
 # check the power estimate
diff --git a/tests/unit/computations/test_correlation.py b/tests/unit/computations/test_correlation.py
index 0de347d..7523d99 100644
--- a/tests/unit/computations/test_correlation.py
+++ b/tests/unit/computations/test_correlation.py
@@ -1,7 +1,6 @@
 """Module contains the tests for correlation"""
 from unittest import TestCase
 from unittest import mock
-import unittest
 
 from collections import namedtuple
 import math
diff --git a/tests/unit/db/test_datasets.py b/tests/unit/db/test_datasets.py
index 39f4af9..0b8c2fe 100644
--- a/tests/unit/db/test_datasets.py
+++ b/tests/unit/db/test_datasets.py
@@ -13,15 +13,17 @@ class TestDatasetsDBFunctions(TestCase):
 
     def test_retrieve_dataset_name(self):
         """Test that the function is called correctly."""
-        for trait_type, thresh, trait_name, dataset_name, columns, table in [
+        for trait_type, thresh, trait_name, dataset_name, columns, table, expected in [
                 ["ProbeSet", 9, "probesetTraitName", "probesetDatasetName",
-                 "Id, Name, FullName, ShortName, DataScale", "ProbeSetFreeze"],
+                 "Id, Name, FullName, ShortName, DataScale", "ProbeSetFreeze",
+                 {"dataset_id": None, "dataset_name": "probesetDatasetName",
+                  "dataset_fullname": "probesetDatasetName"}],
                 ["Geno", 3, "genoTraitName", "genoDatasetName",
-                 "Id, Name, FullName, ShortName", "GenoFreeze"],
+                 "Id, Name, FullName, ShortName", "GenoFreeze", {}],
                 ["Publish", 6, "publishTraitName", "publishDatasetName",
-                 "Id, Name, FullName, ShortName", "PublishFreeze"],
+                 "Id, Name, FullName, ShortName", "PublishFreeze", {}],
                 ["Temp", 4, "tempTraitName", "tempTraitName",
-                 "Id, Name, FullName, ShortName", "TempFreeze"]]:
+                 "Id, Name, FullName, ShortName", "TempFreeze", {}]]:
             db_mock = mock.MagicMock()
             with self.subTest(trait_type=trait_type):
                 with db_mock.cursor() as cursor:
@@ -29,7 +31,7 @@ class TestDatasetsDBFunctions(TestCase):
                     self.assertEqual(
                         retrieve_dataset_name(
                             trait_type, thresh, trait_name, dataset_name, db_mock),
-                        {})
+                        expected)
                     cursor.execute.assert_called_once_with(
                         "SELECT {cols} "
                         "FROM {table} "
diff --git a/tests/unit/db/test_traits.py b/tests/unit/db/test_traits.py
index 4aa9389..75f3d4c 100644
--- a/tests/unit/db/test_traits.py
+++ b/tests/unit/db/test_traits.py
@@ -202,8 +202,6 @@ class TestTraitsDBFunctions(TestCase):
         """
         # pylint: disable=C0103
         db_mock = mock.MagicMock()
-
-        STRAIN_ID_SQL: str = "UPDATE Strain SET Name = %s WHERE Id = %s"
         PUBLISH_DATA_SQL: str = (
             "UPDATE PublishData SET value = %s "
             "WHERE StrainId = %s AND Id = %s")
@@ -216,16 +214,33 @@ class TestTraitsDBFunctions(TestCase):
 
         with db_mock.cursor() as cursor:
             type(cursor).rowcount = 1
+            mock_fetchone = mock.MagicMock()
+            mock_fetchone.return_value = (1, 1)
+            type(cursor).fetchone = mock_fetchone
             self.assertEqual(update_sample_data(
                 conn=db_mock, strain_name="BXD11",
-                strain_id=10, publish_data_id=8967049,
-                value=18.7, error=2.3, count=2),
-                             (1, 1, 1, 1))
+                trait_name="1",
+                phenotype_id=10, value=18.7,
+                error=2.3, count=2),
+                             (1, 1, 1))
             cursor.execute.assert_has_calls(
-                [mock.call(STRAIN_ID_SQL, ('BXD11', 10)),
-                 mock.call(PUBLISH_DATA_SQL, (18.7, 10, 8967049)),
-                 mock.call(PUBLISH_SE_SQL, (2.3, 10, 8967049)),
-                 mock.call(N_STRAIN_SQL, (2, 10, 8967049))]
+                [mock.call('SELECT Strain.Id, PublishData.Id FROM'
+                           ' (PublishData, Strain, PublishXRef, '
+                           'PublishFreeze) LEFT JOIN PublishSE ON '
+                           '(PublishSE.DataId = PublishData.Id '
+                           'AND PublishSE.StrainId = '
+                           'PublishData.StrainId) LEFT JOIN NStrain ON '
+                           '(NStrain.DataId = PublishData.Id AND '
+                           'NStrain.StrainId = PublishData.StrainId) WHERE '
+                           'PublishXRef.InbredSetId = '
+                           'PublishFreeze.InbredSetId AND PublishData.Id = '
+                           'PublishXRef.DataId AND PublishXRef.Id = 1 AND '
+                           'PublishXRef.PhenotypeId = 10 AND '
+                           'PublishData.StrainId = Strain.Id AND '
+                           'Strain.Name = "BXD11"'),
+                 mock.call(PUBLISH_DATA_SQL, (18.7, 1, 1)),
+                 mock.call(PUBLISH_SE_SQL, (2.3, 1, 1)),
+                 mock.call(N_STRAIN_SQL, (2, 1, 1))]
             )
 
     def test_set_haveinfo_field(self):
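The updated tests above pin attributes on a `MagicMock` cursor so the code under test sees deterministic values. A standalone sketch of that pattern (no database required; the asserted values mirror the test's mock setup, not real data):

```python
# Mocking pattern from tests/unit/db/test_traits.py above: a MagicMock
# stands in for the DB connection, and fetchone/rowcount are pinned.
from unittest import mock

db_mock = mock.MagicMock()
with db_mock.cursor() as cursor:
    type(cursor).rowcount = 1
    mock_fetchone = mock.MagicMock()
    mock_fetchone.return_value = (1, 1)
    type(cursor).fetchone = mock_fetchone

    # Any code receiving `cursor` now gets (1, 1) from fetchone()
    # and rowcount == 1, without touching a real database.
    assert cursor.fetchone() == (1, 1)
    assert cursor.rowcount == 1
```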