about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/litellm/proxy/auth/oauth2_check.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/litellm/proxy/auth/oauth2_check.py')
-rw-r--r--.venv/lib/python3.12/site-packages/litellm/proxy/auth/oauth2_check.py80
1 files changed, 80 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/litellm/proxy/auth/oauth2_check.py b/.venv/lib/python3.12/site-packages/litellm/proxy/auth/oauth2_check.py
new file mode 100644
index 00000000..4851c270
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/litellm/proxy/auth/oauth2_check.py
@@ -0,0 +1,80 @@
+from litellm.proxy._types import UserAPIKeyAuth
+
+
+async def check_oauth2_token(token: str) -> UserAPIKeyAuth:
+    """
+    Makes a request to the token info endpoint to validate the OAuth2 token.
+
+    Args:
+    token (str): The OAuth2 token to validate.
+
+    Returns:
+    Literal[True]: If the token is valid.
+
+    Raises:
+    ValueError: If the token is invalid, the request fails, or the token info endpoint is not set.
+    """
+    import os
+
+    import httpx
+
+    from litellm._logging import verbose_proxy_logger
+    from litellm.llms.custom_httpx.http_handler import (
+        get_async_httpx_client,
+        httpxSpecialProvider,
+    )
+    from litellm.proxy._types import CommonProxyErrors
+    from litellm.proxy.proxy_server import premium_user
+
+    if premium_user is not True:
+        raise ValueError(
+            "Oauth2 token validation is only available for premium users"
+            + CommonProxyErrors.not_premium_user.value
+        )
+
+    verbose_proxy_logger.debug("Oauth2 token validation for token=%s", token)
+    # Get the token info endpoint from environment variable
+    token_info_endpoint = os.getenv("OAUTH_TOKEN_INFO_ENDPOINT")
+    user_id_field_name = os.environ.get("OAUTH_USER_ID_FIELD_NAME", "sub")
+    user_role_field_name = os.environ.get("OAUTH_USER_ROLE_FIELD_NAME", "role")
+    user_team_id_field_name = os.environ.get("OAUTH_USER_TEAM_ID_FIELD_NAME", "team_id")
+
+    if not token_info_endpoint:
+        raise ValueError("OAUTH_TOKEN_INFO_ENDPOINT environment variable is not set")
+
+    client = get_async_httpx_client(llm_provider=httpxSpecialProvider.Oauth2Check)
+    headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
+
+    try:
+        response = await client.get(token_info_endpoint, headers=headers)
+
+        # if it's a bad token we expect it to raise an HTTPStatusError
+        response.raise_for_status()
+
+        # If we get here, the request was successful
+        data = response.json()
+
+        verbose_proxy_logger.debug(
+            "Oauth2 token validation for token=%s, response from /token/info=%s",
+            token,
+            data,
+        )
+
+        # You might want to add additional checks here based on the response
+        # For example, checking if the token is expired or has the correct scope
+        user_id = data.get(user_id_field_name)
+        user_team_id = data.get(user_team_id_field_name)
+        user_role = data.get(user_role_field_name)
+
+        return UserAPIKeyAuth(
+            api_key=token,
+            team_id=user_team_id,
+            user_id=user_id,
+            user_role=user_role,
+        )
+    except httpx.HTTPStatusError as e:
+        # This will catch any 4xx or 5xx errors
+        raise ValueError(f"Oauth 2.0 Token validation failed: {e}")
+    except Exception as e:
+        # This will catch any other errors (like network issues)
+        raise ValueError(f"An error occurred during token validation: {e}")