diff options
author | zsloan | 2020-03-03 15:08:48 -0600 |
---|---|---|
committer | zsloan | 2020-03-03 15:08:48 -0600 |
commit | f9849394e3a252b5a1ac59c78a06728d20ca69ed (patch) | |
tree | 94faf9af074e2e1bf61cdf479d76fec45e5b9c9b | |
parent | e81f98118808dc66f591282d9b3adb2bc41dbd57 (diff) | |
download | genenetwork2-f9849394e3a252b5a1ac59c78a06728d20ca69ed.tar.gz |
Refactored all the anonymous (not logged-in) user stuff so that both logged-in and anonymous users share the same code
-rw-r--r-- | wqflask/utility/tools.py | 3 | ||||
-rw-r--r-- | wqflask/wqflask/static/new/images/CITGLogo.png | bin | 0 -> 11962 bytes | |||
-rw-r--r-- | wqflask/wqflask/user_login.py | 467 | ||||
-rw-r--r-- | wqflask/wqflask/user_session.py | 272 | ||||
-rw-r--r-- | wqflask/wqflask/views.py | 7 |
5 files changed, 744 insertions, 5 deletions
diff --git a/wqflask/utility/tools.py b/wqflask/utility/tools.py index ec9f01bd..75bddb24 100644 --- a/wqflask/utility/tools.py +++ b/wqflask/utility/tools.py @@ -254,7 +254,6 @@ JS_GN_PATH = get_setting('JS_GN_PATH') GITHUB_CLIENT_ID = get_setting('GITHUB_CLIENT_ID') GITHUB_CLIENT_SECRET = get_setting('GITHUB_CLIENT_SECRET') -GITHUB_AUTH_URL = None if GITHUB_CLIENT_ID != 'UNKNOWN' and GITHUB_CLIENT_SECRET: GITHUB_AUTH_URL = "https://github.com/login/oauth/authorize?client_id=" + \ GITHUB_CLIENT_ID+"&client_secret="+GITHUB_CLIENT_SECRET @@ -264,7 +263,7 @@ ORCID_CLIENT_ID = get_setting('ORCID_CLIENT_ID') ORCID_CLIENT_SECRET = get_setting('ORCID_CLIENT_SECRET') ORCID_AUTH_URL = None if ORCID_CLIENT_ID != 'UNKNOWN' and ORCID_CLIENT_SECRET: - ORCID_AUTH_URL = "https://sandbox.orcid.org/oauth/authorize?response_type=code&scope=/authenticate&show_login=true&client_id=" + \ + ORCID_AUTH_URL = "https://orcid.org/oauth/authorize?response_type=code&scope=/authenticate&show_login=true&client_id=" + \ ORCID_CLIENT_ID+"&client_secret="+ORCID_CLIENT_SECRET ORCID_TOKEN_URL = get_setting('ORCID_TOKEN_URL') diff --git a/wqflask/wqflask/static/new/images/CITGLogo.png b/wqflask/wqflask/static/new/images/CITGLogo.png Binary files differnew file mode 100644 index 00000000..ae99fedb --- /dev/null +++ b/wqflask/wqflask/static/new/images/CITGLogo.png diff --git a/wqflask/wqflask/user_login.py b/wqflask/wqflask/user_login.py new file mode 100644 index 00000000..05885e2c --- /dev/null +++ b/wqflask/wqflask/user_login.py @@ -0,0 +1,467 @@ +from __future__ import print_function, division, absolute_import + +import os +import hashlib +import datetime +import time +import logging +import uuid +import hashlib +import hmac +import base64 +import requests + +import simplejson as json + +import redis # used for collections +Redis = redis.StrictRedis() + +from flask import (Flask, g, render_template, url_for, request, make_response, + redirect, flash, abort) + +from wqflask import app +from wqflask import pbkdf2 +from wqflask.hmac_func import hmac_creation +from wqflask.user_session import UserSession + +from utility.redis_tools import is_redis_available, get_user_id, get_user_by_unique_column, set_user_attribute, save_user, save_verification_code, check_verification_code, get_user_collections, save_collections + +from utility.logger import getLogger +logger = getLogger(__name__) + +from smtplib import SMTP +from utility.tools import SMTP_CONNECT, SMTP_USERNAME, SMTP_PASSWORD, LOG_SQL_ALCHEMY + +THREE_DAYS = 60 * 60 * 24 * 3 + +def timestamp(): + return datetime.datetime.utcnow().isoformat() + +def basic_info(): + return dict(timestamp = timestamp(), + ip_address = request.remote_addr, + user_agent = request.headers.get('User-Agent')) + +def encode_password(pass_gen_fields): + hashfunc = getattr(hashlib, pass_gen_fields['hashfunc']) + + salt = base64.b64decode(pass_gen_fields['salt']) + password = pbkdf2.pbkdf2_hex(str(pass_gen_fields['unencrypted_password']), + pass_gen_fields['salt'], + pass_gen_fields['iterations'], + pass_gen_fields['keylength'], + hashfunc) + + return password + +def set_password(password): + pass_gen_fields = { + "unencrypted_password": password, + "algorithm": "pbkdf2", + "hashfunc": "sha256", + "salt": base64.b64encode(os.urandom(32)), + "iterations": 100000, + "keylength": 32, + "created_timestamp": timestamp() + } + + assert len(password) >= 6, "Password shouldn't be shorter than 6 characters" + + encoded_password = encode_password(pass_gen_fields) + + return encoded_password + +def encrypt_password(unencrypted_password, pwfields): + hashfunc = getattr(hashlib, pwfields['hashfunc']) + salt = base64.b64decode(pwfields['salt']) + iterations = pwfields['iterations'] + keylength = pwfields['keylength'] + encrypted_password = pbkdf2.pbkdf2_hex(str(unencrypted_password), + salt, iterations, keylength, hashfunc) + return encrypted_password + +def get_signed_session_id(user): + session_id = str(uuid.uuid4()) + + session_id_signature = hmac_creation(session_id) + session_id_signed = session_id + ":" + session_id_signature + + #ZS: Need to check if this is ever actually used or exists + if 'user_id' not in user: + user['user_id'] = str(uuid.uuid4()) + save_user(user, user['user_id']) + + if 'github_id' in user: + session = dict(login_time = time.time(), + user_type = "github", + user_id = user['user_id'], + github_id = user['github_id'], + user_name = user['name'], + user_url = user['user_url']) + elif 'orcid' in user: + session = dict(login_time = time.time(), + user_type = "orcid", + user_id = user['user_id'], + github_id = user['orcid'], + user_name = user['name'], + user_url = user['user_url']) + else: + session = dict(login_time = time.time(), + user_type = "gn2", + user_id = user['user_id'], + user_name = user['full_name'], + user_email_address = user['email_address']) + + key = UserSession.user_cookie_name + ":" + session_id + Redis.hmset(key, session) + Redis.expire(key, THREE_DAYS) + + return session_id_signed + +def send_email(toaddr, msg, fromaddr="no-reply@genenetwork.org"): + """Send an E-mail through SMTP_CONNECT host. If SMTP_USERNAME is not + 'UNKNOWN' TLS is used + + """ + if SMTP_USERNAME == 'UNKNOWN': + server = SMTP(SMTP_CONNECT) + server.sendmail(fromaddr, toaddr, msg) + else: + server = SMTP(SMTP_CONNECT) + server.starttls() + server.login(SMTP_USERNAME, SMTP_PASSWORD) + server.sendmail(fromaddr, toaddr, msg) + server.quit() + logger.info("Successfully sent email to "+toaddr) + +def send_verification_email(user_details, template_name = "email/verification.txt", key_prefix = "verification_code", subject = "GeneNetwork email verification"): + verification_code = str(uuid.uuid4()) + key = key_prefix + ":" + verification_code + + data = json.dumps(dict(id=user_details['user_id'], timestamp = timestamp())) + + Redis.set(key, data) + Redis.expire(key, THREE_DAYS) + + recipient = user_details['email_address'] + body = render_template(template_name, verification_code = verification_code) + send_email(recipient, subject, body) + return {"recipient": recipient, "subject": subject, "body": body} + +@app.route("/n/login", methods=('GET', 'POST')) +def login(): + params = request.form if request.form else request.args + logger.debug("in login params are:", params) + + if not params: #ZS: If coming to page for first time + from utility.tools import GITHUB_AUTH_URL, GITHUB_CLIENT_ID, ORCID_AUTH_URL, ORCID_CLIENT_ID + external_login = {} + if GITHUB_AUTH_URL and GITHUB_CLIENT_ID != 'UNKNOWN': + external_login["github"] = GITHUB_AUTH_URL + if ORCID_AUTH_URL and ORCID_CLIENT_ID != 'UNKNOWN': + external_login["orcid"] = ORCID_AUTH_URL + return render_template("new_security/login_user.html", external_login = external_login, redis_is_available=is_redis_available()) + else: #ZS: After clicking sign-in + if 'type' in params and 'uid' in params: + user_details = get_user_by_unique_column("user_id", params['uid']) + if user_details: + session_id_signed = get_signed_session_id(user_details) + if 'name' in user_details and user_details['name'] != "None": + display_id = user_details['name'] + elif 'github_id' in user_details: + display_id = user_details['github_id'] + elif 'orcid' in user_details: + display_id = user_details['orcid'] + else: + display_id = "" + flash("Thank you for logging in {}.".format(display_id), "alert-success") + response = make_response(redirect(url_for('index_page'))) + response.set_cookie(UserSession.user_cookie_name, session_id_signed, max_age=None) + else: + flash("Something went unexpectedly wrong.", "alert-danger") + response = make_response(redirect(url_for('index_page'))) + return response + else: + user_details = get_user_by_unique_column("email_address", params['email_address']) + password_match = False + if user_details: + submitted_password = params['password'] + pwfields = json.loads(user_details['password']) + encrypted_pass = encrypt_password(submitted_password, pwfields) + password_match = pbkdf2.safe_str_cmp(encrypted_pass, pwfields['password']) + else: # Invalid e-mail + flash("Invalid e-mail address. Please try again.", "alert-danger") + response = make_response(redirect(url_for('login'))) + + return response + if password_match: # If password correct + if user_details['confirmed']: # If account confirmed + import_col = "false" + if 'import_collections' in params: + import_col = "true" + + session_id_signed = get_signed_session_id(user_details) + flash("Thank you for logging in {}.".format(user_details['full_name']), "alert-success") + response = make_response(redirect(url_for('index_page', import_collections = import_col))) + response.set_cookie(UserSession.user_cookie_name, session_id_signed, max_age=None) + return response + else: + email_ob = send_verification_email(user_details) + return render_template("newsecurity/verification_still_needed.html", subject=email_ob['subject']) + else: # Incorrect password + #ZS: It previously seemed to store that there was an incorrect log-in attempt here, but it did so in the MySQL DB so this might need to be reproduced with Redis + flash("Invalid password. Please try again.", "alert-danger") + response = make_response(redirect(url_for('login'))) + + return response + +@app.route("/n/login/github_oauth2", methods=('GET', 'POST')) +def github_oauth2(): + from utility.tools import GITHUB_CLIENT_ID, GITHUB_CLIENT_SECRET, GITHUB_AUTH_URL + code = request.args.get("code") + data = { + "client_id": GITHUB_CLIENT_ID, + "client_secret": GITHUB_CLIENT_SECRET, + "code": code + } + logger.debug("LOGIN DATA:", data) + result = requests.post("https://github.com/login/oauth/access_token", json=data) + result_dict = {arr[0]:arr[1] for arr in [tok.split("=") for tok in [token.encode("utf-8") for token in result.text.split("&")]]} + + github_user = get_github_user_details(result_dict["access_token"]) + + user_details = get_user_by_unique_column("github_id", github_user["id"]) + if user_details == None: + user_details = { + "user_id": str(uuid.uuid4()), + "name": github_user["name"].encode("utf-8") if github_user["name"] else "None", + "github_id": github_user["id"], + "user_url": github_user["html_url"].encode("utf-8"), + "login_type": "github", + "organization": "", + "active": 1, + "confirmed": 1 + } + save_user(user_details, user_details["user_id"]) + + url = "/n/login?type=github&uid="+user_details["user_id"] + return redirect(url) + +def get_github_user_details(access_token): + from utility.tools import GITHUB_API_URL + result = requests.get(GITHUB_API_URL, headers = {'Authorization':'token ' + access_token }).content + + return json.loads(result) + +@app.route("/n/login/orcid_oauth2", methods=('GET', 'POST')) +def orcid_oauth2(): + from uuid import uuid4 + from utility.tools import ORCID_CLIENT_ID, ORCID_CLIENT_SECRET, ORCID_TOKEN_URL, ORCID_AUTH_URL + code = request.args.get("code") + error = request.args.get("error") + url = "/n/login" + if code: + data = { + "client_id": ORCID_CLIENT_ID, + "client_secret": ORCID_CLIENT_SECRET, + "grant_type": "authorization_code", + "code": code + } + result = requests.post(ORCID_TOKEN_URL, data=data) + result_dict = json.loads(result.text.encode("utf-8")) + + user_details = get_user_by_unique_column("orcid", result_dict["orcid"]) + if user_details == None: + user_details = { + "user_id": str(uuid4()), + "name": result_dict["name"], + "orcid": result_dict["orcid"], + "user_url": "%s/%s" % ("/".join(ORCID_AUTH_URL.split("/")[:-2]), result_dict["orcid"]), + "login_type": "orcid", + "organization": "", + "active": 1, + "confirmed": 1 + } + save_user(user_details, user_details["user_id"]) + + url = "/n/login?type=orcid&uid="+user_details["user_id"] + else: + flash("There was an error getting code from ORCID") + return redirect(url) + +def get_github_user_details(access_token): + from utility.tools import GITHUB_API_URL + result = requests.get(GITHUB_API_URL, headers = {'Authorization':'token ' + access_token }).content + + return json.loads(result) + + +@app.route("/n/logout") +def logout(): + logger.debug("Logging out...") + UserSession().delete_session() + flash("You are now logged out. We hope you come back soon!") + response = make_response(redirect(url_for('index_page'))) + # Delete the cookie + response.set_cookie(UserSession.user_cookie_name, '', expires=0) + return response + +@app.route("/n/forgot_password", methods=['GET']) +def forgot_password(): + """Entry point for forgotten password""" + print("ARGS: ", request.args) + errors = {"no-email": request.args.get("no-email")} + print("ERRORS: ", errors) + return render_template("new_security/forgot_password.html", errors=errors) + +def send_forgot_password_email(verification_email): + from email.MIMEMultipart import MIMEMultipart + from email.MIMEText import MIMEText + + template_name = "email/forgot_password.txt" + key_prefix = "forgot_password_code" + subject = "GeneNetwork password reset" + fromaddr = "no-reply@genenetwork.org" + + verification_code = str(uuid.uuid4()) + key = key_prefix + ":" + verification_code + + data = { + "verification_code": verification_code, + "email_address": verification_email, + "timestamp": timestamp() + } + + save_verification_code(verification_email, verification_code) + + body = render_template(template_name, verification_code = verification_code) + + msg = MIMEMultipart() + msg["To"] = verification_email + msg["Subject"] = subject + msg["From"] = fromaddr + msg.attach(MIMEText(body, "plain")) + + send_email(verification_email, msg.as_string()) + + return subject + +@app.route("/n/forgot_password_submit", methods=('POST',)) +def forgot_password_submit(): + """When a forgotten password form is submitted we get here""" + params = request.form + email_address = params['email_address'] + next_page = None + if email_address != "": + logger.debug("Wants to send password E-mail to ", email_address) + user_details = get_user_by_unique_column("email_address", email_address) + if user_details: + email_subject = send_forgot_password_email(user_details["email_address"]) + return render_template("new_security/forgot_password_step2.html", + subject=email_subject) + else: + flash("The e-mail entered is not associated with an account.", "alert-danger") + return redirect(url_for("forgot_password")) + + else: + flash("You MUST provide an email", "alert-danger") + return redirect(url_for("forgot_password")) + +@app.route("/n/password_reset", methods=['GET']) +def password_reset(): + """Entry point after user clicks link in E-mail""" + logger.debug("in password_reset request.url is:", request.url) + + verification_code = request.args.get('code') + hmac = request.args.get('hm') + + if verification_code: + user_email = check_verification_code(verification_code) + if user_email: + user_details = get_user_by_unique_column('email_address', user_email) + if user_details: + return render_template( + "new_security/password_reset.html", user_encode=user_details["email_address"]) + else: + flash("Invalid code: User no longer exists!", "error") + else: + flash("Invalid code: Password reset code does not exist or might have expired!", "error") + else: + return redirect(url_for("login")) + +@app.route("/n/password_reset_step2", methods=('POST',)) +def password_reset_step2(): + """Handle confirmation E-mail for password reset""" + logger.debug("in password_reset request.url is:", request.url) + + errors = [] + user_email = request.form['user_encode'] + + password = request.form['password'] + encoded_password = set_password(password) + + set_user_attribute(user_id, "password", encoded_password) + + flash("Password changed successfully. You can now sign in.", "alert-info") + response = make_response(redirect(url_for('login'))) + + return response + +def register_user(params): + thank_you_mode = False + errors = [] + user_details = {} + + user_details['email_address'] = params.get('email_address', '').encode("utf-8").strip() + if not (5 <= len(user_details['email_address']) <= 50): + errors.append('Email Address needs to be between 5 and 50 characters.') + else: + email_exists = get_user_by_unique_column("email_address", user_details['email_address']) + if email_exists: + errors.append('User already exists with that email') + + user_details['full_name'] = params.get('full_name', '').encode("utf-8").strip() + if not (5 <= len(user_details['full_name']) <= 50): + errors.append('Full Name needs to be between 5 and 50 characters.') + + user_details['organization'] = params.get('organization', '').encode("utf-8").strip() + if user_details['organization'] and not (5 <= len(user_details['organization']) <= 50): + errors.append('Organization needs to be empty or between 5 and 50 characters.') + + password = str(params.get('password', '')) + if not (6 <= len(password)): + errors.append('Password needs to be at least 6 characters.') + + if params.get('password_confirm') != password: + errors.append("Passwords don't match.") + + if errors: + return errors + + user_details['password'] = set_password(password) + user_details['user_id'] = str(uuid.uuid4()) + user_details['confirmed'] = 1 + + user_details['registration_info'] = json.dumps(basic_info(), sort_keys=True) + save_user(user_details, user_details['user_id']) + +@app.route("/n/register", methods=('GET', 'POST')) +def register(): + errors = None + + params = request.form if request.form else request.args + params = params.to_dict(flat=True) + + if params: + logger.debug("Attempting to register the user...") + errors = register_user(params) + + if len(errors) == 0: + flash("Registration successful. You may login with your new account", "alert-info") + return redirect(url_for("login")) + + return render_template("new_security/register_user.html", values=params, errors=errors) + +@app.errorhandler(401) +def unauthorized(error): + return redirect(url_for('login'))
\ No newline at end of file diff --git a/wqflask/wqflask/user_session.py b/wqflask/wqflask/user_session.py new file mode 100644 index 00000000..1f3e6558 --- /dev/null +++ b/wqflask/wqflask/user_session.py @@ -0,0 +1,272 @@ +from __future__ import print_function, division, absolute_import
+
+import datetime
+import time
+import uuid
+
+import simplejson as json
+
+import redis # used for collections
+Redis = redis.StrictRedis()
+
+from flask import (Flask, g, render_template, url_for, request, make_response,
+ redirect, flash, abort)
+
+from wqflask import app
+from wqflask.hmac_func import hmac_creation
+
+#from utility.elasticsearch_tools import get_elasticsearch_connection
+from utility.redis_tools import get_user_id, get_user_by_unique_column, get_user_collections, save_collections
+
+from utility.logger import getLogger
+logger = getLogger(__name__)
+
+THREE_DAYS = 60 * 60 * 24 * 3
+THIRTY_DAYS = 60 * 60 * 24 * 30
+
+def verify_cookie(cookie):
+ the_uuid, separator, the_signature = cookie.partition(':')
+ assert len(the_uuid) == 36, "Is session_id a uuid?"
+ assert separator == ":", "Expected a : here"
+ assert the_signature == hmac_creation(the_uuid), "Uh-oh, someone tampering with the cookie?"
+ return the_uuid
+
+def create_signed_cookie():
+ the_uuid = str(uuid.uuid4())
+ signature = hmac_creation(the_uuid)
+ uuid_signed = the_uuid + ":" + signature
+ logger.debug("uuid_signed:", uuid_signed)
+ return the_uuid, uuid_signed
+
+class UserSession(object):
+ """Logged in user handling"""
+
+ user_cookie_name = 'session_id_v1'
+ anon_cookie_name = 'anon_user_v1'
+
+ def __init__(self):
+ user_cookie = request.cookies.get(self.user_cookie_name)
+ if not user_cookie:
+ self.logged_in = False
+ anon_cookie = request.cookies.get(self.anon_cookie_name)
+ self.cookie_name = self.anon_cookie_name
+ if anon_cookie:
+ self.cookie = anon_cookie
+ session_id = verify_cookie(self.cookie)
+ else:
+ session_id, self.cookie = create_signed_cookie()
+ else:
+ self.cookie_name = self.user_cookie_name
+ self.cookie = user_cookie
+ session_id = verify_cookie(self.cookie)
+
+ self.redis_key = self.cookie_name + ":" + session_id
+ self.session_id = session_id
+ self.record = Redis.hgetall(self.redis_key)
+
+ #ZS: If user correctled logged in but their session expired
+ #ZS: Need to test this by setting the time-out to be really short or something
+ if not self.record:
+ if user_cookie:
+ self.logged_in = False
+
+ ########### Grrr...this won't work because of the way flask handles cookies
+ # Delete the cookie
+ response = make_response(redirect(url_for('login')))
+ #response.set_cookie(self.cookie_name, '', expires=0)
+ flash("Due to inactivity your session has expired. If you'd like please login again.")
+ return response
+ #return
+ else:
+ self.record = dict(login_time = time.time(),
+ user_type = "anon",
+ user_id = str(uuid.uuid4()))
+
+ Redis.hmset(self.redis_key, self.record)
+ Redis.expire(self.redis_key, THIRTY_DAYS)
+ else:
+ if user_cookie:
+ self.logged_in = True
+
+ if user_cookie:
+ session_time = THREE_DAYS
+ else:
+ session_time = THIRTY_DAYS
+
+ if Redis.ttl(self.redis_key) < session_time:
+ # (Almost) everytime the user does something we extend the session_id in Redis...
+ logger.debug("Extending ttl...")
+ Redis.expire(self.redis_key, session_time)
+
+ @property
+ def user_id(self):
+ """Shortcut to the user_id"""
+ if 'user_id' in self.record:
+ return self.record['user_id']
+ else:
+ return ''
+
+ @property
+ def redis_user_id(self):
+ """User id from Redis (need to check if this is the same as the id stored in self.records)"""
+
+ #ZS: This part is a bit weird. Some accounts used to not have saved user ids, and in the process of testing I think I created some duplicate accounts for myself.
+ #ZS: Accounts should automatically generate user_ids if they don't already have one now, so this might not be necessary for anything other than my account's collections
+
+ if 'user_email_address' in self.record:
+ user_email = self.record['user_email_address']
+
+ #ZS: Get user's collections if they exist
+ user_id = None
+ user_id = get_user_id("email_address", user_email)
+ elif 'user_id' in self.record:
+ user_id = self.record['user_id']
+ elif 'github_id' in self.record:
+ user_github_id = self.record['github_id']
+ user_id = None
+ user_id = get_user_id("github_id", user_github_id)
+ else: #ZS: Anonymous user
+ return None
+
+ return user_id
+
+ @property
+ def user_name(self):
+ """Shortcut to the user_name"""
+ if 'user_name' in self.record:
+ return self.record['user_name']
+ else:
+ return ''
+
+ @property
+ def user_collections(self):
+ """List of user's collections"""
+
+ #ZS: Get user's collections if they exist
+ collections = get_user_collections(self.redis_user_id)
+ return collections
+
+ @property
+ def num_collections(self):
+ """Number of user's collections"""
+
+ return len(self.user_collections)
+
+ def add_collection(self, collection_name, traits):
+ """Add collection into ElasticSearch"""
+
+ collection_dict = {'id': unicode(uuid.uuid4()),
+ 'name': collection_name,
+ 'created_timestamp': datetime.datetime.utcnow().strftime('%b %d %Y %I:%M%p'),
+ 'changed_timestamp': datetime.datetime.utcnow().strftime('%b %d %Y %I:%M%p'),
+ 'num_members': len(traits),
+ 'members': list(traits) }
+
+ current_collections = self.user_collections
+ current_collections.append(collection_dict)
+ self.update_collections(current_collections)
+
+ return collection_dict['id']
+
+ def change_collection_name(self, collection_id, new_name):
+ for collection in self.user_collections:
+ if collection['id'] == collection_id:
+ collection['name'] = new_name
+ break
+
+ return new_name
+
+ def delete_collection(self, collection_id):
+ """Remove collection with given ID"""
+
+ updated_collections = []
+ for collection in self.user_collections:
+ if collection['id'] == collection_id:
+ continue
+ else:
+ updated_collections.append(collection)
+
+ self.update_collections(updated_collections)
+
+ return collection['name']
+
+ def add_traits_to_collection(self, collection_id, traits_to_add):
+ """Add specified traits to a collection"""
+
+ this_collection = self.get_collection_by_id(collection_id)
+
+ updated_collection = this_collection
+ updated_traits = this_collection['members'] + traits_to_add
+
+ updated_collection['members'] = updated_traits
+ updated_collection['num_members'] = len(updated_traits)
+ updated_collection['changed_timestamp'] = datetime.datetime.utcnow().strftime('%b %d %Y %I:%M%p')
+
+ updated_collections = []
+ for collection in self.user_collections:
+ if collection['id'] == collection_id:
+ updated_collections.append(updated_collection)
+ else:
+ updated_collections.append(collection)
+
+ self.update_collections(updated_collections)
+
+ def remove_traits_from_collection(self, collection_id, traits_to_remove):
+ """Remove specified traits from a collection"""
+
+ this_collection = self.get_collection_by_id(collection_id)
+
+ updated_collection = this_collection
+ updated_traits = []
+ for trait in this_collection['members']:
+ if trait in traits_to_remove:
+ continue
+ else:
+ updated_traits.append(trait)
+
+ updated_collection['members'] = updated_traits
+ updated_collection['num_members'] = len(updated_traits)
+ updated_collection['changed_timestamp'] = datetime.datetime.utcnow().strftime('%b %d %Y %I:%M%p')
+
+ updated_collections = []
+ for collection in self.user_collections:
+ if collection['id'] == collection_id:
+ updated_collections.append(updated_collection)
+ else:
+ updated_collections.append(collection)
+
+ self.update_collections(updated_collections)
+
+ return updated_traits
+
+ def get_collection_by_id(self, collection_id):
+ for collection in self.user_collections:
+ if collection['id'] == collection_id:
+ return collection
+
+ def get_collection_by_name(self, collection_name):
+ for collection in self.user_collections:
+ if collection['name'] == collection_name:
+ return collection
+
+ return None
+
+ def update_collections(self, updated_collections):
+ collection_body = json.dumps(updated_collections)
+
+ save_collections(self.redis_user_id, collection_body)
+
+ def delete_session(self):
+ # And more importantly delete the redis record
+ Redis.delete(self.redis_key)
+ self.logged_in = False
+
+@app.before_request
+def before_request():
+ g.user_session = UserSession()
+
+@app.after_request
+def set_cookie(response):
+ if not request.cookies.get(g.user_session.cookie_name):
+ response.set_cookie(g.user_session.cookie_name, g.user_session.cookie)
+ return response
\ No newline at end of file diff --git a/wqflask/wqflask/views.py b/wqflask/wqflask/views.py index a84a7480..436ebc91 100644 --- a/wqflask/wqflask/views.py +++ b/wqflask/wqflask/views.py @@ -65,9 +65,8 @@ from utility.benchmark import Bench from pprint import pformat as pf -#from wqflask import user_login -#from wqflask import user_session -from wqflask import user_manager +from wqflask import user_login +from wqflask import user_session from wqflask import collect from wqflask.database import db_session @@ -839,6 +838,8 @@ def browser_inputs(): return flask.jsonify(file_contents) + + ########################################################################## def json_default_handler(obj): |