aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/azure/ai/ml/_utils/_endpoint_utils.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/azure/ai/ml/_utils/_endpoint_utils.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two version of R2R are hereHEADmaster
Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/_utils/_endpoint_utils.py')
-rw-r--r--.venv/lib/python3.12/site-packages/azure/ai/ml/_utils/_endpoint_utils.py229
1 files changed, 229 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/_utils/_endpoint_utils.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/_utils/_endpoint_utils.py
new file mode 100644
index 00000000..b8bab283
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/_utils/_endpoint_utils.py
@@ -0,0 +1,229 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+import ast
+import concurrent.futures
+import logging
+import time
+from concurrent.futures import Future
+from pathlib import Path
+from typing import Any, Callable, Dict, Optional, Union
+
+from azure.ai.ml._restclient.v2020_09_01_dataplanepreview.models import DataVersion, UriFileJobOutput
+from azure.ai.ml._utils._arm_id_utils import is_ARM_id_for_resource, is_registry_id_for_resource
+from azure.ai.ml._utils._logger_utils import initialize_logger_info
+from azure.ai.ml.constants._common import ARM_ID_PREFIX, AzureMLResourceType, DefaultOpenEncoding, LROConfigurations
+from azure.ai.ml.entities import BatchDeployment
+from azure.ai.ml.entities._assets._artifacts.code import Code
+from azure.ai.ml.entities._deployment.deployment import Deployment
+from azure.ai.ml.entities._deployment.model_batch_deployment import ModelBatchDeployment
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, MlException, ValidationErrorType, ValidationException
+from azure.ai.ml.operations._operation_orchestrator import OperationOrchestrator
+from azure.core.exceptions import (
+ ClientAuthenticationError,
+ HttpResponseError,
+ ResourceExistsError,
+ ResourceNotFoundError,
+ ServiceRequestTimeoutError,
+ map_error,
+)
+from azure.core.polling import LROPoller
+from azure.core.rest import HttpResponse
+from azure.mgmt.core.exceptions import ARMErrorFormat
+
+module_logger = logging.getLogger(__name__)
+initialize_logger_info(module_logger, terminator="")
+
+
+def get_duration(start_time: float) -> None:
+ """Calculates the duration of the Long running operation took to finish.
+
+ :param start_time: Start time
+ :type start_time: float
+ """
+ end_time = time.time()
+ duration = divmod(int(round(end_time - start_time)), 60)
+ module_logger.warning("(%sm %ss)\n", duration[0], duration[1])
+
+
+def polling_wait(
+ poller: Union[LROPoller, Future],
+ message: Optional[str] = None,
+ start_time: Optional[float] = None,
+ is_local=False,
+ timeout=LROConfigurations.POLLING_TIMEOUT,
+) -> Any:
+ """Print out status while polling and time of operation once completed.
+
+ :param poller: An poller which will return status update via function done().
+ :type poller: Union[LROPoller, concurrent.futures.Future]
+ :param (str, optional) message: Message to print out before starting operation write-out.
+ :param (float, optional) start_time: Start time of operation.
+ :param (bool, optional) is_local: If poller is for a local endpoint, so the timeout is removed.
+ :param (int, optional) timeout: New value to overwrite the default timeout.
+ """
+ module_logger.warning("%s", message)
+ if is_local:
+ # We removed timeout on local endpoints in case it takes a long time
+ # to pull image or install conda env.
+
+ # We want user to be able to see that.
+
+ while not poller.done():
+ module_logger.warning(".")
+ time.sleep(LROConfigurations.SLEEP_TIME)
+ else:
+ poller.result(timeout=timeout)
+
+ if poller.done():
+ module_logger.warning("Done ")
+ else:
+ module_logger.warning("Timeout waiting for long running operation")
+
+ if start_time:
+ get_duration(start_time)
+
+
+def local_endpoint_polling_wrapper(func: Callable, message: str, **kwargs) -> Any:
+ """Wrapper for polling local endpoint operations.
+
+ :param func: Name of the endpoint.
+ :type func: Callable
+ :param message: Message to print out before starting operation write-out.
+ :type message: str
+ :return: The type returned by Func
+ """
+ pool = concurrent.futures.ThreadPoolExecutor()
+ start_time = time.time()
+ event = pool.submit(func, **kwargs)
+ polling_wait(poller=event, start_time=start_time, message=message, is_local=True)
+ return event.result()
+
+
+def validate_response(response: HttpResponse) -> None:
+ """Validates the response of POST requests, throws on error.
+
+ :param HttpResponse response: the response of a POST requests
+ :raises Exception: Raised when response is not json serializable
+ :raises HttpResponseError: Raised when the response signals that an error occurred
+ """
+ r_json = {}
+
+ if response.status_code not in [200, 201]:
+ # We need to check for an empty response body or catch the exception raised.
+ # It is possible the server responded with a 204 No Content response, and json parsing fails.
+ if response.status_code != 204:
+ try:
+ r_json = response.json()
+ except ValueError as e:
+ # exception is not in the json format
+ msg = response.content.decode("utf-8")
+ raise MlException(message=msg, no_personal_data_message=msg) from e
+ failure_msg = r_json.get("error", {}).get("message", response)
+ error_map = {
+ 401: ClientAuthenticationError,
+ 404: ResourceNotFoundError,
+ 408: ServiceRequestTimeoutError,
+ 409: ResourceExistsError,
+ 424: HttpResponseError,
+ }
+ map_error(status_code=response.status_code, response=response, error_map=error_map)
+ raise HttpResponseError(response=response, message=failure_msg, error_format=ARMErrorFormat)
+
+
+def upload_dependencies(deployment: Deployment, orchestrators: OperationOrchestrator) -> None:
+ """Upload code, dependency, model dependencies. For BatchDeployment only register compute.
+
+ :param Deployment deployment: Endpoint deployment object.
+ :param OperationOrchestrator orchestrators: Operation Orchestrator.
+ """
+
+ module_logger.debug("Uploading the dependencies for deployment %s", deployment.name)
+
+ # Create a code asset if code is not already an ARM ID
+ if (
+ deployment.code_configuration
+ and not is_ARM_id_for_resource(deployment.code_configuration.code, AzureMLResourceType.CODE)
+ and not is_registry_id_for_resource(deployment.code_configuration.code)
+ ):
+ if deployment.code_configuration.code.startswith(ARM_ID_PREFIX):
+ deployment.code_configuration.code = orchestrators.get_asset_arm_id(
+ deployment.code_configuration.code[len(ARM_ID_PREFIX) :],
+ azureml_type=AzureMLResourceType.CODE,
+ )
+ else:
+ deployment.code_configuration.code = orchestrators.get_asset_arm_id(
+ Code(base_path=deployment._base_path, path=deployment.code_configuration.code),
+ azureml_type=AzureMLResourceType.CODE,
+ )
+
+ if not is_registry_id_for_resource(deployment.environment):
+ deployment.environment = (
+ orchestrators.get_asset_arm_id(deployment.environment, azureml_type=AzureMLResourceType.ENVIRONMENT)
+ if deployment.environment
+ else None
+ )
+ if not is_registry_id_for_resource(deployment.model):
+ deployment.model = (
+ orchestrators.get_asset_arm_id(deployment.model, azureml_type=AzureMLResourceType.MODEL)
+ if deployment.model
+ else None
+ )
+ if isinstance(deployment, (BatchDeployment, ModelBatchDeployment)) and deployment.compute:
+ deployment.compute = orchestrators.get_asset_arm_id(
+ deployment.compute, azureml_type=AzureMLResourceType.COMPUTE
+ )
+
+
+def validate_scoring_script(deployment):
+ score_script_path = Path(deployment.base_path).joinpath(
+ deployment.code_configuration.code, deployment.scoring_script
+ )
+ try:
+ with open(score_script_path, "r", encoding=DefaultOpenEncoding.READ) as script:
+ contents = script.read()
+ try:
+ ast.parse(contents, score_script_path)
+ except SyntaxError as err:
+ err.filename = err.filename.split("/")[-1]
+ msg = (
+ f"Failed to submit deployment {deployment.name} due to syntax errors "
+ f"in scoring script {err.filename}.\nError on line {err.lineno}: "
+ f"{err.text}\nIf you wish to bypass this validation use --skip-script-validation paramater."
+ )
+
+ np_msg = (
+ "Failed to submit deployment due to syntax errors in deployment script."
+ "\n If you wish to bypass this validation use --skip-script-validation paramater."
+ )
+ raise ValidationException(
+ message=msg,
+ target=(
+ ErrorTarget.BATCH_DEPLOYMENT
+ if isinstance(deployment, BatchDeployment)
+ else ErrorTarget.ONLINE_DEPLOYMENT
+ ),
+ no_personal_data_message=np_msg,
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.CANNOT_PARSE,
+ ) from err
+ except OSError as err:
+ raise MlException(
+ message=f"Failed to open scoring script {err.filename}.",
+ no_personal_data_message="Failed to open scoring script.",
+ ) from err
+
+
+def convert_v1_dataset_to_v2(output_data_set: DataVersion, file_name: str) -> Dict[str, Any]:
+ if file_name:
+ v2_dataset = UriFileJobOutput(
+ uri=f"azureml://datastores/{output_data_set.datastore_id}/paths/{output_data_set.path}/{file_name}"
+ ).serialize()
+ else:
+ v2_dataset = UriFileJobOutput(
+ uri=f"azureml://datastores/{output_data_set.datastore_id}/paths/{output_data_set.path}"
+ ).serialize()
+ return {"output_name": v2_dataset}