# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

# pylint: disable=protected-access

import json
from typing import Any, Dict, Optional, Union

from marshmallow.exceptions import ValidationError as SchemaValidationError

from azure.ai.ml._azure_environments import _resource_to_scopes
from azure.ai.ml._exception_helper import log_and_raise_error
from azure.ai.ml._restclient.v2022_02_01_preview import AzureMachineLearningWorkspaces as ServiceClient022022Preview
from azure.ai.ml._restclient.v2022_02_01_preview.models import KeyType, RegenerateEndpointKeysRequest
from azure.ai.ml._scope_dependent_operations import (
    OperationConfig,
    OperationsContainer,
    OperationScope,
    _ScopeDependentOperations,
)
from azure.ai.ml._telemetry import ActivityType, monitor_with_activity
from azure.ai.ml._utils._azureml_polling import AzureMLPolling
from azure.ai.ml._utils._endpoint_utils import validate_response
from azure.ai.ml._utils._http_utils import HttpPipeline
from azure.ai.ml._utils._logger_utils import OpsLogger
from azure.ai.ml.constants._common import (
    AAD_TOKEN,
    AAD_TOKEN_RESOURCE_ENDPOINT,
    EMPTY_CREDENTIALS_ERROR,
    KEY,
    AzureMLResourceType,
    LROConfigurations,
)
from azure.ai.ml.constants._endpoint import EndpointInvokeFields, EndpointKeyType
from azure.ai.ml.entities import OnlineDeployment, OnlineEndpoint
from azure.ai.ml.entities._assets import Data
from azure.ai.ml.entities._endpoint.online_endpoint import EndpointAadToken, EndpointAuthKeys, EndpointAuthToken
from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, MlException, ValidationErrorType, ValidationException
from azure.ai.ml.operations._local_endpoint_helper import _LocalEndpointHelper
from azure.core.credentials import TokenCredential
from azure.core.paging import ItemPaged
from azure.core.polling import LROPoller
from azure.core.tracing.decorator import distributed_trace

from ._operation_orchestrator import OperationOrchestrator

ops_logger = OpsLogger(__name__)
module_logger = ops_logger.module_logger


def _strip_zeroes_from_traffic(traffic: Dict[str, str]) -> Dict[str, str]:
    """Return a copy of a traffic map with empty / zero allocations removed.

    Deployment names are lower-cased in the result. An entry is dropped when its
    value is falsy or when ``int(value)`` is 0.

    :param traffic: Mapping of deployment name to traffic allocation.
    :type traffic: Dict[str, str]
    :return: The filtered, lower-cased traffic map.
    :rtype: Dict[str, str]
    """
    return {k.lower(): v for k, v in traffic.items() if v and int(v) != 0}


