diff options
Diffstat (limited to 'wqflask')
-rw-r--r-- | wqflask/base/data_set/dataset.py | 2 | ||||
-rw-r--r-- | wqflask/base/species.py | 71 | ||||
-rw-r--r-- | wqflask/tests/unit/base/test_species.py | 20 | ||||
-rw-r--r-- | wqflask/wqflask/heatmap/heatmap.py | 7 | ||||
-rw-r--r-- | wqflask/wqflask/show_trait/show_trait.py | 6 | ||||
-rw-r--r-- | wqflask/wqflask/snp_browser/snp_browser.py | 10 | ||||
-rw-r--r-- | wqflask/wqflask/views.py | 104 |
7 files changed, 110 insertions, 110 deletions
diff --git a/wqflask/base/data_set/dataset.py b/wqflask/base/data_set/dataset.py index f035e028..69c842ad 100644 --- a/wqflask/base/data_set/dataset.py +++ b/wqflask/base/data_set/dataset.py @@ -45,7 +45,7 @@ class DataSet: self.accession_id = self.get_accession_id() if get_samplelist == True: self.group.get_samplelist(redis_conn) - self.species = species.TheSpecies(self) + self.species = species.TheSpecies(dataset=self) def as_dict(self): return { diff --git a/wqflask/base/species.py b/wqflask/base/species.py index 0ee04630..68b00c70 100644 --- a/wqflask/base/species.py +++ b/wqflask/base/species.py @@ -1,23 +1,19 @@ -from collections import OrderedDict from dataclasses import dataclass -from dataclasses import InitVar -from typing import Optional, Dict, Any, Union +from typing import Optional, Union +from collections import OrderedDict + from wqflask.database import database_connection -@dataclass class TheSpecies: """Data related to species.""" - dataset: Optional[Dict] = None - species_name: Optional[str] = None - def __post_init__(self) -> None: - # Just an alias of species_name. It's safe for this to be None. - self.name = self.species_name - with database_connection() as conn: - self.chromosomes = Chromosomes(conn=conn, - species=self.species_name, - dataset=self.dataset) + def __init__(self, dataset=None, species_name=None) -> None: + "Initialise the Species object" + self.dataset = dataset + self.name = self.species_name = species_name + self.chromosomes = Chromosomes(species=species_name, + dataset=dataset) @dataclass @@ -35,34 +31,31 @@ class IndChromosome: @dataclass class Chromosomes: """Data related to a chromosome""" - conn: Any - dataset: InitVar[Dict] = None - species: Optional[str] = None - def __post_init__(self, dataset) -> None: - if self.species is None: + def __init__(self, dataset, species: Optional[str]) -> None: + "initialise the Chromosome object" + self.species = species + if species is None: self.dataset = dataset - @property - def chromosomes(self) -> OrderedDict: + def chromosomes(self, db_cursor) -> OrderedDict: """Lazily fetch the chromosomes""" chromosomes = OrderedDict() - with database_connection() as conn, conn.cursor() as cursor: - if self.species is not None: - cursor.execute( - "SELECT Chr_Length.Name, Chr_Length.OrderId, Length " - "FROM Chr_Length, Species WHERE " - "Chr_Length.SpeciesId = Species.SpeciesId AND " - "Species.Name = %s " - "ORDER BY OrderId", (self.species.capitalize(),)) - else: - cursor.execute( - "SELECT Chr_Length.Name, Chr_Length.OrderId, " - "Length FROM Chr_Length, InbredSet WHERE " - "Chr_Length.SpeciesId = InbredSet.SpeciesId AND " - "InbredSet.Name = " - "%s ORDER BY OrderId", (self.dataset.group.name,)) - for name, _, length in cursor.fetchall(): - chromosomes[name] = IndChromosome( - name=name, length=length) - return chromosomes + if self.species is not None: + db_cursor.execute( + "SELECT Chr_Length.Name, Chr_Length.OrderId, Length " + "FROM Chr_Length, Species WHERE " + "Chr_Length.SpeciesId = Species.SpeciesId AND " + "Species.Name = %s " + "ORDER BY OrderId", (self.species.capitalize(),)) + else: + db_cursor.execute( + "SELECT Chr_Length.Name, Chr_Length.OrderId, " + "Length FROM Chr_Length, InbredSet WHERE " + "Chr_Length.SpeciesId = InbredSet.SpeciesId AND " + "InbredSet.Name = " + "%s ORDER BY OrderId", (self.dataset.group.name,)) + for name, _, length in db_cursor.fetchall(): + chromosomes[name] = IndChromosome( + name=name, length=length) + return chromosomes diff --git a/wqflask/tests/unit/base/test_species.py b/wqflask/tests/unit/base/test_species.py index d7ba30a3..f12bde6d 100644 --- a/wqflask/tests/unit/base/test_species.py +++ b/wqflask/tests/unit/base/test_species.py @@ -30,17 +30,14 @@ class MockDataset: (None, "Random Dataset", None, 1))) def test_species(mocker, species_name, dataset, expected_name, chromosome_param): - mock_conn = mocker.patch("base.species.database_connection") - mock_conn.return_value.__enter__.return_value = mocker.MagicMock() _c = mocker.patch("base.species.Chromosomes", return_value=chromosome_param) - with mock_conn() as conn: - test_species = TheSpecies(dataset=dataset, - species_name=species_name) - _c.assert_called_with(conn=conn, species=species_name, - dataset=dataset) - assert test_species.name == expected_name - assert test_species.chromosomes == chromosome_param + test_species = TheSpecies(dataset=dataset, + species_name=species_name) + _c.assert_called_with(species=species_name, + dataset=dataset) + assert test_species.name == expected_name + assert test_species.chromosomes == chromosome_param @pytest.mark.parametrize( @@ -74,9 +71,8 @@ def test_create_chromosomes(mocker, species, dataset, expected_call): cursor.fetchall.return_value = (("1", 2, 10,), ("2", 3, 11,), ("4", 5, 15,),) - _c = Chromosomes(conn=mock_conn, - dataset=dataset, species=species) - assert _c.chromosomes == OrderedDict([ + _c = Chromosomes(dataset=dataset, species=species) + assert _c.chromosomes(cursor) == OrderedDict([ ("1", IndChromosome("1", 10)), ("2", IndChromosome("2", 11)), ("4", IndChromosome("4", 15)), diff --git a/wqflask/wqflask/heatmap/heatmap.py b/wqflask/wqflask/heatmap/heatmap.py index 1c8a4ff6..8d9c9e7f 100644 --- a/wqflask/wqflask/heatmap/heatmap.py +++ b/wqflask/wqflask/heatmap/heatmap.py @@ -9,12 +9,14 @@ from utility.tools import flat_files, REAPER_COMMAND, TEMPDIR from redis import Redis from flask import Flask, g +from wqflask.database import database_connection + Redis = Redis() class Heatmap: - def __init__(self, start_vars, temp_uuid): + def __init__(self, db_cursor, start_vars, temp_uuid): trait_db_list = [trait.strip() for trait in start_vars['trait_list'].split(',')] helper_functions.get_trait_db_obs(self, trait_db_list) @@ -30,7 +32,8 @@ class Heatmap: chrnames = [] self.species = species.TheSpecies(dataset=self.trait_list[0][1]) - for key in list(self.species.chromosomes.chromosomes.keys()): + + for key in list(self.species.chromosomes(db_cursor).chromosomes.keys()): chrnames.append([self.species.chromosomes.chromosomes[key].name, self.species.chromosomes.chromosomes[key].mb_length]) diff --git a/wqflask/wqflask/show_trait/show_trait.py b/wqflask/wqflask/show_trait/show_trait.py index ae6cf0cf..8cea271d 100644 --- a/wqflask/wqflask/show_trait/show_trait.py +++ b/wqflask/wqflask/show_trait/show_trait.py @@ -36,7 +36,7 @@ ONE_YEAR = 60 * 60 * 24 * 365 class ShowTrait: - def __init__(self, user_id, kw): + def __init__(self, db_cursor, user_id, kw): self.admin_status = None if 'trait_id' in kw and kw['dataset'] != "Temp": self.temp_trait = False @@ -197,9 +197,9 @@ class ShowTrait: # ZS: Get list of chromosomes to select for mapping self.chr_list = [["All", -1]] - for i, this_chr in enumerate(self.dataset.species.chromosomes.chromosomes): + for i, this_chr in enumerate(self.dataset.species.chromosomes.chromosomes(db_cursor)): self.chr_list.append( - [self.dataset.species.chromosomes.chromosomes[this_chr].name, i]) + [self.dataset.species.chromosomes.chromosomes(db_cursor)[this_chr].name, i]) self.genofiles = self.dataset.group.get_genofiles() study_samplelist_json = self.dataset.group.get_study_samplelists() diff --git a/wqflask/wqflask/snp_browser/snp_browser.py b/wqflask/wqflask/snp_browser/snp_browser.py index 0dfa3e64..d5c4e946 100644 --- a/wqflask/wqflask/snp_browser/snp_browser.py +++ b/wqflask/wqflask/snp_browser/snp_browser.py @@ -9,7 +9,7 @@ from wqflask.database import database_connection class SnpBrowser: - def __init__(self, start_vars): + def __init__(self, db_cursor, start_vars): self.strain_lists = get_browser_sample_lists() self.initialize_parameters(start_vars) @@ -30,7 +30,7 @@ class SnpBrowser: self.header_fields, self.empty_field_count, self.header_data_names = get_header_list( variant_type=self.variant_type, strains=self.strain_lists, species=self.species_name, empty_columns=self.empty_columns) - def initialize_parameters(self, start_vars): + def initialize_parameters(self, db_cursor, start_vars): if 'first_run' in start_vars: self.first_run = "false" else: @@ -51,13 +51,13 @@ class SnpBrowser: self.mouse_chr_list = [] self.rat_chr_list = [] mouse_species_ob = species.TheSpecies(species_name="Mouse") - for key in mouse_species_ob.chromosomes.chromosomes: + for key in mouse_species_ob.chromosomes.chromosomes(db_cursor): self.mouse_chr_list.append( mouse_species_ob.chromosomes.chromosomes[key].name) rat_species_ob = species.TheSpecies(species_name="Rat") - for key in rat_species_ob.chromosomes.chromosomes: + for key in rat_species_ob.chromosomes.chromosomes(db_cursor): self.rat_chr_list.append( - rat_species_ob.chromosomes.chromosomes[key].name) + rat_species_ob.chromosomes.chromosomes(db_cursor)[key].name) if self.species_id == 1: self.this_chr_list = self.mouse_chr_list diff --git a/wqflask/wqflask/views.py b/wqflask/wqflask/views.py index 11482469..4ea2d529 100644 --- a/wqflask/wqflask/views.py +++ b/wqflask/wqflask/views.py @@ -80,6 +80,8 @@ from utility.redis_tools import get_redis_conn from base.webqtlConfig import GENERATED_IMAGE_DIR +from wqflask.database import database_connection + Redis = get_redis_conn() @@ -452,26 +454,30 @@ def export_perm_data(): @app.route("/show_temp_trait", methods=('POST',)) def show_temp_trait_page(): - user_id = ((g.user_session.record.get(b"user_id") or b"").decode("utf-8") - or g.user_session.record.get("user_id") or "") - template_vars = show_trait.ShowTrait(user_id=user_id, - kw=request.form) - template_vars.js_data = json.dumps(template_vars.js_data, - default=json_default_handler, - indent=" ") - return render_template("show_trait.html", **template_vars.__dict__) + with database_connection() as conn, conn.cursor() as cursor: + user_id = ((g.user_session.record.get(b"user_id") or b"").decode("utf-8") + or g.user_session.record.get("user_id") or "") + template_vars = show_trait.ShowTrait(cursor, + user_id=user_id, + kw=request.form) + template_vars.js_data = json.dumps(template_vars.js_data, + default=json_default_handler, + indent=" ") + return render_template("show_trait.html", **template_vars.__dict__) @app.route("/show_trait") def show_trait_page(): - user_id = ((g.user_session.record.get(b"user_id") or b"").decode("utf-8") - or g.user_session.record.get("user_id") or "") - template_vars = show_trait.ShowTrait(user_id=user_id, - kw=request.args) - template_vars.js_data = json.dumps(template_vars.js_data, - default=json_default_handler, - indent=" ") - return render_template("show_trait.html", **template_vars.__dict__) + with database_connection() as conn, conn.cursor() as cursor: + user_id = ((g.user_session.record.get(b"user_id") or b"").decode("utf-8") + or g.user_session.record.get("user_id") or "") + template_vars = show_trait.ShowTrait(cursor, + user_id=user_id, + kw=request.args) + template_vars.js_data = json.dumps(template_vars.js_data, + default=json_default_handler, + indent=" ") + return render_template("show_trait.html", **template_vars.__dict__) @app.route("/heatmap", methods=('POST',)) @@ -480,31 +486,32 @@ def heatmap_page(): temp_uuid = uuid.uuid4() traits = [trait.strip() for trait in start_vars['trait_list'].split(',')] - if traits[0] != "": - version = "v5" - key = "heatmap:{}:".format( - version) + json.dumps(start_vars, sort_keys=True) - result = Redis.get(key) + with database_connection() as conn, conn.cursor() as cursor: + if traits[0] != "": + version = "v5" + key = "heatmap:{}:".format( + version) + json.dumps(start_vars, sort_keys=True) + result = Redis.get(key) - if result: - result = pickle.loads(result) + if result: + result = pickle.loads(result) - else: - template_vars = heatmap.Heatmap(request.form, temp_uuid) - template_vars.js_data = json.dumps(template_vars.js_data, - default=json_default_handler, - indent=" ") + else: + template_vars = heatmap.Heatmap(cursor, request.form, temp_uuid) + template_vars.js_data = json.dumps(template_vars.js_data, + default=json_default_handler, + indent=" ") - result = template_vars.__dict__ + result = template_vars.__dict__ - pickled_result = pickle.dumps(result, pickle.HIGHEST_PROTOCOL) - Redis.set(key, pickled_result) - Redis.expire(key, 60 * 60) - rendered_template = render_template("heatmap.html", **result) + pickled_result = pickle.dumps(result, pickle.HIGHEST_PROTOCOL) + Redis.set(key, pickled_result) + Redis.expire(key, 60 * 60) + rendered_template = render_template("heatmap.html", **result) - else: - rendered_template = render_template( - "empty_collection.html", **{'tool': 'Heatmap'}) + else: + rendered_template = render_template( + "empty_collection.html", **{'tool': 'Heatmap'}) return rendered_template @@ -856,9 +863,9 @@ def corr_scatter_plot_page(): @app.route("/snp_browser", methods=('GET',)) def snp_browser_page(): - template_vars = snp_browser.SnpBrowser(request.args) - - return render_template("snp_browser.html", **template_vars.__dict__) + with database_connection() as conn, conn.cursor() as cursor: + template_vars = snp_browser.SnpBrowser(cursor, request.args) + return render_template("snp_browser.html", **template_vars.__dict__) @app.route("/db_info", methods=('GET',)) @@ -870,15 +877,16 @@ def db_info_page(): @app.route("/snp_browser_table", methods=('GET',)) def snp_browser_table(): - snp_table_data = snp_browser.SnpBrowser(request.args) - current_page = server_side.ServerSideTable( - snp_table_data.rows_count, - snp_table_data.table_rows, - snp_table_data.header_data_names, - request.args, - ).get_page() - - return flask.jsonify(current_page) + with database_connection() as conn, conn.cursor() as cursor: + snp_table_data = snp_browser.SnpBrowser(cursor, request.args) + current_page = server_side.ServerSideTable( + snp_table_data.rows_count, + snp_table_data.table_rows, + snp_table_data.header_data_names, + request.args, + ).get_page() + + return flask.jsonify(current_page) @app.route("/tutorial/WebQTLTour", methods=('GET',)) |