aboutsummaryrefslogtreecommitdiff
path: root/gn2/wqflask/user_session.py
blob: af4e8cb469c53b65fa7cbfff7a55badaefd37b51 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
import datetime
import time
import uuid

import simplejson as json

from flask import (Flask, g, render_template, url_for, request, make_response,
                   redirect, flash, abort)

from gn2.wqflask import app
from gn2.utility import hmac

from gn2.utility.redis_tools import get_redis_conn, get_user_id, get_user_by_unique_column, set_user_attribute, get_user_collections, save_collections
Redis = get_redis_conn()


# Session lifetimes in seconds, used by UserSession.__init__ when setting or
# extending a session's Redis expiry (THREE_DAYS for session_id_v2 cookies,
# THIRTY_DAYS for anonymous sessions and freshly created records).
THREE_DAYS = 3 * 24 * 60 * 60
THIRTY_DAYS = 30 * 24 * 60 * 60


@app.before_request
def get_user_session():
    """Attach a UserSession for the current request to ``flask.g``.

    Intended to redirect to the login page and clear the ``session_id_v2``
    cookie when the user's session has expired.

    NOTE(review): as written, UserSession defines no ``__bool__``/``__len__``,
    so the constructed instance is always truthy and the redirect below never
    runs (the ``return None`` inside ``UserSession.__init__`` does not make the
    instance falsy). Confirm the intended expiry handling before relying on it.
    """
    session = UserSession()
    g.user_session = session
    if session:
        return None
    expired = make_response(redirect(url_for('login')))
    expired.set_cookie('session_id_v2', '', expires=0)
    return expired


@app.after_request
def set_user_session(response):
    """Make sure the outgoing response carries the session cookie.

    If a UserSession was attached to ``g`` and the client did not already
    send its cookie, the signed cookie value is set on ``response``. If no
    session was attached, the ``session_id_v2`` cookie is cleared instead.
    Returns the (possibly modified) response.
    """
    if not hasattr(g, 'user_session'):
        response.set_cookie('session_id_v2', '', expires=0)
        return response

    session = g.user_session
    if not request.cookies.get(session.cookie_name):
        response.set_cookie(session.cookie_name, session.cookie)
    return response


def verify_cookie(cookie):
    """Check a signed session cookie and return the session uuid it carries.

    The expected format is ``"<uuid>:<signature>"`` (as built by
    ``create_signed_cookie``), where the signature must equal
    ``hmac.hmac_creation(<uuid>)``.

    Args:
        cookie: the raw cookie value (a ``str``).

    Returns:
        The 36-character uuid part of the cookie.

    Raises:
        AssertionError: if the cookie is malformed or its signature does not
            match. It is raised explicitly rather than via ``assert`` so the
            checks still run when Python is started with ``-O`` (which strips
            ``assert`` statements and would otherwise accept any cookie).
    """
    # Function-scope import: the module-level name ``hmac`` is
    # gn2.utility.hmac, not the standard-library module.
    import secrets

    the_uuid, separator, the_signature = cookie.partition(':')
    if len(the_uuid) != 36:
        raise AssertionError("Is session_id a uuid?")
    if separator != ":":
        raise AssertionError("Expected a : here")
    expected_signature = hmac.hmac_creation(the_uuid)
    # Constant-time comparison so response timing does not reveal how much of
    # a forged signature matched. Both sides are encoded to bytes because
    # compare_digest rejects non-ASCII str arguments.
    if not secrets.compare_digest(the_signature.encode("utf-8"),
                                  expected_signature.encode("utf-8")):
        raise AssertionError("Uh-oh, someone tampering with the cookie?")
    return the_uuid


def create_signed_cookie():
    """Mint a fresh session id together with its signed cookie value.

    Returns:
        A ``(session_uuid, signed_value)`` pair, where ``signed_value`` is
        ``"<session_uuid>:<hmac.hmac_creation(session_uuid)>"``, the format
        ``verify_cookie`` checks.
    """
    session_uuid = str(uuid.uuid4())
    signed_value = ":".join((session_uuid, hmac.hmac_creation(session_uuid)))
    return session_uuid, signed_value


@app.route("/user/manage", methods=('GET', 'POST'))
def manage_user():
    """Update the current user's profile fields and render the manage page.

    Reads ``new_full_name`` / ``new_organization`` from the form body, or
    from the query string when the form is empty. Each value present is
    stored as the ``full_name`` / ``organization`` attribute of the current
    session's user. The page is then rendered with the freshly loaded user
    details.

    NOTE(review): state is also changed on GET requests via query-string
    parameters; confirm whether CSRF protection is applied elsewhere.
    """
    params = request.form or request.args
    user_id = g.user_session.user_id

    editable_fields = (('new_full_name', 'full_name'),
                       ('new_organization', 'organization'))
    for param_name, attribute in editable_fields:
        if param_name in params:
            set_user_attribute(user_id, attribute, params[param_name])

    user_details = get_user_by_unique_column("user_id", user_id)
    return render_template("admin/manage_user.html", user_details=user_details)


