aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMuriithi Frederick Muriuki2018-02-13 13:48:28 +0300
committerPjotr Prins2018-03-26 09:29:29 +0000
commitc68cffc414ac6d7536db36e79914f7d57af741c6 (patch)
tree0d70bd2e62a01a15553255aba4831b5b5069c8fc
parent690e525a8a063d1be107c15521474052bd6ae2b4 (diff)
downloadgenenetwork2-c68cffc414ac6d7536db36e79914f7d57af741c6.tar.gz
Update module to make it more testable
* Update functions to make them more testable. * Update code using updated functions.
-rw-r--r--wqflask/utility/elasticsearch_tools.py40
-rw-r--r--wqflask/wqflask/user_manager.py61
2 files changed, 58 insertions, 43 deletions
diff --git a/wqflask/utility/elasticsearch_tools.py b/wqflask/utility/elasticsearch_tools.py
index 4fc0035c..a964b025 100644
--- a/wqflask/utility/elasticsearch_tools.py
+++ b/wqflask/utility/elasticsearch_tools.py
@@ -1,23 +1,31 @@
-es = None
-try:
- from elasticsearch import Elasticsearch, TransportError
- from utility.tools import ELASTICSEARCH_HOST, ELASTICSEARCH_PORT
+from elasticsearch import Elasticsearch, TransportError
+import logging
- es = Elasticsearch([{
- "host": ELASTICSEARCH_HOST
- , "port": ELASTICSEARCH_PORT
- }]) if (ELASTICSEARCH_HOST and ELASTICSEARCH_PORT) else None
-
-except:
+def get_elasticsearch_connection():
es = None
+ try:
+ from utility.tools import ELASTICSEARCH_HOST, ELASTICSEARCH_PORT
+
+ es = Elasticsearch([{
+ "host": ELASTICSEARCH_HOST
+ , "port": ELASTICSEARCH_PORT
+ }]) if (ELASTICSEARCH_HOST and ELASTICSEARCH_PORT) else None
+
+ es_logger = logging.getLogger("elasticsearch")
+ es_logger.setLevel(logging.INFO)
+ es_logger.addHandler(logging.NullHandler())
+ except:
+ es = None
+
+ return es
-def get_user_by_unique_column(column_name, column_value):
- return get_item_by_unique_column(column_name, column_value, index="users", doc_type="local")
+def get_user_by_unique_column(es, column_name, column_value, index="users", doc_type="local"):
+ return get_item_by_unique_column(es, column_name, column_value, index=index, doc_type=doc_type)
-def save_user(user, user_id):
- es_save_data("users", "local", user, user_id)
+def save_user(es, user, user_id):
+ es_save_data(es, "users", "local", user, user_id)
-def get_item_by_unique_column(column_name, column_value, index, doc_type):
+def get_item_by_unique_column(es, column_name, column_value, index, doc_type):
item_details = None
try:
response = es.search(
@@ -32,7 +40,7 @@ def get_item_by_unique_column(column_name, column_value, index, doc_type):
pass
return item_details
-def es_save_data(index, doc_type, data_item, data_id,):
+def es_save_data(es, index, doc_type, data_item, data_id,):
from time import sleep
es.create(index, doc_type, body=data_item, id=data_id)
sleep(1) # Delay 1 second to allow indexing
diff --git a/wqflask/wqflask/user_manager.py b/wqflask/wqflask/user_manager.py
index 630be9aa..6b667615 100644
--- a/wqflask/wqflask/user_manager.py
+++ b/wqflask/wqflask/user_manager.py
@@ -55,8 +55,9 @@ logger = getLogger(__name__)
from base.data_set import create_datasets_list
import requests
-from utility.elasticsearch_tools import get_user_by_unique_column, save_user, es_save_data
+from utility.elasticsearch_tools import *
+es = get_elasticsearch_connection()
THREE_DAYS = 60 * 60 * 24 * 3
#THREE_DAYS = 45
@@ -271,14 +272,18 @@ class RegisterUser(object):
self.thank_you_mode = False
self.errors = []
self.user = Bunch()
+ es = kw.get('es_connection', None)
+
+ if not es:
+ self.errors.append("Missing connection object")
self.user.email_address = kw.get('email_address', '').encode("utf-8").strip()
if not (5 <= len(self.user.email_address) <= 50):
self.errors.append('Email Address needs to be between 5 and 50 characters.')
-
- email_exists = get_user_by_unique_column("email_address", self.user.email_address)
- if email_exists:
- self.errors.append('User already exists with that email')
+ else:
+ email_exists = get_user_by_unique_column(es, "email_address", self.user.email_address)
+ if email_exists:
+ self.errors.append('User already exists with that email')
self.user.full_name = kw.get('full_name', '').encode("utf-8").strip()
if not (5 <= len(self.user.full_name) <= 50):
@@ -305,7 +310,7 @@ class RegisterUser(object):
self.user.confirmed = 1
self.user.registration_info = json.dumps(basic_info(), sort_keys=True)
- save_user(self.user.__dict__, self.user.user_id)
+ save_user(es, self.user.__dict__, self.user.user_id)
def set_password(password, user):
pwfields = Bunch()
@@ -381,7 +386,7 @@ class ForgotPasswordEmail(VerificationEmail):
"email_address": toaddr,
"timestamp": timestamp()
}
- es_save_data(self.key_prefix, "local", data, verification_code)
+ es_save_data(es, self.key_prefix, "local", data, verification_code)
subject = self.subject
body = render_template(
@@ -431,7 +436,6 @@ def verify_email():
@app.route("/n/password_reset", methods=['GET'])
def password_reset():
- from utility.elasticsearch_tools import get_item_by_unique_column
logger.debug("in password_reset request.url is:", request.url)
# We do this mainly just to assert that it's in proper form for displaying next page
@@ -441,14 +445,16 @@ def password_reset():
hmac = request.args.get('hm')
if verification_code:
code_details = get_item_by_unique_column(
- "verification_code",
- verification_code,
- ForgotPasswordEmail.key_prefix,
- "local")
+ es
+ , "verification_code"
+ , verification_code
+ , ForgotPasswordEmail.key_prefix
+ , "local")
if code_details:
user_details = get_user_by_unique_column(
- "email_address",
- code_details["email_address"])
+ es
+ , "email_address"
+ , code_details["email_address"])
if user_details:
return render_template(
"new_security/password_reset.html", user_encode=user_details["user_id"])
@@ -533,7 +539,7 @@ def github_oauth2():
result_dict = {arr[0]:arr[1] for arr in [tok.split("=") for tok in [token.encode("utf-8") for token in result.text.split("&")]]}
github_user = get_github_user_details(result_dict["access_token"])
- user_details = get_user_by_unique_column("github_id", github_user["id"])
+ user_details = get_user_by_unique_column(es, "github_id", github_user["id"])
if user_details == None:
user_details = {
"user_id": str(uuid.uuid4())
@@ -545,7 +551,7 @@ def github_oauth2():
, "active": 1
, "confirmed": 1
}
- save_user(user_details, user_details["user_id"])
+ save_user(es, user_details, user_details["user_id"])
url = "/n/login?type=github&uid="+user_details["user_id"]
return redirect(url)
@@ -566,7 +572,7 @@ def orcid_oauth2():
result = requests.post(ORCID_TOKEN_URL, data=data)
result_dict = json.loads(result.text.encode("utf-8"))
- user_details = get_user_by_unique_column("orcid", result_dict["orcid"])
+ user_details = get_user_by_unique_column(es, "orcid", result_dict["orcid"])
if user_details == None:
user_details = {
"user_id": str(uuid4())
@@ -580,7 +586,7 @@ def orcid_oauth2():
, "active": 1
, "confirmed": 1
}
- save_user(user_details, user_details["user_id"])
+ save_user(es, user_details, user_details["user_id"])
url = "/n/login?type=orcid&uid="+user_details["user_id"]
else:
flash("There was an error getting code from ORCID")
@@ -599,7 +605,7 @@ class LoginUser(object):
def oauth2_login(self, login_type, user_id):
"""Login via an OAuth2 provider"""
- user_details = get_user_by_unique_column("user_id", user_id)
+ user_details = get_user_by_unique_column(es, "user_id", user_id)
if user_details:
user = model.User()
user.id = user_details["user_id"]
@@ -616,7 +622,6 @@ class LoginUser(object):
logger.debug("in login params are:", params)
if not params:
from utility.tools import GITHUB_AUTH_URL, ORCID_AUTH_URL
- from utility.elasticsearch_tools import es
external_login = None
if GITHUB_AUTH_URL or ORCID_AUTH_URL:
external_login={
@@ -628,8 +633,9 @@ class LoginUser(object):
, external_login=external_login
, es_server=es.ping())
else:
- user_details = get_user_by_unique_column("email_address", params["email_address"])
+ user_details = get_user_by_unique_column(es, "email_address", params["email_address"])
user = None
+ valid = None
if user_details:
user = model.User();
for key in user_details:
@@ -672,7 +678,7 @@ class LoginUser(object):
else:
if user:
self.unsuccessful_login(user)
- flash("Invalid email-address or password. Please try again.", "alert-error")
+ flash("Invalid email-address or password. Please try again.", "alert-danger")
response = make_response(redirect(url_for('login')))
return response
@@ -739,7 +745,7 @@ def forgot_password():
def forgot_password_submit():
params = request.form
email_address = params['email_address']
- user_details = get_user_by_unique_column("email_address", email_address)
+ user_details = get_user_by_unique_column(es, "email_address", email_address)
if user_details:
ForgotPasswordEmail(user_details["email_address"])
# try:
@@ -815,16 +821,17 @@ def register():
params = request.form if request.form else request.args
+ params = params.to_dict(flat=True)
+ params["es_connection"] = es
if params:
logger.debug("Attempting to register the user...")
result = RegisterUser(params)
errors = result.errors
- if result.thank_you_mode:
- assert not errors, "Errors while in thank you mode? That seems wrong..."
- return render_template("new_security/registered.html",
- subject=VerificationEmail.subject)
+ if len(errors) == 0:
+ flash("Registration successful. You may login with your new account", "alert-info")
+ return redirect(url_for("login"))
return render_template("new_security/register_user.html", values=params, errors=errors)