about summary refs log tree commit diff
path: root/wqflask
diff options
context:
space:
mode:
Diffstat (limited to 'wqflask')
-rw-r--r--wqflask/base/data_set/dataset.py2
-rw-r--r--wqflask/base/species.py71
-rw-r--r--wqflask/tests/unit/base/test_species.py20
-rw-r--r--wqflask/wqflask/heatmap/heatmap.py7
-rw-r--r--wqflask/wqflask/show_trait/show_trait.py6
-rw-r--r--wqflask/wqflask/snp_browser/snp_browser.py10
-rw-r--r--wqflask/wqflask/views.py104
7 files changed, 110 insertions, 110 deletions
diff --git a/wqflask/base/data_set/dataset.py b/wqflask/base/data_set/dataset.py
index f035e028..69c842ad 100644
--- a/wqflask/base/data_set/dataset.py
+++ b/wqflask/base/data_set/dataset.py
@@ -45,7 +45,7 @@ class DataSet:
             self.accession_id = self.get_accession_id()
         if get_samplelist == True:
             self.group.get_samplelist(redis_conn)
-        self.species = species.TheSpecies(self)
+        self.species = species.TheSpecies(dataset=self)
 
     def as_dict(self):
         return {
diff --git a/wqflask/base/species.py b/wqflask/base/species.py
index 0ee04630..68b00c70 100644
--- a/wqflask/base/species.py
+++ b/wqflask/base/species.py
@@ -1,23 +1,19 @@
-from collections import OrderedDict
 from dataclasses import dataclass
-from dataclasses import InitVar
-from typing import Optional, Dict, Any, Union
+from typing import Optional, Union
+from collections import OrderedDict
+
 from wqflask.database import database_connection
 
 
-@dataclass
 class TheSpecies:
     """Data related to species."""
-    dataset: Optional[Dict] = None
-    species_name: Optional[str] = None
 
-    def __post_init__(self) -> None:
-        # Just an alias of species_name.  It's safe for this to be None.
-        self.name = self.species_name
-        with database_connection() as conn:
-            self.chromosomes = Chromosomes(conn=conn,
-                                           species=self.species_name,
-                                           dataset=self.dataset)
+    def __init__(self, dataset=None, species_name=None) -> None:
+        "Initialise the Species object"
+        self.dataset = dataset
+        self.name = self.species_name = species_name
+        self.chromosomes = Chromosomes(species=species_name,
+                                       dataset=dataset)
 
 
 @dataclass
@@ -35,34 +31,31 @@ class IndChromosome:
 @dataclass
 class Chromosomes:
     """Data related to a chromosome"""
-    conn: Any
-    dataset: InitVar[Dict] = None
-    species: Optional[str] = None
 
-    def __post_init__(self, dataset) -> None:
-        if self.species is None:
+    def __init__(self, dataset, species: Optional[str]) -> None:
+        "initialise the Chromosome object"
+        self.species = species
+        if species is None:
             self.dataset = dataset
 
-    @property
-    def chromosomes(self) -> OrderedDict:
+    def chromosomes(self, db_cursor) -> OrderedDict:
         """Lazily fetch the chromosomes"""
         chromosomes = OrderedDict()
-        with database_connection() as conn, conn.cursor() as cursor:
-            if self.species is not None:
-                cursor.execute(
-                    "SELECT Chr_Length.Name, Chr_Length.OrderId, Length "
-                    "FROM Chr_Length, Species WHERE "
-                    "Chr_Length.SpeciesId = Species.SpeciesId AND "
-                    "Species.Name = %s "
-                    "ORDER BY OrderId", (self.species.capitalize(),))
-            else:
-                cursor.execute(
-                    "SELECT Chr_Length.Name, Chr_Length.OrderId, "
-                    "Length FROM Chr_Length, InbredSet WHERE "
-                    "Chr_Length.SpeciesId = InbredSet.SpeciesId AND "
-                    "InbredSet.Name = "
-                    "%s ORDER BY OrderId", (self.dataset.group.name,))
-            for name, _, length in cursor.fetchall():
-                chromosomes[name] = IndChromosome(
-                    name=name, length=length)
-            return chromosomes
+        if self.species is not None:
+            db_cursor.execute(
+                "SELECT Chr_Length.Name, Chr_Length.OrderId, Length "
+                "FROM Chr_Length, Species WHERE "
+                "Chr_Length.SpeciesId = Species.SpeciesId AND "
+                "Species.Name = %s "
+                "ORDER BY OrderId", (self.species.capitalize(),))
+        else:
+            db_cursor.execute(
+                "SELECT Chr_Length.Name, Chr_Length.OrderId, "
+                "Length FROM Chr_Length, InbredSet WHERE "
+                "Chr_Length.SpeciesId = InbredSet.SpeciesId AND "
+                "InbredSet.Name = "
+                "%s ORDER BY OrderId", (self.dataset.group.name,))
+        for name, _, length in db_cursor.fetchall():
+            chromosomes[name] = IndChromosome(
+                name=name, length=length)
+        return chromosomes
diff --git a/wqflask/tests/unit/base/test_species.py b/wqflask/tests/unit/base/test_species.py
index d7ba30a3..f12bde6d 100644
--- a/wqflask/tests/unit/base/test_species.py
+++ b/wqflask/tests/unit/base/test_species.py
@@ -30,17 +30,14 @@ class MockDataset:
      (None, "Random Dataset", None, 1)))
 def test_species(mocker, species_name, dataset,
                  expected_name, chromosome_param):
