about summary refs log tree commit diff
diff options
context:
space:
mode:
authorzsloan2020-01-22 14:46:49 -0600
committerzsloan2020-01-22 14:46:49 -0600
commit9079631e2eb431b4fdd87b4b0249483d3e01df97 (patch)
treeea25d286caee6922010cfe7b27591cb4623fe595
parent26b3b191f12b6265b62a623d9c5675cffb369247 (diff)
downloadgenenetwork2-9079631e2eb431b4fdd87b4b0249483d3e01df97.tar.gz
Switched user authentication code to using Redis instead of ElasticSearch
-rw-r--r--wqflask/wqflask/templates/new_security/login_user.html14
-rw-r--r--wqflask/wqflask/user_manager.py167
2 files changed, 55 insertions, 126 deletions
diff --git a/wqflask/wqflask/templates/new_security/login_user.html b/wqflask/wqflask/templates/new_security/login_user.html
index c9aaf028..9df0e16a 100644
--- a/wqflask/wqflask/templates/new_security/login_user.html
+++ b/wqflask/wqflask/templates/new_security/login_user.html
@@ -8,7 +8,7 @@
 
         <h4>Already have an account? Sign in here.</h4>
 
-	    {% if es_server: %}
+	    {% if redis_is_available: %}
             <form class="form-horizontal" action="/n/login" method="POST" name="login_user_form" id="loginUserForm">
                  <fieldset>
                     <div class="form-group">
@@ -48,9 +48,9 @@
 
                     <h4>Don't have an account?</h4>
 
-                    {% if es_server: %}
+                {% if redis_is_available: %}
                 <a href="/n/register" class="btn btn-primary modalize">Create a new account</a>
-                {% else: %}
+                {% else %}
                 <div class="alert alert-warning">
                   <p>You cannot create an account at this moment.<br />
                 Please try again later.</p>
@@ -81,15 +81,15 @@
                 {% endif %}
 
             </form>
-	    {% else: %}
+        {% else: %}
 	    <div class="alert alert-warning">
 	      <p>You cannot login at this moment using your GeneNetwork account (the authentication service is down).<br />
 		Please try again later.</p>
-	    </div>
+        </div>
 	    {% endif %}
-            {% if not es_server and not external_login: %}
+            {% if not external_login: %}
             <hr>
-	    <div class="alert alert-warning">
+	        <div class="alert alert-warning">
                Note: it is safe to use GeneNetwork without a login. Login is only required for keeping track of
             collections and getting access to some types of restricted data.
             </div>
diff --git a/wqflask/wqflask/user_manager.py b/wqflask/wqflask/user_manager.py
index bf403536..1b27d7cb 100644
--- a/wqflask/wqflask/user_manager.py
+++ b/wqflask/wqflask/user_manager.py
@@ -36,7 +36,8 @@ logger = getLogger(__name__)
 from base.data_set import create_datasets_list
 
 import requests
-from utility.elasticsearch_tools import get_elasticsearch_connection, get_user_by_unique_column, get_item_by_unique_column, save_user, es_save_data
+
+from utility.redis_tools import get_user_id, get_user_by_unique_column, set_user_attribute, save_user, save_verification_code, check_verification_code, get_user_collections, save_collections
 
 from smtplib import SMTP
 from utility.tools import SMTP_CONNECT, SMTP_USERNAME, SMTP_PASSWORD, LOG_SQL_ALCHEMY
@@ -197,19 +198,14 @@ class UserSession(object):
             return ''
 
     @property
-    def es_user_id(self):
+    def redis_user_id(self):
         """User id from ElasticSearch (need to check if this is the same as the id stored in self.records)"""
 
-        es = get_elasticsearch_connection()
         user_email = self.record['user_email_address']
 
         #ZS: Get user's collections if they exist
