about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/oauthlib/oauth2/rfc6749/endpoints/base.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/oauthlib/oauth2/rfc6749/endpoints/base.py')
-rw-r--r--.venv/lib/python3.12/site-packages/oauthlib/oauth2/rfc6749/endpoints/base.py113
1 files changed, 113 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/oauthlib/oauth2/rfc6749/endpoints/base.py b/.venv/lib/python3.12/site-packages/oauthlib/oauth2/rfc6749/endpoints/base.py
new file mode 100644
index 00000000..3f239917
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/oauthlib/oauth2/rfc6749/endpoints/base.py
@@ -0,0 +1,113 @@
+"""
+oauthlib.oauth2.rfc6749
+~~~~~~~~~~~~~~~~~~~~~~~
+
+This module is an implementation of various logic needed
+for consuming and providing OAuth 2.0 RFC6749.
+"""
+import functools
+import logging
+
+from ..errors import (
+    FatalClientError, InvalidClientError, InvalidRequestError, OAuth2Error,
+    ServerError, TemporarilyUnavailableError, UnsupportedTokenTypeError,
+)
+
+log = logging.getLogger(__name__)
+
+
+class BaseEndpoint:
+
+    def __init__(self):
+        self._available = True
+        self._catch_errors = False
+        self._valid_request_methods = None
+
+    @property
+    def valid_request_methods(self):
+        return self._valid_request_methods
+
+    @valid_request_methods.setter
+    def valid_request_methods(self, valid_request_methods):
+        if valid_request_methods is not None:
+            valid_request_methods = [x.upper() for x in valid_request_methods]
+        self._valid_request_methods = valid_request_methods
+    
+
+    @property
+    def available(self):
+        return self._available
+
+    @available.setter
+    def available(self, available):
+        self._available = available       
+
+    @property
+    def catch_errors(self):
+        return self._catch_errors
+
+    @catch_errors.setter
+    def catch_errors(self, catch_errors):
+        self._catch_errors = catch_errors
+
+    def _raise_on_missing_token(self, request):
+        """Raise error on missing token."""
+        if not request.token:
+            raise InvalidRequestError(request=request,
+                                      description='Missing token parameter.')
+    def _raise_on_invalid_client(self, request):
+        """Raise on failed client authentication."""
+        if self.request_validator.client_authentication_required(request):
+            if not self.request_validator.authenticate_client(request):
+                log.debug('Client authentication failed, %r.', request)
+                raise InvalidClientError(request=request)
+        elif not self.request_validator.authenticate_client_id(request.client_id, request):
+            log.debug('Client authentication failed, %r.', request)
+            raise InvalidClientError(request=request)
+
+    def _raise_on_unsupported_token(self, request):
+        """Raise on unsupported tokens."""
+        if (request.token_type_hint and
+            request.token_type_hint in self.valid_token_types and
+            request.token_type_hint not in self.supported_token_types):
+            raise UnsupportedTokenTypeError(request=request)
+
+    def _raise_on_bad_method(self, request):
+        if self.valid_request_methods is None:
+            raise ValueError('Configure "valid_request_methods" property first')
+        if request.http_method.upper() not in self.valid_request_methods:
+            raise InvalidRequestError(request=request,
+                                      description=('Unsupported request method %s' % request.http_method.upper()))
+
+    def _raise_on_bad_post_request(self, request):
+        """Raise if invalid POST request received
+        """
+        if request.http_method.upper() == 'POST':
+            query_params = request.uri_query or ""
+            if query_params:
+                raise InvalidRequestError(request=request,
+                                          description=('URL query parameters are not allowed'))
+
+def catch_errors_and_unavailability(f):
+    @functools.wraps(f)
+    def wrapper(endpoint, uri, *args, **kwargs):
+        if not endpoint.available:
+            e = TemporarilyUnavailableError()
+            log.info('Endpoint unavailable, ignoring request %s.' % uri)
+            return {}, e.json, 503
+
+        if endpoint.catch_errors:
+            try:
+                return f(endpoint, uri, *args, **kwargs)
+            except OAuth2Error:
+                raise
+            except FatalClientError:
+                raise
+            except Exception as e:
+                error = ServerError()
+                log.warning(
+                    'Exception caught while processing request, %s.' % e)
+                return {}, error.json, 500
+        else:
+            return f(endpoint, uri, *args, **kwargs)
+    return wrapper