-    mock_conn = mocker.patch("base.species.database_connection")
-    mock_conn.return_value.__enter__.return_value = mocker.MagicMock()
     _c = mocker.patch("base.species.Chromosomes",
                       return_value=chromosome_param)
-    with mock_conn() as conn:
-        test_species = TheSpecies(dataset=dataset,
-                                  species_name=species_name)
-        _c.assert_called_with(conn=conn, species=species_name,
-                              dataset=dataset)
-        assert test_species.name == expected_name
-        assert test_species.chromosomes == chromosome_param
+    test_species = TheSpecies(dataset=dataset,
+                              species_name=species_name)
+    _c.assert_called_with(species=species_name,
+                          dataset=dataset)
+    assert test_species.name == expected_name
+    assert test_species.chromosomes == chromosome_param
 
 
 @pytest.mark.parametrize(
@@ -74,9 +71,8 @@ def test_create_chromosomes(mocker, species, dataset, expected_call):
         cursor.fetchall.return_value = (("1", 2, 10,),
                                         ("2", 3, 11,),
                                         ("4", 5, 15,),)
-        _c = Chromosomes(conn=mock_conn,
-                         dataset=dataset, species=species)
-        assert _c.chromosomes == OrderedDict([
+        _c = Chromosomes(dataset=dataset, species=species)
+        assert _c.chromosomes(cursor) == OrderedDict([
             ("1", IndChromosome("1", 10)),
             ("2", IndChromosome("2", 11)),
             ("4", IndChromosome("4", 15)),
diff --git a/wqflask/wqflask/heatmap/heatmap.py b/wqflask/wqflask/heatmap/heatmap.py
index 1c8a4ff6..8d9c9e7f 100644
--- a/wqflask/wqflask/heatmap/heatmap.py
+++ b/wqflask/wqflask/heatmap/heatmap.py
@@ -9,12 +9,14 @@ from utility.tools import flat_files, REAPER_COMMAND, TEMPDIR
 from redis import Redis
 from flask import Flask, g
 
+from wqflask.database import database_connection
+
 Redis = Redis()
 
 
 class Heatmap:
 
-    def __init__(self, start_vars, temp_uuid):
+    def __init__(self, db_cursor, start_vars, temp_uuid):
         trait_db_list = [trait.strip()
                          for trait in start_vars['trait_list'].split(',')]
         helper_functions.get_trait_db_obs(self, trait_db_list)
@@ -30,7 +32,8 @@ class Heatmap:
 
         chrnames = []
         self.species = species.TheSpecies(dataset=self.trait_list[0][1])
-        for key in list(self.species.chromosomes.chromosomes.keys()):
+
+        for key in list(self.species.chromosomes(db_cursor).chromosomes.keys()):
             chrnames.append([self.species.chromosomes.chromosomes[key].name,
                              self.species.chromosomes.chromosomes[key].mb_length])
 
diff --git a/wqflask/wqflask/show_trait/show_trait.py b/wqflask/wqflask/show_trait/show_trait.py
index ae6cf0cf..8cea271d 100644
--- a/wqflask/wqflask/show_trait/show_trait.py
+++ b/wqflask/wqflask/show_trait/show_trait.py
@@ -36,7 +36,7 @@ ONE_YEAR = 60 * 60 * 24 * 365
 
 
 class ShowTrait:
-    def __init__(self, user_id, kw):
+    def __init__(self, db_cursor, user_id, kw):
         self.admin_status = None
         if 'trait_id' in kw and kw['dataset'] != "Temp":
             self.temp_trait = False
@@ -197,9 +197,9 @@ class ShowTrait:
 
         # ZS: Get list of chromosomes to select for mapping
         self.chr_list = [["All", -1]]
-        for i, this_chr in enumerate(self.dataset.species.chromosomes.chromosomes):
+        for i, this_chr in enumerate(self.dataset.species.chromosomes.chromosomes(db_cursor)):
             self.chr_list.append(
-                [self.dataset.species.chromosomes.chromosomes[this_chr].name, i])
+                [self.dataset.species.chromosomes.chromosomes(db_cursor)[this_chr].name, i])
 
         self.genofiles = self.dataset.group.get_genofiles()
         study_samplelist_json = self.dataset.group.get_study_samplelists()
diff --git a/wqflask/wqflask/snp_browser/snp_browser.py b/wqflask/wqflask/snp_browser/snp_browser.py
index 0dfa3e64..d5c4e946 100644
--- a/wqflask/wqflask/snp_browser/snp_browser.py
+++ b/wqflask/wqflask/snp_browser/snp_browser.py
@@ -9,7 +9,7 @@ from wqflask.database import database_connection
 
 class SnpBrowser:
 
-    def __init__(self, start_vars):
+    def __init__(self, db_cursor, start_vars):
         self.strain_lists = get_browser_sample_lists()
         self.initialize_parameters(start_vars)
 
@@ -30,7 +30,7 @@ class SnpBrowser:
                 self.header_fields, self.empty_field_count, self.header_data_names = get_header_list(
                     variant_type=self.variant_type, strains=self.strain_lists, species=self.species_name, empty_columns=self.empty_columns)
 
-    def initialize_parameters(self, start_vars):
+    def initialize_parameters(self, db_cursor, start_vars):
         if 'first_run' in start_vars:
             self.first_run = "false"
         else:
@@ -51,13 +51,13 @@ class SnpBrowser:
         self.mouse_chr_list = []
         self.rat_chr_list = []
         mouse_species_ob = species.TheSpecies(species_name="Mouse")
-        for key in mouse_species_ob.chromosomes.chromosomes:
+        for key in mouse_species_ob.chromosomes.chromosomes(db_cursor):
             self.mouse_chr_list.append(
                 mouse_species_ob.chromosomes.chromosomes[key].name)
         rat_species_ob = species.TheSpecies(species_name="Rat")
-        for key in rat_species_ob.chromosomes.chromosomes:
+        for key in rat_species_ob.chromosomes.chromosomes(db_cursor):
             self.rat_chr_list.append(
-                rat_species_ob.chromosomes.chromosomes[key].name)
+                rat_species_ob.chromosomes.chromosomes(db_cursor)[key].name)
 
         if self.species_id == 1:
             self.this_chr_list = self.mouse_chr_list
diff --git a/wqflask/wqflask/views.py b/wqflask/wqflask/views.py
index 11482469..4ea2d529 100644
--- a/wqflask/wqflask/views.py
+++ b/wqflask/wqflask/views.py
@@ -80,6 +80,8 @@ from utility.redis_tools import get_redis_conn
 
 from base.webqtlConfig import GENERATED_IMAGE_DIR
 
+from wqflask.database import database_connection
+
 
 Redis = get_redis_conn()
 
@@ -452,26 +454,30 @@ def export_perm_data():
 
 @app.route("/show_temp_trait", methods=('POST',))
 def show_temp_trait_page():
-    user_id = ((g.user_session.record.get(b"user_id") or b"").decode("utf-8")
-               or g.user_session.record.get("user_id") or "")
-    template_vars = show_trait.ShowTrait(user_id=user_id,
-                                         kw=request.form)
-    template_vars.js_data = json.dumps(template_vars.js_data,
-                                       default=json_default_handler,
-                                       indent="   ")
-    return render_template("show_trait.html", **template_vars.__dict__)
+    with database_connection() as conn, conn.cursor() as cursor:
+        user_id = ((g.user_session.record.get(b"user_id") or b"").decode("utf-8")
+                   or g.user_session.record.get("user_id") or "")
+        template_vars = show_trait.ShowTrait(cursor,
+                                             user_id=user_id,
+                                             kw=request.form)
+        template_vars.js_data = json.dumps(template_vars.js_data,
+                                           default=json_default_handler,
+                                           indent="   ")
+        return render_template("show_trait.html", **template_vars.__dict__)
 
 
 @app.route("/show_trait")
 def show_trait_page():
-    user_id = ((g.user_session.record.get(b"user_id") or b"").decode("utf-8")
-               or g.user_session.record.get("user_id") or "")
-    template_vars = show_trait.ShowTrait(user_id=user_id,
-                                         kw=request.args)
-    template_vars.js_data = json.dumps(template_vars.js_data,
-                                       default=json_default_handler,
-                                       indent="   ")
-    return render_template("show_trait.html", **template_vars.__dict__)
+    with database_connection() as conn, conn.cursor() as cursor:
+        user_id = ((g.user_session.record.get(b"user_id") or b"").decode("utf-8")
+                   or g.user_session.record.get("user_id") or "")
+        template_vars = show_trait.ShowTrait(cursor,
+                                             user_id=user_id,
+                                             kw=request.args)
+        template_vars.js_data = json.dumps(template_vars.js_data,
+                                           default=json_default_handler,
+                                           indent="   ")
+        return render_template("show_trait.html", **template_vars.__dict__)
 
 
 @app.route("/heatmap", methods=('POST',))
@@ -480,31 +486,32 @@ def heatmap_page():
     temp_uuid = uuid.uuid4()
 
     traits = [trait.strip() for trait in start_vars['trait_list'].split(',')]
-    if traits[0] != "":
-        version = "v5"
-        key = "heatmap:{}:".format(
-            version) + json.dumps(start_vars, sort_keys=True)
-        result = Redis.get(key)
+    with database_connection() as conn, conn.cursor() as cursor:
+        if traits[0] != "":
+            version = "v5"
+            key = "heatmap:{}:".format(
+                version) + json.dumps(start_vars, sort_keys=True)
+            result = Redis.get(key)
 
-        if result:
-            result = pickle.loads(result)
+            if result:
+                result = pickle.loads(result)
 
-        else:
-            template_vars = heatmap.Heatmap(request.form, temp_uuid)
-            template_vars.js_data = json.dumps(template_vars.js_data,
-                                               default=json_default_handler,
-                                               indent="   ")
+            else:
+                template_vars = heatmap.Heatmap(cursor, request.form, temp_uuid)
+                template_vars.js_data = json.dumps(template_vars.js_data,
+                                                   default=json_default_handler,
+                                                   indent="   ")
 
-            result = template_vars.__dict__
+                result = template_vars.__dict__
 
-            pickled_result = pickle.dumps(result, pickle.HIGHEST_PROTOCOL)
-            Redis.set(key, pickled_result)
-            Redis.expire(key, 60 * 60)
-        rendered_template = render_template("heatmap.html", **result)
+                pickled_result = pickle.dumps(result, pickle.HIGHEST_PROTOCOL)
+                Redis.set(key, pickled_result)
+                Redis.expire(key, 60 * 60)
+            rendered_template = render_template("heatmap.html", **result)
 
-    else:
-        rendered_template = render_template(
-            "empty_collection.html", **{'tool': 'Heatmap'})
+        else:
+            rendered_template = render_template(
+                "empty_collection.html", **{'tool': 'Heatmap'})
 
     return rendered_template
 
@@ -856,9 +863,9 @@ def corr_scatter_plot_page():
 
 @app.route("/snp_browser", methods=('GET',))
 def snp_browser_page():
-    template_vars = snp_browser.SnpBrowser(request.args)
-
-    return render_template("snp_browser.html", **template_vars.__dict__)
+    with database_connection() as conn, conn.cursor() as cursor:
+        template_vars = snp_browser.SnpBrowser(cursor, request.args)
+        return render_template("snp_browser.html", **template_vars.__dict__)
 
 
 @app.route("/db_info", methods=('GET',))
@@ -870,15 +877,16 @@ def db_info_page():
 
 @app.route("/snp_browser_table", methods=('GET',))
 def snp_browser_table():
-    snp_table_data = snp_browser.SnpBrowser(request.args)
-    current_page = server_side.ServerSideTable(
-        snp_table_data.rows_count,
-        snp_table_data.table_rows,
-        snp_table_data.header_data_names,
-        request.args,
-    ).get_page()
-
-    return flask.jsonify(current_page)
+    with database_connection() as conn, conn.cursor() as cursor:
+        snp_table_data = snp_browser.SnpBrowser(cursor, request.args)
+        current_page = server_side.ServerSideTable(
+            snp_table_data.rows_count,
+            snp_table_data.table_rows,
+            snp_table_data.header_data_names,
+            request.args,
+        ).get_page()
+
+        return flask.jsonify(current_page)
 
 
 @app.route("/tutorial/WebQTLTour", methods=('GET',))