-        response = es.search(
-                       index = "users", doc_type = "local", body = {
-                       "query": { "match": { "email_address": user_email } }
-                   })
-
-        user_id = response['hits']['hits'][0]['_id']
+        user_id = None
+        user_id = get_user_id("email_address", user_email)
         return user_id
 
     @property
@@ -224,49 +220,15 @@ class UserSession(object):
     def user_collections(self):
         """List of user's collections"""
 
-        es = get_elasticsearch_connection()
-
-        user_email = self.record['user_email_address']
-
         #ZS: Get user's collections if they exist
-        response = es.search(
-                       index = "users", doc_type = "local", body = {
-                       "query": { "match": { "email_address": user_email } }
-                   })
-        user_info = response['hits']['hits'][0]['_source']
-        if 'collections' in user_info.keys():
-            if len(user_info['collections']) > 0:
-                collection_list = json.loads(user_info['collections'])
-                return sorted(collection_list, key = lambda i: datetime.datetime.strptime(i['changed_timestamp'], '%b %d %Y %I:%M%p'), reverse=True)
-            else:
-                return []
-        else:
-            return []
+        collections = get_user_collections(self.redis_user_id)
+        return collections
 
     @property
     def num_collections(self):
         """Number of user's collections"""
 
-        es = get_elasticsearch_connection()
-
-        user_email = self.record['user_email_address']
-
-        #ZS: Get user's collections if they exist
-        response = es.search(
-                       index = "users", doc_type = "local", body = {
-                       "query": { "match": { "email_address": user_email } }
-                   })
-
-        user_info = response['hits']['hits'][0]['_source']
-        logger.debug("USER NUM COLL:", user_info)
-        if 'collections' in user_info.keys():
-            if user_info['collections'] != "[]" and len(user_info['collections']) > 0:
-                collections_json = json.loads(user_info['collections'])
-                return len(collections_json)
-            else:
-                return 0
-        else:
-            return 0
+        return len(self.user_collections)
 
 ###
 # ZS: This is currently not used, but I'm leaving it here commented out because the old "set superuser" code (at the bottom of this file) used it
@@ -297,29 +259,9 @@ class UserSession(object):
                            'num_members': len(traits),
                            'members': list(traits) }
 
-        es = get_elasticsearch_connection()
-
-        user_email = self.record['user_email_address']
-        response = es.search(
-                       index = "users", doc_type = "local", body = {
-                       "query": { "match": { "email_address": user_email } }
-                   })
-
-        user_id = response['hits']['hits'][0]['_id']
-        user_info = response['hits']['hits'][0]['_source']
-
-        if 'collections' in user_info.keys():
-            if user_info['collections'] != [] and user_info['collections'] != "[]":
-                current_collections = json.loads(user_info['collections'])
-                current_collections.append(collection_dict)
-                self.update_collections(current_collections)
-                #collections_json = json.dumps(current_collections)
-            else:
-                self.update_collections([collection_dict])
-                #collections_json = json.dumps([collection_dict])
-        else:
-            self.update_collections([collection_dict])
-            #collections_json = json.dumps([collection_dict])
+        current_collections = self.user_collections
+        current_collections.append(collection_dict)
+        self.update_collections(current_collections)
 
         return collection_dict['id']
 
@@ -399,10 +341,9 @@ class UserSession(object):
         return None
 
     def update_collections(self, updated_collections):
-        es = get_elasticsearch_connection()
+        collection_body = json.dumps(updated_collections)
 
-        collection_body = {'doc': {'collections': json.dumps(updated_collections)}}
-        es.update(index='users', doc_type='local', id=self.es_user_id, refresh='wait_for', body=collection_body)
+        save_collections(self.redis_user_id, collection_body)
 
     def delete_session(self):
         # And more importantly delete the redis record
@@ -449,16 +390,13 @@ class RegisterUser(object):
         self.thank_you_mode = False
         self.errors = []
         self.user = Bunch()
