about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--wqflask/wqflask/user_manager.py58
1 files changed, 53 insertions, 5 deletions
diff --git a/wqflask/wqflask/user_manager.py b/wqflask/wqflask/user_manager.py
index 9012c842..daeb7bc5 100644
--- a/wqflask/wqflask/user_manager.py
+++ b/wqflask/wqflask/user_manager.py
@@ -54,6 +54,9 @@ logger = getLogger(__name__)
 
 from base.data_set import create_datasets_list
 
+import requests
+from utility.elasticsearch_tools import get_user_by_unique_column, save_user
+
 THREE_DAYS = 60 * 60 * 24 * 3
 #THREE_DAYS = 45
 
@@ -492,13 +495,16 @@ class DecodeUser(object):
 @app.route("/n/login", methods=('GET', 'POST'))
 def login():
     lu = LoginUser()
-    return lu.standard_login()
+    login_type = request.args.get("type")
+    if login_type:
+        uid = request.args.get("uid")
+        return lu.oauth2_login(login_type, uid)
+    else:
+        return lu.standard_login()
 
 @app.route("/n/login/github_oauth2", methods=('GET', 'POST'))
 def github_oauth2():
     from utility.tools import GITHUB_CLIENT_ID, GITHUB_CLIENT_SECRET
-    from utility.elasticsearch_tools import get_user_by_unique_column
-    import requests
     code = request.args.get("code")
     data = {
         "client_id": GITHUB_CLIENT_ID,
@@ -512,13 +518,15 @@ def github_oauth2():
     user_details = get_user_by_unique_column("github_id", github_user["id"])
     if user_details == None:
         user_details = {
-            "user_id": str(uuid4())
+            "user_id": str(uuid.uuid4())
             , "name": github_user["name"]
             , "github_id": github_user["id"]
             , "user_url": github_user["html_url"]
             , "login_type": "github"
+            , "organization": ""
         }
-    url = "/n/login?type=github"
+        save_user(user_details, user_details.get("user_id"))
+    url = "/n/login?type=github&uid="+user_details["user_id"]
     return redirect(url)
 
 def get_github_user_details(access_token):
@@ -532,6 +540,46 @@ class LoginUser(object):
     def __init__(self):
         self.remember_me = False
 
+    def oauth2_login(self, login_type, user_id):
+        """Login via an OAuth2 provider"""
+        user_details = get_user_by_unique_column("user_id", user_id)
+        if user_details:
+            user = model.User()
+            user.id = user_details["user_id"]
+            user.full_name = user_details["name"]
+            user.login_type = user_details["login_type"]
+            return self.actual_login_oauth2(user)
+        else:
+            flash("Error logging in via OAuth2")
+            return make_response(redirect(url_for('login')))
+
+    def actual_login_oauth2(self, user, assumed_by=None, import_collections=None):
+        """The meat of the logging in process"""
+        session_id_signed = self.successful_login_oauth2(user)
+        flash("Thank you for logging in {}.".format(user.full_name), "alert-success")
+        print("IMPORT1:", import_collections)
+        response = make_response(redirect(url_for('index_page', import_collections=import_collections)))
+        if self.remember_me:
+            max_age = self.remember_time
+        else:
+            max_age = None
+        response.set_cookie(UserSession.cookie_name, session_id_signed, max_age=max_age)
+        return response
+
+    def successful_login_oauth2(self, user, assumed_by=None):
+        login_rec = model.Login(user)
+        login_rec.successful = True
+        login_rec.session_id = str(uuid.uuid4())
+        login_rec.assumed_by = assumed_by
+        session_id_signature = actual_hmac_creation(login_rec.session_id)
+        session_id_signed = login_rec.session_id + ":" + session_id_signature
+        logger.debug("session_id_signed:", session_id_signed)
+
+        session = dict(login_time = time.time(),
+                       user_id = user.id,
+                       user_login_type = user.login_type)
+        return session_id_signed
+
     def standard_login(self):
         """Login through the normal form"""
         params = request.form if request.form else request.args