aboutsummaryrefslogtreecommitdiff
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

# pylint: disable=protected-access

import ast
import concurrent.futures
import logging
import time
from concurrent.futures import Future
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Union

from azure.ai.ml._restclient.v2020_09_01_dataplanepreview.models import DataVersion, UriFileJobOutput
from azure.ai.ml._utils._arm_id_utils import is_ARM_id_for_resource, is_registry_id_for_resource
from azure.ai.ml._utils._logger_utils import initialize_logger_info
from azure.ai.ml.constants._common import ARM_ID_PREFIX, AzureMLResourceType, DefaultOpenEncoding, LROConfigurations
from azure.ai.ml.entities import BatchDeployment
from azure.ai.ml.entities._assets._artifacts.code import Code
from azure.ai.ml.entities._deployment.deployment import Deployment
from azure.ai.ml.entities._deployment.model_batch_deployment import ModelBatchDeployment
from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, MlException, ValidationErrorType, ValidationException
from azure.ai.ml.operations._operation_orchestrator import OperationOrchestrator
from azure.core.exceptions import (
    ClientAuthenticationError,
    HttpResponseError,
    ResourceExistsError,
    ResourceNotFoundError,
    ServiceRequestTimeoutError,
    map_error,
)
from azure.core.polling import LROPoller
from azure.core.rest import HttpResponse
from azure.mgmt.core.exceptions import ARMErrorFormat

module_logger = logging.getLogger(__name__)
initialize_logger_info(module_logger, terminator="")


def get_duration(start_time: float) -> None:
    """Log how long a long-running operation took, as whole minutes and seconds.

    The elapsed time is measured from ``start_time`` up to now (``time.time()``) and rounded to the
    nearest second before being split into minutes and seconds.

    :param start_time: Timestamp, as produced by ``time.time()``, taken when the operation started.
    :type start_time: float
    """
    elapsed_seconds = int(round(time.time() - start_time))
    minutes, seconds = divmod(elapsed_seconds, 60)
    module_logger.warning("(%sm %ss)\n", minutes, seconds)


def polling_wait(
    poller: Union[LROPoller, Future],
    message: Optional[str] = None,
    start_time: Optional[float] = None,
    is_local=False,
    timeout=LROConfigurations.POLLING_TIMEOUT,
) -> Any:
    """Print out status while polling, and the time the operation took once completed.

    :param poller: A poller or future that reports completion via ``done()``.
    :type poller: Union[LROPoller, concurrent.futures.Future]
    :param message: Message to print before polling starts. Nothing is printed when it is omitted.
    :type message: Optional[str]
    :param start_time: Start time of the operation (e.g. from ``time.time()``). When given, the elapsed
        time is logged after polling finishes.
    :type start_time: Optional[float]
    :param is_local: Whether the poller is for a local endpoint. If so, no timeout is applied and a dot
        is printed every ``LROConfigurations.SLEEP_TIME`` seconds until the poller is done.
    :type is_local: bool
    :param timeout: Timeout passed to ``poller.result()`` for non-local pollers.
    :type timeout: int
    :return: None. The operation's result is not returned; read it from ``poller`` instead.
    """
    # Only echo a message when one was supplied; logging the default ``None`` would print "None".
    if message:
        module_logger.warning("%s", message)

    if is_local:
        # We removed timeout on local endpoints in case it takes a long time
        # to pull image or install conda env.
        # We want user to be able to see that, so print a progress dot on every tick.
        while not poller.done():
            module_logger.warning(".")
            time.sleep(LROConfigurations.SLEEP_TIME)
    else:
        # NOTE: a concurrent.futures.Future raises TimeoutError here when the timeout expires, so the
        # "Timeout" message below is only reached for pollers whose result() returns instead of raising.
        poller.result(timeout=timeout)

    if poller.done():
        module_logger.warning("Done ")
    else:
        module_logger.warning("Timeout waiting for long running operation")

    # Compare against None explicitly so a (degenerate) start time of 0.0 is still reported.
    if start_time is not None:
        get_duration(start_time)


