# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
# pylint: disable=protected-access
import ast
import concurrent.futures
import logging
import time
from concurrent.futures import Future
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Union

from azure.ai.ml._restclient.v2020_09_01_dataplanepreview.models import DataVersion, UriFileJobOutput
from azure.ai.ml._utils._arm_id_utils import is_ARM_id_for_resource, is_registry_id_for_resource
from azure.ai.ml._utils._logger_utils import initialize_logger_info
from azure.ai.ml.constants._common import ARM_ID_PREFIX, AzureMLResourceType, DefaultOpenEncoding, LROConfigurations
from azure.ai.ml.entities import BatchDeployment
from azure.ai.ml.entities._assets._artifacts.code import Code
from azure.ai.ml.entities._deployment.deployment import Deployment
from azure.ai.ml.entities._deployment.model_batch_deployment import ModelBatchDeployment
from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, MlException, ValidationErrorType, ValidationException
from azure.ai.ml.operations._operation_orchestrator import OperationOrchestrator
from azure.core.exceptions import (
ClientAuthenticationError,
HttpResponseError,
ResourceExistsError,
ResourceNotFoundError,
ServiceRequestTimeoutError,
map_error,
)
from azure.core.polling import LROPoller
from azure.core.rest import HttpResponse
from azure.mgmt.core.exceptions import ARMErrorFormat

module_logger = logging.getLogger(__name__)
initialize_logger_info(module_logger, terminator="")


def get_duration(start_time: float) -> None:
"""Calculates the duration of the Long running operation took to finish.
:param start_time: Start time
:type start_time: float
"""
end_time = time.time()
duration = divmod(int(round(end_time - start_time)), 60)
module_logger.warning("(%sm %ss)\n", duration[0], duration[1])
def polling_wait(
poller: Union[LROPoller, Future],
message: Optional[str] = None,
start_time: Optional[float] = None,
    is_local: bool = False,
    timeout: int = LROConfigurations.POLLING_TIMEOUT,
) -> None:
    """Print status while polling, and log the duration once the operation completes.

    :param poller: A poller which reports status updates via its done() method.
    :type poller: Union[LROPoller, concurrent.futures.Future]
    :param message: Message to print before the status write-out begins.
    :type message: Optional[str]
    :param start_time: Start time of the operation, used to log its duration on completion.
    :type start_time: Optional[float]
    :param is_local: Whether the poller is for a local endpoint; if True, no timeout is applied.
    :type is_local: bool
    :param timeout: Value overriding the default polling timeout, in seconds.
    :type timeout: int
    """
module_logger.warning("%s", message)
if is_local:
        # No timeout for local endpoints, since pulling the image or installing
        # the conda environment can take a long time, and the user should be
        # able to watch the progress.
while not poller.done():
module_logger.warning(".")
time.sleep(LROConfigurations.SLEEP_TIME)
else:
poller.result(timeout=timeout)
if poller.done():
module_logger.warning("Done ")
else:
module_logger.warning("Timeout waiting for long running operation")
if start_time:
get_duration(start_time)
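

# Example (a minimal sketch; `ml_client` and `endpoint` are assumptions standing in
# for an MLClient and an endpoint entity, not values this module provides):
#
#     start = time.time()
#     poller = ml_client.online_endpoints.begin_create_or_update(endpoint)
#     polling_wait(poller=poller, message="Creating endpoint", start_time=start)

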
def local_endpoint_polling_wrapper(func: Callable, message: str, **kwargs) -> Any:
"""Wrapper for polling local endpoint operations.
:param func: Name of the endpoint.
:type func: Callable
:param message: Message to print out before starting operation write-out.
:type message: str
:return: The type returned by Func
"""
pool = concurrent.futures.ThreadPoolExecutor()
start_time = time.time()
event = pool.submit(func, **kwargs)
polling_wait(poller=event, start_time=start_time, message=message, is_local=True)
return event.result()
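

# Example (a minimal sketch; `create_local_deployment` is a hypothetical callable,
# not part of this module):
#
#     result = local_endpoint_polling_wrapper(
#         func=create_local_deployment,
#         message="Creating local deployment ",
#         deployment_name="blue",
#     )

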
def validate_response(response: HttpResponse) -> None:
"""Validates the response of POST requests, throws on error.
:param HttpResponse response: the response of a POST requests
:raises Exception: Raised when response is not json serializable
:raises HttpResponseError: Raised when the response signals that an error occurred
"""
r_json = {}
if response.status_code not in [200, 201]:
        # Check for an empty response body, or catch the exception raised when parsing fails.
        # The server may have responded with 204 No Content, whose body is not JSON.
if response.status_code != 204:
try:
r_json = response.json()
except ValueError as e:
                # The response body is not valid JSON; surface the raw content instead.
msg = response.content.decode("utf-8")
raise MlException(message=msg, no_personal_data_message=msg) from e
failure_msg = r_json.get("error", {}).get("message", response)
error_map = {
401: ClientAuthenticationError,
404: ResourceNotFoundError,
408: ServiceRequestTimeoutError,
409: ResourceExistsError,
424: HttpResponseError,
}
map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response, message=failure_msg, error_format=ARMErrorFormat)
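

# Example (a minimal sketch; `client`, `scoring_uri`, and `payload` are assumptions
# showing where a raw HttpResponse might come from):
#
#     from azure.core.rest import HttpRequest
#
#     response = client.send_request(HttpRequest("POST", scoring_uri, json=payload))
#     validate_response(response)  # raises HttpResponseError on failure status codes

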
def upload_dependencies(deployment: Deployment, orchestrators: OperationOrchestrator) -> None:
"""Upload code, dependency, model dependencies. For BatchDeployment only register compute.
:param Deployment deployment: Endpoint deployment object.
:param OperationOrchestrator orchestrators: Operation Orchestrator.
"""
module_logger.debug("Uploading the dependencies for deployment %s", deployment.name)
# Create a code asset if code is not already an ARM ID
if (
deployment.code_configuration
and not is_ARM_id_for_resource(deployment.code_configuration.code, AzureMLResourceType.CODE)
and not is_registry_id_for_resource(deployment.code_configuration.code)
):
if deployment.code_configuration.code.startswith(ARM_ID_PREFIX):
deployment.code_configuration.code = orchestrators.get_asset_arm_id(
deployment.code_configuration.code[len(ARM_ID_PREFIX) :],
azureml_type=AzureMLResourceType.CODE,
)
else:
deployment.code_configuration.code = orchestrators.get_asset_arm_id(
Code(base_path=deployment._base_path, path=deployment.code_configuration.code),
azureml_type=AzureMLResourceType.CODE,
)
if not is_registry_id_for_resource(deployment.environment):
deployment.environment = (
orchestrators.get_asset_arm_id(deployment.environment, azureml_type=AzureMLResourceType.ENVIRONMENT)
if deployment.environment
else None
)
if not is_registry_id_for_resource(deployment.model):
deployment.model = (
orchestrators.get_asset_arm_id(deployment.model, azureml_type=AzureMLResourceType.MODEL)
if deployment.model
else None
)
if isinstance(deployment, (BatchDeployment, ModelBatchDeployment)) and deployment.compute:
deployment.compute = orchestrators.get_asset_arm_id(
deployment.compute, azureml_type=AzureMLResourceType.COMPUTE
)
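

# Example (a minimal sketch; the orchestrator construction is an assumption that
# mirrors how the SDK's operations classes wire it up internally):
#
#     orchestrators = OperationOrchestrator(operation_container, operation_scope, operation_config)
#     upload_dependencies(deployment, orchestrators)
#     # deployment.code_configuration.code, .environment, and .model now hold ARM IDs

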
def validate_scoring_script(deployment) -> None:
    """Parses the deployment's scoring script, raising if it contains syntax errors.

    :param deployment: Deployment entity whose code configuration points at the scoring script.
    :raises ValidationException: Raised when the scoring script has syntax errors.
    :raises MlException: Raised when the scoring script cannot be opened.
    """
score_script_path = Path(deployment.base_path).joinpath(
deployment.code_configuration.code, deployment.scoring_script
)
try:
with open(score_script_path, "r", encoding=DefaultOpenEncoding.READ) as script:
contents = script.read()
try:
ast.parse(contents, score_script_path)
except SyntaxError as err:
                # Keep only the file name so the full path is not surfaced in the error.
                err.filename = Path(err.filename).name
msg = (
f"Failed to submit deployment {deployment.name} due to syntax errors "
f"in scoring script {err.filename}.\nError on line {err.lineno}: "
f"{err.text}\nIf you wish to bypass this validation use --skip-script-validation paramater."
)
np_msg = (
"Failed to submit deployment due to syntax errors in deployment script."
"\n If you wish to bypass this validation use --skip-script-validation paramater."
)
raise ValidationException(
message=msg,
target=(
ErrorTarget.BATCH_DEPLOYMENT
if isinstance(deployment, BatchDeployment)
else ErrorTarget.ONLINE_DEPLOYMENT
),
no_personal_data_message=np_msg,
error_category=ErrorCategory.USER_ERROR,
error_type=ValidationErrorType.CANNOT_PARSE,
) from err
except OSError as err:
raise MlException(
message=f"Failed to open scoring script {err.filename}.",
no_personal_data_message="Failed to open scoring script.",
) from err
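

# Example (a minimal sketch; the local paths are illustrative values):
#
#     from azure.ai.ml.entities import CodeConfiguration, ManagedOnlineDeployment
#
#     deployment = ManagedOnlineDeployment(
#         name="blue",
#         code_configuration=CodeConfiguration(code="./src", scoring_script="score.py"),
#     )
#     validate_scoring_script(deployment)  # raises ValidationException on syntax errors

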
def convert_v1_dataset_to_v2(output_data_set: DataVersion, file_name: str) -> Dict[str, Any]:
    """Converts a v1 output dataset into a v2 URI-file job output.

    :param DataVersion output_data_set: The v1 dataset to convert.
    :param str file_name: File name to append to the datastore path; may be empty.
    :return: A dict mapping "output_name" to the serialized v2 output.
    :rtype: Dict[str, Any]
    """
    if file_name:
v2_dataset = UriFileJobOutput(
uri=f"azureml://datastores/{output_data_set.datastore_id}/paths/{output_data_set.path}/{file_name}"
).serialize()
else:
v2_dataset = UriFileJobOutput(
uri=f"azureml://datastores/{output_data_set.datastore_id}/paths/{output_data_set.path}"
).serialize()
return {"output_name": v2_dataset}