about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_serverless_endpoint_operations.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_serverless_endpoint_operations.py')
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_serverless_endpoint_operations.py223
1 files changed, 223 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_serverless_endpoint_operations.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_serverless_endpoint_operations.py
new file mode 100644
index 00000000..5efec117
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_serverless_endpoint_operations.py
@@ -0,0 +1,223 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+import re
+from typing import Iterable
+
+from azure.ai.ml._restclient.v2024_01_01_preview import (
+    AzureMachineLearningWorkspaces as ServiceClient202401Preview,
+)
+from azure.ai.ml._restclient.v2024_01_01_preview.models import (
+    KeyType,
+    RegenerateEndpointKeysRequest,
+)
+from azure.ai.ml._scope_dependent_operations import (
+    OperationConfig,
+    OperationsContainer,
+    OperationScope,
+    _ScopeDependentOperations,
+)
+from azure.ai.ml._telemetry import ActivityType, monitor_with_activity
+from azure.ai.ml._utils._experimental import experimental
+from azure.ai.ml._utils._logger_utils import OpsLogger
+from azure.ai.ml.constants._common import REGISTRY_VERSION_PATTERN, AzureMLResourceType
+from azure.ai.ml.constants._endpoint import EndpointKeyType
+from azure.ai.ml.entities._autogen_entities.models import ServerlessEndpoint
+from azure.ai.ml.entities._endpoint.online_endpoint import EndpointAuthKeys
+from azure.ai.ml.exceptions import (
+    ErrorCategory,
+    ErrorTarget,
+    ValidationErrorType,
+    ValidationException,
+)
+from azure.core.polling import LROPoller
+
+ops_logger = OpsLogger(__name__)
+module_logger = ops_logger.module_logger
+
+
+class ServerlessEndpointOperations(_ScopeDependentOperations):
+    """ServerlessEndpointOperations.
+
+    You should not instantiate this class directly. Instead, you should
+    create an MLClient instance that instantiates it for you and
+    attaches it as an attribute.
+    """
+
+    def __init__(
+        self,
+        operation_scope: OperationScope,
+        operation_config: OperationConfig,
+        service_client: ServiceClient202401Preview,
+        all_operations: OperationsContainer,
+    ):
+        super().__init__(operation_scope, operation_config)
+        ops_logger.update_filter()
+        self._service_client = service_client.serverless_endpoints
+        self._marketplace_subscriptions = service_client.marketplace_subscriptions
+        self._all_operations = all_operations
+
+    def _get_workspace_location(self) -> str:
+        return str(
+            self._all_operations.all_operations[AzureMLResourceType.WORKSPACE].get(self._workspace_name).location
+        )
+
+    @experimental
+    @monitor_with_activity(ops_logger, "ServerlessEndpoint.BeginCreateOrUpdate", ActivityType.PUBLICAPI)
+    def begin_create_or_update(self, endpoint: ServerlessEndpoint, **kwargs) -> LROPoller[ServerlessEndpoint]:
+        """Create or update a serverless endpoint.
+
+        :param endpoint: The serverless endpoint entity.
+        :type endpoint: ~azure.ai.ml.entities.ServerlessEndpoint
+        :raises ~azure.ai.ml.exceptions.ValidationException: Raised if ServerlessEndpoint cannot be
+            successfully validated. Details will be provided in the error message.
+        :return: A poller to track the operation status
+        :rtype: ~azure.core.polling.LROPoller[~azure.ai.ml.entities.ServerlessEndpoint]
+        """
+        if not endpoint.location:
+            endpoint.location = self._get_workspace_location()
+        if re.match(REGISTRY_VERSION_PATTERN, endpoint.model_id):
+            msg = (
+                "The given model_id {} points to a specific model version, which is not supported. "
+                "Please provide a model_id without the version information."
+            )
+            raise ValidationException(
+                message=msg.format(endpoint.model_id),
+                no_personal_data_message="Invalid model_id given for serverless endpoint",
+                target=ErrorTarget.SERVERLESS_ENDPOINT,
+                error_category=ErrorCategory.USER_ERROR,
+                error_type=ValidationErrorType.INVALID_VALUE,
+            )
+        return self._service_client.begin_create_or_update(
+            self._resource_group_name,
+            self._workspace_name,
+            endpoint.name,
+            endpoint._to_rest_object(),  # type: ignore
+            cls=(
+                lambda response, deserialized, headers: ServerlessEndpoint._from_rest_object(  # type: ignore
+                    deserialized
+                )
+            ),
+            **kwargs,
+        )
+
+    @experimental
+    @monitor_with_activity(ops_logger, "ServerlessEndpoint.Get", ActivityType.PUBLICAPI)
+    def get(self, name: str, **kwargs) -> ServerlessEndpoint:
+        """Get a Serverless Endpoint resource.
+
+        :param name: Name of the serverless endpoint.
+        :type name: str
+        :return: Serverless endpoint object retrieved from the service.
+        :rtype: ~azure.ai.ml.entities.ServerlessEndpoint
+        """
+        return self._service_client.get(
+            self._resource_group_name,
+            self._workspace_name,
+            name,
+            cls=(
+                lambda response, deserialized, headers: ServerlessEndpoint._from_rest_object(  # type: ignore
+                    deserialized
+                )
+            ),
+            **kwargs,
+        )
+
+    @experimental
+    @monitor_with_activity(ops_logger, "ServerlessEndpoint.list", ActivityType.PUBLICAPI)
+    def list(self, **kwargs) -> Iterable[ServerlessEndpoint]:
+        """List serverless endpoints of the workspace.
+
+        :return: A list of serverless endpoints
+        :rtype: ~typing.Iterable[~azure.ai.ml.entities.ServerlessEndpoint]
+        """
+        return self._service_client.list(
+            self._resource_group_name,
+            self._workspace_name,
+            cls=lambda objs: [ServerlessEndpoint._from_rest_object(obj) for obj in objs],  # type: ignore
+            **kwargs,
+        )
+
+    @experimental
+    @monitor_with_activity(ops_logger, "ServerlessEndpoint.BeginDelete", ActivityType.PUBLICAPI)
+    def begin_delete(self, name: str, **kwargs) -> LROPoller[None]:
+        """Delete a Serverless Endpoint.
+
+        :param name: Name of the serverless endpoint.
+        :type name: str
+        :return: A poller to track the operation status.
+        :rtype: ~azure.core.polling.LROPoller[None]
+        """
+        return self._service_client.begin_delete(
+            self._resource_group_name,
+            self._workspace_name,
+            name,
+            **kwargs,
+        )
+
+    @experimental
+    @monitor_with_activity(ops_logger, "ServerlessEndpoint.GetKeys", ActivityType.PUBLICAPI)
+    def get_keys(self, name: str, **kwargs) -> EndpointAuthKeys:
+        """Get serveless endpoint auth keys.
+
+        :param name: The serverless endpoint name
+        :type name: str
+        :return: Returns the keys of the serverless endpoint
+        :rtype: ~azure.ai.ml.entities.EndpointAuthKeys
+        """
+        return self._service_client.list_keys(
+            self._resource_group_name,
+            self._workspace_name,
+            name,
+            cls=lambda response, deserialized, headers: EndpointAuthKeys._from_rest_object(deserialized),
+            **kwargs,
+        )
+
+    @experimental
+    @monitor_with_activity(ops_logger, "ServerlessEndpoint.BeginRegenerateKeys", ActivityType.PUBLICAPI)
+    def begin_regenerate_keys(
+        self,
+        name: str,
+        *,
+        key_type: str = EndpointKeyType.PRIMARY_KEY_TYPE,
+        **kwargs,
+    ) -> LROPoller[EndpointAuthKeys]:
+        """Regenerate keys for a serverless endpoint.
+
+        :param name: The endpoint name.
+        :type name: str
+        :keyword key_type: One of "primary", "secondary". Defaults to "primary".
+        :paramtype key_type: str
+        :raises ~azure.ai.ml.exceptions.ValidationException: Raised if key_type is not "primary"
+            or "secondary"
+        :return: A poller to track the operation status.
+        :rtype: ~azure.core.polling.LROPoller[EndpointAuthKeys]
+        """
+        keys = self.get_keys(
+            name=name,
+        )
+        if key_type.lower() == EndpointKeyType.PRIMARY_KEY_TYPE:
+            key_request = RegenerateEndpointKeysRequest(key_type=KeyType.Primary, key_value=keys.primary_key)
+        elif key_type.lower() == EndpointKeyType.SECONDARY_KEY_TYPE:
+            key_request = RegenerateEndpointKeysRequest(key_type=KeyType.Secondary, key_value=keys.secondary_key)
+        else:
+            msg = "Key type must be 'primary' or 'secondary'."
+            raise ValidationException(
+                message=msg,
+                target=ErrorTarget.SERVERLESS_ENDPOINT,
+                no_personal_data_message=msg,
+                error_category=ErrorCategory.USER_ERROR,
+                error_type=ValidationErrorType.INVALID_VALUE,
+            )
+
+        return self._service_client.begin_regenerate_keys(
+            resource_group_name=self._resource_group_name,
+            workspace_name=self._workspace_name,
+            endpoint_name=name,
+            body=key_request,
+            cls=lambda response, deserialized, headers: EndpointAuthKeys._from_rest_object(deserialized),
+            **kwargs,
+        )