def local_endpoint_polling_wrapper(func: Callable, message: str, **kwargs) -> Any:
    """Run ``func`` on a worker thread while printing local-endpoint polling progress.

    :param func: The operation to run; it is called as ``func(**kwargs)``.
    :type func: Callable
    :param message: Message to print out before starting operation write-out.
    :type message: str
    :return: The value returned by ``func``. Any exception raised by ``func`` is re-raised here.
    :rtype: Any
    """
    # A single task is submitted, so one worker is enough.
    pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
    try:
        start_time = time.time()
        future = pool.submit(func, **kwargs)
        # With is_local=True this blocks until the future is done.
        polling_wait(poller=future, start_time=start_time, message=message, is_local=True)
        return future.result()
    finally:
        # Release the executor instead of leaking one per call. wait=False keeps an interrupted poll
        # (e.g. Ctrl+C during the sleep loop) from blocking here until ``func`` finishes.
        pool.shutdown(wait=False)


def validate_response(response: HttpResponse) -> None:
    """Validates the response of POST requests, throws on error.

    Status 200 and 201 are accepted. Every other status, including 204, raises. For non-204 statuses
    the body is parsed as JSON to extract ``error.message``. Known status codes are mapped to specific
    azure-core exceptions; anything else raises ``HttpResponseError``.

    :param HttpResponse response: the response of a POST request
    :raises MlException: Raised when a non-204 error response body is not valid JSON
    :raises ClientAuthenticationError: Raised for status 401
    :raises ResourceNotFoundError: Raised for status 404
    :raises ServiceRequestTimeoutError: Raised for status 408
    :raises ResourceExistsError: Raised for status 409
    :raises HttpResponseError: Raised for any other error status
    """
    if response.status_code in (200, 201):
        return

    r_json = {}
    # We need to check for an empty response body or catch the exception raised.
    # It is possible the server responded with a 204 No Content response, and json parsing fails.
    if response.status_code != 204:
        try:
            r_json = response.json()
        except ValueError as e:
            # Body is not JSON; surface it as-is. errors="replace" keeps a non-UTF-8 body from
            # masking the real failure with a UnicodeDecodeError.
            msg = response.content.decode("utf-8", errors="replace")
            raise MlException(message=msg, no_personal_data_message=msg) from e

    # The body may be valid JSON that is not an object (list, string, null), or its "error" entry may
    # not be an object. Fall back to the response rather than failing with AttributeError, which
    # would hide the HTTP error raised below.
    error = r_json.get("error") if isinstance(r_json, dict) else None
    failure_msg = error.get("message", response) if isinstance(error, dict) else response

    error_map = {
        401: ClientAuthenticationError,
        404: ResourceNotFoundError,
        408: ServiceRequestTimeoutError,
        409: ResourceExistsError,
        424: HttpResponseError,
    }
    # Raises the mapped exception type if the status code is listed above.
    map_error(status_code=response.status_code, response=response, error_map=error_map)
    raise HttpResponseError(response=response, message=failure_msg, error_format=ARMErrorFormat)


def upload_dependencies(deployment: Deployment, orchestrators: OperationOrchestrator) -> None:
    """Resolve a deployment's code, environment and model to asset ARM IDs, in place.

    Local code is wrapped in a ``Code`` asset before resolution. Values that are already ARM IDs or
    registry IDs are left unchanged. For batch deployments (``BatchDeployment`` or
    ``ModelBatchDeployment``), the compute target is resolved as well.

    :param Deployment deployment: Endpoint deployment object. Its attributes are updated in place.
    :param OperationOrchestrator orchestrators: Operation Orchestrator used to resolve asset ARM IDs.
    """

    module_logger.debug("Uploading the dependencies for deployment %s", deployment.name)

    # Create a code asset if code is not already an ARM ID
    code_configuration = deployment.code_configuration
    if (
        code_configuration
        and not is_ARM_id_for_resource(code_configuration.code, AzureMLResourceType.CODE)
        and not is_registry_id_for_resource(code_configuration.code)
    ):
        code = code_configuration.code
        # ``code`` may be a path-like object rather than a str; only strings can carry the ARM prefix.
        if isinstance(code, str) and code.startswith(ARM_ID_PREFIX):
            code_configuration.code = orchestrators.get_asset_arm_id(
                code[len(ARM_ID_PREFIX) :],
                azureml_type=AzureMLResourceType.CODE,
            )
        else:
            code_configuration.code = orchestrators.get_asset_arm_id(
                Code(base_path=deployment._base_path, path=code),
                azureml_type=AzureMLResourceType.CODE,
            )

    if not is_registry_id_for_resource(deployment.environment):
        deployment.environment = (
            orchestrators.get_asset_arm_id(deployment.environment, azureml_type=AzureMLResourceType.ENVIRONMENT)
            if deployment.environment
            else None
        )
    if not is_registry_id_for_resource(deployment.model):
        deployment.model = (
            orchestrators.get_asset_arm_id(deployment.model, azureml_type=AzureMLResourceType.MODEL)
            if deployment.model
            else None
        )
    if isinstance(deployment, (BatchDeployment, ModelBatchDeployment)) and deployment.compute:
        deployment.compute = orchestrators.get_asset_arm_id(
            deployment.compute, azureml_type=AzureMLResourceType.COMPUTE
        )


