about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--wqflask/utility/elasticsearch_tools.py40
-rw-r--r--wqflask/wqflask/user_manager.py61
2 files changed, 58 insertions, 43 deletions
diff --git a/wqflask/utility/elasticsearch_tools.py b/wqflask/utility/elasticsearch_tools.py
index 4fc0035c..a964b025 100644
--- a/wqflask/utility/elasticsearch_tools.py
+++ b/wqflask/utility/elasticsearch_tools.py
@@ -1,23 +1,31 @@
-es = None
-try:
-    from elasticsearch import Elasticsearch, TransportError
-    from utility.tools import ELASTICSEARCH_HOST, ELASTICSEARCH_PORT
+from elasticsearch import Elasticsearch, TransportError
+import logging
 
-    es = Elasticsearch([{
-        "host": ELASTICSEARCH_HOST
-        , "port": ELASTICSEARCH_PORT
-    }]) if (ELASTICSEARCH_HOST and ELASTICSEARCH_PORT) else None
-
-except:
+def get_elasticsearch_connection():
     es = None
+    try:
+        from utility.tools import ELASTICSEARCH_HOST, ELASTICSEARCH_PORT
+
+        es = Elasticsearch([{
+            "host": ELASTICSEARCH_HOST
+            , "port": ELASTICSEARCH_PORT
+        }]) if (ELASTICSEARCH_HOST and ELASTICSEARCH_PORT) else None
+
+        es_logger = logging.getLogger("elasticsearch")
+        es_logger.setLevel(logging.INFO)
+        es_logger.addHandler(logging.NullHandler())
+    except:
+        es = None
+
+    return es
 
-def get_user_by_unique_column(column_name, column_value):
-    return get_item_by_unique_column(column_name, column_value, index="users", doc_type="local")
+def get_user_by_unique_column(es, column_name, column_value, index="users", doc_type="local"):
+    return get_item_by_unique_column(es, column_name, column_value, index=index, doc_type=doc_type)
 
-def save_user(user, user_id):
-    es_save_data("users", "local", user, user_id)
+def save_user(es, user, user_id):
+    es_save_data(es, "users", "local", user, user_id)
 
-def get_item_by_unique_column(column_name, column_value, index, doc_type):
+def get_item_by_unique_column(es, column_name, column_value, index, doc_type):
     item_details = None
     try:
         response = es.search(
@@ -32,7 +40,7 @@ def get_item_by_unique_column(column_name, column_value, index, doc_type):
         pass
     return item_details
 
-def es_save_data(index, doc_type, data_item, data_id,):
+def es_save_data(es, index, doc_type, data_item, data_id,):
     from time import sleep
     es.create(index, doc_type, body=data_item, id=data_id)
     sleep(1) # Delay 1 second to allow indexing
diff --git a/wqflask/wqflask/user_manager.py b/wqflask/wqflask/user_manager.py
index 630be9aa..6b667615 100644
--- a/wqflask/wqflask/user_manager.py
+++ b/wqflask/wqflask/user_manager.py
@@ -55,8 +55,9 @@ logger = getLogger(__name__)
 from base.data_set import create_datasets_list
 
 import requests
-from utility.elasticsearch_tools import get_user_by_unique_column, save_user, es_save_data
+from utility.elasticsearch_tools import *
 
+es = get_elasticsearch_connection()
 THREE_DAYS = 60 * 60 * 24 * 3
 #THREE_DAYS = 45
 
@@ -271,14 +272,18 @@ class RegisterUser(object):
         self.thank_you_mode = False
         self.errors = []
         self.user = Bunch()
+        es = kw.get('es_connection', None)
+
+        if not es:
+            self.errors.append("Missing connection object")
 
         self.user.email_address = kw.get('email_address', '').encode("utf-8").strip()
         if not (5 <= len(self.user.email_address) <= 50):
             self.errors.append('Email Address needs to be between 5 and 50 characters.')
-
-        email_exists = get_user_by_unique_column("email_address", self.user.email_address)
-        if email_exists:
-            self.errors.append('User already exists with that email')
+        else:
+            email_exists = get_user_by_unique_column(es, "email_address", self.user.email_address)
+            if email_exists:
+                self.errors.append('User already exists with that email')
 
         self.user.full_name = kw.get('full_name', '').encode("utf-8").strip()
         if not (5 <= len(self.user.full_name) <= 50):
@@ -305,7 +310,7 @@ class RegisterUser(object):
         self.user.confirmed = 1
 
         self.user.registration_info = json.dumps(basic_info(), sort_keys=True)
-        save_user(self.user.__dict__, self.user.user_id)
+        save_user(es, self.user.__dict__, self.user.user_id)
 
 def set_password(password, user):
     pwfields = Bunch()
@@ -381,7 +386,7 @@ class ForgotPasswordEmail(VerificationEmail):
             "email_address": toaddr,
             "timestamp": timestamp()
         }
