diff options
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_serverless_endpoint_operations.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/azure/ai/ml/operations/_serverless_endpoint_operations.py | 223 |
1 files changed, 223 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_serverless_endpoint_operations.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_serverless_endpoint_operations.py new file mode 100644 index 00000000..5efec117 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_serverless_endpoint_operations.py @@ -0,0 +1,223 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=protected-access + +import re +from typing import Iterable + +from azure.ai.ml._restclient.v2024_01_01_preview import ( + AzureMachineLearningWorkspaces as ServiceClient202401Preview, +) +from azure.ai.ml._restclient.v2024_01_01_preview.models import ( + KeyType, + RegenerateEndpointKeysRequest, +) +from azure.ai.ml._scope_dependent_operations import ( + OperationConfig, + OperationsContainer, + OperationScope, + _ScopeDependentOperations, +) +from azure.ai.ml._telemetry import ActivityType, monitor_with_activity +from azure.ai.ml._utils._experimental import experimental +from azure.ai.ml._utils._logger_utils import OpsLogger +from azure.ai.ml.constants._common import REGISTRY_VERSION_PATTERN, AzureMLResourceType +from azure.ai.ml.constants._endpoint import EndpointKeyType +from azure.ai.ml.entities._autogen_entities.models import ServerlessEndpoint +from azure.ai.ml.entities._endpoint.online_endpoint import EndpointAuthKeys +from azure.ai.ml.exceptions import ( + ErrorCategory, + ErrorTarget, + ValidationErrorType, + ValidationException, +) +from azure.core.polling import LROPoller + +ops_logger = OpsLogger(__name__) +module_logger = ops_logger.module_logger + + +class ServerlessEndpointOperations(_ScopeDependentOperations): + """ServerlessEndpointOperations. + + You should not instantiate this class directly. Instead, you should + create an MLClient instance that instantiates it for you and + attaches it as an attribute. + """ + + def __init__( + self, + operation_scope: OperationScope, + operation_config: OperationConfig, + service_client: ServiceClient202401Preview, + all_operations: OperationsContainer, + ): + super().__init__(operation_scope, operation_config) + ops_logger.update_filter() + self._service_client = service_client.serverless_endpoints + self._marketplace_subscriptions = service_client.marketplace_subscriptions + self._all_operations = all_operations + + def _get_workspace_location(self) -> str: + return str( + self._all_operations.all_operations[AzureMLResourceType.WORKSPACE].get(self._workspace_name).location + ) + + @experimental + @monitor_with_activity(ops_logger, "ServerlessEndpoint.BeginCreateOrUpdate", ActivityType.PUBLICAPI) + def begin_create_or_update(self, endpoint: ServerlessEndpoint, **kwargs) -> LROPoller[ServerlessEndpoint]: + """Create or update a serverless endpoint. + + :param endpoint: The serverless endpoint entity. + :type endpoint: ~azure.ai.ml.entities.ServerlessEndpoint + :raises ~azure.ai.ml.exceptions.ValidationException: Raised if ServerlessEndpoint cannot be + successfully validated. Details will be provided in the error message. + :return: A poller to track the operation status + :rtype: ~azure.core.polling.LROPoller[~azure.ai.ml.entities.ServerlessEndpoint] + """ + if not endpoint.location: + endpoint.location = self._get_workspace_location() + if re.match(REGISTRY_VERSION_PATTERN, endpoint.model_id): + msg = ( + "The given model_id {} points to a specific model version, which is not supported. " + "Please provide a model_id without the version information." + ) + raise ValidationException( + message=msg.format(endpoint.model_id), + no_personal_data_message="Invalid model_id given for serverless endpoint", + target=ErrorTarget.SERVERLESS_ENDPOINT, + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.INVALID_VALUE, + ) + return self._service_client.begin_create_or_update( + self._resource_group_name, + self._workspace_name, + endpoint.name, + endpoint._to_rest_object(), # type: ignore + cls=( + lambda response, deserialized, headers: ServerlessEndpoint._from_rest_object( # type: ignore + deserialized + ) + ), + **kwargs, + ) + + @experimental + @monitor_with_activity(ops_logger, "ServerlessEndpoint.Get", ActivityType.PUBLICAPI) + def get(self, name: str, **kwargs) -> ServerlessEndpoint: + """Get a Serverless Endpoint resource. + + :param name: Name of the serverless endpoint. + :type name: str + :return: Serverless endpoint object retrieved from the service. + :rtype: ~azure.ai.ml.entities.ServerlessEndpoint + """ + return self._service_client.get( + self._resource_group_name, + self._workspace_name, + name, + cls=( + lambda response, deserialized, headers: ServerlessEndpoint._from_rest_object( # type: ignore + deserialized + ) + ), + **kwargs, + ) + + @experimental + @monitor_with_activity(ops_logger, "ServerlessEndpoint.list", ActivityType.PUBLICAPI) + def list(self, **kwargs) -> Iterable[ServerlessEndpoint]: + """List serverless endpoints of the workspace. + + :return: A list of serverless endpoints + :rtype: ~typing.Iterable[~azure.ai.ml.entities.ServerlessEndpoint] + """ + return self._service_client.list( + self._resource_group_name, + self._workspace_name, + cls=lambda objs: [ServerlessEndpoint._from_rest_object(obj) for obj in objs], # type: ignore + **kwargs, + ) + + @experimental + @monitor_with_activity(ops_logger, "ServerlessEndpoint.BeginDelete", ActivityType.PUBLICAPI) + def begin_delete(self, name: str, **kwargs) -> LROPoller[None]: + """Delete a Serverless Endpoint. + + :param name: Name of the serverless endpoint. + :type name: str + :return: A poller to track the operation status. + :rtype: ~azure.core.polling.LROPoller[None] + """ + return self._service_client.begin_delete( + self._resource_group_name, + self._workspace_name, + name, + **kwargs, + ) + + @experimental + @monitor_with_activity(ops_logger, "ServerlessEndpoint.GetKeys", ActivityType.PUBLICAPI) + def get_keys(self, name: str, **kwargs) -> EndpointAuthKeys: + """Get serveless endpoint auth keys. + + :param name: The serverless endpoint name + :type name: str + :return: Returns the keys of the serverless endpoint + :rtype: ~azure.ai.ml.entities.EndpointAuthKeys + """ + return self._service_client.list_keys( + self._resource_group_name, + self._workspace_name, + name, + cls=lambda response, deserialized, headers: EndpointAuthKeys._from_rest_object(deserialized), + **kwargs, + ) + + @experimental + @monitor_with_activity(ops_logger, "ServerlessEndpoint.BeginRegenerateKeys", ActivityType.PUBLICAPI) + def begin_regenerate_keys( + self, + name: str, + *, + key_type: str = EndpointKeyType.PRIMARY_KEY_TYPE, + **kwargs, + ) -> LROPoller[EndpointAuthKeys]: + """Regenerate keys for a serverless endpoint. + + :param name: The endpoint name. + :type name: str + :keyword key_type: One of "primary", "secondary". Defaults to "primary". + :paramtype key_type: str + :raises ~azure.ai.ml.exceptions.ValidationException: Raised if key_type is not "primary" + or "secondary" + :return: A poller to track the operation status. + :rtype: ~azure.core.polling.LROPoller[EndpointAuthKeys] + """ + keys = self.get_keys( + name=name, + ) + if key_type.lower() == EndpointKeyType.PRIMARY_KEY_TYPE: + key_request = RegenerateEndpointKeysRequest(key_type=KeyType.Primary, key_value=keys.primary_key) + elif key_type.lower() == EndpointKeyType.SECONDARY_KEY_TYPE: + key_request = RegenerateEndpointKeysRequest(key_type=KeyType.Secondary, key_value=keys.secondary_key) + else: + msg = "Key type must be 'primary' or 'secondary'." + raise ValidationException( + message=msg, + target=ErrorTarget.SERVERLESS_ENDPOINT, + no_personal_data_message=msg, + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.INVALID_VALUE, + ) + + return self._service_client.begin_regenerate_keys( + resource_group_name=self._resource_group_name, + workspace_name=self._workspace_name, + endpoint_name=name, + body=key_request, + cls=lambda response, deserialized, headers: EndpointAuthKeys._from_rest_object(deserialized), + **kwargs, + ) |