From f9849394e3a252b5a1ac59c78a06728d20ca69ed Mon Sep 17 00:00:00 2001 From: zsloan Date: Tue, 3 Mar 2020 15:08:48 -0600 Subject: Refactored all the anonymous (not logged-in) user stuff so that both logged-in and anonymous users share the same code --- wqflask/utility/tools.py | 3 +- wqflask/wqflask/static/new/images/CITGLogo.png | Bin 0 -> 11962 bytes wqflask/wqflask/user_login.py | 467 +++++++++++++++++++++++++ wqflask/wqflask/user_session.py | 272 ++++++++++++++ wqflask/wqflask/views.py | 7 +- 5 files changed, 744 insertions(+), 5 deletions(-) create mode 100644 wqflask/wqflask/static/new/images/CITGLogo.png create mode 100644 wqflask/wqflask/user_login.py create mode 100644 wqflask/wqflask/user_session.py (limited to 'wqflask') diff --git a/wqflask/utility/tools.py b/wqflask/utility/tools.py index ec9f01bd..75bddb24 100644 --- a/wqflask/utility/tools.py +++ b/wqflask/utility/tools.py @@ -254,7 +254,6 @@ JS_GN_PATH = get_setting('JS_GN_PATH') GITHUB_CLIENT_ID = get_setting('GITHUB_CLIENT_ID') GITHUB_CLIENT_SECRET = get_setting('GITHUB_CLIENT_SECRET') -GITHUB_AUTH_URL = None if GITHUB_CLIENT_ID != 'UNKNOWN' and GITHUB_CLIENT_SECRET: GITHUB_AUTH_URL = "https://github.com/login/oauth/authorize?client_id=" + \ GITHUB_CLIENT_ID+"&client_secret="+GITHUB_CLIENT_SECRET @@ -264,7 +263,7 @@ ORCID_CLIENT_ID = get_setting('ORCID_CLIENT_ID') ORCID_CLIENT_SECRET = get_setting('ORCID_CLIENT_SECRET') ORCID_AUTH_URL = None if ORCID_CLIENT_ID != 'UNKNOWN' and ORCID_CLIENT_SECRET: - ORCID_AUTH_URL = "https://sandbox.orcid.org/oauth/authorize?response_type=code&scope=/authenticate&show_login=true&client_id=" + \ + ORCID_AUTH_URL = "https://orcid.org/oauth/authorize?response_type=code&scope=/authenticate&show_login=true&client_id=" + \ ORCID_CLIENT_ID+"&client_secret="+ORCID_CLIENT_SECRET ORCID_TOKEN_URL = get_setting('ORCID_TOKEN_URL') diff --git a/wqflask/wqflask/static/new/images/CITGLogo.png b/wqflask/wqflask/static/new/images/CITGLogo.png new file mode 100644 index 00000000..ae99fedb Binary files /dev/null and b/wqflask/wqflask/static/new/images/CITGLogo.png differ diff --git a/wqflask/wqflask/user_login.py b/wqflask/wqflask/user_login.py new file mode 100644 index 00000000..05885e2c --- /dev/null +++ b/wqflask/wqflask/user_login.py @@ -0,0 +1,467 @@ +from __future__ import print_function, division, absolute_import + +import os +import hashlib +import datetime +import time +import logging +import uuid +import hashlib +import hmac +import base64 +import requests + +import simplejson as json + +import redis # used for collections +Redis = redis.StrictRedis() + +from flask import (Flask, g, render_template, url_for, request, make_response, + redirect, flash, abort) + +from wqflask import app +from wqflask import pbkdf2 +from wqflask.hmac_func import hmac_creation +from wqflask.user_session import UserSession + +from utility.redis_tools import is_redis_available, get_user_id, get_user_by_unique_column, set_user_attribute, save_user, save_verification_code, check_verification_code, get_user_collections, save_collections + +from utility.logger import getLogger +logger = getLogger(__name__) + +from smtplib import SMTP +from utility.tools import SMTP_CONNECT, SMTP_USERNAME, SMTP_PASSWORD, LOG_SQL_ALCHEMY + +THREE_DAYS = 60 * 60 * 24 * 3 + +def timestamp(): + return datetime.datetime.utcnow().isoformat() + +def basic_info(): + return dict(timestamp = timestamp(), + ip_address = request.remote_addr, + user_agent = request.headers.get('User-Agent')) + +def encode_password(pass_gen_fields): + hashfunc = getattr(hashlib, pass_gen_fields['hashfunc']) + + salt = base64.b64decode(pass_gen_fields['salt']) + password = pbkdf2.pbkdf2_hex(str(pass_gen_fields['unencrypted_password']), + pass_gen_fields['salt'], + pass_gen_fields['iterations'], + pass_gen_fields['keylength'], + hashfunc) + + return password + +def set_password(password): + pass_gen_fields = { + "unencrypted_password": password, + "algorithm": "pbkdf2", + "hashfunc": "sha256", + "salt": base64.b64encode(os.urandom(32)), + "iterations": 100000, + "keylength": 32, + "created_timestamp": timestamp() + } + + assert len(password) >= 6, "Password shouldn't be shorter than 6 characters" + + encoded_password = encode_password(pass_gen_fields) + + return encoded_password + +def encrypt_password(unencrypted_password, pwfields): + hashfunc = getattr(hashlib, pwfields['hashfunc']) + salt = base64.b64decode(pwfields['salt']) + iterations = pwfields['iterations'] + keylength = pwfields['keylength'] + encrypted_password = pbkdf2.pbkdf2_hex(str(unencrypted_password), + salt, iterations, keylength, hashfunc) + return encrypted_password + +def get_signed_session_id(user): + session_id = str(uuid.uuid4()) + + session_id_signature = hmac_creation(session_id) + session_id_signed = session_id + ":" + session_id_signature + + #ZS: Need to check if this is ever actually used or exists + if 'user_id' not in user: + user['user_id'] = str(uuid.uuid4()) + save_user(user, user['user_id']) + + if 'github_id' in user: + session = dict(login_time = time.time(), + user_type = "github", + user_id = user['user_id'], + github_id = user['github_id'], + user_name = user['name'], + user_url = user['user_url']) + elif 'orcid' in user: + session = dict(login_time = time.time(), + user_type = "orcid", + user_id = user['user_id'], + github_id = user['orcid'], + user_name = user['name'], + user_url = user['user_url']) + else: + session = dict(login_time = time.time(), + user_type = "gn2", + user_id = user['user_id'], + user_name = user['full_name'], + user_email_address = user['email_address']) + + key = UserSession.user_cookie_name + ":" + session_id + Redis.hmset(key, session) + Redis.expire(key, THREE_DAYS) + + return session_id_signed + +def send_email(toaddr, msg, fromaddr="no-reply@genenetwork.org"): + """Send an E-mail through SMTP_CONNECT host. If SMTP_USERNAME is not + 'UNKNOWN' TLS is used + + """ + if SMTP_USERNAME == 'UNKNOWN': + server = SMTP(SMTP_CONNECT) + server.sendmail(fromaddr, toaddr, msg) + else: + server = SMTP(SMTP_CONNECT) + server.starttls() + server.login(SMTP_USERNAME, SMTP_PASSWORD) + server.sendmail(fromaddr, toaddr, msg) + server.quit() + logger.info("Successfully sent email to "+toaddr) + +def send_verification_email(user_details, template_name = "email/verification.txt", key_prefix = "verification_code", subject = "GeneNetwork email verification"): + verification_code = str(uuid.uuid4()) + key = key_prefix + ":" + verification_code + + data = json.dumps(dict(id=user_details['user_id'], timestamp = timestamp())) + + Redis.set(key, data) + Redis.expire(key, THREE_DAYS) + + recipient = user_details['email_address'] + body = render_template(template_name, verification_code = verification_code) + send_email(recipient, subject, body) + return {"recipient": recipient, "subject": subject, "body": body} + +@app.route("/n/login", methods=('GET', 'POST')) +def login(): + params = request.form if request.form else request.args + logger.debug("in login params are:", params) + + if not params: #ZS: If coming to page for first time + from utility.tools import GITHUB_AUTH_URL, GITHUB_CLIENT_ID, ORCID_AUTH_URL, ORCID_CLIENT_ID + external_login = {} + if GITHUB_AUTH_URL and GITHUB_CLIENT_ID != 'UNKNOWN': + external_login["github"] = GITHUB_AUTH_URL + if ORCID_AUTH_URL and ORCID_CLIENT_ID != 'UNKNOWN': + external_login["orcid"] = ORCID_AUTH_URL + return render_template("new_security/login_user.html", external_login = external_login, redis_is_available=is_redis_available()) + else: #ZS: After clicking sign-in + if 'type' in params and 'uid' in params: + user_details = get_user_by_unique_column("user_id", params['uid']) + if user_details: + session_id_signed = get_signed_session_id(user_details) + if 'name' in user_details and user_details['name'] != "None": + display_id = user_details['name'] + elif 'github_id' in user_details: + display_id = user_details['github_id'] + elif 'orcid' in user_details: + display_id = user_details['orcid'] + else: + display_id = "" + flash("Thank you for logging in {}.".format(display_id), "alert-success") + response = make_response(redirect(url_for('index_page'))) + response.set_cookie(UserSession.user_cookie_name, session_id_signed, max_age=None) + else: + flash("Something went unexpectedly wrong.", "alert-danger") + response = make_response(redirect(url_for('index_page'))) + return response + else: + user_details = get_user_by_unique_column("email_address", params['email_address']) + password_match = False + if user_details: + submitted_password = params['password'] + pwfields = json.loads(user_details['password']) + encrypted_pass = encrypt_password(submitted_password, pwfields) + password_match = pbkdf2.safe_str_cmp(encrypted_pass, pwfields['password']) + else: # Invalid e-mail + flash("Invalid e-mail address. Please try again.", "alert-danger") + response = make_response(redirect(url_for('login'))) + + return response + if password_match: # If password correct + if user_details['confirmed']: # If account confirmed + import_col = "false" + if 'import_collections' in params: + import_col = "true" + + session_id_signed = get_signed_session_id(user_details) + flash("Thank you for logging in {}.".format(user_details['full_name']), "alert-success") + response = make_response(redirect(url_for('index_page', import_collections = import_col))) + response.set_cookie(UserSession.user_cookie_name, session_id_signed, max_age=None) + return response + else: + email_ob = send_verification_email(user_details) + return render_template("newsecurity/verification_still_needed.html", subject=email_ob['subject']) + else: # Incorrect password + #ZS: It previously seemed to store that there was an incorrect log-in attempt here, but it did so in the MySQL DB so this might need to be reproduced with Redis + flash("Invalid password. Please try again.", "alert-danger") + response = make_response(redirect(url_for('login'))) + + return response + +@app.route("/n/login/github_oauth2", methods=('GET', 'POST')) +def github_oauth2(): + from utility.tools import GITHUB_CLIENT_ID, GITHUB_CLIENT_SECRET, GITHUB_AUTH_URL + code = request.args.get("code") + data = { + "client_id": GITHUB_CLIENT_ID, + "client_secret": GITHUB_CLIENT_SECRET, + "code": code + } + logger.debug("LOGIN DATA:", data) + result = requests.post("https://github.com/login/oauth/access_token", json=data) + result_dict = {arr[0]:arr[1] for arr in [tok.split("=") for tok in [token.encode("utf-8") for token in result.text.split("&")]]} + + github_user = get_github_user_details(result_dict["access_token"]) + + user_details = get_user_by_unique_column("github_id", github_user["id"]) + if user_details == None: + user_details = { + "user_id": str(uuid.uuid4()), + "name": github_user["name"].encode("utf-8") if github_user["name"] else "None", + "github_id": github_user["id"], + "user_url": github_user["html_url"].encode("utf-8"), + "login_type": "github", + "organization": "", + "active": 1, + "confirmed": 1 + } + save_user(user_details, user_details["user_id"]) + + url = "/n/login?type=github&uid="+user_details["user_id"] + return redirect(url) + +def get_github_user_details(access_token): + from utility.tools import GITHUB_API_URL + result = requests.get(GITHUB_API_URL, headers = {'Authorization':'token ' + access_token }).content + + return json.loads(result) + +@app.route("/n/login/orcid_oauth2", methods=('GET', 'POST')) +def orcid_oauth2(): + from uuid import uuid4 + from utility.tools import ORCID_CLIENT_ID, ORCID_CLIENT_SECRET, ORCID_TOKEN_URL, ORCID_AUTH_URL + code = request.args.get("code") + error = request.args.get("error") + url = "/n/login" + if code: + data = { + "client_id": ORCID_CLIENT_ID, + "client_secret": ORCID_CLIENT_SECRET, + "grant_type": "authorization_code", + "code": code + } + result = requests.post(ORCID_TOKEN_URL, data=data) + result_dict = json.loads(result.text.encode("utf-8")) + + user_details = get_user_by_unique_column("orcid", result_dict["orcid"]) + if user_details == None: + user_details = { + "user_id": str(uuid4()), + "name": result_dict["name"], + "orcid": result_dict["orcid"], + "user_url": "%s/%s" % ("/".join(ORCID_AUTH_URL.split("/")[:-2]), result_dict["orcid"]), + "login_type": "orcid", + "organization": "", + "active": 1, + "confirmed": 1 + } + save_user(user_details, user_details["user_id"]) + + url = "/n/login?type=orcid&uid="+user_details["user_id"] + else: + flash("There was an error getting code from ORCID") + return redirect(url) + +def get_github_user_details(access_token): + from utility.tools import GITHUB_API_URL + result = requests.get(GITHUB_API_URL, headers = {'Authorization':'token ' + access_token }).content + + return json.loads(result) + + +@app.route("/n/logout") +def logout(): + logger.debug("Logging out...") + UserSession().delete_session() + flash("You are now logged out. We hope you come back soon!") + response = make_response(redirect(url_for('index_page'))) + # Delete the cookie + response.set_cookie(UserSession.user_cookie_name, '', expires=0) + return response + +@app.route("/n/forgot_password", methods=['GET']) +def forgot_password(): + """Entry point for forgotten password""" + print("ARGS: ", request.args) + errors = {"no-email": request.args.get("no-email")} + print("ERRORS: ", errors) + return render_template("new_security/forgot_password.html", errors=errors) + +def send_forgot_password_email(verification_email): + from email.MIMEMultipart import MIMEMultipart + from email.MIMEText import MIMEText + + template_name = "email/forgot_password.txt" + key_prefix = "forgot_password_code" + subject = "GeneNetwork password reset" + fromaddr = "no-reply@genenetwork.org" + + verification_code = str(uuid.uuid4()) + key = key_prefix + ":" + verification_code + + data = { + "verification_code": verification_code, + "email_address": verification_email, + "timestamp": timestamp() + } + + save_verification_code(verification_email, verification_code) + + body = render_template(template_name, verification_code = verification_code) + + msg = MIMEMultipart() + msg["To"] = verification_email + msg["Subject"] = subject + msg["From"] = fromaddr + msg.attach(MIMEText(body, "plain")) + + send_email(verification_email, msg.as_string()) + + return subject + +@app.route("/n/forgot_password_submit", methods=('POST',)) +def forgot_password_submit(): + """When a forgotten password form is submitted we get here""" + params = request.form + email_address = params['email_address'] + next_page = None + if email_address != "": + logger.debug("Wants to send password E-mail to ", email_address) + user_details = get_user_by_unique_column("email_address", email_address) + if user_details: + email_subject = send_forgot_password_email(user_details["email_address"]) + return render_template("new_security/forgot_password_step2.html", + subject=email_subject) + else: + flash("The e-mail entered is not associated with an account.", "alert-danger") + return redirect(url_for("forgot_password")) + + else: + flash("You MUST provide an email", "alert-danger") + return redirect(url_for("forgot_password")) + +@app.route("/n/password_reset", methods=['GET']) +def password_reset(): + """Entry point after user clicks link in E-mail""" + logger.debug("in password_reset request.url is:", request.url) + + verification_code = request.args.get('code') + hmac = request.args.get('hm') + + if verification_code: + user_email = check_verification_code(verification_code) + if user_email: + user_details = get_user_by_unique_column('email_address', user_email) + if user_details: + return render_template( + "new_security/password_reset.html", user_encode=user_details["email_address"]) + else: + flash("Invalid code: User no longer exists!", "error") + else: + flash("Invalid code: Password reset code does not exist or might have expired!", "error") + else: + return redirect(url_for("login")) + +@app.route("/n/password_reset_step2", methods=('POST',)) +def password_reset_step2(): + """Handle confirmation E-mail for password reset""" + logger.debug("in password_reset request.url is:", request.url) + + errors = [] + user_email = request.form['user_encode'] + + password = request.form['password'] + encoded_password = set_password(password) + + set_user_attribute(user_id, "password", encoded_password) + + flash("Password changed successfully. You can now sign in.", "alert-info") + response = make_response(redirect(url_for('login'))) + + return response + +def register_user(params): + thank_you_mode = False + errors = [] + user_details = {} + + user_details['email_address'] = params.get('email_address', '').encode("utf-8").strip() + if not (5 <= len(user_details['email_address']) <= 50): + errors.append('Email Address needs to be between 5 and 50 characters.') + else: + email_exists = get_user_by_unique_column("email_address", user_details['email_address']) + if email_exists: + errors.append('User already exists with that email') + + user_details['full_name'] = params.get('full_name', '').encode("utf-8").strip() + if not (5 <= len(user_details['full_name']) <= 50): + errors.append('Full Name needs to be between 5 and 50 characters.') + + user_details['organization'] = params.get('organization', '').encode("utf-8").strip() + if user_details['organization'] and not (5 <= len(user_details['organization']) <= 50): + errors.append('Organization needs to be empty or between 5 and 50 characters.') + + password = str(params.get('password', '')) + if not (6 <= len(password)): + errors.append('Password needs to be at least 6 characters.') + + if params.get('password_confirm') != password: + errors.append("Passwords don't match.") + + if errors: + return errors + + user_details['password'] = set_password(password) + user_details['user_id'] = str(uuid.uuid4()) + user_details['confirmed'] = 1 + + user_details['registration_info'] = json.dumps(basic_info(), sort_keys=True) + save_user(user_details, user_details['user_id']) + +@app.route("/n/register", methods=('GET', 'POST')) +def register(): + errors = None + + params = request.form if request.form else request.args + params = params.to_dict(flat=True) + + if params: + logger.debug("Attempting to register the user...") + errors = register_user(params) + + if len(errors) == 0: + flash("Registration successful. You may login with your new account", "alert-info") + return redirect(url_for("login")) + + return render_template("new_security/register_user.html", values=params, errors=errors) + +@app.errorhandler(401) +def unauthorized(error): + return redirect(url_for('login')) \ No newline at end of file diff --git a/wqflask/wqflask/user_session.py b/wqflask/wqflask/user_session.py new file mode 100644 index 00000000..1f3e6558 --- /dev/null +++ b/wqflask/wqflask/user_session.py @@ -0,0 +1,272 @@ +from __future__ import print_function, division, absolute_import + +import datetime +import time +import uuid + +import simplejson as json + +import redis # used for collections +Redis = redis.StrictRedis() + +from flask import (Flask, g, render_template, url_for, request, make_response, + redirect, flash, abort) + +from wqflask import app +from wqflask.hmac_func import hmac_creation + +#from utility.elasticsearch_tools import get_elasticsearch_connection +from utility.redis_tools import get_user_id, get_user_by_unique_column, get_user_collections, save_collections + +from utility.logger import getLogger +logger = getLogger(__name__) + +THREE_DAYS = 60 * 60 * 24 * 3 +THIRTY_DAYS = 60 * 60 * 24 * 30 + +def verify_cookie(cookie): + the_uuid, separator, the_signature = cookie.partition(':') + assert len(the_uuid) == 36, "Is session_id a uuid?" + assert separator == ":", "Expected a : here" + assert the_signature == hmac_creation(the_uuid), "Uh-oh, someone tampering with the cookie?" + return the_uuid + +def create_signed_cookie(): + the_uuid = str(uuid.uuid4()) + signature = hmac_creation(the_uuid) + uuid_signed = the_uuid + ":" + signature + logger.debug("uuid_signed:", uuid_signed) + return the_uuid, uuid_signed + +class UserSession(object): + """Logged in user handling""" + + user_cookie_name = 'session_id_v1' + anon_cookie_name = 'anon_user_v1' + + def __init__(self): + user_cookie = request.cookies.get(self.user_cookie_name) + if not user_cookie: + self.logged_in = False + anon_cookie = request.cookies.get(self.anon_cookie_name) + self.cookie_name = self.anon_cookie_name + if anon_cookie: + self.cookie = anon_cookie + session_id = verify_cookie(self.cookie) + else: + session_id, self.cookie = create_signed_cookie() + else: + self.cookie_name = self.user_cookie_name + self.cookie = user_cookie + session_id = verify_cookie(self.cookie) + + self.redis_key = self.cookie_name + ":" + session_id + self.session_id = session_id + self.record = Redis.hgetall(self.redis_key) + + #ZS: If user correctled logged in but their session expired + #ZS: Need to test this by setting the time-out to be really short or something + if not self.record: + if user_cookie: + self.logged_in = False + + ########### Grrr...this won't work because of the way flask handles cookies + # Delete the cookie + response = make_response(redirect(url_for('login'))) + #response.set_cookie(self.cookie_name, '', expires=0) + flash("Due to inactivity your session has expired. If you'd like please login again.") + return response + #return + else: + self.record = dict(login_time = time.time(), + user_type = "anon", + user_id = str(uuid.uuid4())) + + Redis.hmset(self.redis_key, self.record) + Redis.expire(self.redis_key, THIRTY_DAYS) + else: + if user_cookie: + self.logged_in = True + + if user_cookie: + session_time = THREE_DAYS + else: + session_time = THIRTY_DAYS + + if Redis.ttl(self.redis_key) < session_time: + # (Almost) everytime the user does something we extend the session_id in Redis... + logger.debug("Extending ttl...") + Redis.expire(self.redis_key, session_time) + + @property + def user_id(self): + """Shortcut to the user_id""" + if 'user_id' in self.record: + return self.record['user_id'] + else: + return '' + + @property + def redis_user_id(self): + """User id from Redis (need to check if this is the same as the id stored in self.records)""" + + #ZS: This part is a bit weird. Some accounts used to not have saved user ids, and in the process of testing I think I created some duplicate accounts for myself. + #ZS: Accounts should automatically generate user_ids if they don't already have one now, so this might not be necessary for anything other than my account's collections + + if 'user_email_address' in self.record: + user_email = self.record['user_email_address'] + + #ZS: Get user's collections if they exist + user_id = None + user_id = get_user_id("email_address", user_email) + elif 'user_id' in self.record: + user_id = self.record['user_id'] + elif 'github_id' in self.record: + user_github_id = self.record['github_id'] + user_id = None + user_id = get_user_id("github_id", user_github_id) + else: #ZS: Anonymous user + return None + + return user_id + + @property + def user_name(self): + """Shortcut to the user_name""" + if 'user_name' in self.record: + return self.record['user_name'] + else: + return '' + + @property + def user_collections(self): + """List of user's collections""" + + #ZS: Get user's collections if they exist + collections = get_user_collections(self.redis_user_id) + return collections + + @property + def num_collections(self): + """Number of user's collections""" + + return len(self.user_collections) + + def add_collection(self, collection_name, traits): + """Add collection into ElasticSearch""" + + collection_dict = {'id': unicode(uuid.uuid4()), + 'name': collection_name, + 'created_timestamp': datetime.datetime.utcnow().strftime('%b %d %Y %I:%M%p'), + 'changed_timestamp': datetime.datetime.utcnow().strftime('%b %d %Y %I:%M%p'), + 'num_members': len(traits), + 'members': list(traits) } + + current_collections = self.user_collections + current_collections.append(collection_dict) + self.update_collections(current_collections) + + return collection_dict['id'] + + def change_collection_name(self, collection_id, new_name): + for collection in self.user_collections: + if collection['id'] == collection_id: + collection['name'] = new_name + break + + return new_name + + def delete_collection(self, collection_id): + """Remove collection with given ID""" + + updated_collections = [] + for collection in self.user_collections: + if collection['id'] == collection_id: + continue + else: + updated_collections.append(collection) + + self.update_collections(updated_collections) + + return collection['name'] + + def add_traits_to_collection(self, collection_id, traits_to_add): + """Add specified traits to a collection""" + + this_collection = self.get_collection_by_id(collection_id) + + updated_collection = this_collection + updated_traits = this_collection['members'] + traits_to_add + + updated_collection['members'] = updated_traits + updated_collection['num_members'] = len(updated_traits) + updated_collection['changed_timestamp'] = datetime.datetime.utcnow().strftime('%b %d %Y %I:%M%p') + + updated_collections = [] + for collection in self.user_collections: + if collection['id'] == collection_id: + updated_collections.append(updated_collection) + else: + updated_collections.append(collection) + + self.update_collections(updated_collections) + + def remove_traits_from_collection(self, collection_id, traits_to_remove): + """Remove specified traits from a collection""" + + this_collection = self.get_collection_by_id(collection_id) + + updated_collection = this_collection + updated_traits = [] + for trait in this_collection['members']: + if trait in traits_to_remove: + continue + else: + updated_traits.append(trait) + + updated_collection['members'] = updated_traits + updated_collection['num_members'] = len(updated_traits) + updated_collection['changed_timestamp'] = datetime.datetime.utcnow().strftime('%b %d %Y %I:%M%p') + + updated_collections = [] + for collection in self.user_collections: + if collection['id'] == collection_id: + updated_collections.append(updated_collection) + else: + updated_collections.append(collection) + + self.update_collections(updated_collections) + + return updated_traits + + def get_collection_by_id(self, collection_id): + for collection in self.user_collections: + if collection['id'] == collection_id: + return collection + + def get_collection_by_name(self, collection_name): + for collection in self.user_collections: + if collection['name'] == collection_name: + return collection + + return None + + def update_collections(self, updated_collections): + collection_body = json.dumps(updated_collections) + + save_collections(self.redis_user_id, collection_body) + + def delete_session(self): + # And more importantly delete the redis record + Redis.delete(self.redis_key) + self.logged_in = False + +@app.before_request +def before_request(): + g.user_session = UserSession() + +@app.after_request +def set_cookie(response): + if not request.cookies.get(g.user_session.cookie_name): + response.set_cookie(g.user_session.cookie_name, g.user_session.cookie) + return response \ No newline at end of file diff --git a/wqflask/wqflask/views.py b/wqflask/wqflask/views.py index a84a7480..436ebc91 100644 --- a/wqflask/wqflask/views.py +++ b/wqflask/wqflask/views.py @@ -65,9 +65,8 @@ from utility.benchmark import Bench from pprint import pformat as pf -#from wqflask import user_login -#from wqflask import user_session -from wqflask import user_manager +from wqflask import user_login +from wqflask import user_session from wqflask import collect from wqflask.database import db_session @@ -839,6 +838,8 @@ def browser_inputs(): return flask.jsonify(file_contents) + + ########################################################################## def json_default_handler(obj): -- cgit v1.2.3