about summary refs log tree commit diff
diff options
context:
space:
mode:
authorFrederick Muriuki Muriithi2022-09-21 12:33:20 +0300
committerFrederick Muriuki Muriithi2022-09-21 12:37:36 +0300
commit8a51d58ed66593a420740b347ef841454e254237 (patch)
tree7b184d39444d3cac33c8ff54ab013ba9c3fb9a6c
parent347f874a82275ae9d6069f00d39b8e1737d5b3e7 (diff)
downloadgenenetwork2-8a51d58ed66593a420740b347ef841454e254237.tar.gz
Refactor: pass redis and db connections as argument
Pass the redis and database connection/cursors as arguments to the
class methods, and do not retain a copy of the connections/cursors.

This allows us to do the connection management in the context
managers elsewhere - ideally, at the top-level. For now the context
manager is within the `create_dataset` function, but this should be
moved out to a higher level once the lower levels are verified to be
working as expected.
-rw-r--r--wqflask/base/data_set/__init__.py4
-rw-r--r--wqflask/base/data_set/datasettype.py41
2 files changed, 21 insertions, 24 deletions
diff --git a/wqflask/base/data_set/__init__.py b/wqflask/base/data_set/__init__.py
index 6d475df2..4667d5ce 100644
--- a/wqflask/base/data_set/__init__.py
+++ b/wqflask/base/data_set/__init__.py
@@ -37,7 +37,9 @@ def create_dataset(dataset_name, dataset_type=None,
         dataset_type = "Temp"
 
     if not dataset_type:
-        dataset_type = DatasetType(redis_conn)(dataset_name)
+        with database_connection() as db_conn, db_conn.cursor() as cursor:
+            dataset_type = DatasetType(redis_conn)(
+                dataset_name, redis_conn, cursor)
 
     dataset_ob = DS_NAME_MAP[dataset_type]
     dataset_class = globals()[dataset_ob]
diff --git a/wqflask/base/data_set/datasettype.py b/wqflask/base/data_set/datasettype.py
index ca6515b6..05f0f564 100644
--- a/wqflask/base/data_set/datasettype.py
+++ b/wqflask/base/data_set/datasettype.py
@@ -1,11 +1,8 @@
-# builtins imports
+"DatasetType class ..."
 
 import json
 import requests
-from dataclasses import field
-from dataclasses import InitVar
 from typing import Optional, Dict
-from dataclasses import dataclass
 
 
 from redis import Redis
@@ -14,7 +11,7 @@ from redis import Redis
 from utility.tools import GN2_BASE_URL
 from wqflask.database import database_connection
 
-@dataclass
+
 class DatasetType:
     """Create a dictionary of samples where the value is set to Geno,
     Publish or ProbeSet. E.g.
@@ -30,13 +27,13 @@ class DatasetType:
          'B139_K_1206_R': 'ProbeSet' ...
         }
         """
-    redis_instance: InitVar[Redis]
-    datasets: Optional[Dict] = field(init=False, default_factory=dict)
-    data: Optional[Dict] = field(init=False)
 
-    def __post_init__(self, redis_instance):
-        self.redis_instance = redis_instance
-        data = redis_instance.get("dataset_structure")
+    def __init__(self, redis_conn):
+        "Initialise the object"
+        self.datasets = {}
+        self.data = {}
+        # self.redis_instance = redis_instance
+        data = redis_conn.get("dataset_structure")
         if data:
             self.datasets = json.loads(data)
         else:
@@ -61,11 +58,10 @@ class DatasetType:
             except Exception:  # Do nothing
                 pass
 
-            self.redis_instance.set("dataset_structure",
-                                    json.dumps(self.datasets))
+            redis_conn.set("dataset_structure", json.dumps(self.datasets))
         self.data = data
 
-    def set_dataset_key(self, t, name):
+    def set_dataset_key(self, t, name, redis_conn, db_cursor):
         """If name is not in the object's dataset dictionary, set it, and
         update dataset_structure in Redis
         args:
@@ -102,21 +98,20 @@ class DatasetType:
         if t in ['pheno', 'other_pheno']:
             group_name = name.replace("Publish", "")
 
-        with database_connection() as conn, conn.cursor() as cursor:
-            cursor.execute(sql_query_mapping[t], (group_name,))
-            if cursor.fetchone():
-                self.datasets[name] = dataset_name_mapping[t]
-                self.redis_instance.set(
-                    "dataset_structure", json.dumps(self.datasets))
-                return True
+        db_cursor.execute(sql_query_mapping[t], (group_name,))
+        if db_cursor.fetchone():
+            self.datasets[name] = dataset_name_mapping[t]
+            redis_conn.set(
+                "dataset_structure", json.dumps(self.datasets))
+            return True
 
 
-    def __call__(self, name):
+    def __call__(self, name, redis_conn, db_cursor):
         if name not in self.datasets:
             for t in ["mrna_expr", "pheno", "other_pheno", "geno"]:
                 # This has side-effects, with the end result being a
                 # truth-y value
-                if(self.set_dataset_key(t, name)):
+                if(self.set_dataset_key(t, name, redis_conn, db_cursor)):
                     break
         # Return None if name has not been set
         return self.datasets.get(name, None)