about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_aio/_internal
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_aio/_internal')
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_aio/_internal/__init__.py20
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_aio/_internal/get_token_mixin.py94
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_aio/_internal/managed_identity_base.py55
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_aio/_internal/managed_identity_client.py38
4 files changed, 207 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_aio/_internal/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_aio/_internal/__init__.py
new file mode 100644
index 00000000..1059f88c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_aio/_internal/__init__.py
@@ -0,0 +1,20 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+import abc
+from typing import Any
+
+
+class AsyncContextManager(abc.ABC):
+    @abc.abstractmethod
+    async def close(self) -> None:
+        pass
+
+    async def __aenter__(self) -> "AsyncContextManager":
+        return self
+
+    async def __aexit__(self, *args: Any) -> None:
+        await self.close()
+
+
+__all__ = ["AsyncContextManager"]
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_aio/_internal/get_token_mixin.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_aio/_internal/get_token_mixin.py
new file mode 100644
index 00000000..dc2dcef2
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_aio/_internal/get_token_mixin.py
@@ -0,0 +1,94 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+import abc
+import logging
+import time
+from typing import TYPE_CHECKING, Any, Optional
+
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationException
+
+from ..._constants import DEFAULT_REFRESH_OFFSET, DEFAULT_TOKEN_REFRESH_RETRY_DELAY
+
+if TYPE_CHECKING:
+    from azure.core.credentials import AccessToken
+
+_LOGGER = logging.getLogger(__name__)
+
+
+class GetTokenMixin(abc.ABC):
+    def __init__(self, *args: "Any", **kwargs: "Any") -> None:
+        self._last_request_time = 0
+
+        # https://github.com/python/mypy/issues/5887
+        super(GetTokenMixin, self).__init__(*args, **kwargs)
+
+    @abc.abstractmethod
+    # pylint: disable-next=docstring-missing-param
+    async def _acquire_token_silently(self, *scopes: str, **kwargs: "Any") -> "Optional[AccessToken]":
+        """Attempt to acquire an access token from a cache or by redeeming a refresh token."""
+
+    @abc.abstractmethod
+    # pylint: disable-next=docstring-missing-param
+    async def _request_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken":
+        """Request an access token from the STS."""
+
+    def _should_refresh(self, token: "AccessToken") -> bool:
+        now = int(time.time())
+        if token.expires_on - now > DEFAULT_REFRESH_OFFSET:
+            return False
+        if now - self._last_request_time < DEFAULT_TOKEN_REFRESH_RETRY_DELAY:
+            return False
+        return True
+
+    async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken":
+        """Request an access token for `scopes`.
+
+        This method is called automatically by Azure SDK clients.
+
+        :param scopes: Scopes to request access for
+        :type: str
+        :raises CredentialUnavailableError: the credential is unable to attempt authentication because it lacks
+            required data, state, or platform support
+        :raises ~azure.core.exceptions.ClientAuthenticationError: authentication failed. The error's ``message``
+            attribute gives a reason.
+        :return: The access token
+        :rtype: ~azure.core.credentials.AccessToken
+        """
+        if not scopes:
+            msg = '"get_token" requires at least one scope'
+            raise ValidationException(
+                message=msg,
+                no_personal_data_message=msg,
+                target=ErrorTarget.IDENTITY,
+                error_category=ErrorCategory.USER_ERROR,
+            )
+
+        try:
+            token = await self._acquire_token_silently(*scopes, **kwargs)
+            if not token:
+                self._last_request_time = int(time.time())
+                token = await self._request_token(*scopes, **kwargs)
+            elif self._should_refresh(token):
+                try:
+                    self._last_request_time = int(time.time())
+                    token = await self._request_token(*scopes, **kwargs)
+                except Exception:  # pylint:disable=broad-except
+                    pass
+            _LOGGER.log(
+                logging.INFO,
+                "%s.get_token succeeded",
+                self.__class__.__name__,
+            )
+            return token
+
+        except Exception as ex:
+            _LOGGER.log(
+                logging.WARNING,
+                "%s.get_token failed: %s",
+                self.__class__.__name__,
+                ex,
+                exc_info=_LOGGER.isEnabledFor(logging.DEBUG),
+            )
+            raise
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_aio/_internal/managed_identity_base.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_aio/_internal/managed_identity_base.py
new file mode 100644
index 00000000..4f3cc167
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_aio/_internal/managed_identity_base.py
@@ -0,0 +1,55 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+import abc
+from typing import TYPE_CHECKING, Any, Optional, cast
+
+from ..._exceptions import CredentialUnavailableError
+from . import AsyncContextManager
+from .get_token_mixin import GetTokenMixin
+from .managed_identity_client import AsyncManagedIdentityClient
+
+if TYPE_CHECKING:
+    from azure.core.credentials import AccessToken
+
+
+class AsyncManagedIdentityBase(AsyncContextManager, GetTokenMixin):
+    """Base class for internal credentials using AsyncManagedIdentityClient."""
+
+    def __init__(self, **kwargs: Any) -> None:
+        super().__init__()
+        self._client = self.get_client(**kwargs)
+
+    @abc.abstractmethod
+    def get_client(self, **kwargs: Any) -> "Optional[AsyncManagedIdentityClient]":
+        pass
+
+    @abc.abstractmethod
+    def get_unavailable_message(self) -> str:
+        pass
+
+    async def __aenter__(self) -> "AsyncManagedIdentityBase":
+        if self._client:
+            await self._client.__aenter__()
+        return self
+
+    async def __aexit__(self, *args: Any) -> None:
+        if self._client:
+            await self._client.__aexit__(*args)
+
+    async def close(self) -> None:
+        await self.__aexit__()
+
+    async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken":
+        if not self._client:
+            raise CredentialUnavailableError(message=self.get_unavailable_message())
+        return await super().get_token(*scopes, **kwargs)
+
+    async def _acquire_token_silently(self, *scopes: str, **kwargs: "Any") -> "Optional[AccessToken]":
+        # casting because mypy can't determine that these methods are called
+        # only by get_token, which raises when self._client is None
+        return cast(AsyncManagedIdentityClient, self._client).get_cached_token(*scopes)
+
+    async def _request_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken":
+        return await cast(AsyncManagedIdentityClient, self._client).request_token(*scopes, **kwargs)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_aio/_internal/managed_identity_client.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_aio/_internal/managed_identity_client.py
new file mode 100644
index 00000000..01d951e9
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/identity/_aio/_internal/managed_identity_client.py
@@ -0,0 +1,38 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+import time
+from typing import TYPE_CHECKING, Any
+
+from ..._internal import _scopes_to_resource
+from ..._internal.managed_identity_client import ManagedIdentityClientBase
+from ..._internal.pipeline import build_async_pipeline
+from .._internal import AsyncContextManager
+
+if TYPE_CHECKING:
+    from azure.core.credentials import AccessToken
+    from azure.core.pipeline import AsyncPipeline
+
+
+# pylint:disable=async-client-bad-name
+class AsyncManagedIdentityClient(AsyncContextManager, ManagedIdentityClientBase):
+    async def __aenter__(self) -> "AsyncManagedIdentityClient":
+        await self._pipeline.__aenter__()
+        return self
+
+    async def close(self) -> None:
+        await self._pipeline.__aexit__()
+
+    async def request_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken":
+        # pylint:disable=invalid-overridden-method
+        resource = _scopes_to_resource(*scopes)
+        # pylint: disable=no-member
+        request = self._request_factory(resource, self._identity_config)  # type: ignore
+        request_time = int(time.time())
+        response = await self._pipeline.run(request, retry_on_methods=[request.method], **kwargs)
+        token = self._process_response(response=response, request_time=request_time, resource=resource)
+        return token
+
+    def _build_pipeline(self, **kwargs: "Any") -> "AsyncPipeline":
+        return build_async_pipeline(**kwargs)