about summary refs log tree commit diff
path: root/wqflask/wqflask/collect.py
diff options
context:
space:
mode:
Diffstat (limited to 'wqflask/wqflask/collect.py')
-rw-r--r--wqflask/wqflask/collect.py86
1 files changed, 68 insertions, 18 deletions
diff --git a/wqflask/wqflask/collect.py b/wqflask/wqflask/collect.py
index 34def295..a77e19af 100644
--- a/wqflask/wqflask/collect.py
+++ b/wqflask/wqflask/collect.py
@@ -44,28 +44,54 @@ from wqflask import user_manager
 from base import trait
 
 
+def get_collection():
+    if g.user_session.logged_in:
+        return UserCollection()
+    else:
+        return AnonCollection()
+    #else:
+    #    CauseError
+
 class AnonCollection(object):
     """User is not logged in"""
     def __init__(self):
         self.anon_user = user_manager.AnonUser()
-        self.key = "anon_collection:v1:{}".format(self.anon_user.anon_id)
-    
+        self.key = "anon_collection:v4:{}".format(self.anon_user.anon_id)
     
-    def add_traits(params, collection_name):
+    def add_traits(self, params, collection_name):
         assert collection_name == "Default", "Unexpected collection name for anonymous user"
+        print("params[traits]:", params['traits'])
         traits = process_traits(params['traits'])
-        len_before = len(Redis.smembers)
-        Redis.sadd(self.key, traits)
+        print("traits is:", traits)
+        print("self.key is:", self.key)
+        len_before = len(Redis.smembers(self.key))
+        Redis.sadd(self.key, *list(traits))
         Redis.expire(self.key, 60 * 60 * 24 * 3)
-        len_now = len(Redis.smembers)
+        print("currently in redis:", Redis.smembers(self.key))
+        len_now = len(Redis.smembers(self.key))
         report_change(len_before, len_now)
         
+    def remove_traits(self, params):
+        traits_to_remove = params.getlist('traits[]')
+        print("traits_to_remove:", traits_to_remove)
+        len_before = len(Redis.smembers(self.key))
+        Redis.srem(self.key, traits_to_remove)
+        len_now = len(Redis.smembers(self.key))
+        print("Went from {} to {} members in set.".format(len(self.collection_members), len(members_now)))
+
+        # We need to return something so we'll return this...maybe in the future
+        # we can use it to check the results
+        return str(len(members_now))
     
+    def get_traits(self):
+        traits = Redis.smembers(self.key)
+        print("traits:", traits)
+        return traits
     
 class UserCollection(object):
     """User is logged in"""
     
-    def add_traits(params, collection_name):
+    def add_traits(self, params, collection_name):
         print("---> params are:", params.keys())
         print("     type(params):", type(params))
         if collection_name=="Default":
@@ -94,15 +120,40 @@ class UserCollection(object):
         # Probably have to change that
         return redirect(url_for('view_collection', uc_id=uc.id))
     
+    def remove_traits(self, params):
+    
+        #params = request.form
+        print("params are:", params)
+        uc_id = params['uc_id']
+        uc = model.UserCollection.query.get(uc_id)
+        traits_to_remove = params.getlist('traits[]')
+        print("traits_to_remove are:", traits_to_remove)
+        traits_to_remove = process_traits(traits_to_remove)
+        print("\n\n  after processing, traits_to_remove:", traits_to_remove)
+        all_traits = uc.members_as_set()
+        print("  all_traits:", all_traits)
+        members_now = all_traits - traits_to_remove
+        print("  members_now:", members_now)
+        print("Went from {} to {} members in set.".format(len(all_traits), len(members_now)))
+        uc.members = json.dumps(list(members_now))
+        uc.changed_timestamp = datetime.datetime.utcnow()
+        db_session.commit()
+    
+        # We need to return something so we'll return this...maybe in the future
+        # we can use it to check the results
+        return str(len(members_now))
+    
 def report_change(len_before, len_now):
     new_length = len_now - len_before
     if new_length:
+        print("We've added {} to your collection.".format(
+            numify(new_length, 'new trait', 'new traits')))
         flash("We've added {} to your collection.".format(
             numify(new_length, 'new trait', 'new traits')))
     else:
-        flash("No new traits were added.")
+        print("No new traits were added.")
+
 
-    
 
 
 @app.route("/collections/add")
@@ -128,7 +179,8 @@ def collections_new():
     print("request.args in collections_new are:", params)
 
     if "anonymous_add" in params:
-        return add_anon_traits(params)
+        AnonCollection().add_traits(params, "Default")
+        return redirect(url_for('view_collection'))
 
     collection_name = params['new_collection']
 
@@ -143,8 +195,6 @@ def collections_new():
         CauseAnError
 
 
-
-
 def process_traits(unprocessed_traits):
     print("unprocessed_traits are:", unprocessed_traits)
     if isinstance(unprocessed_traits, basestring):
@@ -239,11 +289,9 @@ def view_collection():
         uc_id = params['uc_id']
         uc = model.UserCollection.query.get(uc_id)
         traits = json.loads(uc.members)
+        print("traits are:", traits)
     else:
-        anon_id = params['key']
-        uc = model.AnonCollection(anon_id = anon_id)
-        traits = Redis.smembers(anon_id)
-        print("the traits are:", traits)
+        traits = AnonCollection().get_traits()
 
     print("in view_collection traits are:", traits)
 
@@ -251,6 +299,7 @@ def view_collection():
     json_version = []
 
     for atrait in traits:
+        print("atrait is:", atrait)
         name, dataset_name = atrait.split(':')
 
         trait_ob = trait.GeneralTrait(name=name, dataset_name=dataset_name)
@@ -267,8 +316,9 @@ def view_collection():
         #                         dis=trait_ob.description))
         #json_version.append(trait_ob.__dict__th)
         
-    collection_info = dict(trait_obs=trait_obs,
-                           uc = uc)
+    #collection_info = dict(trait_obs=trait_obs,
+    #                       uc = uc)
+    collection_info = dict(trait_obs=trait_obs)
     if "json" in params:
         print("json_version:", json_version)
         return json.dumps(json_version)