class OnlineEndpointOperations(_ScopeDependentOperations):
    """OnlineEndpointOperations.

    You should not instantiate this class directly. Instead, you should create an MLClient instance that instantiates it
    for you and attaches it as an attribute.
    """

    def __init__(
        self,
        operation_scope: OperationScope,
        operation_config: OperationConfig,
        service_client_02_2022_preview: ServiceClient022022Preview,
        all_operations: OperationsContainer,
        local_endpoint_helper: _LocalEndpointHelper,
        credentials: Optional[TokenCredential] = None,
        **kwargs: Dict,
    ):
        super(OnlineEndpointOperations, self).__init__(operation_scope, operation_config)
        ops_logger.update_filter()
        self._online_operation = service_client_02_2022_preview.online_endpoints
        self._online_deployment_operation = service_client_02_2022_preview.online_deployments
        self._all_operations = all_operations
        self._local_endpoint_helper = local_endpoint_helper
        self._credentials = credentials
        # ``self._init_kwargs`` is the same dict object as ``kwargs``, so popping
        # "requests_pipeline" below also removes it from the kwargs that are
        # forwarded to every REST client call.
        self._init_kwargs = kwargs

        self._requests_pipeline: HttpPipeline = kwargs.pop("requests_pipeline")

    @distributed_trace
    @monitor_with_activity(ops_logger, "OnlineEndpoint.List", ActivityType.PUBLICAPI)
    def list(self, *, local: bool = False) -> ItemPaged[OnlineEndpoint]:
        """List endpoints of the workspace.

        :keyword local: (Optional) Flag to indicate whether to interact with endpoints in local Docker environment.
            Default: False
        :type local: bool
        :return: A list of endpoints
        :rtype: ~azure.core.paging.ItemPaged[~azure.ai.ml.entities.OnlineEndpoint]
        """
        if local:
            return self._local_endpoint_helper.list()

        return self._online_operation.list(
            resource_group_name=self._resource_group_name,
            workspace_name=self._workspace_name,
            cls=lambda objs: [OnlineEndpoint._from_rest_object(obj) for obj in objs],
            **self._init_kwargs,
        )

    @distributed_trace
    @monitor_with_activity(ops_logger, "OnlineEndpoint.ListKeys", ActivityType.PUBLICAPI)
    def get_keys(self, name: str) -> Union[EndpointAuthKeys, EndpointAuthToken, EndpointAadToken]:
        """Get the auth credentials.

        :param name: The endpoint name
        :type name: str
        :raises Exception: If the online credentials cannot be retrieved.
        :return: Depending on the auth mode in the endpoint, returns either keys or a token.
        :rtype: Union[~azure.ai.ml.entities.EndpointAuthKeys, ~azure.ai.ml.entities.EndpointAuthToken,
            ~azure.ai.ml.entities.EndpointAadToken]
        """
        return self._get_online_credentials(name=name)

    @distributed_trace
    @monitor_with_activity(ops_logger, "OnlineEndpoint.Get", ActivityType.PUBLICAPI)
    def get(
        self,
        name: str,
        *,
        local: bool = False,
    ) -> OnlineEndpoint:
        """Get an endpoint resource.

        Deployments of the endpoint that have neither a traffic nor a mirror-traffic
        entry are added to the returned endpoint's ``traffic`` map with a value of 0.

        :param name: Name of the endpoint.
        :type name: str
        :keyword local: Indicates whether to interact with endpoints in local Docker environment. Defaults to False.
        :paramtype local: Optional[bool]
        :raises ~azure.ai.ml.exceptions.LocalEndpointNotFoundError: Raised if local endpoint resource does not exist.
        :return: Endpoint object retrieved from the service.
        :rtype: ~azure.ai.ml.entities.OnlineEndpoint
        """
        if local:
            return self._local_endpoint_helper.get(endpoint_name=name)

        endpoint = self._online_operation.get(
            resource_group_name=self._resource_group_name,
            workspace_name=self._workspace_name,
            endpoint_name=name,
            **self._init_kwargs,
        )

        deployments_list = self._online_deployment_operation.list(
            endpoint_name=name,
            resource_group_name=self._resource_group_name,
            workspace_name=self._workspace_name,
            cls=lambda objs: [OnlineDeployment._from_rest_object(obj) for obj in objs],
            **self._init_kwargs,
        )

        # Populate deployments without any traffic with zeroes in the traffic map.
        converted_endpoint = OnlineEndpoint._from_rest_object(endpoint)
        if deployments_list:
            for deployment in deployments_list:
                if not converted_endpoint.traffic.get(deployment.name) and not converted_endpoint.mirror_traffic.get(
                    deployment.name
                ):
                    converted_endpoint.traffic[deployment.name] = 0

        return converted_endpoint

    @distributed_trace
    @monitor_with_activity(ops_logger, "OnlineEndpoint.BeginDelete", ActivityType.PUBLICAPI)
    def begin_delete(self, name: Optional[str] = None, *, local: bool = False) -> LROPoller[None]:
        """Delete an Online Endpoint.

        :param name: Name of the endpoint.
        :type name: str
        :keyword local: Whether to interact with the endpoint in local Docker environment. Defaults to False.
        :paramtype local: bool
        :raises ~azure.ai.ml.exceptions.LocalEndpointNotFoundError: Raised if local endpoint resource does not exist.
        :return: A poller to track the operation status if remote, else returns None if local.
        :rtype: ~azure.core.polling.LROPoller[None]
        """
        if local:
            return self._local_endpoint_helper.delete(name=str(name))

        # Arguments used by AzureMLPolling to build the status URL while polling.
        path_format_arguments = {
            "endpointName": name,
            "resourceGroupName": self._resource_group_name,
            "workspaceName": self._workspace_name,
        }

        delete_poller = self._online_operation.begin_delete(
            resource_group_name=self._resource_group_name,
            workspace_name=self._workspace_name,
            endpoint_name=name,
            polling=AzureMLPolling(
                LROConfigurations.POLL_INTERVAL,
                path_format_arguments=path_format_arguments,
                **self._init_kwargs,
            ),
            polling_interval=LROConfigurations.POLL_INTERVAL,
            **self._init_kwargs,
        )
        return delete_poller

    @distributed_trace
    # NOTE(review): the activity name says "DeleteOrUpdate" although this is the create/update operation;
    # it is left unchanged so existing telemetry stays continuous.
    @monitor_with_activity(ops_logger, "OnlineEndpoint.BeginDeleteOrUpdate", ActivityType.PUBLICAPI)
    def begin_create_or_update(self, endpoint: OnlineEndpoint, *, local: bool = False) -> LROPoller[OnlineEndpoint]:
        """Create or update an endpoint.

        For remote endpoints, zero allocations are stripped from ``endpoint.traffic`` and
        ``endpoint.mirror_traffic`` (the passed-in entity is modified in place) and the
        compute, if present on the REST payload, is resolved to an ARM id.

        :param endpoint: The endpoint entity.
        :type endpoint: ~azure.ai.ml.entities.OnlineEndpoint
        :keyword local: Whether to interact with the endpoint in local Docker environment. Defaults to False.
        :paramtype local: bool
        :raises ~azure.ai.ml.exceptions.ValidationException: Raised if OnlineEndpoint cannot be successfully validated.
            Details will be provided in the error message.
        :raises ~azure.ai.ml.exceptions.AssetException: Raised if OnlineEndpoint assets
            (e.g. Data, Code, Model, Environment) cannot be successfully validated.
            Details will be provided in the error message.
        :raises ~azure.ai.ml.exceptions.ModelException: Raised if OnlineEndpoint model cannot be successfully validated.
            Details will be provided in the error message.
        :raises ~azure.ai.ml.exceptions.EmptyDirectoryError: Raised if local path provided points to an empty directory.
        :raises ~azure.ai.ml.exceptions.LocalEndpointNotFoundError: Raised if local endpoint resource does not exist.
        :return: A poller to track the operation status if remote, else returns None if local.
        :rtype: ~azure.core.polling.LROPoller[~azure.ai.ml.entities.OnlineEndpoint]
        """
        try:
            if local:
                return self._local_endpoint_helper.create_or_update(endpoint=endpoint)

            location = self._get_workspace_location()
            if endpoint.traffic:
                endpoint.traffic = _strip_zeroes_from_traffic(endpoint.traffic)

            if endpoint.mirror_traffic:
                endpoint.mirror_traffic = _strip_zeroes_from_traffic(endpoint.mirror_traffic)

            endpoint_resource = endpoint._to_rest_online_endpoint(location=location)
            orchestrators = OperationOrchestrator(
                operation_container=self._all_operations,
                operation_scope=self._operation_scope,
                operation_config=self._operation_config,
            )
            if hasattr(endpoint_resource.properties, "compute"):
                endpoint_resource.properties.compute = orchestrators.get_asset_arm_id(
                    endpoint_resource.properties.compute,
                    azureml_type=AzureMLResourceType.COMPUTE,
                )
            poller = self._online_operation.begin_create_or_update(
                resource_group_name=self._resource_group_name,
                workspace_name=self._workspace_name,
                endpoint_name=endpoint.name,
                body=endpoint_resource,
                cls=lambda response, deserialized, headers: OnlineEndpoint._from_rest_object(deserialized),
                **self._init_kwargs,
            )
            return poller
        except Exception as ex:  # pylint: disable=W0718
            # Validation errors are routed through the shared error helper; everything else propagates unchanged.
            if isinstance(ex, (ValidationException, SchemaValidationError)):
                log_and_raise_error(ex)
            else:
                raise

    @distributed_trace
    @monitor_with_activity(ops_logger, "OnlineEndpoint.BeginGenerateKeys", ActivityType.PUBLICAPI)
    def begin_regenerate_keys(
        self,
        name: str,
        *,
        key_type: str = EndpointKeyType.PRIMARY_KEY_TYPE,
    ) -> LROPoller[None]:
        """Regenerate keys for endpoint.

        :param name: The endpoint name.
        :type name: str
        :keyword key_type: One of "primary", "secondary". Defaults to "primary".
        :paramtype key_type: str
        :raises ~azure.ai.ml.exceptions.ValidationException: Raised if the endpoint does not use key authentication.
        :return: A poller to track the operation status.
        :rtype: ~azure.core.polling.LROPoller[None]
        """
        endpoint = self._online_operation.get(
            resource_group_name=self._resource_group_name,
            workspace_name=self._workspace_name,
            endpoint_name=name,
            **self._init_kwargs,
        )

        auth_mode = endpoint.properties.auth_mode
        # Guard against a missing auth mode instead of failing with AttributeError on ``None.lower()``.
        if auth_mode is not None and auth_mode.lower() == "key":
            return self._regenerate_online_keys(name=name, key_type=key_type)
        raise ValidationException(
            message=f"Endpoint '{name}' does not use keys for authentication.",
            target=ErrorTarget.ONLINE_ENDPOINT,
            no_personal_data_message="Endpoint does not use keys for authentication.",
            error_category=ErrorCategory.USER_ERROR,
            error_type=ValidationErrorType.INVALID_VALUE,
        )

    @distributed_trace
    @monitor_with_activity(ops_logger, "OnlineEndpoint.Invoke", ActivityType.PUBLICAPI)
    def invoke(
        self,
        endpoint_name: str,
        *,
        request_file: Optional[str] = None,
        deployment_name: Optional[str] = None,
        # pylint: disable=unused-argument
        input_data: Optional[Union[str, Data]] = None,
        params_override: Any = None,  # pylint: disable=unused-argument
        local: bool = False,
        **kwargs: Any,
    ) -> str:
        """Invokes the endpoint with the provided payload.

        :param endpoint_name: The endpoint name
        :type endpoint_name: str
        :keyword request_file: File containing the request payload. This is only valid for online endpoint.
            Required in practice: the payload is read from this file.
        :paramtype request_file: Optional[str]
        :keyword deployment_name: Name of a specific deployment to invoke. This is optional.
            By default requests are routed to any of the deployments according to the traffic rules.
        :paramtype deployment_name: Optional[str]
        :keyword input_data: To use a pre-registered data asset, pass str in format. Currently not used.
        :paramtype input_data: Optional[Union[str, Data]]
        :keyword params_override: A dictionary of payload parameters to override and their desired values.
            Currently not used.
        :paramtype params_override: Any
        :keyword local: Indicates whether to interact with endpoints in local Docker environment. Defaults to False.
        :paramtype local: Optional[bool]
        :raises ~azure.ai.ml.exceptions.ValidationException: Raised if no request_file is given, or if a remote
            endpoint has no deployment matching deployment_name.
        :raises ~azure.ai.ml.exceptions.LocalEndpointNotFoundError: Raised if local endpoint resource does not exist.
        :raises ~azure.ai.ml.exceptions.MultipleLocalDeploymentsFoundError: Raised if there are multiple deployments
            and no deployment_name is specified.
        :raises ~azure.ai.ml.exceptions.InvalidLocalEndpointError: Raised if local endpoint is None.
        :return: Prediction output for online endpoint.
        :rtype: str
        """
        # Until this bug is resolved https://msdata.visualstudio.com/Vienna/_workitems/edit/1446538
        # The check lists deployments through the remote service, so it only applies to remote endpoints;
        # local endpoints are resolved by the local endpoint helper.
        if deployment_name and not local:
            self._validate_deployment_name(endpoint_name, deployment_name)

        if request_file is None:
            msg = "A request_file must be provided to invoke an online endpoint."
            raise ValidationException(
                message=msg,
                target=ErrorTarget.ONLINE_ENDPOINT,
                no_personal_data_message=msg,
                error_category=ErrorCategory.USER_ERROR,
                error_type=ValidationErrorType.INVALID_VALUE,
            )

        with open(request_file, "rb") as f:
            data = json.loads(f.read())

        if local:
            return self._local_endpoint_helper.invoke(
                endpoint_name=endpoint_name, data=data, deployment_name=deployment_name
            )

        endpoint = self._online_operation.get(
            resource_group_name=self._resource_group_name,
            workspace_name=self._workspace_name,
            endpoint_name=endpoint_name,
            **self._init_kwargs,
        )
        keys = self._get_online_credentials(name=endpoint_name, auth_mode=endpoint.properties.auth_mode)
        if isinstance(keys, EndpointAuthKeys):
            key = keys.primary_key
        elif isinstance(keys, (EndpointAuthToken, EndpointAadToken)):
            key = keys.access_token
        else:
            key = ""

        # Work on a copy: writing into the shared DEFAULT_HEADER constant would leak this call's
        # Authorization and deployment headers into every later invocation.
        headers = dict(EndpointInvokeFields.DEFAULT_HEADER)
        if key:
            headers[EndpointInvokeFields.AUTHORIZATION] = f"Bearer {key}"
        if deployment_name:
            headers[EndpointInvokeFields.MODEL_DEPLOYMENT] = deployment_name

        response = self._requests_pipeline.post(endpoint.properties.scoring_uri, json=data, headers=headers)
        validate_response(response)
        return str(response.text())

    def _get_workspace_location(self) -> str:
        """Return the location of the current workspace, as reported by the workspace operations."""
        return str(
            self._all_operations.all_operations[AzureMLResourceType.WORKSPACE].get(self._workspace_name).location
        )

    def _get_online_credentials(
        self, name: str, auth_mode: Optional[str] = None
    ) -> Union[EndpointAuthKeys, EndpointAuthToken, EndpointAadToken]:
        """Fetch the credentials matching the endpoint's auth mode.

        :param name: The endpoint name.
        :type name: str
        :param auth_mode: The endpoint's auth mode; fetched from the service when not given.
        :type auth_mode: Optional[str]
        :raises ~azure.ai.ml.exceptions.MlException: Raised for AAD-token auth when no credentials are configured.
        :return: Keys for key auth, an AAD token for AAD-token auth, otherwise the service-issued token.
        :rtype: Union[EndpointAuthKeys, EndpointAuthToken, EndpointAadToken]
        """
        if not auth_mode:
            endpoint = self._online_operation.get(
                resource_group_name=self._resource_group_name,
                workspace_name=self._workspace_name,
                endpoint_name=name,
                **self._init_kwargs,
            )
            auth_mode = endpoint.properties.auth_mode

        normalized_mode = auth_mode.lower() if auth_mode is not None else None

        if normalized_mode == KEY:
            return self._online_operation.list_keys(
                resource_group_name=self._resource_group_name,
                workspace_name=self._workspace_name,
                endpoint_name=name,
                # pylint: disable=protected-access
                cls=lambda x, response, z: EndpointAuthKeys._from_rest_object(response),
                **self._init_kwargs,
            )

        if normalized_mode == AAD_TOKEN:
            if self._credentials:
                return EndpointAadToken(self._credentials.get_token(*_resource_to_scopes(AAD_TOKEN_RESOURCE_ENDPOINT)))
            msg = EMPTY_CREDENTIALS_ERROR
            raise MlException(message=msg, no_personal_data_message=msg)

        # Any other (or missing) auth mode falls back to the service-issued token.
        return self._online_operation.get_token(
            resource_group_name=self._resource_group_name,
            workspace_name=self._workspace_name,
            endpoint_name=name,
            # pylint: disable=protected-access
            cls=lambda x, response, z: EndpointAuthToken._from_rest_object(response),
            **self._init_kwargs,
        )

    def _regenerate_online_keys(
        self,
        name: str,
        key_type: str = EndpointKeyType.PRIMARY_KEY_TYPE,
    ) -> LROPoller[None]:
        """Start regeneration of the primary or secondary key of an endpoint.

        The current value of the selected key is read first and sent in the request body.

        :param name: The endpoint name.
        :type name: str
        :param key_type: "primary" or "secondary" (case-insensitive).
        :type key_type: str
        :raises ~azure.ai.ml.exceptions.ValidationException: Raised for any other key type.
        :return: A poller to track the operation status.
        :rtype: ~azure.core.polling.LROPoller[None]
        """
        keys = self._online_operation.list_keys(
            resource_group_name=self._resource_group_name,
            workspace_name=self._workspace_name,
            endpoint_name=name,
            **self._init_kwargs,
        )

        requested_type = key_type.lower()
        if requested_type == EndpointKeyType.PRIMARY_KEY_TYPE:
            key_request = RegenerateEndpointKeysRequest(key_type=KeyType.Primary, key_value=keys.primary_key)
        elif requested_type == EndpointKeyType.SECONDARY_KEY_TYPE:
            key_request = RegenerateEndpointKeysRequest(key_type=KeyType.Secondary, key_value=keys.secondary_key)
        else:
            msg = "Key type must be 'primary' or 'secondary'."
            raise ValidationException(
                message=msg,
                target=ErrorTarget.ONLINE_ENDPOINT,
                no_personal_data_message=msg,
                error_category=ErrorCategory.USER_ERROR,
                error_type=ValidationErrorType.INVALID_VALUE,
            )

        poller = self._online_operation.begin_regenerate_keys(
            resource_group_name=self._resource_group_name,
            workspace_name=self._workspace_name,
            endpoint_name=name,
            body=key_request,
            **self._init_kwargs,
        )

        return poller

    def _validate_deployment_name(self, endpoint_name: str, deployment_name: str) -> None:
        """Ensure the remote endpoint has a deployment called ``deployment_name``.

        :param endpoint_name: The endpoint name.
        :type endpoint_name: str
        :param deployment_name: The deployment name to look for.
        :type deployment_name: str
        :raises ~azure.ai.ml.exceptions.ValidationException: Raised if the endpoint has no deployments, or none
            with the given name.
        """
        deployments_list = self._online_deployment_operation.list(
            endpoint_name=endpoint_name,
            resource_group_name=self._resource_group_name,
            workspace_name=self._workspace_name,
            cls=lambda objs: [obj.name for obj in objs],
            **self._init_kwargs,
        )

        if not deployments_list:
            msg = "No deployment exists for this endpoint"
            raise ValidationException(
                message=msg,
                target=ErrorTarget.ONLINE_ENDPOINT,
                no_personal_data_message=msg,
                error_category=ErrorCategory.USER_ERROR,
                error_type=ValidationErrorType.RESOURCE_NOT_FOUND,
            )

        if deployment_name not in deployments_list:
            raise ValidationException(
                message=f"Deployment name {deployment_name} not found for this endpoint",
                target=ErrorTarget.ONLINE_ENDPOINT,
                no_personal_data_message="Deployment name not found for this endpoint",
                error_category=ErrorCategory.USER_ERROR,
                error_type=ValidationErrorType.RESOURCE_NOT_FOUND,
            )