about summary refs log tree commit diff
path: root/wqflask/wqflask/collect.py
diff options
context:
space:
mode:
Diffstat (limited to 'wqflask/wqflask/collect.py')
-rw-r--r--wqflask/wqflask/collect.py84
1 files changed, 50 insertions, 34 deletions
diff --git a/wqflask/wqflask/collect.py b/wqflask/wqflask/collect.py
index 4ea8407c..5ad6b1f4 100644
--- a/wqflask/wqflask/collect.py
+++ b/wqflask/wqflask/collect.py
@@ -57,16 +57,25 @@ class AnonCollection(object):
     def __init__(self):
         self.anon_user = user_manager.AnonUser()
         self.key = "anon_collection:v5:{}".format(self.anon_user.anon_id)
-
+        self.name = None
+
+    @property
+    def num_members(self):
+        try:
+            return len(Redis.smembers(self.key))
+        except:
+            return 0
+        
     def add_traits(self, params, collection_name):
-        assert collection_name == "Default", "Unexpected collection name for anonymous user"
+        #assert collection_name == "Default", "Unexpected collection name for anonymous user"
+        self.name = collection_name
         print("params[traits]:", params['traits'])
         traits = process_traits(params['traits'])
         print("traits is:", traits)
         print("self.key is:", self.key)
         len_before = len(Redis.smembers(self.key))
         Redis.sadd(self.key, *list(traits))
-        Redis.expire(self.key, 60 * 60 * 24 * 3)
+        Redis.expire(self.key, 60 * 60 * 24 * 5)
         print("currently in redis:", Redis.smembers(self.key))
         len_now = len(Redis.smembers(self.key))
         report_change(len_before, len_now)
@@ -165,12 +174,17 @@ def collections_add():
         print("user_collections are:", user_collections)
         return render_template("collections/add.html",
                                traits=traits,
-                               user_collections = user_collections,
+                               collections = user_collections,
                                )
     else:
-        return render_template("collections/add_anonymous.html",
-                               traits=traits
-                               )
+        anon_collections = user_manager.AnonUser().get_collections()
+        return render_template("collections/add.html",
+                                   traits=traits,
+                                   collections = anon_collections,
+                                   )
+        # return render_template("collections/add_anonymous.html",
+                                   # traits=traits
+                                   # )
 
 
 @app.route("/collections/new")
@@ -178,14 +192,14 @@ def collections_new():
     params = request.args
     print("request.args in collections_new are:", params)
 
+    collection_name = params['new_collection']
+    
     if "anonymous_add" in params:
-        AnonCollection().add_traits(params, "Default")
+        AnonCollection(name=collection_name).add_traits(params, "Default")
         return redirect(url_for('view_collection'))
-    elif "sign_in" in params:
+    if "sign_in" in params:
         return redirect(url_for('login'))
 
-    collection_name = params['new_collection']
-
     if "create_new" in params:
         print("in create_new")
         return create_new(collection_name)
@@ -214,26 +228,29 @@ def process_traits(unprocessed_traits):
 
 def create_new(collection_name):
     params = request.args
-    uc = model.UserCollection()
-    uc.name = collection_name
-    print("user_session:", g.user_session.__dict__)
-    uc.user = g.user_session.user_id
+    
     unprocessed_traits = params['traits']
-
     traits = process_traits(unprocessed_traits)
+    
+    if 'uc_id' in params:
+        uc = model.UserCollection()
+        uc.name = collection_name
+        print("user_session:", g.user_session.__dict__)
+        uc.user = g.user_session.user_id
+        uc.members = json.dumps(list(traits))
+        db_session.add(uc)
+        db_session.commit()
+    else:
+        ac = AnonCollection().add_traits(params, collection_name)
+        print("traits are:", ac.get_traits())
+        user_manager.AnonUser().add_collection(ac)
 
-    uc.members = json.dumps(list(traits))
-    print("traits are:", traits)
-
-    db_session.add(uc)
-    db_session.commit()
-
-    print("Created: " + uc.name)
     return redirect(url_for('view_collection', uc_id=uc.id))
 
 @app.route("/collections/list")
 def list_collections():
     params = request.args
+    print("PARAMS:", params)
     try:
         user_collections = list(g.user_session.user_ob.user_collections)
         print("user_collections are:", user_collections)
@@ -285,7 +302,6 @@ def delete_collection():
     # But might want to check ownership in the future
     collection_name = uc.name
     db_session.delete(uc)
-    db_session.commit()
     flash("We've deletet the collection: {}.".format(collection_name), "alert-info")
 
     return redirect(url_for('list_collections'))
@@ -297,14 +313,14 @@ def view_collection():
     params = request.args
     print("PARAMS in view collection:", params)
 
-    #if "uc_id" in params:
-    uc_id = params['uc_id']
-    uc = model.UserCollection.query.get(uc_id)
-    traits = json.loads(uc.members)
-    print("traits are:", traits)
-    #else:
-    #    traits = AnonCollection().get_traits()
-
+    if "uc_id" in params:
+        uc_id = params['uc_id']
+        uc = model.UserCollection.query.get(uc_id)
+        traits = json.loads(uc.members)
+        print("traits are:", traits)
+    else:
+        traits = AnonCollection().get_traits()
+                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                   
     print("in view_collection traits are:", traits)
 
     trait_obs = []
@@ -312,8 +328,8 @@ def view_collection():
 
     for atrait in traits:
         print("atrait is:", atrait)
-        name, dataset_name = atrait.split(':')
-
+        name, dataset_name = atrait.split(':')                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        
+                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                         
         trait_ob = trait.GeneralTrait(name=name, dataset_name=dataset_name)
         trait_ob.retrieve_info(get_qtl_info=True)
         trait_obs.append(trait_ob)