-        es = kw.get('es_connection', None)
-
-        if not es:
-            self.errors.append("Missing connection object")
 
         self.user.email_address = kw.get('email_address', '').encode("utf-8").strip()
         if not (5 <= len(self.user.email_address) <= 50):
             self.errors.append('Email Address needs to be between 5 and 50 characters.')
         else:
-            email_exists = get_user_by_unique_column(es, "email_address", self.user.email_address)
+            email_exists = get_user_by_unique_column("email_address", self.user.email_address)
+            #email_exists = get_user_by_unique_column(es, "email_address", self.user.email_address)
             if email_exists:
                 self.errors.append('User already exists with that email')
 
@@ -487,7 +425,7 @@ class RegisterUser(object):
         self.user.confirmed = 1
 
         self.user.registration_info = json.dumps(basic_info(), sort_keys=True)
-        save_user(es, self.user.__dict__, self.user.user_id)
+        save_user(self.user.__dict__, self.user.user_id)
 
 def set_password(password, user):
     pwfields = Bunch()
@@ -563,8 +501,9 @@ class ForgotPasswordEmail(VerificationEmail):
             "email_address": toaddr,
             "timestamp": timestamp()
         }
-        es = get_elasticsearch_connection()
-        es_save_data(es, self.key_prefix, "local", data, verification_code)
+
+        save_verification_code(toaddr, verification_code)
+
 
         subject = self.subject
         body = render_template(
@@ -621,19 +560,11 @@ def password_reset():
     # user_encode = DecodeUser(ForgotPasswordEmail.key_prefix).reencode_standalone()
     verification_code = request.args.get('code')
     hmac = request.args.get('hm')
-    es = get_elasticsearch_connection()
+
     if verification_code:
-        code_details = get_item_by_unique_column(
-            es
-            , "verification_code"
-            , verification_code
-            , ForgotPasswordEmail.key_prefix
-            , "local")
-        if code_details:
-            user_details = get_user_by_unique_column(
-                es
-                , "email_address"
-                , code_details["email_address"])
+        user_email = check_verification_code(verification_code)
+        if user_email:
+            user_details = get_user_by_unique_column('email_address', user_email)
             if user_details:
                 return render_template(
                     "new_security/password_reset.html", user_encode=user_details["user_id"])
@@ -641,7 +572,8 @@ def password_reset():
                 flash("Invalid code: User no longer exists!", "error")
         else:
             flash("Invalid code: Password reset code does not exist or might have expired!", "error")
-        return redirect(url_for("login"))#render_template("new_security/login_user.html", error=error)
+    else:
+        return redirect(url_for("login"))
 
 @app.route("/n/password_reset_step2", methods=('POST',))
 def password_reset_step2():
@@ -658,16 +590,7 @@ def password_reset_step2():
     password = request.form['password']
     set_password(password, user)
 
-    es = get_elasticsearch_connection()
-    es.update(
-        index = "users"
-        , doc_type = "local"
-        , id = user_id
-        , body = {
-            "doc": {
-                "password": user.__dict__.get("password")
-            }
-        })
+    set_user_attribute(user_id, "password", user.__dict__.get("password"))
 
     flash("Password changed successfully. You can now sign in.", "alert-info")
     response = make_response(redirect(url_for('login')))
@@ -719,8 +642,8 @@ def github_oauth2():
     result_dict = {arr[0]:arr[1] for arr in [tok.split("=") for tok in [token.encode("utf-8") for token in result.text.split("&")]]}
 
     github_user = get_github_user_details(result_dict["access_token"])
-    es = get_elasticsearch_connection()
-    user_details = get_user_by_unique_column(es, "github_id", github_user["id"])
+
+    user_details = get_user_by_unique_column("github_id", github_user["id"])
     if user_details == None:
         user_details = {
             "user_id": str(uuid.uuid4())
@@ -732,7 +655,8 @@ def github_oauth2():
             , "active": 1
             , "confirmed": 1
         }
-        save_user(es, user_details, user_details["user_id"])
+        save_user(user_details, user_details["user_id"])
+
     url = "/n/login?type=github&uid="+user_details["user_id"]
     return redirect(url)
 
@@ -753,8 +677,7 @@ def orcid_oauth2():
         result = requests.post(ORCID_TOKEN_URL, data=data)
         result_dict = json.loads(result.text.encode("utf-8"))
 
-        es = get_elasticsearch_connection()
-        user_details = get_user_by_unique_column(es, "orcid", result_dict["orcid"])
+        user_details = get_user_by_unique_column("orcid", result_dict["orcid"])
         if user_details == None:
             user_details = {
                 "user_id": str(uuid4())
@@ -768,7 +691,8 @@ def orcid_oauth2():
                 , "active": 1
                 , "confirmed": 1
             }
-            save_user(es, user_details, user_details["user_id"])
+            save_user(user_details, user_details["user_id"])
+
         url = "/n/login?type=orcid&uid="+user_details["user_id"]
     else:
         flash("There was an error getting code from ORCID")
@@ -788,8 +712,8 @@ class LoginUser(object):
 
     def oauth2_login(self, login_type, user_id):
         """Login via an OAuth2 provider"""
-        es = get_elasticsearch_connection()
-        user_details = get_user_by_unique_column(es, "user_id", user_id)
+
+        user_details = get_user_by_unique_column("user_id", user_id)
         if user_details:
             user = model.User()
             user.id = user_details["user_id"] if user_details["user_id"] == None else "N/A"
@@ -804,7 +728,7 @@ class LoginUser(object):
         """Login through the normal form"""
         params = request.form if request.form else request.args
         logger.debug("in login params are:", params)
-        es = get_elasticsearch_connection()
+
         if not params:
             from utility.tools import GITHUB_AUTH_URL, GITHUB_CLIENT_ID, ORCID_AUTH_URL, ORCID_CLIENT_ID
             external_login = {}
@@ -812,13 +736,14 @@ class LoginUser(object):
                 external_login["github"] = GITHUB_AUTH_URL
             if ORCID_AUTH_URL and ORCID_CLIENT_ID != 'UNKNOWN':
                 external_login["orcid"] = ORCID_AUTH_URL
-            assert(es is not None)
+
             return render_template(
                 "new_security/login_user.html"
                 , external_login=external_login
-                , es_server=es.ping())
+                , redis_is_available = is_redis_available())
         else:
-            user_details = get_user_by_unique_column(es, "email_address", params["email_address"])
+            user_details = get_user_by_unique_column("email_address", params["email_address"])
+            #user_details = get_user_by_unique_column(es, "email_address", params["email_address"])
             user = None
             valid = None
             if user_details:
@@ -942,8 +867,7 @@ def forgot_password_submit():
     next_page = None
     if email_address != "":
         logger.debug("Wants to send password E-mail to ",email_address)
-        es = get_elasticsearch_connection()
-        user_details = get_user_by_unique_column(es, "email_address", email_address)
+        user_details = get_user_by_unique_column("email_address", email_address)
         if user_details:
             ForgotPasswordEmail(user_details["email_address"])
             return render_template("new_security/forgot_password_step2.html",
@@ -960,6 +884,13 @@ def forgot_password_submit():
 def unauthorized(error):
     return redirect(url_for('login'))
 
+def is_redis_available():
+    try:
+        Redis.ping()
+    except:
+        return False
+    return True
+
 ###
 # ZS: The following 6 functions require the old MySQL User accounts; I'm leaving them commented out just in case we decide to reimplement them using ElasticSearch
 ###
@@ -1021,8 +952,6 @@ def register():
 
     params = request.form if request.form else request.args
     params = params.to_dict(flat=True)
-    es = get_elasticsearch_connection()
-    params["es_connection"] = es
 
     if params:
         logger.debug("Attempting to register the user...")