-        es_save_data(self.key_prefix, "local", data, verification_code)
+        es_save_data(es, self.key_prefix, "local", data, verification_code)
 
         subject = self.subject
         body = render_template(
@@ -431,7 +436,6 @@ def verify_email():
 
 @app.route("/n/password_reset", methods=['GET'])
 def password_reset():
-    from utility.elasticsearch_tools import get_item_by_unique_column
     logger.debug("in password_reset request.url is:", request.url)
 
     # We do this mainly just to assert that it's in proper form for displaying next page
@@ -441,14 +445,16 @@ def password_reset():
     hmac = request.args.get('hm')
     if verification_code:
         code_details = get_item_by_unique_column(
-            "verification_code",
-            verification_code,
-            ForgotPasswordEmail.key_prefix,
-            "local")
+            es
+            , "verification_code"
+            , verification_code
+            , ForgotPasswordEmail.key_prefix
+            , "local")
         if code_details:
             user_details = get_user_by_unique_column(
-                "email_address",
-                code_details["email_address"])
+                es
+                , "email_address"
+                , code_details["email_address"])
             if user_details:
                 return render_template(
                     "new_security/password_reset.html", user_encode=user_details["user_id"])
@@ -533,7 +539,7 @@ def github_oauth2():
     result_dict = {arr[0]:arr[1] for arr in [tok.split("=") for tok in [token.encode("utf-8") for token in result.text.split("&")]]}
 
     github_user = get_github_user_details(result_dict["access_token"])
-    user_details = get_user_by_unique_column("github_id", github_user["id"])
+    user_details = get_user_by_unique_column(es, "github_id", github_user["id"])
     if user_details == None:
         user_details = {
             "user_id": str(uuid.uuid4())
@@ -545,7 +551,7 @@ def github_oauth2():
             , "active": 1
             , "confirmed": 1
         }
-        save_user(user_details, user_details["user_id"])
+        save_user(es, user_details, user_details["user_id"])
     url = "/n/login?type=github&uid="+user_details["user_id"]
     return redirect(url)
 
@@ -566,7 +572,7 @@ def orcid_oauth2():
         result = requests.post(ORCID_TOKEN_URL, data=data)
         result_dict = json.loads(result.text.encode("utf-8"))
 
-        user_details = get_user_by_unique_column("orcid", result_dict["orcid"])
+        user_details = get_user_by_unique_column(es, "orcid", result_dict["orcid"])
         if user_details == None:
             user_details = {
                 "user_id": str(uuid4())
@@ -580,7 +586,7 @@ def orcid_oauth2():
                 , "active": 1
                 , "confirmed": 1
             }
-            save_user(user_details, user_details["user_id"])
+            save_user(es, user_details, user_details["user_id"])
         url = "/n/login?type=orcid&uid="+user_details["user_id"]
     else:
         flash("There was an error getting code from ORCID")
@@ -599,7 +605,7 @@ class LoginUser(object):
 
     def oauth2_login(self, login_type, user_id):
         """Login via an OAuth2 provider"""
-        user_details = get_user_by_unique_column("user_id", user_id)
+        user_details = get_user_by_unique_column(es, "user_id", user_id)
         if user_details:
             user = model.User()
             user.id = user_details["user_id"]
@@ -616,7 +622,6 @@ class LoginUser(object):
         logger.debug("in login params are:", params)
         if not params:
             from utility.tools import GITHUB_AUTH_URL, ORCID_AUTH_URL
-            from utility.elasticsearch_tools import es
             external_login = None
             if GITHUB_AUTH_URL or ORCID_AUTH_URL:
                 external_login={
@@ -628,8 +633,9 @@ class LoginUser(object):
                 , external_login=external_login
                 , es_server=es.ping())
         else:
-            user_details = get_user_by_unique_column("email_address", params["email_address"])
+            user_details = get_user_by_unique_column(es, "email_address", params["email_address"])
             user = None
+            valid = None
             if user_details:
                 user = model.User();
                 for key in user_details:
@@ -672,7 +678,7 @@ class LoginUser(object):
         else:
             if user:
                 self.unsuccessful_login(user)
-            flash("Invalid email-address or password. Please try again.", "alert-error")
+            flash("Invalid email-address or password. Please try again.", "alert-danger")
             response = make_response(redirect(url_for('login')))
 
             return response
@@ -739,7 +745,7 @@ def forgot_password():
 def forgot_password_submit():
     params = request.form
     email_address = params['email_address']
-    user_details = get_user_by_unique_column("email_address", email_address)
+    user_details = get_user_by_unique_column(es, "email_address", email_address)
     if user_details:
         ForgotPasswordEmail(user_details["email_address"])
     # try:
@@ -815,16 +821,17 @@ def register():
 
 
     params = request.form if request.form else request.args
+    params = params.to_dict(flat=True)
+    params["es_connection"] = es
 
     if params:
         logger.debug("Attempting to register the user...")
         result = RegisterUser(params)
         errors = result.errors
 
-        if result.thank_you_mode:
-            assert not errors, "Errors while in thank you mode? That seems wrong..."
-            return render_template("new_security/registered.html",
-                                   subject=VerificationEmail.subject)
+        if len(errors) == 0:
+            flash("Registration successful. You may login with your new account", "alert-info")
+            return redirect(url_for("login"))
 
     return render_template("new_security/register_user.html", values=params, errors=errors)