class UserSession:
    """Logged in user handling.

    One instance is built per request (by ``get_user_session``). The session
    is identified by a signed cookie: ``session_id_v2`` when the client sends
    it, otherwise ``anon_user_v1`` (minted on the first visit). Session data
    is a Redis hash stored under ``"<cookie_name>:<session uuid>"``. The
    user's collections are loaded and saved through the ``redis_tools``
    helpers, keyed by :attr:`user_id`.
    """

    user_cookie_name = 'session_id_v2'
    anon_cookie_name = 'anon_user_v1'

    def __init__(self):
        """Load (or create) the session record for the current request.

        Verifies the session cookie with ``verify_cookie`` (which raises
        ``AssertionError`` on a malformed or tampered cookie) and reads the
        Redis record. If there is no record, a fresh anonymous record is
        written with a THIRTY_DAYS expiry. Sets ``logged_in`` and, when a
        ``session_id_v2`` record exists, ``user_details``.
        """
        user_cookie = request.cookies.get(self.user_cookie_name)
        if not user_cookie:
            self.logged_in = False
            anon_cookie = request.cookies.get(self.anon_cookie_name)
            self.cookie_name = self.anon_cookie_name
            if anon_cookie:
                self.cookie = anon_cookie
                session_id = verify_cookie(self.cookie)
            else:
                # First visit: mint a new signed anonymous cookie. The
                # set_user_session hook sends it because the request lacked it.
                session_id, self.cookie = create_signed_cookie()
        else:
            self.cookie_name = self.user_cookie_name
            self.cookie = user_cookie
            session_id = verify_cookie(self.cookie)

        self.redis_key = self.cookie_name + ":" + session_id
        self.session_id = session_id
        self.record = Redis.hgetall(self.redis_key)

        if not self.record:
            # No Redis record: either a brand-new anonymous session or a
            # session_id_v2 session whose record has expired. Both start a
            # fresh anonymous record.
            if user_cookie:
                self.logged_in = False
            self.record = dict(login_time=time.time(),
                               user_type="anon",
                               user_id=str(uuid.uuid4()))
            Redis.hmset(self.redis_key, self.record)
            Redis.expire(self.redis_key, THIRTY_DAYS)
            if user_cookie:
                # The expired cookie itself is not cleared here; the user only
                # gets a message, and the TTL refresh below is skipped.
                flash(
                    "Due to inactivity your session has expired. If you'd like please login again.")
                return None
        elif user_cookie:
            self.logged_in = True
            self.user_details = get_user_by_unique_column("user_id", self.user_id)
            if not self.user_details:
                self.logged_in = False
                return None

        session_time = THREE_DAYS if user_cookie else THIRTY_DAYS
        if Redis.ttl(self.redis_key) < session_time:
            # (Almost) every request pushes the key's expiry back out to
            # session_time, so active sessions stay alive.
            Redis.expire(self.redis_key, session_time)

    @staticmethod
    def _timestamp():
        """Current UTC time formatted for collection created/changed stamps."""
        return datetime.datetime.utcnow().strftime('%b %d %Y %I:%M%p')

    @property
    def user_id(self):
        """Shortcut to the user_id, always returned as ``str``.

        The record may use bytes keys/values (``b'user_id'``, the form this
        property has always looked up) or str keys (the dict ``__init__``
        builds for a new session). Both are checked, so a brand-new session
        reports the same id that was written to Redis. If neither key is
        present, a new uuid is generated and cached in ``self.record`` (in
        memory only).
        """
        for key in (b'user_id', 'user_id'):
            if key in self.record:
                value = self.record[key]
                break
        else:
            value = str(uuid.uuid4())
            self.record[b'user_id'] = value

        if isinstance(value, bytes):
            return value.decode("utf-8")
        return value

    @property
    def user_email(self):
        """Shortcut to the user email address; ``None`` unless logged in."""
        if self.logged_in and 'email_address' in self.user_details:
            return self.user_details['email_address']
        return None

    @property
    def redis_user_id(self):
        """User id from Redis (need to check if this is the same as the id stored in self.records)

        Legacy lookup: some accounts used not to have saved user ids, so the
        id is resolved from the record's email address, stored user id, or
        GitHub id, in that order. Returns ``None`` for anonymous records.

        NOTE(review): only str keys are checked here, whereas ``user_id``
        also handles bytes keys. Confirm which form records read from Redis
        actually use.
        """
        if 'user_email_address' in self.record:
            return get_user_id("email_address", self.record['user_email_address'])
        if 'user_id' in self.record:
            return self.record['user_id']
        if 'github_id' in self.record:
            return get_user_id("github_id", self.record['github_id'])
        # Anonymous user
        return None

    @property
    def user_name(self):
        """Shortcut to the user_name; ``''`` when the record has none.

        NOTE(review): looks up a str key only (see ``user_id`` for the bytes
        form); confirm against how the record is stored.
        """
        if 'user_name' in self.record:
            return self.record['user_name']
        return ''

    @property
    def user_collections(self):
        """List of the user's collections, with "Your Default Collection" last.

        ``sorted`` is stable, so the other collections keep their stored
        order.
        """
        collections = get_user_collections(self.user_id)
        return sorted(collections,
                      key=lambda item: item['name'] == "Your Default Collection")

    @property
    def num_collections(self):
        """Number of the user's collections that have at least one member."""
        return sum(1 for item in self.user_collections if item['num_members'] > 0)

    def add_collection(self, collection_name, traits):
        """Create a new collection holding ``traits`` and save it.

        Args:
            collection_name: display name of the new collection.
            traits: any iterable of trait entries; it is materialized once.

        Returns:
            The new collection's id (a uuid4 string).
        """
        members = list(traits)
        # One timestamp so created and changed are identical for a new item.
        now = self._timestamp()
        collection_dict = {'id': str(uuid.uuid4()),
                           'name': collection_name,
                           'created_timestamp': now,
                           'changed_timestamp': now,
                           'num_members': len(members),
                           'members': members}

        current_collections = self.user_collections
        current_collections.append(collection_dict)
        self.update_collections(current_collections)

        return collection_dict['id']

    def change_collection_name(self, collection_id, new_name):
        """Rename the collection with ``collection_id``; returns ``new_name``."""
        collections = self.user_collections
        for collection in collections:
            if collection['id'] == collection_id:
                collection['name'] = new_name

        self.update_collections(collections)
        return new_name

    def delete_collection(self, collection_id):
        """Remove collection with given ID.

        Returns:
            The deleted collection's name, or ``None`` if no collection had
            that id. (Previously the name of the last collection iterated was
            returned, which was usually a different collection.)
        """
        remaining = []
        deleted_name = None
        for collection in self.user_collections:
            if collection['id'] == collection_id:
                deleted_name = collection['name']
            else:
                remaining.append(collection)

        self.update_collections(remaining)
        return deleted_name

    def add_traits_to_collection(self, collection_id, traits_to_add):
        """Add specified traits to a collection.

        The new member list is ``traits_to_add`` followed by the existing
        members not already in it. ``traits_to_add`` must support ``+`` with
        a list (presumably a list; confirm with callers). Raises ``TypeError``
        if no collection has ``collection_id``.
        """
        this_collection = self.get_collection_by_id(collection_id)

        kept_members = [member for member in this_collection['members']
                        if member not in traits_to_add]
        updated_traits = traits_to_add + kept_members

        this_collection['members'] = updated_traits
        this_collection['num_members'] = len(updated_traits)
        this_collection['changed_timestamp'] = self._timestamp()

        self.update_collections(
            [this_collection if collection['id'] == collection_id else collection
             for collection in self.user_collections])

    def remove_traits_from_collection(self, collection_id, traits_to_remove):
        """Remove specified traits from a collection.

        Returns:
            The collection's remaining members. Raises ``TypeError`` if no
            collection has ``collection_id``.
        """
        this_collection = self.get_collection_by_id(collection_id)

        updated_traits = [trait for trait in this_collection['members']
                          if trait not in traits_to_remove]

        this_collection['members'] = updated_traits
        this_collection['num_members'] = len(updated_traits)
        this_collection['changed_timestamp'] = self._timestamp()

        self.update_collections(
            [this_collection if collection['id'] == collection_id else collection
             for collection in self.user_collections])

        return updated_traits

    def get_collection_by_id(self, collection_id):
        """Return the collection with ``collection_id``, or ``None``."""
        return next((collection for collection in self.user_collections
                     if collection['id'] == collection_id), None)

    def get_collection_by_name(self, collection_name):
        """Return the first collection named ``collection_name``, or ``None``."""
        return next((collection for collection in self.user_collections
                     if collection['name'] == collection_name), None)

    def update_collections(self, updated_collections):
        """Serialize ``updated_collections`` to JSON and save them for this user."""
        collection_body = json.dumps(updated_collections)

        save_collections(self.user_id, collection_body)

    def import_traits_to_user(self, anon_id):
        """Copy collections owned by ``anon_id`` into this user's collections.

        Collections whose name already exists for this user are skipped.
        """
        for collection in get_user_collections(anon_id):
            if not self.get_collection_by_name(collection['name']):
                self.add_collection(collection['name'], collection['members'])

    def delete_session(self):
        """Delete this session's Redis record and mark the session logged out."""
        # And more importantly delete the redis record
        Redis.delete(self.redis_key)
        self.logged_in = False