diff options
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/_utils/_endpoint_utils.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/azure/ai/ml/_utils/_endpoint_utils.py | 229 |
1 files changed, 229 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_utils/_endpoint_utils.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_utils/_endpoint_utils.py new file mode 100644 index 00000000..b8bab283 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_utils/_endpoint_utils.py @@ -0,0 +1,229 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=protected-access + +import ast +import concurrent.futures +import logging +import time +from concurrent.futures import Future +from pathlib import Path +from typing import Any, Callable, Dict, Optional, Union + +from azure.ai.ml._restclient.v2020_09_01_dataplanepreview.models import DataVersion, UriFileJobOutput +from azure.ai.ml._utils._arm_id_utils import is_ARM_id_for_resource, is_registry_id_for_resource +from azure.ai.ml._utils._logger_utils import initialize_logger_info +from azure.ai.ml.constants._common import ARM_ID_PREFIX, AzureMLResourceType, DefaultOpenEncoding, LROConfigurations +from azure.ai.ml.entities import BatchDeployment +from azure.ai.ml.entities._assets._artifacts.code import Code +from azure.ai.ml.entities._deployment.deployment import Deployment +from azure.ai.ml.entities._deployment.model_batch_deployment import ModelBatchDeployment +from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, MlException, ValidationErrorType, ValidationException +from azure.ai.ml.operations._operation_orchestrator import OperationOrchestrator +from azure.core.exceptions import ( + ClientAuthenticationError, + HttpResponseError, + ResourceExistsError, + ResourceNotFoundError, + ServiceRequestTimeoutError, + map_error, +) +from azure.core.polling import LROPoller +from azure.core.rest import HttpResponse +from azure.mgmt.core.exceptions import ARMErrorFormat + +module_logger = logging.getLogger(__name__) +initialize_logger_info(module_logger, terminator="") + + +def get_duration(start_time: float) -> None: + """Calculates the duration of the Long running operation took to finish. + + :param start_time: Start time + :type start_time: float + """ + end_time = time.time() + duration = divmod(int(round(end_time - start_time)), 60) + module_logger.warning("(%sm %ss)\n", duration[0], duration[1]) + + +def polling_wait( + poller: Union[LROPoller, Future], + message: Optional[str] = None, + start_time: Optional[float] = None, + is_local=False, + timeout=LROConfigurations.POLLING_TIMEOUT, +) -> Any: + """Print out status while polling and time of operation once completed. + + :param poller: An poller which will return status update via function done(). + :type poller: Union[LROPoller, concurrent.futures.Future] + :param (str, optional) message: Message to print out before starting operation write-out. + :param (float, optional) start_time: Start time of operation. + :param (bool, optional) is_local: If poller is for a local endpoint, so the timeout is removed. + :param (int, optional) timeout: New value to overwrite the default timeout. + """ + module_logger.warning("%s", message) + if is_local: + # We removed timeout on local endpoints in case it takes a long time + # to pull image or install conda env. + + # We want user to be able to see that. + + while not poller.done(): + module_logger.warning(".") + time.sleep(LROConfigurations.SLEEP_TIME) + else: + poller.result(timeout=timeout) + + if poller.done(): + module_logger.warning("Done ") + else: + module_logger.warning("Timeout waiting for long running operation") + + if start_time: + get_duration(start_time) + + +def local_endpoint_polling_wrapper(func: Callable, message: str, **kwargs) -> Any: + """Wrapper for polling local endpoint operations. + + :param func: Name of the endpoint. + :type func: Callable + :param message: Message to print out before starting operation write-out. + :type message: str + :return: The type returned by Func + """ + pool = concurrent.futures.ThreadPoolExecutor() + start_time = time.time() + event = pool.submit(func, **kwargs) + polling_wait(poller=event, start_time=start_time, message=message, is_local=True) + return event.result() + + +def validate_response(response: HttpResponse) -> None: + """Validates the response of POST requests, throws on error. + + :param HttpResponse response: the response of a POST requests + :raises Exception: Raised when response is not json serializable + :raises HttpResponseError: Raised when the response signals that an error occurred + """ + r_json = {} + + if response.status_code not in [200, 201]: + # We need to check for an empty response body or catch the exception raised. + # It is possible the server responded with a 204 No Content response, and json parsing fails. + if response.status_code != 204: + try: + r_json = response.json() + except ValueError as e: + # exception is not in the json format + msg = response.content.decode("utf-8") + raise MlException(message=msg, no_personal_data_message=msg) from e + failure_msg = r_json.get("error", {}).get("message", response) + error_map = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 408: ServiceRequestTimeoutError, + 409: ResourceExistsError, + 424: HttpResponseError, + } + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response, message=failure_msg, error_format=ARMErrorFormat) + + +def upload_dependencies(deployment: Deployment, orchestrators: OperationOrchestrator) -> None: + """Upload code, dependency, model dependencies. For BatchDeployment only register compute. + + :param Deployment deployment: Endpoint deployment object. + :param OperationOrchestrator orchestrators: Operation Orchestrator. + """ + + module_logger.debug("Uploading the dependencies for deployment %s", deployment.name) + + # Create a code asset if code is not already an ARM ID + if ( + deployment.code_configuration + and not is_ARM_id_for_resource(deployment.code_configuration.code, AzureMLResourceType.CODE) + and not is_registry_id_for_resource(deployment.code_configuration.code) + ): + if deployment.code_configuration.code.startswith(ARM_ID_PREFIX): + deployment.code_configuration.code = orchestrators.get_asset_arm_id( + deployment.code_configuration.code[len(ARM_ID_PREFIX) :], + azureml_type=AzureMLResourceType.CODE, + ) + else: + deployment.code_configuration.code = orchestrators.get_asset_arm_id( + Code(base_path=deployment._base_path, path=deployment.code_configuration.code), + azureml_type=AzureMLResourceType.CODE, + ) + + if not is_registry_id_for_resource(deployment.environment): + deployment.environment = ( + orchestrators.get_asset_arm_id(deployment.environment, azureml_type=AzureMLResourceType.ENVIRONMENT) + if deployment.environment + else None + ) + if not is_registry_id_for_resource(deployment.model): + deployment.model = ( + orchestrators.get_asset_arm_id(deployment.model, azureml_type=AzureMLResourceType.MODEL) + if deployment.model + else None + ) + if isinstance(deployment, (BatchDeployment, ModelBatchDeployment)) and deployment.compute: + deployment.compute = orchestrators.get_asset_arm_id( + deployment.compute, azureml_type=AzureMLResourceType.COMPUTE + ) + + +def validate_scoring_script(deployment): + score_script_path = Path(deployment.base_path).joinpath( + deployment.code_configuration.code, deployment.scoring_script + ) + try: + with open(score_script_path, "r", encoding=DefaultOpenEncoding.READ) as script: + contents = script.read() + try: + ast.parse(contents, score_script_path) + except SyntaxError as err: + err.filename = err.filename.split("/")[-1] + msg = ( + f"Failed to submit deployment {deployment.name} due to syntax errors " + f"in scoring script {err.filename}.\nError on line {err.lineno}: " + f"{err.text}\nIf you wish to bypass this validation use --skip-script-validation paramater." + ) + + np_msg = ( + "Failed to submit deployment due to syntax errors in deployment script." + "\n If you wish to bypass this validation use --skip-script-validation paramater." + ) + raise ValidationException( + message=msg, + target=( + ErrorTarget.BATCH_DEPLOYMENT + if isinstance(deployment, BatchDeployment) + else ErrorTarget.ONLINE_DEPLOYMENT + ), + no_personal_data_message=np_msg, + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.CANNOT_PARSE, + ) from err + except OSError as err: + raise MlException( + message=f"Failed to open scoring script {err.filename}.", + no_personal_data_message="Failed to open scoring script.", + ) from err + + +def convert_v1_dataset_to_v2(output_data_set: DataVersion, file_name: str) -> Dict[str, Any]: + if file_name: + v2_dataset = UriFileJobOutput( + uri=f"azureml://datastores/{output_data_set.datastore_id}/paths/{output_data_set.path}/{file_name}" + ).serialize() + else: + v2_dataset = UriFileJobOutput( + uri=f"azureml://datastores/{output_data_set.datastore_id}/paths/{output_data_set.path}" + ).serialize() + return {"output_name": v2_dataset} |