def validate_scoring_script(deployment):
    """Check that a deployment's scoring script can be read and parses as Python.

    The script path is ``deployment.base_path / deployment.code_configuration.code /
    deployment.scoring_script``. The script is parsed with ``ast.parse`` and is not executed.

    :param deployment: Deployment whose scoring script is validated.
    :raises MlException: Raised when the scoring script cannot be opened or read.
    :raises ValidationException: Raised when the scoring script contains a syntax error.
    """
    score_script_path = Path(deployment.base_path).joinpath(
        deployment.code_configuration.code, deployment.scoring_script
    )

    # Keep the try limited to the file I/O so only read failures are reported as "Failed to open".
    try:
        with open(score_script_path, "r", encoding=DefaultOpenEncoding.READ) as script:
            contents = script.read()
    except OSError as err:
        raise MlException(
            message=f"Failed to open scoring script {err.filename}.",
            no_personal_data_message="Failed to open scoring script.",
        ) from err

    try:
        ast.parse(contents, score_script_path)
    except SyntaxError as err:
        # Report only the file's base name. Path handles the platform's separators; splitting on "/"
        # left the full path on Windows and crashed when filename was None.
        if err.filename:
            err.filename = Path(err.filename).name
        msg = (
            f"Failed to submit deployment {deployment.name} due to syntax errors "
            f"in scoring script {err.filename}.\nError on line {err.lineno}: "
            f"{err.text}\nIf you wish to bypass this validation use --skip-script-validation paramater."
        )

        np_msg = (
            "Failed to submit deployment due to syntax errors in deployment script."
            "\n If you wish to bypass this validation use --skip-script-validation paramater."
        )
        raise ValidationException(
            message=msg,
            # Treat both batch deployment types as batch, consistent with upload_dependencies.
            target=(
                ErrorTarget.BATCH_DEPLOYMENT
                if isinstance(deployment, (BatchDeployment, ModelBatchDeployment))
                else ErrorTarget.ONLINE_DEPLOYMENT
            ),
            no_personal_data_message=np_msg,
            error_category=ErrorCategory.USER_ERROR,
            error_type=ValidationErrorType.CANNOT_PARSE,
        ) from err


def convert_v1_dataset_to_v2(output_data_set: DataVersion, file_name: str) -> Dict[str, Any]:
    """Build a v2 job-output mapping that points at a v1 dataset's datastore location.

    :param output_data_set: v1 dataset whose ``datastore_id`` and ``path`` locate the output.
    :type output_data_set: DataVersion
    :param file_name: File name appended to the dataset path. A falsy value leaves the path unchanged.
    :type file_name: str
    :return: A single-entry dict, keyed by ``"output_name"``, holding the serialized ``UriFileJobOutput``.
    :rtype: Dict[str, Any]
    """
    uri = f"azureml://datastores/{output_data_set.datastore_id}/paths/{output_data_set.path}"
    if file_name:
        uri = f"{uri}/{file_name}"
    return {"output_name": UriFileJobOutput(uri=uri).serialize()}