Diffstat (limited to 'wqflask')
-rw-r--r--  wqflask/base/webqtlConfig.py                                         1
-rw-r--r--  wqflask/mock/__init__.py (renamed from wqflask/tests/__init__.py)    0
-rw-r--r--  wqflask/mock/es_double.py                                           15
-rw-r--r--  wqflask/run_gunicorn.py                                              3
-rw-r--r--  wqflask/tests/es_double.py                                          30
-rw-r--r--  wqflask/tests/test_registration.py                                 113
-rw-r--r--  wqflask/utility/elasticsearch_tools.py                              16
-rw-r--r--  wqflask/utility/tools.py                                            42
-rw-r--r--  wqflask/wqflask/correlation/show_corr_results.py                   100
-rw-r--r--  wqflask/wqflask/marker_regression/gemma_mapping.py                  16
-rw-r--r--  wqflask/wqflask/user_manager.py                                      7
11 files changed, 133 insertions(+), 210 deletions(-)
diff --git a/wqflask/base/webqtlConfig.py b/wqflask/base/webqtlConfig.py
index 1ef2bc26..1e66e957 100644
--- a/wqflask/base/webqtlConfig.py
+++ b/wqflask/base/webqtlConfig.py
@@ -82,6 +82,7 @@ assert_writable_dir(GENERATED_TEXT_DIR)
 # Flat file directories
 GENODIR              = flat_files('genotype')+'/'
 assert_dir(GENODIR)
+assert_dir(GENODIR+'bimbam') # for gemma
 
 # JSON genotypes are OBSOLETE
 JSON_GENODIR         = flat_files('genotype/json')+'/'
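
The new assert_dir(GENODIR+'bimbam') call makes a missing BIMBAM genotype directory fail at startup rather than when GEMMA first runs. A minimal sketch of what such a guard does (an assumption; the real helper lives in utility.tools):

    import os

    def assert_dir(dirname):
        # Fail fast at startup if a required directory is missing.
        if not os.path.isdir(dirname):
            raise AssertionError("ERROR: can not find directory " + dirname)
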
diff --git a/wqflask/tests/__init__.py b/wqflask/mock/__init__.py
index e69de29b..e69de29b 100644
--- a/wqflask/tests/__init__.py
+++ b/wqflask/mock/__init__.py
diff --git a/wqflask/mock/es_double.py b/wqflask/mock/es_double.py
new file mode 100644
index 00000000..6ef8a1b9
--- /dev/null
+++ b/wqflask/mock/es_double.py
@@ -0,0 +1,15 @@
+class ESDouble(object):
+    def __init__(self):
+        self.items = {}
+
+    def ping(self):
+        return True
+
+    def create(self, index, doc_type, body, id):
+        self.items.setdefault(index, {}).setdefault(doc_type, []).append({"id": id, "_source": body})
+
+    def search(self, index, doc_type, body):
+        return {
+            "hits": {
+                "hits": self.items[index][doc_type][body]
+            }}
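
The mock/es_double.py double keeps documents in an in-memory dict so code paths that talk to Elasticsearch can be exercised without a running cluster. A minimal usage sketch (the import path and field values are illustrative):

    from mock.es_double import ESDouble

    es = ESDouble()
    assert es.ping()
    # Store one document and read it back through the same interface
    # the real elasticsearch client exposes.
    es.create(index="users", doc_type="local",
              body={"email_address": "user@example.com"}, id=1)
    hits = es.search(index="users", doc_type="local", body={})["hits"]["hits"]
    assert hits[0]["_source"]["email_address"] == "user@example.com"
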
diff --git a/wqflask/run_gunicorn.py b/wqflask/run_gunicorn.py
index 14a2d689..ebe3add5 100644
--- a/wqflask/run_gunicorn.py
+++ b/wqflask/run_gunicorn.py
@@ -11,6 +11,9 @@ print "Starting up Gunicorn process"
 
 from wqflask import app
 
+app.config['SESSION_TYPE'] = 'filesystem'
+app.config['SECRET_KEY'] = 'super secret key'
+
 @app.route("/gunicorn")
 def hello():
     return "<h1 style='color:blue'>Hello There!</h1>"
diff --git a/wqflask/tests/es_double.py b/wqflask/tests/es_double.py
deleted file mode 100644
index 00739016..00000000
--- a/wqflask/tests/es_double.py
+++ /dev/null
@@ -1,30 +0,0 @@
-class ESDouble(object):
-    def __init__(self):
-        self.items = {
-            "users": {
-                "local": []
-            }}
-
-    def ping(self):
-        return true
-
-    def create(self, index, doc_type, body, id):
-        item = {"id": id, "_source": body}
-        if not self.items.get("index", None):
-            self.items[index] = {doc_type: [item]}
-        else:
-            self.items[index][doc_type].append(item)
-
-    def search(self, index, doc_type, body):
-        d = body["query"]["match"]
-        column = [(key, d[key]) for key in d]
-
-        items = []
-        for thing in self.items[index][doc_type]:
-            if thing["_source"][column[0][0]] == column[0][1]:
-                items.append(thing)
-                break
-        return {
-            "hits": {
-                "hits": items
-            }}
diff --git a/wqflask/tests/test_registration.py b/wqflask/tests/test_registration.py
deleted file mode 100644
index 50a2a84c..00000000
--- a/wqflask/tests/test_registration.py
+++ /dev/null
@@ -1,113 +0,0 @@
-import unittest
-import es_double
-import wqflask.user_manager
-from wqflask.user_manager import RegisterUser
-
-class TestRegisterUser(unittest.TestCase):
-    def setUp(self):
-        # Mock elasticsearch
-        self.es = es_double.ESDouble()
-
-        # Patch method
-        wqflask.user_manager.basic_info = lambda : {"basic_info": "some info"}
-
-    def tearDown(self):
-        self.es = None
-
-    def testRegisterUserWithNoData(self):
-        data = {}
-        result = RegisterUser(data)
-        self.assertNotEqual(len(result.errors), 0, "Data was not provided. Error was expected")
-
-    def testRegisterUserWithNoEmail(self):
-        data = {
-            "email_address": ""
-            , "full_name": "A.N. Other"
-            , "organization": "Some Organisation"
-            , "password": "testing"
-            , "password_confirm": "testing"
-            , "es_connection": self.es
-        }
-
-        result = RegisterUser(data)
-        self.assertNotEqual(len(result.errors), 0, "Email not provided. Error was expected")
-
-    def testRegisterUserWithNoName(self):
-        data = {
-            "email_address": "user@example.com"
-            , "full_name": ""
-            , "organization": "Some Organisation"
-            , "password": "testing"
-            , "password_confirm": "testing"
-            , "es_connection": self.es
-        }
-
-        result = RegisterUser(data)
-        self.assertNotEqual(len(result.errors), 0, "Name not provided. Error was expected")
-
-    def testRegisterUserWithNoOrganisation(self):
-        data = {
-            "email_address": "user@example.com"
-            , "full_name": "A.N. Other"
-            , "organization": ""
-            , "password": "testing"
-            , "password_confirm": "testing"
-            , "es_connection": self.es
-        }
-        
-        result = RegisterUser(data)
-        self.assertEqual(len(result.errors), 0, "Organisation not provided. Error not expected")
-
-    def testRegisterUserWithShortOrganisation(self):
-        data = {
-            "email_address": "user@example.com"
-            , "full_name": "A.N. Other"
-            , "organization": "SO"
-            , "password": "testing"
-            , "password_confirm": "testing"
-            , "es_connection": self.es
-        }
-        
-        result = RegisterUser(data)
-        self.assertNotEqual(len(result.errors), 0, "Organisation name too short. Error expected")
-
-    def testRegisterUserWithNoPassword(self):
-        data = {
-            "email_address": "user@example.com"
-            , "full_name": "A.N. Other"
-            , "organization": "Some Organisation"
-            , "password": None
-            , "password_confirm": None
-            , "es_connection": self.es
-        }
-
-        result = RegisterUser(data)
-        self.assertNotEqual(len(result.errors), 0, "Password not provided. Error was expected")
-
-    def testRegisterUserWithNonMatchingPasswords(self):
-        data = {
-            "email_address": "user@example.com"
-            , "full_name": "A.N. Other"
-            , "organization": "Some Organisation"
-            , "password": "testing"
-            , "password_confirm": "stilltesting"
-            , "es_connection": self.es
-        }
-
-        result = RegisterUser(data)
-        self.assertNotEqual(len(result.errors), 0, "Password mismatch. Error was expected")
-
-    def testRegisterUserWithCorrectData(self):
-        data = {
-            "email_address": "user@example.com"
-            , "full_name": "A.N. Other"
-            , "organization": "Some Organisation"
-            , "password": "testing"
-            , "password_confirm": "testing"
-            , "es_connection": self.es
-        }
-        result = RegisterUser(data)
-        self.assertEqual(len(result.errors), 0, "All data items provided. Errors were not expected")
-
-if __name__ == "__main__":
-    unittest.main()
diff --git a/wqflask/utility/elasticsearch_tools.py b/wqflask/utility/elasticsearch_tools.py
index a964b025..2d3d5add 100644
--- a/wqflask/utility/elasticsearch_tools.py
+++ b/wqflask/utility/elasticsearch_tools.py
@@ -1,10 +1,18 @@
 from elasticsearch import Elasticsearch, TransportError
 import logging
 
+from utility.logger import getLogger
+logger = getLogger(__name__)
+
+from utility.tools import ELASTICSEARCH_HOST, ELASTICSEARCH_PORT
+
 def get_elasticsearch_connection():
+    logger.info("get_elasticsearch_connection")
     es = None
     try:
-        from utility.tools import ELASTICSEARCH_HOST, ELASTICSEARCH_PORT
+        assert(ELASTICSEARCH_HOST)
+        assert(ELASTICSEARCH_PORT)
+        logger.info("ES HOST",ELASTICSEARCH_HOST)
 
         es = Elasticsearch([{
             "host": ELASTICSEARCH_HOST
@@ -31,12 +39,12 @@ def get_item_by_unique_column(es, column_name, column_value, index, doc_type):
         response = es.search(
             index = index
             , doc_type = doc_type
-            , body = { 
-                "query": { "match": { column_name: column_value } } 
+            , body = {
+                "query": { "match": { column_name: column_value } }
             })
         if len(response["hits"]["hits"]) > 0:
             item_details = response["hits"]["hits"][0]["_source"]
-    except TransportError as te: 
+    except TransportError as te:
         pass
     return item_details
 
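
With ELASTICSEARCH_HOST and ELASTICSEARCH_PORT now imported at module level and asserted inside get_elasticsearch_connection, callers obtain a client on demand. A minimal usage sketch (the lookup values are illustrative):

    from utility.elasticsearch_tools import get_elasticsearch_connection, get_item_by_unique_column

    es = get_elasticsearch_connection()
    if es is not None:
        # Fetch a single user document by its unique email column.
        user = get_item_by_unique_column(es, "email_address", "user@example.com",
                                         index="users", doc_type="local")
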
diff --git a/wqflask/utility/tools.py b/wqflask/utility/tools.py
index ec673cf5..8c9fed96 100644
--- a/wqflask/utility/tools.py
+++ b/wqflask/utility/tools.py
@@ -16,7 +16,7 @@ OVERRIDES = {}
 def app_set(command_id, value):
     """Set application wide value"""
     app.config.setdefault(command_id,value)
-    value
+    return value
 
 def get_setting(command_id,guess=None):
     """Resolve a setting from the environment or the global settings in
@@ -51,7 +51,7 @@ def get_setting(command_id,guess=None):
             return None
 
     # ---- Check whether environment exists
-    logger.debug("Looking for "+command_id+"\n")
+    # print("Looking for "+command_id+"\n")
     command = value(os.environ.get(command_id))
     if command is None or command == "":
         command = OVERRIDES.get(command_id)
@@ -63,7 +63,7 @@ def get_setting(command_id,guess=None):
                 if command is None or command == "":
                     # print command
                     raise Exception(command_id+' setting unknown or faulty (update default_settings.py?).')
-    logger.debug("Set "+command_id+"="+str(command))
+    # print("Set "+command_id+"="+str(command))
     return command
 
 def get_setting_bool(id):
@@ -251,35 +251,29 @@ assert_dir(JS_GUIX_PATH)
 JS_GN_PATH         = get_setting('JS_GN_PATH')
 # assert_dir(JS_GN_PATH)
 
-def get_setting_safe(setting):
-    try:
-        return get_setting(setting)
-    except:
-        print("Could not find the setting '", setting, "'. Continuing with value unset")
-        return None
-
-GITHUB_CLIENT_ID = get_setting_safe('GITHUB_CLIENT_ID')
-GITHUB_CLIENT_SECRET = get_setting_safe('GITHUB_CLIENT_SECRET')
+GITHUB_CLIENT_ID = get_setting('GITHUB_CLIENT_ID')
+GITHUB_CLIENT_SECRET = get_setting('GITHUB_CLIENT_SECRET')
 GITHUB_AUTH_URL = None
 if GITHUB_CLIENT_ID and GITHUB_CLIENT_SECRET:
     GITHUB_AUTH_URL = "https://github.com/login/oauth/authorize?client_id="+GITHUB_CLIENT_ID+"&client_secret="+GITHUB_CLIENT_SECRET
-GITHUB_API_URL = get_setting_safe('GITHUB_API_URL')
-ORCID_CLIENT_ID = get_setting_safe('ORCID_CLIENT_ID')
-ORCID_CLIENT_SECRET = get_setting_safe('ORCID_CLIENT_SECRET')
+GITHUB_API_URL = get_setting('GITHUB_API_URL')
+ORCID_CLIENT_ID = get_setting('ORCID_CLIENT_ID')
+ORCID_CLIENT_SECRET = get_setting('ORCID_CLIENT_SECRET')
 ORCID_AUTH_URL = None
 if ORCID_CLIENT_ID and ORCID_CLIENT_SECRET:
     ORCID_AUTH_URL = "https://sandbox.orcid.org/oauth/authorize?response_type=code&scope=/authenticate&show_login=true&client_id="+ORCID_CLIENT_ID+"&client_secret="+ORCID_CLIENT_SECRET
-ORCID_TOKEN_URL = get_setting_safe('ORCID_TOKEN_URL')
+ORCID_TOKEN_URL = get_setting('ORCID_TOKEN_URL')
 
-ELASTICSEARCH_HOST = get_setting_safe('ELASTICSEARCH_HOST')
-ELASTICSEARCH_PORT = get_setting_safe('ELASTICSEARCH_PORT')
+ELASTICSEARCH_HOST = get_setting('ELASTICSEARCH_HOST')
+ELASTICSEARCH_PORT = get_setting('ELASTICSEARCH_PORT')
 
-SMTP_CONNECT = get_setting_safe('SMTP_CONNECT')
-SMTP_USERNAME = get_setting_safe('SMTP_USERNAME')
-SMTP_PASSWORD = get_setting_safe('SMTP_PASSWORD')
+SMTP_CONNECT = get_setting('SMTP_CONNECT')
+SMTP_USERNAME = get_setting('SMTP_USERNAME')
+SMTP_PASSWORD = get_setting('SMTP_PASSWORD')
 
 PYLMM_COMMAND      = app_set("PYLMM_COMMAND",pylmm_command())
 GEMMA_COMMAND      = app_set("GEMMA_COMMAND",gemma_command())
+assert(GEMMA_COMMAND is not None)
 PLINK_COMMAND      = app_set("PLINK_COMMAND",plink_command())
 GEMMA_WRAPPER_COMMAND = gemma_wrapper_command()
 TEMPDIR            = tempdir() # defaults to UNIX TMPDIR
@@ -293,7 +287,7 @@ from six import string_types
 
 if os.environ.get('WQFLASK_OVERRIDES'):
     jsonfn = get_setting('WQFLASK_OVERRIDES')
-    logger.error("WQFLASK_OVERRIDES: %s" % jsonfn)
+    logger.info("WQFLASK_OVERRIDES: %s" % jsonfn)
     with open(jsonfn) as data_file:
         overrides = json.load(data_file)
         for k in overrides:
@@ -305,8 +299,4 @@ if os.environ.get('WQFLASK_OVERRIDES'):
             logger.debug(OVERRIDES)
 
 # assert_file(PHEWAS_FILES+"/auwerx/PheWAS_pval_EMMA_norm.RData")
-# assert_dir(get_setting("JS_BIODALLIANCE"))
-# assert_file(get_setting("JS_BIODALLIANCE")+"/build/dalliance-all.js")
-# assert_file(get_setting("JS_BIODALLIANCE")+"/build/worker-all.js")
-# assert_dir(get_setting("JS_TWITTER_POST_FETCHER"))
 assert_file(JS_TWITTER_POST_FETCHER_PATH+"/js/twitterFetcher_min.js")
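
get_setting resolves each value from the process environment first, then the WQFLASK_OVERRIDES file, then the guess supplied by default_settings.py, and raises if all three are empty; switching the ES/SMTP/OAuth settings from get_setting_safe back to get_setting therefore makes a missing value fail loudly at startup. A simplified sketch of that lookup order (not the real implementation):

    import os

    def resolve_setting(name, overrides, guess=None):
        # Environment wins, then the overrides file, then the built-in default.
        value = os.environ.get(name) or overrides.get(name) or guess
        if value is None or value == "":
            raise Exception(name + " setting unknown or faulty (update default_settings.py?)")
        return value
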
diff --git a/wqflask/wqflask/correlation/show_corr_results.py b/wqflask/wqflask/correlation/show_corr_results.py
index 24432ad0..3d1c0d17 100644
--- a/wqflask/wqflask/correlation/show_corr_results.py
+++ b/wqflask/wqflask/correlation/show_corr_results.py
@@ -75,6 +75,46 @@ def print_mem(stage=""):
     mem = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
     #print("{}: {}".format(stage, mem/1024))
 
+def is_float(value):
+    try:
+        float(value)
+        return True
+    except:
+        return False
+
+def is_int(value):
+    try:
+        int(value)
+        return True
+    except:
+        return False
+
+def is_str(value):
+    if value is None:
+        return False
+    try:
+        str(value)
+        return True
+    except:
+        return False
+
+def get_float(vars,name,default=None):
+    if name in vars:
+        if is_float(vars[name]):
+            return float(vars[name])
+    return default
+
+def get_int(vars,name,default=None):
+    if name in vars:
+        if is_int(vars[name]):
+            return int(vars[name])
+    return default
+
+def get_string(vars,name,default=None):
+    if name in vars:
+        if not vars[name] is None:
+            return str(vars[name])
+    return default
 
 class AuthException(Exception):
     pass
@@ -96,7 +136,19 @@ class CorrelationResults(object):
         # get trait list from db (database name)
         # calculate correlation with Base vector and targets
 
-        print("TESTING...")
+        # Check parameters
+        assert('corr_type' in start_vars)
+        assert(is_str(start_vars['corr_type']))
+        assert('dataset' in start_vars)
+        # assert('group' in start_vars) permitted to be empty?
+        assert('corr_sample_method' in start_vars)
+        assert('corr_samples_group' in start_vars)
+        assert('corr_dataset' in start_vars)
+        assert('min_expr' in start_vars)
+        assert('corr_return_results' in start_vars)
+        if 'loc_chr' in start_vars:
+            assert('min_loc_mb' in start_vars)
+            assert('max_loc_mb' in start_vars)
 
         with Bench("Doing correlations"):
             if start_vars['dataset'] == "Temp":
@@ -115,27 +167,17 @@ class CorrelationResults(object):
             self.sample_data = {}
             self.corr_type = start_vars['corr_type']
             self.corr_method = start_vars['corr_sample_method']
-            if 'min_expr' in start_vars:
-                if start_vars['min_expr'] != "":
-                    self.min_expr = float(start_vars['min_expr'])
-                else:
-                    self.min_expr = None
-            self.p_range_lower = float(start_vars['p_range_lower'])
-            self.p_range_upper = float(start_vars['p_range_upper'])
+            self.min_expr = get_float(start_vars,'min_expr')
+            self.p_range_lower = get_float(start_vars,'p_range_lower',-1.0)
+            self.p_range_upper = get_float(start_vars,'p_range_upper',1.0)
 
             if ('loc_chr' in start_vars and
                 'min_loc_mb' in start_vars and
                 'max_loc_mb' in start_vars):
 
-                self.location_chr = start_vars['loc_chr']
-                if start_vars['min_loc_mb'].isdigit():
-                    self.min_location_mb = start_vars['min_loc_mb']
-                else:
-                    self.min_location_mb = None
-                if start_vars['max_loc_mb'].isdigit():
-                    self.max_location_mb = start_vars['max_loc_mb']
-                else:
-                    self.max_location_mb = None
+                self.location_chr = get_string(start_vars,'loc_chr')
+                self.min_location_mb = get_int(start_vars,'min_loc_mb')
+                self.max_location_mb = get_int(start_vars,'max_loc_mb')
 
             self.get_formatted_corr_type()
             self.return_number = int(start_vars['corr_return_results'])
@@ -183,7 +225,7 @@ class CorrelationResults(object):
                 else:
                     for trait, values in self.target_dataset.trait_data.iteritems():
                         self.get_sample_r_and_p_values(trait, values)
-                        
+
             elif self.corr_type == "lit":
                 self.trait_geneid_dict = self.dataset.retrieve_genes("GeneId")
                 lit_corr_data = self.do_lit_correlation_for_all_traits()
@@ -564,7 +606,7 @@ class CorrelationResults(object):
                 self.this_trait_vals.append(sample_value)
                 target_vals.append(target_sample_value)
 
-        self.this_trait_vals, target_vals, num_overlap = corr_result_helpers.normalize_values(self.this_trait_vals, target_vals)	
+        self.this_trait_vals, target_vals, num_overlap = corr_result_helpers.normalize_values(self.this_trait_vals, target_vals)
 
         #ZS: 2015 could add biweight correlation, see http://www.ncbi.nlm.nih.gov/pmc/articles/PMC3465711/
         if self.corr_method == 'pearson':
@@ -574,8 +616,8 @@ class CorrelationResults(object):
 
         if num_overlap > 5:
             self.correlation_data[trait] = [sample_r, sample_p, num_overlap]
-		
-		
+
+
         """
         correlations = []
 
@@ -673,8 +715,8 @@ class CorrelationResults(object):
                         method=self.method)
 
         return trait_list
-        """		
-		
+        """
+
 
     def do_tissue_corr_for_all_traits_2(self):
         """Comments Possibly Out of Date!!!!!
@@ -1089,7 +1131,7 @@ class CorrelationResults(object):
             totalTraits = len(traits) #XZ, 09/18/2008: total trait number
 
         return traits
-			
+
     def calculate_corr_for_all_tissues(self, tissue_dataset_id=None):
 
         symbol_corr_dict = {}
@@ -1129,7 +1171,7 @@ class CorrelationResults(object):
                     values_2.append(target_value)
             correlation = calCorrelation(values_1, values_2)
             self.correlation_data[trait] = correlation
-			
+
     def getFileName(self, target_db_name):  ### dcrowell  August 2008
         """Returns the name of the reference database file with which correlations are calculated.
         Takes argument cursor which is a cursor object of any instance of a subclass of templatePage
@@ -1144,7 +1186,7 @@ class CorrelationResults(object):
         return FileName
 
     def do_parallel_correlation(self, db_filename, num_overlap):
-	
+
         #XZ, 01/14/2009: This method is for parallel computing only.
         #XZ: It is supposed to be called when "Genetic Correlation, Pearson's r" (method 1)
         #XZ: or "Genetic Correlation, Spearman's rho" (method 2) is selected
@@ -1313,7 +1355,7 @@ class CorrelationResults(object):
                         z_value = z_value*math.sqrt(nOverlap-3)
                         sample_p = 2.0*(1.0 - reaper.normp(abs(z_value)))
 
-                correlation_data[traitdataName] = [sample_r, sample_p, nOverlap]	
+                correlation_data[traitdataName] = [sample_r, sample_p, nOverlap]
 
                 # traitinfo = [traitdataName, sample_r, nOverlap]
                 # allcorrelations.append(traitinfo)
@@ -1321,7 +1363,7 @@ class CorrelationResults(object):
             return correlation_data
             # return allcorrelations
 
-	
+
         datasetFile = open(webqtlConfig.GENERATED_TEXT_DIR+db_filename,'r')
 
         print("Invoking parallel computing")
@@ -1378,5 +1420,3 @@ class CorrelationResults(object):
         # for one_result in results:
             # for one_traitinfo in one_result:
                 # allcorrelations.append( one_traitinfo )
-
-
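
The new is_float/is_int/is_str predicates and get_float/get_int/get_string accessors replace the scattered isdigit() and empty-string checks on start_vars, so malformed form values degrade to a default instead of raising. A short usage sketch with illustrative values:

    start_vars = {"min_expr": "", "p_range_lower": "-0.5", "min_loc_mb": "12"}

    min_expr = get_float(start_vars, "min_expr")             # None: "" is not a float
    p_lower  = get_float(start_vars, "p_range_lower", -1.0)  # -0.5
    min_mb   = get_int(start_vars, "min_loc_mb")             # 12
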
diff --git a/wqflask/wqflask/marker_regression/gemma_mapping.py b/wqflask/wqflask/marker_regression/gemma_mapping.py
index a24e43d4..68920130 100644
--- a/wqflask/wqflask/marker_regression/gemma_mapping.py
+++ b/wqflask/wqflask/marker_regression/gemma_mapping.py
@@ -3,7 +3,7 @@ import os, math, string, random, json
 from base import webqtlConfig
 from base.trait import GeneralTrait
 from base.data_set import create_dataset
-from utility.tools import flat_files, GEMMA_COMMAND, GEMMA_WRAPPER_COMMAND, TEMPDIR
+from utility.tools import flat_files, GEMMA_COMMAND, GEMMA_WRAPPER_COMMAND, TEMPDIR, assert_bin, assert_file
 
 import utility.logger
 logger = utility.logger.getLogger(__name__ )
@@ -11,6 +11,7 @@ logger = utility.logger.getLogger(__name__ )
 def run_gemma(this_dataset, samples, vals, covariates, method, use_loco):
     """Generates p-values for each marker using GEMMA"""
 
+    assert_bin(GEMMA_COMMAND)
     if this_dataset.group.genofile != None:
         genofile_name = this_dataset.group.genofile[:-5]
     else:
@@ -27,7 +28,7 @@ def run_gemma(this_dataset, samples, vals, covariates, method, use_loco):
         if i < (len(this_chromosomes) - 1):
             chr_list_string += this_chromosomes[i+1].name + ","
         else:
-            chr_list_string += this_chromosomes[i+1].name  
+            chr_list_string += this_chromosomes[i+1].name
 
     if covariates != "":
         gen_covariates_file(this_dataset, covariates)
@@ -209,8 +210,13 @@ def parse_gemma_output(genofile_name):
 def parse_loco_output(this_dataset, gwa_output_filename):
 
     output_filelist = []
-    with open("{}/gn2/".format(TEMPDIR) + gwa_output_filename + ".json") as data_file:
-       data = json.load(data_file)
+    jsonfn = "{}/gn2/".format(TEMPDIR) + gwa_output_filename + ".json"
+    assert_file(jsonfn)
+    try:
+        with open(jsonfn) as data_file:
+            data = json.load(data_file)
+    except:
+        logger.error("Cannot parse "+jsonfn)
 
     files = data['files']
     for file in files:
@@ -247,4 +253,4 @@ def parse_loco_output(this_dataset, gwa_output_filename):
                     included_markers.append(line.split("\t")[1])
                     p_values.append(float(line.split("\t")[10]))
 
-    return marker_obs
\ No newline at end of file
+    return marker_obs
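
parse_loco_output now asserts that the gemma-wrapper JSON manifest exists before opening it, but the bare except still lets a malformed file fall through to an undefined data variable. A hedged sketch of a stricter loader that surfaces the failure instead (the helper name is illustrative):

    import json

    def load_gemma_manifest(jsonfn):
        # Let a missing or malformed manifest raise, rather than continuing
        # with 'data' unset further down in parse_loco_output.
        with open(jsonfn) as data_file:
            return json.load(data_file)
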
diff --git a/wqflask/wqflask/user_manager.py b/wqflask/wqflask/user_manager.py
index 6b667615..c8471cb1 100644
--- a/wqflask/wqflask/user_manager.py
+++ b/wqflask/wqflask/user_manager.py
@@ -55,9 +55,8 @@ logger = getLogger(__name__)
 from base.data_set import create_datasets_list
 
 import requests
-from utility.elasticsearch_tools import *
+from utility.elasticsearch_tools import get_elasticsearch_connection, get_user_by_unique_column, save_user
 
-es = get_elasticsearch_connection()
 THREE_DAYS = 60 * 60 * 24 * 3
 #THREE_DAYS = 45
 
@@ -479,6 +478,7 @@ def password_reset_step2():
     password = request.form['password']
     set_password(password, user)
 
+    es = get_elasticsearch_connection()
     es.update(
         index = "users"
         , doc_type = "local"
@@ -620,6 +620,7 @@ class LoginUser(object):
         """Login through the normal form"""
         params = request.form if request.form else request.args
         logger.debug("in login params are:", params)
+        es = get_elasticsearch_connection()
         if not params:
             from utility.tools import GITHUB_AUTH_URL, ORCID_AUTH_URL
             external_login = None
@@ -628,6 +629,7 @@ class LoginUser(object):
                     "github": GITHUB_AUTH_URL,
                     "orcid": ORCID_AUTH_URL
                 }
+            assert(es is not None)
             return render_template(
                 "new_security/login_user.html"
                 , external_login=external_login
@@ -822,6 +824,7 @@ def register():
 
     params = request.form if request.form else request.args
     params = params.to_dict(flat=True)
+    es = get_elasticsearch_connection()
     params["es_connection"] = es
 
     if params:
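
Dropping the module-level es = get_elasticsearch_connection() means importing user_manager no longer requires a reachable Elasticsearch; each view opens its connection when it actually needs one. A minimal sketch of the pattern (the helper name is illustrative):

    from utility.elasticsearch_tools import get_elasticsearch_connection, get_item_by_unique_column

    def lookup_user(email_address):
        # Connect lazily, inside the request, instead of at import time.
        es = get_elasticsearch_connection()
        if es is None:
            return None
        return get_item_by_unique_column(es, "email_address", email_address,
                                         index="users", doc_type="local")
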