aboutsummaryrefslogtreecommitdiff
path: root/wqflask/flask_security/decorators.py
diff options
context:
space:
mode:
Diffstat (limited to 'wqflask/flask_security/decorators.py')
-rw-r--r--wqflask/flask_security/decorators.py207
1 files changed, 207 insertions, 0 deletions
diff --git a/wqflask/flask_security/decorators.py b/wqflask/flask_security/decorators.py
new file mode 100644
index 00000000..0ea1105c
--- /dev/null
+++ b/wqflask/flask_security/decorators.py
@@ -0,0 +1,207 @@
+# -*- coding: utf-8 -*-
+"""
+ flask.ext.security.decorators
+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+ Flask-Security decorators module
+
+ :copyright: (c) 2012 by Matt Wright.
+ :license: MIT, see LICENSE for more details.
+"""
+
+from functools import wraps
+
+from flask import current_app, Response, request, redirect, _request_ctx_stack
+from flask.ext.login import current_user, login_required
+from flask.ext.principal import RoleNeed, Permission, Identity, identity_changed
+from werkzeug.local import LocalProxy
+
+from . import utils
+
+
+# Convenient references
+_security = LocalProxy(lambda: current_app.extensions['security'])
+
+
+_default_unauthorized_html = """
+ <h1>Unauthorized</h1>
+ <p>The server could not verify that you are authorized to access the URL
+ requested. You either supplied the wrong credentials (e.g. a bad password),
+ or your browser doesn't understand how to supply the credentials required.</p>
+ """
+
+
+def _get_unauthorized_response(text=None, headers=None):
+ text = text or _default_unauthorized_html
+ headers = headers or {}
+ return Response(text, 401, headers)
+
+
+def _get_unauthorized_view():
+ cv = utils.get_url(utils.config_value('UNAUTHORIZED_VIEW'))
+ utils.do_flash(*utils.get_message('UNAUTHORIZED'))
+ return redirect(cv or request.referrer or '/')
+
+
+def _check_token():
+ header_key = _security.token_authentication_header
+ args_key = _security.token_authentication_key
+ header_token = request.headers.get(header_key, None)
+ token = request.args.get(args_key, header_token)
+ if request.json:
+ token = request.json.get(args_key, token)
+ serializer = _security.remember_token_serializer
+
+ try:
+ data = serializer.loads(token)
+ except:
+ return False
+
+ user = _security.datastore.find_user(id=data[0])
+
+ if utils.md5(user.password) == data[1]:
+ app = current_app._get_current_object()
+ _request_ctx_stack.top.user = user
+ identity_changed.send(app, identity=Identity(user.id))
+ return True
+
+
+def _check_http_auth():
+ auth = request.authorization or dict(username=None, password=None)
+ user = _security.datastore.find_user(email=auth.username)
+
+ if user and utils.verify_and_update_password(auth.password, user):
+ _security.datastore.commit()
+ app = current_app._get_current_object()
+ _request_ctx_stack.top.user = user
+ identity_changed.send(app, identity=Identity(user.id))
+ return True
+
+ return False
+
+
+def http_auth_required(realm):
+ """Decorator that protects endpoints using Basic HTTP authentication.
+ The username should be set to the user's email address.
+
+ :param realm: optional realm name"""
+
+ def decorator(fn):
+ @wraps(fn)
+ def wrapper(*args, **kwargs):
+ if _check_http_auth():
+ return fn(*args, **kwargs)
+ r = _security.default_http_auth_realm if callable(realm) else realm
+ h = {'WWW-Authenticate': 'Basic realm="%s"' % r}
+ return _get_unauthorized_response(headers=h)
+ return wrapper
+
+ if callable(realm):
+ return decorator(realm)
+ return decorator
+
+
+def auth_token_required(fn):
+ """Decorator that protects endpoints using token authentication. The token
+ should be added to the request by the client by using a query string
+ variable with a name equal to the configuration value of
+ `SECURITY_TOKEN_AUTHENTICATION_KEY` or in a request header named that of
+ the configuration value of `SECURITY_TOKEN_AUTHENTICATION_HEADER`
+ """
+
+ @wraps(fn)
+ def decorated(*args, **kwargs):
+ if _check_token():
+ return fn(*args, **kwargs)
+ return _get_unauthorized_response()
+ return decorated
+
+
+def auth_required(*auth_methods):
+ """
+ Decorator that protects enpoints through multiple mechanisms
+ Example::
+
+ @app.route('/dashboard')
+ @auth_required('token', 'session')
+ def dashboard():
+ return 'Dashboard'
+
+ :param auth_methods: Specified mechanisms.
+ """
+ login_mechanisms = {
+ 'token': lambda: _check_token(),
+ 'basic': lambda: _check_http_auth(),
+ 'session': lambda: current_user.is_authenticated()
+ }
+
+ def wrapper(fn):
+ @wraps(fn)
+ def decorated_view(*args, **kwargs):
+ mechanisms = [login_mechanisms.get(method) for method in auth_methods]
+ for mechanism in mechanisms:
+ if mechanism and mechanism():
+ return fn(*args, **kwargs)
+ return _get_unauthorized_response()
+ return decorated_view
+ return wrapper
+
+
+def roles_required(*roles):
+ """Decorator which specifies that a user must have all the specified roles.
+ Example::
+
+ @app.route('/dashboard')
+ @roles_required('admin', 'editor')
+ def dashboard():
+ return 'Dashboard'
+
+ The current user must have both the `admin` role and `editor` role in order
+ to view the page.
+
+ :param args: The required roles.
+ """
+ def wrapper(fn):
+ @wraps(fn)
+ def decorated_view(*args, **kwargs):
+ perms = [Permission(RoleNeed(role)) for role in roles]
+ for perm in perms:
+ if not perm.can():
+ return _get_unauthorized_view()
+ return fn(*args, **kwargs)
+ return decorated_view
+ return wrapper
+
+
+def roles_accepted(*roles):
+ """Decorator which specifies that a user must have at least one of the
+ specified roles. Example::
+
+ @app.route('/create_post')
+ @roles_accepted('editor', 'author')
+ def create_post():
+ return 'Create Post'
+
+ The current user must have either the `editor` role or `author` role in
+ order to view the page.
+
+ :param args: The possible roles.
+ """
+ def wrapper(fn):
+ @wraps(fn)
+ def decorated_view(*args, **kwargs):
+ perm = Permission(*[RoleNeed(role) for role in roles])
+ if perm.can():
+ return fn(*args, **kwargs)
+ return _get_unauthorized_view()
+ return decorated_view
+ return wrapper
+
+
+def anonymous_user_required(f):
+ @wraps(f)
+ def wrapper(*args, **kwargs):
+ if current_user.is_authenticated():
+ return redirect(utils.get_url(_security.post_login_view))
+ return f(*args, **kwargs)
+ return wrapper