Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_batch_endpoint_operations.py')
-rw-r--r-- | .venv/lib/python3.12/site-packages/azure/ai/ml/operations/_batch_endpoint_operations.py | 553 |
1 file changed, 553 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_batch_endpoint_operations.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_batch_endpoint_operations.py
new file mode 100644
index 00000000..650185b3
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_batch_endpoint_operations.py
@@ -0,0 +1,553 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+import json
+import os
+import re
+import uuid
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Dict, Optional, cast
+
+from marshmallow.exceptions import ValidationError as SchemaValidationError
+
+from azure.ai.ml._artifacts._artifact_utilities import _upload_and_generate_remote_uri
+from azure.ai.ml._azure_environments import _get_aml_resource_id_from_metadata, _resource_to_scopes
+from azure.ai.ml._exception_helper import log_and_raise_error
+from azure.ai.ml._restclient.v2020_09_01_dataplanepreview.models import BatchJobResource
+from azure.ai.ml._restclient.v2023_10_01 import AzureMachineLearningServices as ServiceClient102023
+from azure.ai.ml._schema._deployment.batch.batch_job import BatchJobSchema
+from azure.ai.ml._scope_dependent_operations import (
+    OperationConfig,
+    OperationsContainer,
+    OperationScope,
+    _ScopeDependentOperations,
+)
+from azure.ai.ml._telemetry import ActivityType, monitor_with_activity
+from azure.ai.ml._utils._arm_id_utils import is_ARM_id_for_resource, remove_aml_prefix
+from azure.ai.ml._utils._azureml_polling import AzureMLPolling
+from azure.ai.ml._utils._endpoint_utils import convert_v1_dataset_to_v2, validate_response
+from azure.ai.ml._utils._http_utils import HttpPipeline
+from azure.ai.ml._utils._logger_utils import OpsLogger
+from azure.ai.ml._utils.utils import (
+    _get_mfe_base_url_from_discovery_service,
+    is_private_preview_enabled,
+    modified_operation_client,
+)
+from azure.ai.ml.constants._common import (
+    ARM_ID_FULL_PREFIX,
+    AZUREML_REGEX_FORMAT,
+    BASE_PATH_CONTEXT_KEY,
+    HTTP_PREFIX,
+    LONG_URI_REGEX_FORMAT,
+    PARAMS_OVERRIDE_KEY,
+    SHORT_URI_REGEX_FORMAT,
+    AssetTypes,
+    AzureMLResourceType,
+    InputTypes,
+    LROConfigurations,
+)
+from azure.ai.ml.constants._endpoint import EndpointInvokeFields, EndpointYamlFields
+from azure.ai.ml.entities import BatchEndpoint, BatchJob
+from azure.ai.ml.entities._inputs_outputs import Input
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, MlException, ValidationErrorType, ValidationException
+from azure.core.credentials import TokenCredential
+from azure.core.exceptions import HttpResponseError, ServiceRequestError, ServiceResponseError
+from azure.core.paging import ItemPaged
+from azure.core.polling import LROPoller
+from azure.core.tracing.decorator import distributed_trace
+
+from ._operation_orchestrator import OperationOrchestrator
+
+if TYPE_CHECKING:
+    from azure.ai.ml.operations import DatastoreOperations
+
+ops_logger = OpsLogger(__name__)
+module_logger = ops_logger.module_logger
+
+
+class BatchEndpointOperations(_ScopeDependentOperations):
+    """BatchEndpointOperations.
+
+    You should not instantiate this class directly. Instead, you should create an MLClient instance
+    that instantiates it for you and attaches it as an attribute.
+
+    :param operation_scope: Scope variables for the operations classes of an MLClient object.
+    :type operation_scope: ~azure.ai.ml._scope_dependent_operations.OperationScope
+    :param operation_config: Common configuration for operations classes of an MLClient object.
+    :type operation_config: ~azure.ai.ml._scope_dependent_operations.OperationConfig
+    :param service_client_10_2023: Service client to allow end users to operate on Azure Machine Learning Workspace
+        resources.
+    :type service_client_10_2023: ~azure.ai.ml._restclient.v2023_10_01._azure_machine_learning_workspaces.
+        AzureMachineLearningWorkspaces
+    :param all_operations: All operations classes of an MLClient object.
+    :type all_operations: ~azure.ai.ml._scope_dependent_operations.OperationsContainer
+    :param credentials: Credential to use for authentication.
+    :type credentials: ~azure.core.credentials.TokenCredential
+    """
+
+    def __init__(
+        self,
+        operation_scope: OperationScope,
+        operation_config: OperationConfig,
+        service_client_10_2023: ServiceClient102023,
+        all_operations: OperationsContainer,
+        credentials: Optional[TokenCredential] = None,
+        **kwargs: Any,
+    ):
+        super(BatchEndpointOperations, self).__init__(operation_scope, operation_config)
+        ops_logger.update_filter()
+        self._batch_operation = service_client_10_2023.batch_endpoints
+        self._batch_deployment_operation = service_client_10_2023.batch_deployments
+        self._batch_job_endpoint = kwargs.pop("service_client_09_2020_dataplanepreview").batch_job_endpoint
+        self._all_operations = all_operations
+        self._credentials = credentials
+        self._init_kwargs = kwargs
+
+        self._requests_pipeline: HttpPipeline = kwargs.pop("requests_pipeline")
+
+    @property
+    def _datastore_operations(self) -> "DatastoreOperations":
+        from azure.ai.ml.operations import DatastoreOperations
+
+        return cast(DatastoreOperations, self._all_operations.all_operations[AzureMLResourceType.DATASTORE])
+
+    @distributed_trace
+    @monitor_with_activity(ops_logger, "BatchEndpoint.List", ActivityType.PUBLICAPI)
+    def list(self) -> ItemPaged[BatchEndpoint]:
+        """List endpoints of the workspace.
+
+        :return: A list of endpoints
+        :rtype: ~azure.core.paging.ItemPaged[~azure.ai.ml.entities.BatchEndpoint]
+
+        .. admonition:: Example:
+
+            .. literalinclude:: ../samples/ml_samples_misc.py
+                :start-after: [START batch_endpoint_operations_list]
+                :end-before: [END batch_endpoint_operations_list]
+                :language: python
+                :dedent: 8
+                :caption: List example.
+        """
+        return self._batch_operation.list(
+            resource_group_name=self._resource_group_name,
+            workspace_name=self._workspace_name,
+            cls=lambda objs: [BatchEndpoint._from_rest_object(obj) for obj in objs],
+            **self._init_kwargs,
+        )
+
+    @distributed_trace
+    @monitor_with_activity(ops_logger, "BatchEndpoint.Get", ActivityType.PUBLICAPI)
+    def get(
+        self,
+        name: str,
+    ) -> BatchEndpoint:
+        """Get an Endpoint resource.
+
+        :param name: Name of the endpoint.
+        :type name: str
+        :return: Endpoint object retrieved from the service.
+        :rtype: ~azure.ai.ml.entities.BatchEndpoint
+
+        .. admonition:: Example:
+
+            .. literalinclude:: ../samples/ml_samples_misc.py
+                :start-after: [START batch_endpoint_operations_get]
+                :end-before: [END batch_endpoint_operations_get]
+                :language: python
+                :dedent: 8
+                :caption: Get endpoint example.
+        """
+        # First, get the endpoint from the service
+        endpoint = self._batch_operation.get(
+            resource_group_name=self._resource_group_name,
+            workspace_name=self._workspace_name,
+            endpoint_name=name,
+            **self._init_kwargs,
+        )
+
+        endpoint_data = BatchEndpoint._from_rest_object(endpoint)
+        return endpoint_data
+
+    @distributed_trace
+    @monitor_with_activity(ops_logger, "BatchEndpoint.BeginDelete", ActivityType.PUBLICAPI)
+    def begin_delete(self, name: str) -> LROPoller[None]:
+        """Delete a batch endpoint.
+
+        :param name: Name of the batch endpoint.
+        :type name: str
+        :return: A poller to track the operation status.
+        :rtype: ~azure.core.polling.LROPoller[None]
+
+        .. admonition:: Example:
+
+            .. literalinclude:: ../samples/ml_samples_misc.py
+                :start-after: [START batch_endpoint_operations_delete]
+                :end-before: [END batch_endpoint_operations_delete]
+                :language: python
+                :dedent: 8
+                :caption: Delete endpoint example.
+        """
+        path_format_arguments = {
+            "endpointName": name,
+            "resourceGroupName": self._resource_group_name,
+            "workspaceName": self._workspace_name,
+        }
+
+        delete_poller = self._batch_operation.begin_delete(
+            resource_group_name=self._resource_group_name,
+            workspace_name=self._workspace_name,
+            endpoint_name=name,
+            polling=AzureMLPolling(
+                LROConfigurations.POLL_INTERVAL,
+                path_format_arguments=path_format_arguments,
+                **self._init_kwargs,
+            ),
+            polling_interval=LROConfigurations.POLL_INTERVAL,
+            **self._init_kwargs,
+        )
+        return delete_poller
+
+    @distributed_trace
+    @monitor_with_activity(ops_logger, "BatchEndpoint.BeginCreateOrUpdate", ActivityType.PUBLICAPI)
+    def begin_create_or_update(self, endpoint: BatchEndpoint) -> LROPoller[BatchEndpoint]:
+        """Create or update a batch endpoint.
+
+        :param endpoint: The endpoint entity.
+        :type endpoint: ~azure.ai.ml.entities.BatchEndpoint
+        :return: A poller to track the operation status.
+        :rtype: ~azure.core.polling.LROPoller[~azure.ai.ml.entities.BatchEndpoint]
+
+        .. admonition:: Example:
+
+            .. literalinclude:: ../samples/ml_samples_misc.py
+                :start-after: [START batch_endpoint_operations_create_or_update]
+                :end-before: [END batch_endpoint_operations_create_or_update]
+                :language: python
+                :dedent: 8
+                :caption: Create endpoint example.
+        """
+
+        try:
+            location = self._get_workspace_location()
+
+            endpoint_resource = endpoint._to_rest_batch_endpoint(location=location)
+            poller = self._batch_operation.begin_create_or_update(
+                resource_group_name=self._resource_group_name,
+                workspace_name=self._workspace_name,
+                endpoint_name=endpoint.name,
+                body=endpoint_resource,
+                polling=True,
+                **self._init_kwargs,
+                cls=lambda response, deserialized, headers: BatchEndpoint._from_rest_object(deserialized),
+            )
+            return poller
+
+        except Exception as ex:
+            if isinstance(ex, (ValidationException, SchemaValidationError)):
+                log_and_raise_error(ex)
+            raise ex
+
+    @distributed_trace
+    @monitor_with_activity(ops_logger, "BatchEndpoint.Invoke", ActivityType.PUBLICAPI)
+    def invoke(  # pylint: disable=too-many-statements
+        self,
+        endpoint_name: str,
+        *,
+        deployment_name: Optional[str] = None,
+        inputs: Optional[Dict[str, Input]] = None,
+        **kwargs: Any,
+    ) -> BatchJob:
+        """Invokes the batch endpoint with the provided payload.
+
+        :param endpoint_name: The endpoint name.
+        :type endpoint_name: str
+        :keyword deployment_name: (Optional) The name of a specific deployment to invoke. By default,
+            requests are routed to any of the deployments according to the traffic rules.
+        :paramtype deployment_name: str
+        :keyword inputs: (Optional) A dictionary mapping input names to an existing data asset, public URI file,
+            or folder to use with the deployment.
+        :paramtype inputs: Dict[str, Input]
+        :raises ~azure.ai.ml.exceptions.ValidationException: Raised if deployment cannot be successfully validated.
+            Details will be provided in the error message.
+        :raises ~azure.ai.ml.exceptions.AssetException: Raised if BatchEndpoint assets
+            (e.g. Data, Code, Model, Environment) cannot be successfully validated.
+            Details will be provided in the error message.
+        :raises ~azure.ai.ml.exceptions.ModelException: Raised if BatchEndpoint model cannot be successfully validated.
+            Details will be provided in the error message.
+        :raises ~azure.ai.ml.exceptions.EmptyDirectoryError: Raised if the local path provided points to
+            an empty directory.
+        :return: The invoked batch deployment job.
+        :rtype: ~azure.ai.ml.entities.BatchJob
+
+        .. admonition:: Example:
+
+            .. literalinclude:: ../samples/ml_samples_misc.py
+                :start-after: [START batch_endpoint_operations_invoke]
+                :end-before: [END batch_endpoint_operations_invoke]
+                :language: python
+                :dedent: 8
+                :caption: Invoke endpoint example.
+        """
+        outputs = kwargs.get("outputs", None)
+        job_name = kwargs.get("job_name", None)
+        params_override = kwargs.get("params_override", None) or []
+        experiment_name = kwargs.get("experiment_name", None)
+        input = kwargs.get("input", None)  # pylint: disable=redefined-builtin
+        # Until this bug is resolved https://msdata.visualstudio.com/Vienna/_workitems/edit/1446538
+        if deployment_name:
+            self._validate_deployment_name(endpoint_name, deployment_name)
+
+        if input and isinstance(input, Input):
+            if HTTP_PREFIX not in input.path:
+                self._resolve_input(input, os.getcwd())
+            # MFE expects a dictionary as input_data, which is why we use
+            # "UriFolder" or "UriFile" as keys depending on the input type
+            if input.type == "uri_folder":
+                params_override.append({EndpointYamlFields.BATCH_JOB_INPUT_DATA: {"UriFolder": input}})
+            elif input.type == "uri_file":
+                params_override.append({EndpointYamlFields.BATCH_JOB_INPUT_DATA: {"UriFile": input}})
+            else:
+                msg = (
+                    "Unsupported input type. Please use either a path on the datastore, a public URI, "
+                    "a registered data asset, or a local folder path."
+                )
+                raise ValidationException(
+                    message=msg,
+                    target=ErrorTarget.BATCH_ENDPOINT,
+                    no_personal_data_message=msg,
+                    error_category=ErrorCategory.USER_ERROR,
+                    error_type=ValidationErrorType.INVALID_VALUE,
+                )
+        elif inputs:
+            for key, input_data in inputs.items():
+                if (
+                    isinstance(input_data, Input)
+                    and input_data.type
+                    not in [InputTypes.NUMBER, InputTypes.BOOLEAN, InputTypes.INTEGER, InputTypes.STRING]
+                    and HTTP_PREFIX not in input_data.path
+                ):
+                    self._resolve_input(input_data, os.getcwd())
+            params_override.append({EndpointYamlFields.BATCH_JOB_INPUT_DATA: inputs})
+
+        properties = {}
+
+        if outputs:
+            params_override.append({EndpointYamlFields.BATCH_JOB_OUTPUT_DATA: outputs})
+        if job_name:
+            params_override.append({EndpointYamlFields.BATCH_JOB_NAME: job_name})
+        if experiment_name:
+            properties["experimentName"] = experiment_name
+
+        if properties:
+            params_override.append({EndpointYamlFields.BATCH_JOB_PROPERTIES: properties})
+
+        # Batch job doesn't have a Python class; load a REST object using params override
+        context = {
+            BASE_PATH_CONTEXT_KEY: Path(".").parent,
+            PARAMS_OVERRIDE_KEY: params_override,
+        }
+
+        batch_job = BatchJobSchema(context=context).load(data={})
+        # Update output datastore to ARM id if needed
+        # TODO: Unify datastore name -> arm id logic, TASK: 1104172
+        request = {}
+        if (
+            batch_job.output_dataset
+            and batch_job.output_dataset.datastore_id
+            and (not is_ARM_id_for_resource(batch_job.output_dataset.datastore_id))
+        ):
+            v2_dataset_dictionary = convert_v1_dataset_to_v2(batch_job.output_dataset, batch_job.output_file_name)
+            batch_job.output_dataset = None
+            batch_job.output_file_name = None
+            request = BatchJobResource(properties=batch_job).serialize()
+            request["properties"]["outputData"] = v2_dataset_dictionary
+        else:
+            request = BatchJobResource(properties=batch_job).serialize()
+
+        endpoint = self._batch_operation.get(
+            resource_group_name=self._resource_group_name,
+            workspace_name=self._workspace_name,
+            endpoint_name=endpoint_name,
+            **self._init_kwargs,
+        )
+
+        headers = EndpointInvokeFields.DEFAULT_HEADER
+        ml_audience_scopes = _resource_to_scopes(_get_aml_resource_id_from_metadata())
+        module_logger.debug("ml_audience_scopes used: `%s`\n", ml_audience_scopes)
+        key = self._credentials.get_token(*ml_audience_scopes).token if self._credentials is not None else ""
+        headers[EndpointInvokeFields.AUTHORIZATION] = f"Bearer {key}"
+        headers[EndpointInvokeFields.REPEATABILITY_REQUEST_ID] = str(uuid.uuid4())
+
+        if deployment_name:
+            headers[EndpointInvokeFields.MODEL_DEPLOYMENT] = deployment_name
+
+        retry_attempts = 0
+        while retry_attempts < 5:
+            try:
+                response = self._requests_pipeline.post(
+                    endpoint.properties.scoring_uri,
+                    json=request,
+                    headers=headers,
+                )
+            except (ServiceRequestError, ServiceResponseError):
+                retry_attempts += 1
+                continue
+            break
+        if retry_attempts == 5:
+            retry_msg = "Max retry attempts reached while trying to connect to server. Please check connection and invoke again."  # pylint: disable=line-too-long
+            raise MlException(message=retry_msg, no_personal_data_message=retry_msg, target=ErrorTarget.BATCH_ENDPOINT)
+        validate_response(response)
+        batch_job = json.loads(response.text())
+        return BatchJobResource.deserialize(batch_job)
+
+    @distributed_trace
+    @monitor_with_activity(ops_logger, "BatchEndpoint.ListJobs", ActivityType.PUBLICAPI)
+    def list_jobs(self, endpoint_name: str) -> ItemPaged[BatchJob]:
+        """List jobs under the provided batch endpoint deployment. This is only valid for batch endpoints.
+
+        :param endpoint_name: The endpoint name
+        :type endpoint_name: str
+        :return: List of jobs
+        :rtype: ~azure.core.paging.ItemPaged[~azure.ai.ml.entities.BatchJob]
+
+        .. admonition:: Example:
+
+            .. literalinclude:: ../samples/ml_samples_misc.py
+                :start-after: [START batch_endpoint_operations_list_jobs]
+                :end-before: [END batch_endpoint_operations_list_jobs]
+                :language: python
+                :dedent: 8
+                :caption: List jobs example.
+        """
+
+        workspace_operations = self._all_operations.all_operations[AzureMLResourceType.WORKSPACE]
+        mfe_base_uri = _get_mfe_base_url_from_discovery_service(
+            workspace_operations, self._workspace_name, self._requests_pipeline
+        )
+
+        with modified_operation_client(self._batch_job_endpoint, mfe_base_uri):
+            result = self._batch_job_endpoint.list(
+                endpoint_name=endpoint_name,
+                resource_group_name=self._resource_group_name,
+                workspace_name=self._workspace_name,
+                **self._init_kwargs,
+            )
+
+            # This is necessary as the paged result needs to be resolved inside the context manager
+            return list(result)
+
+    def _get_workspace_location(self) -> str:
+        return str(
+            self._all_operations.all_operations[AzureMLResourceType.WORKSPACE].get(self._workspace_name).location
+        )
+
+    def _validate_deployment_name(self, endpoint_name: str, deployment_name: str) -> None:
+        deployments_list = self._batch_deployment_operation.list(
+            endpoint_name=endpoint_name,
+            resource_group_name=self._resource_group_name,
+            workspace_name=self._workspace_name,
+            cls=lambda objs: [obj.name for obj in objs],
+            **self._init_kwargs,
+        )
+        if deployments_list:
+            if deployment_name not in deployments_list:
+                # Use a plain template (not an f-string) so no_personal_data_message
+                # does not leak the deployment name
+                msg = "Deployment name {} not found for this endpoint"
+                raise ValidationException(
+                    message=msg.format(deployment_name),
+                    no_personal_data_message=msg.format("[deployment_name]"),
+                    target=ErrorTarget.DEPLOYMENT,
+                    error_category=ErrorCategory.USER_ERROR,
+                    error_type=ValidationErrorType.RESOURCE_NOT_FOUND,
+                )
+        else:
+            msg = "No deployment exists for this endpoint"
+            raise ValidationException(
+                message=msg,
+                no_personal_data_message=msg,
+                error_category=ErrorCategory.USER_ERROR,
+                target=ErrorTarget.DEPLOYMENT,
+                error_type=ValidationErrorType.RESOURCE_NOT_FOUND,
+            )
+
+    def _resolve_input(self, entry: Input, base_path: str) -> None:
+        # We should not verify anything that is not of type Input
+        if not isinstance(entry, Input):
+            return
+
+        # Input path should not be empty
+        if not entry.path:
+            msg = "Input path can't be empty for batch endpoint invoke"
+            raise MlException(message=msg, no_personal_data_message=msg)
+
+        if entry.type in [InputTypes.NUMBER, InputTypes.BOOLEAN, InputTypes.INTEGER, InputTypes.STRING]:
+            return
+
+        try:
+            if entry.path.startswith(ARM_ID_FULL_PREFIX):
+                if not is_ARM_id_for_resource(entry.path, AzureMLResourceType.DATA):
+                    raise ValidationException(
+                        message="Invalid input path",
+                        target=ErrorTarget.BATCH_ENDPOINT,
+                        no_personal_data_message="Invalid input path",
+                        error_type=ValidationErrorType.INVALID_VALUE,
+                    )
+            elif os.path.isabs(entry.path):  # absolute local path, upload, transform to remote url
+                if entry.type == AssetTypes.URI_FOLDER and not os.path.isdir(entry.path):
+                    raise ValidationException(
+                        message="There is no folder on target path: {}".format(entry.path),
+                        target=ErrorTarget.BATCH_ENDPOINT,
+                        no_personal_data_message="There is no folder on target path",
+                        error_type=ValidationErrorType.FILE_OR_FOLDER_NOT_FOUND,
+                    )
+                if entry.type == AssetTypes.URI_FILE and not os.path.isfile(entry.path):
+                    raise ValidationException(
+                        message="There is no file on target path: {}".format(entry.path),
+                        target=ErrorTarget.BATCH_ENDPOINT,
+                        no_personal_data_message="There is no file on target path",
+                        error_type=ValidationErrorType.FILE_OR_FOLDER_NOT_FOUND,
+                    )
+                # absolute local path
+                entry.path = _upload_and_generate_remote_uri(
+                    self._operation_scope,
+                    self._datastore_operations,
+                    entry.path,
+                )
+                if entry.type == AssetTypes.URI_FOLDER and entry.path and not entry.path.endswith("/"):
+                    entry.path = entry.path + "/"
+            elif ":" in entry.path or "@" in entry.path:  # Check registered file or folder datastore
+                # If we receive a datastore path in long/short form we don't need
+                # to get the arm asset id
+                if re.match(SHORT_URI_REGEX_FORMAT, entry.path) or re.match(LONG_URI_REGEX_FORMAT, entry.path):
+                    return
+                if is_private_preview_enabled() and re.match(AZUREML_REGEX_FORMAT, entry.path):
+                    return
+                asset_type = AzureMLResourceType.DATA
+                entry.path = remove_aml_prefix(entry.path)
+                orchestrator = OperationOrchestrator(
+                    self._all_operations, self._operation_scope, self._operation_config
+                )
+                entry.path = orchestrator.get_asset_arm_id(entry.path, asset_type)
+            else:  # relative local path, upload, transform to remote url
+                local_path = Path(base_path, entry.path).resolve()
+                entry.path = _upload_and_generate_remote_uri(
+                    self._operation_scope,
+                    self._datastore_operations,
+                    local_path,
+                )
+                if entry.type == AssetTypes.URI_FOLDER and entry.path and not entry.path.endswith("/"):
+                    entry.path = entry.path + "/"
+        except (MlException, HttpResponseError) as e:
+            raise e
+        except Exception as e:
+            raise ValidationException(
+                message="Supported input path values are: a path on the datastore, a public URI, "
+                "a registered data asset, or a local folder path.\n"
+                f"Met {type(e)}:\n{e}",
+                target=ErrorTarget.BATCH_ENDPOINT,
+                no_personal_data_message="Supported input path values are: a path on the datastore, "
+                "a public URI, a registered data asset, or a local folder path.",
+                error=e,
+                error_type=ValidationErrorType.INVALID_VALUE,
+            ) from e