Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/operations')
38 files changed, 16210 insertions, 0 deletions
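The diffs below vendor the azure.ai.ml.operations package. As the module and class docstrings note, these operations classes are not meant to be instantiated directly; they are attached as attributes of an MLClient. A minimal sketch of how they are normally reached, assuming the public azure-ai-ml and azure-identity packages and placeholder workspace coordinates:

    # Sketch: the operations classes added below are reached through MLClient,
    # not constructed directly (subscription, resource group and workspace are placeholders).
    from azure.identity import DefaultAzureCredential
    from azure.ai.ml import MLClient

    ml_client = MLClient(
        credential=DefaultAzureCredential(),
        subscription_id="<subscription-id>",
        resource_group_name="<resource-group>",
        workspace_name="<workspace-name>",
    )

    # Each attribute wraps one of the classes exported from __init__.py below,
    # e.g. JobOperations, ModelOperations, BatchEndpointOperations.
    for job in ml_client.jobs.list():
        print(job.name)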
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/__init__.py new file mode 100644 index 00000000..4508bd95 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/__init__.py @@ -0,0 +1,65 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +"""Contains supported operations for Azure Machine Learning SDKv2. + +Operations are classes contain logic to interact with backend services, usually auto generated operations call. +""" +__path__ = __import__("pkgutil").extend_path(__path__, __name__) + + +from ._azure_openai_deployment_operations import AzureOpenAIDeploymentOperations +from ._batch_deployment_operations import BatchDeploymentOperations +from ._batch_endpoint_operations import BatchEndpointOperations +from ._component_operations import ComponentOperations +from ._compute_operations import ComputeOperations +from ._data_operations import DataOperations +from ._datastore_operations import DatastoreOperations +from ._environment_operations import EnvironmentOperations +from ._feature_set_operations import FeatureSetOperations +from ._feature_store_entity_operations import FeatureStoreEntityOperations +from ._feature_store_operations import FeatureStoreOperations +from ._index_operations import IndexOperations +from ._job_operations import JobOperations +from ._model_operations import ModelOperations +from ._online_deployment_operations import OnlineDeploymentOperations +from ._online_endpoint_operations import OnlineEndpointOperations +from ._registry_operations import RegistryOperations +from ._schedule_operations import ScheduleOperations +from ._workspace_connections_operations import WorkspaceConnectionsOperations +from ._workspace_operations import WorkspaceOperations +from ._workspace_outbound_rule_operations import WorkspaceOutboundRuleOperations +from ._evaluator_operations import EvaluatorOperations +from ._serverless_endpoint_operations import ServerlessEndpointOperations +from ._marketplace_subscription_operations import MarketplaceSubscriptionOperations +from ._capability_hosts_operations import CapabilityHostsOperations + +__all__ = [ + "ComputeOperations", + "DatastoreOperations", + "JobOperations", + "ModelOperations", + "EvaluatorOperations", + "WorkspaceOperations", + "RegistryOperations", + "OnlineEndpointOperations", + "BatchEndpointOperations", + "OnlineDeploymentOperations", + "BatchDeploymentOperations", + "DataOperations", + "EnvironmentOperations", + "ComponentOperations", + "WorkspaceConnectionsOperations", + "RegistryOperations", + "ScheduleOperations", + "WorkspaceOutboundRuleOperations", + "FeatureSetOperations", + "FeatureStoreEntityOperations", + "FeatureStoreOperations", + "ServerlessEndpointOperations", + "MarketplaceSubscriptionOperations", + "IndexOperations", + "AzureOpenAIDeploymentOperations", + "CapabilityHostsOperations", +] diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_azure_openai_deployment_operations.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_azure_openai_deployment_operations.py new file mode 100644 index 00000000..542448a9 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_azure_openai_deployment_operations.py @@ -0,0 +1,60 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft 
Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=protected-access + +import logging +from typing import Iterable + +from azure.ai.ml._restclient.v2024_04_01_preview import AzureMachineLearningWorkspaces as ServiceClient2020404Preview +from azure.ai.ml._scope_dependent_operations import OperationConfig, OperationScope, _ScopeDependentOperations +from azure.ai.ml.entities._autogen_entities.models import AzureOpenAIDeployment + +from ._workspace_connections_operations import WorkspaceConnectionsOperations + +module_logger = logging.getLogger(__name__) + + +class AzureOpenAIDeploymentOperations(_ScopeDependentOperations): + """AzureOpenAIDeploymentOperations. + + You should not instantiate this class directly. Instead, you should + create an MLClient instance that instantiates it for you and + attaches it as an attribute. + """ + + def __init__( + self, + operation_scope: OperationScope, + operation_config: OperationConfig, + service_client: ServiceClient2020404Preview, + connections_operations: WorkspaceConnectionsOperations, + ): + super().__init__(operation_scope, operation_config) + self._service_client = service_client.connection + self._workspace_connections_operations = connections_operations + + def list(self, connection_name: str, **kwargs) -> Iterable[AzureOpenAIDeployment]: + """List Azure OpenAI deployments of the workspace. + + :param connection_name: Name of the connection from which to list deployments + :type connection_name: str + :return: A list of Azure OpenAI deployments + :rtype: ~typing.Iterable[~azure.ai.ml.entities.AzureOpenAIDeployment] + """ + connection = self._workspace_connections_operations.get(connection_name) + + def _from_rest_add_connection_name(obj): + from_rest_deployment = AzureOpenAIDeployment._from_rest_object(obj) + from_rest_deployment.connection_name = connection_name + from_rest_deployment.target_url = connection.target + return from_rest_deployment + + return self._service_client.list_deployments( + self._resource_group_name, + self._workspace_name, + connection_name, + cls=lambda objs: [_from_rest_add_connection_name(obj) for obj in objs], + **kwargs, + ) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_batch_deployment_operations.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_batch_deployment_operations.py new file mode 100644 index 00000000..9bde33c8 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_batch_deployment_operations.py @@ -0,0 +1,392 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
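AzureOpenAIDeploymentOperations.list above takes the name of a workspace connection and yields AzureOpenAIDeployment entities with connection_name and target_url filled in from that connection. A hedged usage sketch; the connection name is a placeholder and the azure_openai_deployments attribute name on MLClient is an assumption, not taken from this diff:

    # Sketch: listing Azure OpenAI deployments behind a workspace connection.
    # "my-aoai-connection" is a placeholder; the MLClient attribute name is assumed.
    for deployment in ml_client.azure_openai_deployments.list("my-aoai-connection"):
        print(deployment.name, deployment.target_url)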
+# --------------------------------------------------------- + +# pylint: disable=protected-access, too-many-boolean-expressions + +import re +from typing import Any, Optional, TypeVar, Union + +from azure.ai.ml._restclient.v2024_01_01_preview import AzureMachineLearningWorkspaces as ServiceClient012024Preview +from azure.ai.ml._scope_dependent_operations import ( + OperationConfig, + OperationsContainer, + OperationScope, + _ScopeDependentOperations, +) +from azure.ai.ml._telemetry import ActivityType, monitor_with_activity +from azure.ai.ml._utils._arm_id_utils import AMLVersionedArmId +from azure.ai.ml._utils._azureml_polling import AzureMLPolling +from azure.ai.ml._utils._endpoint_utils import upload_dependencies, validate_scoring_script +from azure.ai.ml._utils._http_utils import HttpPipeline +from azure.ai.ml._utils._logger_utils import OpsLogger +from azure.ai.ml._utils._package_utils import package_deployment +from azure.ai.ml._utils.utils import _get_mfe_base_url_from_discovery_service, modified_operation_client +from azure.ai.ml.constants._common import ARM_ID_PREFIX, AzureMLResourceType, LROConfigurations +from azure.ai.ml.entities import BatchDeployment, BatchJob, ModelBatchDeployment, PipelineComponent, PipelineJob +from azure.ai.ml.entities._deployment.pipeline_component_batch_deployment import PipelineComponentBatchDeployment +from azure.core.credentials import TokenCredential +from azure.core.exceptions import HttpResponseError, ResourceNotFoundError +from azure.core.paging import ItemPaged +from azure.core.polling import LROPoller +from azure.core.tracing.decorator import distributed_trace + +from ._operation_orchestrator import OperationOrchestrator + +ops_logger = OpsLogger(__name__) +module_logger = ops_logger.module_logger +DeploymentType = TypeVar( + "DeploymentType", bound=Union[BatchDeployment, PipelineComponentBatchDeployment, ModelBatchDeployment] +) + + +class BatchDeploymentOperations(_ScopeDependentOperations): + """BatchDeploymentOperations. + + You should not instantiate this class directly. Instead, you should create an MLClient instance that instantiates it + for you and attaches it as an attribute. + + :param operation_scope: Scope variables for the operations classes of an MLClient object. + :type operation_scope: ~azure.ai.ml._scope_dependent_operations.OperationScope + :param operation_config: Common configuration for operations classes of an MLClient object. + :type operation_config: ~azure.ai.ml._scope_dependent_operations.OperationConfig + :param service_client_05_2022: Service client to allow end users to operate on Azure Machine Learning Workspace + resources. + :type service_client_05_2022: ~azure.ai.ml._restclient.v2022_05_01._azure_machine_learning_workspaces. + AzureMachineLearningWorkspaces + :param all_operations: All operations classes of an MLClient object. + :type all_operations: ~azure.ai.ml._scope_dependent_operations.OperationsContainer + :param credentials: Credential to use for authentication. 
+ :type credentials: ~azure.core.credentials.TokenCredential + """ + + def __init__( + self, + operation_scope: OperationScope, + operation_config: OperationConfig, + service_client_01_2024_preview: ServiceClient012024Preview, + all_operations: OperationsContainer, + credentials: Optional[TokenCredential] = None, + **kwargs: Any, + ): + super(BatchDeploymentOperations, self).__init__(operation_scope, operation_config) + ops_logger.update_filter() + self._batch_deployment = service_client_01_2024_preview.batch_deployments + self._batch_job_deployment = kwargs.pop("service_client_09_2020_dataplanepreview").batch_job_deployment + service_client_02_2023_preview = kwargs.pop("service_client_02_2023_preview") + self._component_batch_deployment_operations = service_client_02_2023_preview.batch_deployments + self._batch_endpoint_operations = service_client_01_2024_preview.batch_endpoints + self._component_operations = service_client_02_2023_preview.component_versions + self._all_operations = all_operations + self._credentials = credentials + self._init_kwargs = kwargs + + self._requests_pipeline: HttpPipeline = kwargs.pop("requests_pipeline") + + @distributed_trace + @monitor_with_activity(ops_logger, "BatchDeployment.BeginCreateOrUpdate", ActivityType.PUBLICAPI) + def begin_create_or_update( + self, + deployment: DeploymentType, + *, + skip_script_validation: bool = False, + **kwargs: Any, + ) -> LROPoller[DeploymentType]: + """Create or update a batch deployment. + + :param deployment: The deployment entity. + :type deployment: ~azure.ai.ml.entities.BatchDeployment + :keyword skip_script_validation: If set to True, the script validation will be skipped. Defaults to False. + :paramtype skip_script_validation: bool + :raises ~azure.ai.ml.exceptions.ValidationException: Raised if BatchDeployment cannot be + successfully validated. Details will be provided in the error message. + :raises ~azure.ai.ml.exceptions.AssetException: Raised if BatchDeployment assets + (e.g. Data, Code, Model, Environment) cannot be successfully validated. + Details will be provided in the error message. + :raises ~azure.ai.ml.exceptions.ModelException: Raised if BatchDeployment model + cannot be successfully validated. Details will be provided in the error message. + :return: A poller to track the operation status. + :rtype: ~azure.core.polling.LROPoller[~azure.ai.ml.entities.BatchDeployment] + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START batch_deployment_operations_begin_create_or_update] + :end-before: [END batch_deployment_operations_begin_create_or_update] + :language: python + :dedent: 8 + :caption: Create example. 
+ """ + if ( + not skip_script_validation + and not isinstance(deployment, PipelineComponentBatchDeployment) + and deployment + and deployment.code_configuration # type: ignore + and not deployment.code_configuration.code.startswith(ARM_ID_PREFIX) # type: ignore + and not re.match(AMLVersionedArmId.REGEX_PATTERN, deployment.code_configuration.code) # type: ignore + ): + validate_scoring_script(deployment) + module_logger.debug("Checking endpoint %s exists", deployment.endpoint_name) + self._batch_endpoint_operations.get( + endpoint_name=deployment.endpoint_name, + resource_group_name=self._resource_group_name, + workspace_name=self._workspace_name, + ) + orchestrators = OperationOrchestrator( + operation_container=self._all_operations, + operation_scope=self._operation_scope, + operation_config=self._operation_config, + ) + if isinstance(deployment, PipelineComponentBatchDeployment): + self._validate_component(deployment, orchestrators) # type: ignore + else: + upload_dependencies(deployment, orchestrators) + try: + location = self._get_workspace_location() + if kwargs.pop("package_model", False): + deployment = package_deployment(deployment, self._all_operations.all_operations) + module_logger.info("\nStarting deployment") + deployment_rest = deployment._to_rest_object(location=location) + if isinstance(deployment, PipelineComponentBatchDeployment): # pylint: disable=no-else-return + return self._component_batch_deployment_operations.begin_create_or_update( + resource_group_name=self._resource_group_name, + workspace_name=self._workspace_name, + endpoint_name=deployment.endpoint_name, + deployment_name=deployment.name, + body=deployment_rest, + **self._init_kwargs, + cls=lambda response, deserialized, headers: PipelineComponentBatchDeployment._from_rest_object( + deserialized + ), + ) + else: + return self._batch_deployment.begin_create_or_update( + resource_group_name=self._resource_group_name, + workspace_name=self._workspace_name, + endpoint_name=deployment.endpoint_name, + deployment_name=deployment.name, + body=deployment_rest, + **self._init_kwargs, + cls=lambda response, deserialized, headers: BatchDeployment._from_rest_object(deserialized), + ) + except Exception as ex: + raise ex + + @distributed_trace + @monitor_with_activity(ops_logger, "BatchDeployment.Get", ActivityType.PUBLICAPI) + def get(self, name: str, endpoint_name: str) -> BatchDeployment: + """Get a deployment resource. + + :param name: The name of the deployment + :type name: str + :param endpoint_name: The name of the endpoint + :type endpoint_name: str + :return: A deployment entity + :rtype: ~azure.ai.ml.entities.BatchDeployment + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START batch_deployment_operations_get] + :end-before: [END batch_deployment_operations_get] + :language: python + :dedent: 8 + :caption: Get example. + """ + deployment = BatchDeployment._from_rest_object( + self._batch_deployment.get( + endpoint_name=endpoint_name, + deployment_name=name, + resource_group_name=self._resource_group_name, + workspace_name=self._workspace_name, + **self._init_kwargs, + ) + ) + deployment.endpoint_name = endpoint_name + return deployment + + @distributed_trace + @monitor_with_activity(ops_logger, "BatchDeployment.BeginDelete", ActivityType.PUBLICAPI) + def begin_delete(self, name: str, endpoint_name: str) -> LROPoller[None]: + """Delete a batch deployment. + + :param name: Name of the batch deployment. 
+ :type name: str + :param endpoint_name: Name of the batch endpoint + :type endpoint_name: str + :return: A poller to track the operation status. + :rtype: ~azure.core.polling.LROPoller[None] + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START batch_deployment_operations_delete] + :end-before: [END batch_deployment_operations_delete] + :language: python + :dedent: 8 + :caption: Delete example. + """ + path_format_arguments = { + "endpointName": name, + "resourceGroupName": self._resource_group_name, + "workspaceName": self._workspace_name, + } + + delete_poller = self._batch_deployment.begin_delete( + resource_group_name=self._resource_group_name, + workspace_name=self._workspace_name, + endpoint_name=endpoint_name, + deployment_name=name, + polling=AzureMLPolling( + LROConfigurations.POLL_INTERVAL, + path_format_arguments=path_format_arguments, + **self._init_kwargs, + ), + polling_interval=LROConfigurations.POLL_INTERVAL, + **self._init_kwargs, + ) + return delete_poller + + @distributed_trace + @monitor_with_activity(ops_logger, "BatchDeployment.List", ActivityType.PUBLICAPI) + def list(self, endpoint_name: str) -> ItemPaged[BatchDeployment]: + """List a deployment resource. + + :param endpoint_name: The name of the endpoint + :type endpoint_name: str + :return: An iterator of deployment entities + :rtype: ~azure.core.paging.ItemPaged[~azure.ai.ml.entities.BatchDeployment] + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START batch_deployment_operations_list] + :end-before: [END batch_deployment_operations_list] + :language: python + :dedent: 8 + :caption: List deployment resource example. + """ + return self._batch_deployment.list( + endpoint_name=endpoint_name, + resource_group_name=self._resource_group_name, + workspace_name=self._workspace_name, + cls=lambda objs: [BatchDeployment._from_rest_object(obj) for obj in objs], + **self._init_kwargs, + ) + + @distributed_trace + @monitor_with_activity(ops_logger, "BatchDeployment.ListJobs", ActivityType.PUBLICAPI) + def list_jobs(self, endpoint_name: str, *, name: Optional[str] = None) -> ItemPaged[BatchJob]: + """List jobs under the provided batch endpoint deployment. This is only valid for batch endpoint. + + :param endpoint_name: Name of endpoint. + :type endpoint_name: str + :keyword name: (Optional) Name of deployment. + :paramtype name: str + :raise: Exception if endpoint_type is not BATCH_ENDPOINT_TYPE + :return: List of jobs + :rtype: ~azure.core.paging.ItemPaged[~azure.ai.ml.entities.BatchJob] + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START batch_deployment_operations_list_jobs] + :end-before: [END batch_deployment_operations_list_jobs] + :language: python + :dedent: 8 + :caption: List jobs example. 
+ """ + + workspace_operations = self._all_operations.all_operations[AzureMLResourceType.WORKSPACE] + mfe_base_uri = _get_mfe_base_url_from_discovery_service( + workspace_operations, self._workspace_name, self._requests_pipeline + ) + + with modified_operation_client(self._batch_job_deployment, mfe_base_uri): + result = self._batch_job_deployment.list( + endpoint_name=endpoint_name, + deployment_name=name, + resource_group_name=self._resource_group_name, + workspace_name=self._workspace_name, + **self._init_kwargs, + ) + + # This is necessary as the paged result need to be resolved inside the context manager + return list(result) + + def _get_workspace_location(self) -> str: + """Get the workspace location + + TODO[TASK 1260265]: can we cache this information and only refresh when the operation_scope is changed? + + :return: The workspace location + :rtype: str + """ + return str( + self._all_operations.all_operations[AzureMLResourceType.WORKSPACE].get(self._workspace_name).location + ) + + def _validate_component(self, deployment: Any, orchestrators: OperationOrchestrator) -> None: + """Validates that the value provided is associated to an existing component or otherwise we will try to create + an anonymous component that will be use for batch deployment. + + :param deployment: Batch deployment + :type deployment: ~azure.ai.ml.entities._deployment.deployment.Deployment + :param orchestrators: Operation Orchestrator + :type orchestrators: _operation_orchestrator.OperationOrchestrator + """ + if isinstance(deployment.component, PipelineComponent): + try: + registered_component = self._all_operations.all_operations[AzureMLResourceType.COMPONENT].get( + name=deployment.component.name, version=deployment.component.version + ) + deployment.component = registered_component.id + except Exception as err: # pylint: disable=W0718 + if isinstance(err, (ResourceNotFoundError, HttpResponseError)): + deployment.component = self._all_operations.all_operations[ + AzureMLResourceType.COMPONENT + ].create_or_update( + name=deployment.component.name, + resource_group_name=self._resource_group_name, + workspace_name=self._workspace_name, + component=deployment.component, + version=deployment.component.version, + **self._init_kwargs, + ) + else: + raise err + elif isinstance(deployment.component, str): + component_id = orchestrators.get_asset_arm_id( + deployment.component, azureml_type=AzureMLResourceType.COMPONENT + ) + deployment.component = component_id + elif isinstance(deployment.job_definition, str): + job_component = PipelineComponent(source_job_id=deployment.job_definition) + job_component = self._component_operations.create_or_update( + name=job_component.name, + resource_group_name=self._resource_group_name, + workspace_name=self._workspace_name, + body=job_component._to_rest_object(), + version=job_component.version, + **self._init_kwargs, + ) + deployment.component = job_component.id + + elif isinstance(deployment.job_definition, PipelineJob): + try: + registered_job = self._all_operations.all_operations[AzureMLResourceType.JOB].get( + name=deployment.job_definition.name + ) + if registered_job: + job_component = PipelineComponent(source_job_id=registered_job.name) + job_component = self._component_operations.create_or_update( + name=job_component.name, + resource_group_name=self._resource_group_name, + workspace_name=self._workspace_name, + body=job_component._to_rest_object(), + version=job_component.version, + **self._init_kwargs, + ) + deployment.component = job_component.id + except 
ResourceNotFoundError as err: + raise err diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_batch_endpoint_operations.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_batch_endpoint_operations.py new file mode 100644 index 00000000..650185b3 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_batch_endpoint_operations.py @@ -0,0 +1,553 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=protected-access + +import json +import os +import re +import uuid +from pathlib import Path +from typing import TYPE_CHECKING, Any, Dict, Optional, cast + +from marshmallow.exceptions import ValidationError as SchemaValidationError + +from azure.ai.ml._artifacts._artifact_utilities import _upload_and_generate_remote_uri +from azure.ai.ml._azure_environments import _get_aml_resource_id_from_metadata, _resource_to_scopes +from azure.ai.ml._exception_helper import log_and_raise_error +from azure.ai.ml._restclient.v2020_09_01_dataplanepreview.models import BatchJobResource +from azure.ai.ml._restclient.v2023_10_01 import AzureMachineLearningServices as ServiceClient102023 +from azure.ai.ml._schema._deployment.batch.batch_job import BatchJobSchema +from azure.ai.ml._scope_dependent_operations import ( + OperationConfig, + OperationsContainer, + OperationScope, + _ScopeDependentOperations, +) +from azure.ai.ml._telemetry import ActivityType, monitor_with_activity +from azure.ai.ml._utils._arm_id_utils import is_ARM_id_for_resource, remove_aml_prefix +from azure.ai.ml._utils._azureml_polling import AzureMLPolling +from azure.ai.ml._utils._endpoint_utils import convert_v1_dataset_to_v2, validate_response +from azure.ai.ml._utils._http_utils import HttpPipeline +from azure.ai.ml._utils._logger_utils import OpsLogger +from azure.ai.ml._utils.utils import ( + _get_mfe_base_url_from_discovery_service, + is_private_preview_enabled, + modified_operation_client, +) +from azure.ai.ml.constants._common import ( + ARM_ID_FULL_PREFIX, + AZUREML_REGEX_FORMAT, + BASE_PATH_CONTEXT_KEY, + HTTP_PREFIX, + LONG_URI_REGEX_FORMAT, + PARAMS_OVERRIDE_KEY, + SHORT_URI_REGEX_FORMAT, + AssetTypes, + AzureMLResourceType, + InputTypes, + LROConfigurations, +) +from azure.ai.ml.constants._endpoint import EndpointInvokeFields, EndpointYamlFields +from azure.ai.ml.entities import BatchEndpoint, BatchJob +from azure.ai.ml.entities._inputs_outputs import Input +from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, MlException, ValidationErrorType, ValidationException +from azure.core.credentials import TokenCredential +from azure.core.exceptions import HttpResponseError, ServiceRequestError, ServiceResponseError +from azure.core.paging import ItemPaged +from azure.core.polling import LROPoller +from azure.core.tracing.decorator import distributed_trace + +from ._operation_orchestrator import OperationOrchestrator + +if TYPE_CHECKING: + from azure.ai.ml.operations import DatastoreOperations + +ops_logger = OpsLogger(__name__) +module_logger = ops_logger.module_logger + + +class BatchEndpointOperations(_ScopeDependentOperations): + """BatchEndpointOperations. + + You should not instantiate this class directly. Instead, you should create an MLClient instance that instantiates it + for you and attaches it as an attribute. + + :param operation_scope: Scope variables for the operations classes of an MLClient object. 
+ :type operation_scope: ~azure.ai.ml._scope_dependent_operations.OperationScope + :param operation_config: Common configuration for operations classes of an MLClient object. + :type operation_config: ~azure.ai.ml._scope_dependent_operations.OperationConfig + :param service_client_10_2023: Service client to allow end users to operate on Azure Machine Learning Workspace + resources. + :type service_client_10_2023: ~azure.ai.ml._restclient.v2023_10_01._azure_machine_learning_workspaces. + AzureMachineLearningWorkspaces + :param all_operations: All operations classes of an MLClient object. + :type all_operations: ~azure.ai.ml._scope_dependent_operations.OperationsContainer + :param credentials: Credential to use for authentication. + :type credentials: ~azure.core.credentials.TokenCredential + """ + + def __init__( + self, + operation_scope: OperationScope, + operation_config: OperationConfig, + service_client_10_2023: ServiceClient102023, + all_operations: OperationsContainer, + credentials: Optional[TokenCredential] = None, + **kwargs: Any, + ): + super(BatchEndpointOperations, self).__init__(operation_scope, operation_config) + ops_logger.update_filter() + self._batch_operation = service_client_10_2023.batch_endpoints + self._batch_deployment_operation = service_client_10_2023.batch_deployments + self._batch_job_endpoint = kwargs.pop("service_client_09_2020_dataplanepreview").batch_job_endpoint + self._all_operations = all_operations + self._credentials = credentials + self._init_kwargs = kwargs + + self._requests_pipeline: HttpPipeline = kwargs.pop("requests_pipeline") + + @property + def _datastore_operations(self) -> "DatastoreOperations": + from azure.ai.ml.operations import DatastoreOperations + + return cast(DatastoreOperations, self._all_operations.all_operations[AzureMLResourceType.DATASTORE]) + + @distributed_trace + @monitor_with_activity(ops_logger, "BatchEndpoint.List", ActivityType.PUBLICAPI) + def list(self) -> ItemPaged[BatchEndpoint]: + """List endpoints of the workspace. + + :return: A list of endpoints + :rtype: ~azure.core.paging.ItemPaged[~azure.ai.ml.entities.BatchEndpoint] + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START batch_endpoint_operations_list] + :end-before: [END batch_endpoint_operations_list] + :language: python + :dedent: 8 + :caption: List example. + """ + return self._batch_operation.list( + resource_group_name=self._resource_group_name, + workspace_name=self._workspace_name, + cls=lambda objs: [BatchEndpoint._from_rest_object(obj) for obj in objs], + **self._init_kwargs, + ) + + @distributed_trace + @monitor_with_activity(ops_logger, "BatchEndpoint.Get", ActivityType.PUBLICAPI) + def get( + self, + name: str, + ) -> BatchEndpoint: + """Get a Endpoint resource. + + :param name: Name of the endpoint. + :type name: str + :return: Endpoint object retrieved from the service. + :rtype: ~azure.ai.ml.entities.BatchEndpoint + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START batch_endpoint_operations_get] + :end-before: [END batch_endpoint_operations_get] + :language: python + :dedent: 8 + :caption: Get endpoint example. 
+ """ + # first get the endpoint + endpoint = self._batch_operation.get( + resource_group_name=self._resource_group_name, + workspace_name=self._workspace_name, + endpoint_name=name, + **self._init_kwargs, + ) + + endpoint_data = BatchEndpoint._from_rest_object(endpoint) + return endpoint_data + + @distributed_trace + @monitor_with_activity(ops_logger, "BatchEndpoint.BeginDelete", ActivityType.PUBLICAPI) + def begin_delete(self, name: str) -> LROPoller[None]: + """Delete a batch Endpoint. + + :param name: Name of the batch endpoint. + :type name: str + :return: A poller to track the operation status. + :rtype: ~azure.core.polling.LROPoller[None] + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START batch_endpoint_operations_delete] + :end-before: [END batch_endpoint_operations_delete] + :language: python + :dedent: 8 + :caption: Delete endpoint example. + """ + path_format_arguments = { + "endpointName": name, + "resourceGroupName": self._resource_group_name, + "workspaceName": self._workspace_name, + } + + delete_poller = self._batch_operation.begin_delete( + resource_group_name=self._resource_group_name, + workspace_name=self._workspace_name, + endpoint_name=name, + polling=AzureMLPolling( + LROConfigurations.POLL_INTERVAL, + path_format_arguments=path_format_arguments, + **self._init_kwargs, + ), + polling_interval=LROConfigurations.POLL_INTERVAL, + **self._init_kwargs, + ) + return delete_poller + + @distributed_trace + @monitor_with_activity(ops_logger, "BatchEndpoint.BeginCreateOrUpdate", ActivityType.PUBLICAPI) + def begin_create_or_update(self, endpoint: BatchEndpoint) -> LROPoller[BatchEndpoint]: + """Create or update a batch endpoint. + + :param endpoint: The endpoint entity. + :type endpoint: ~azure.ai.ml.entities.BatchEndpoint + :return: A poller to track the operation status. + :rtype: ~azure.core.polling.LROPoller[~azure.ai.ml.entities.BatchEndpoint] + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START batch_endpoint_operations_create_or_update] + :end-before: [END batch_endpoint_operations_create_or_update] + :language: python + :dedent: 8 + :caption: Create endpoint example. + """ + + try: + location = self._get_workspace_location() + + endpoint_resource = endpoint._to_rest_batch_endpoint(location=location) + poller = self._batch_operation.begin_create_or_update( + resource_group_name=self._resource_group_name, + workspace_name=self._workspace_name, + endpoint_name=endpoint.name, + body=endpoint_resource, + polling=True, + **self._init_kwargs, + cls=lambda response, deserialized, headers: BatchEndpoint._from_rest_object(deserialized), + ) + return poller + + except Exception as ex: + if isinstance(ex, (ValidationException, SchemaValidationError)): + log_and_raise_error(ex) + raise ex + + @distributed_trace + @monitor_with_activity(ops_logger, "BatchEndpoint.Invoke", ActivityType.PUBLICAPI) + def invoke( # pylint: disable=too-many-statements + self, + endpoint_name: str, + *, + deployment_name: Optional[str] = None, + inputs: Optional[Dict[str, Input]] = None, + **kwargs: Any, + ) -> BatchJob: + """Invokes the batch endpoint with the provided payload. + + :param endpoint_name: The endpoint name. + :type endpoint_name: str + :keyword deployment_name: (Optional) The name of a specific deployment to invoke. This is optional. + By default requests are routed to any of the deployments according to the traffic rules. 
+ :paramtype deployment_name: str + :keyword inputs: (Optional) A dictionary of existing data asset, public uri file or folder + to use with the deployment + :paramtype inputs: Dict[str, Input] + :raises ~azure.ai.ml.exceptions.ValidationException: Raised if deployment cannot be successfully validated. + Details will be provided in the error message. + :raises ~azure.ai.ml.exceptions.AssetException: Raised if BatchEndpoint assets + (e.g. Data, Code, Model, Environment) cannot be successfully validated. + Details will be provided in the error message. + :raises ~azure.ai.ml.exceptions.ModelException: Raised if BatchEndpoint model cannot be successfully validated. + Details will be provided in the error message. + :raises ~azure.ai.ml.exceptions.EmptyDirectoryError: Raised if local path provided points to an empty directory. + :return: The invoked batch deployment job. + :rtype: ~azure.ai.ml.entities.BatchJob + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START batch_endpoint_operations_invoke] + :end-before: [END batch_endpoint_operations_invoke] + :language: python + :dedent: 8 + :caption: Invoke endpoint example. + """ + outputs = kwargs.get("outputs", None) + job_name = kwargs.get("job_name", None) + params_override = kwargs.get("params_override", None) or [] + experiment_name = kwargs.get("experiment_name", None) + input = kwargs.get("input", None) # pylint: disable=redefined-builtin + # Until this bug is resolved https://msdata.visualstudio.com/Vienna/_workitems/edit/1446538 + if deployment_name: + self._validate_deployment_name(endpoint_name, deployment_name) + + if input and isinstance(input, Input): + if HTTP_PREFIX not in input.path: + self._resolve_input(input, os.getcwd()) + # MFE expects a dictionary as input_data that's why we are using + # "UriFolder" or "UriFile" as keys depending on the input type + if input.type == "uri_folder": + params_override.append({EndpointYamlFields.BATCH_JOB_INPUT_DATA: {"UriFolder": input}}) + elif input.type == "uri_file": + params_override.append({EndpointYamlFields.BATCH_JOB_INPUT_DATA: {"UriFile": input}}) + else: + msg = ( + "Unsupported input type please use a dictionary of either a path on the datastore, public URI, " + "a registered data asset, or a local folder path." 
+ ) + raise ValidationException( + message=msg, + target=ErrorTarget.BATCH_ENDPOINT, + no_personal_data_message=msg, + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.INVALID_VALUE, + ) + elif inputs: + for key, input_data in inputs.items(): + if ( + isinstance(input_data, Input) + and input_data.type + not in [InputTypes.NUMBER, InputTypes.BOOLEAN, InputTypes.INTEGER, InputTypes.STRING] + and HTTP_PREFIX not in input_data.path + ): + self._resolve_input(input_data, os.getcwd()) + params_override.append({EndpointYamlFields.BATCH_JOB_INPUT_DATA: inputs}) + + properties = {} + + if outputs: + params_override.append({EndpointYamlFields.BATCH_JOB_OUTPUT_DATA: outputs}) + if job_name: + params_override.append({EndpointYamlFields.BATCH_JOB_NAME: job_name}) + if experiment_name: + properties["experimentName"] = experiment_name + + if properties: + params_override.append({EndpointYamlFields.BATCH_JOB_PROPERTIES: properties}) + + # Batch job doesn't have a python class, loading a rest object using params override + context = { + BASE_PATH_CONTEXT_KEY: Path(".").parent, + PARAMS_OVERRIDE_KEY: params_override, + } + + batch_job = BatchJobSchema(context=context).load(data={}) + # update output datastore to arm id if needed + # TODO: Unify datastore name -> arm id logic, TASK: 1104172 + request = {} + if ( + batch_job.output_dataset + and batch_job.output_dataset.datastore_id + and (not is_ARM_id_for_resource(batch_job.output_dataset.datastore_id)) + ): + v2_dataset_dictionary = convert_v1_dataset_to_v2(batch_job.output_dataset, batch_job.output_file_name) + batch_job.output_dataset = None + batch_job.output_file_name = None + request = BatchJobResource(properties=batch_job).serialize() + request["properties"]["outputData"] = v2_dataset_dictionary + else: + request = BatchJobResource(properties=batch_job).serialize() + + endpoint = self._batch_operation.get( + resource_group_name=self._resource_group_name, + workspace_name=self._workspace_name, + endpoint_name=endpoint_name, + **self._init_kwargs, + ) + + headers = EndpointInvokeFields.DEFAULT_HEADER + ml_audience_scopes = _resource_to_scopes(_get_aml_resource_id_from_metadata()) + module_logger.debug("ml_audience_scopes used: `%s`\n", ml_audience_scopes) + key = self._credentials.get_token(*ml_audience_scopes).token if self._credentials is not None else "" + headers[EndpointInvokeFields.AUTHORIZATION] = f"Bearer {key}" + headers[EndpointInvokeFields.REPEATABILITY_REQUEST_ID] = str(uuid.uuid4()) + + if deployment_name: + headers[EndpointInvokeFields.MODEL_DEPLOYMENT] = deployment_name + + retry_attempts = 0 + while retry_attempts < 5: + try: + response = self._requests_pipeline.post( + endpoint.properties.scoring_uri, + json=request, + headers=headers, + ) + except (ServiceRequestError, ServiceResponseError): + retry_attempts += 1 + continue + break + if retry_attempts == 5: + retry_msg = "Max retry attempts reached while trying to connect to server. Please check connection and invoke again." # pylint: disable=line-too-long + raise MlException(message=retry_msg, no_personal_data_message=retry_msg, target=ErrorTarget.BATCH_ENDPOINT) + validate_response(response) + batch_job = json.loads(response.text()) + return BatchJobResource.deserialize(batch_job) + + @distributed_trace + @monitor_with_activity(ops_logger, "BatchEndpoint.ListJobs", ActivityType.PUBLICAPI) + def list_jobs(self, endpoint_name: str) -> ItemPaged[BatchJob]: + """List jobs under the provided batch endpoint deployment. This is only valid for batch endpoint. 
+ + :param endpoint_name: The endpoint name + :type endpoint_name: str + :return: List of jobs + :rtype: ~azure.core.paging.ItemPaged[~azure.ai.ml.entities.BatchJob] + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START batch_endpoint_operations_list_jobs] + :end-before: [END batch_endpoint_operations_list_jobs] + :language: python + :dedent: 8 + :caption: List jobs example. + """ + + workspace_operations = self._all_operations.all_operations[AzureMLResourceType.WORKSPACE] + mfe_base_uri = _get_mfe_base_url_from_discovery_service( + workspace_operations, self._workspace_name, self._requests_pipeline + ) + + with modified_operation_client(self._batch_job_endpoint, mfe_base_uri): + result = self._batch_job_endpoint.list( + endpoint_name=endpoint_name, + resource_group_name=self._resource_group_name, + workspace_name=self._workspace_name, + **self._init_kwargs, + ) + + # This is necessary as the paged result need to be resolved inside the context manager + return list(result) + + def _get_workspace_location(self) -> str: + return str( + self._all_operations.all_operations[AzureMLResourceType.WORKSPACE].get(self._workspace_name).location + ) + + def _validate_deployment_name(self, endpoint_name: str, deployment_name: str) -> None: + deployments_list = self._batch_deployment_operation.list( + endpoint_name=endpoint_name, + resource_group_name=self._resource_group_name, + workspace_name=self._workspace_name, + cls=lambda objs: [obj.name for obj in objs], + **self._init_kwargs, + ) + if deployments_list: + if deployment_name not in deployments_list: + msg = f"Deployment name {deployment_name} not found for this endpoint" + raise ValidationException( + message=msg.format(deployment_name), + no_personal_data_message=msg.format("[deployment_name]"), + target=ErrorTarget.DEPLOYMENT, + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.RESOURCE_NOT_FOUND, + ) + else: + msg = "No deployment exists for this endpoint" + raise ValidationException( + message=msg, + no_personal_data_message=msg, + error_category=ErrorCategory.USER_ERROR, + target=ErrorTarget.DEPLOYMENT, + error_type=ValidationErrorType.RESOURCE_NOT_FOUND, + ) + + def _resolve_input(self, entry: Input, base_path: str) -> None: + # We should not verify anything that is not of type Input + if not isinstance(entry, Input): + return + + # Input path should not be empty + if not entry.path: + msg = "Input path can't be empty for batch endpoint invoke" + raise MlException(message=msg, no_personal_data_message=msg) + + if entry.type in [InputTypes.NUMBER, InputTypes.BOOLEAN, InputTypes.INTEGER, InputTypes.STRING]: + return + + try: + if entry.path.startswith(ARM_ID_FULL_PREFIX): + if not is_ARM_id_for_resource(entry.path, AzureMLResourceType.DATA): + raise ValidationException( + message="Invalid input path", + target=ErrorTarget.BATCH_ENDPOINT, + no_personal_data_message="Invalid input path", + error_type=ValidationErrorType.INVALID_VALUE, + ) + elif os.path.isabs(entry.path): # absolute local path, upload, transform to remote url + if entry.type == AssetTypes.URI_FOLDER and not os.path.isdir(entry.path): + raise ValidationException( + message="There is no folder on target path: {}".format(entry.path), + target=ErrorTarget.BATCH_ENDPOINT, + no_personal_data_message="There is no folder on target path", + error_type=ValidationErrorType.FILE_OR_FOLDER_NOT_FOUND, + ) + if entry.type == AssetTypes.URI_FILE and not os.path.isfile(entry.path): + raise ValidationException( + 
message="There is no file on target path: {}".format(entry.path), + target=ErrorTarget.BATCH_ENDPOINT, + no_personal_data_message="There is no file on target path", + error_type=ValidationErrorType.FILE_OR_FOLDER_NOT_FOUND, + ) + # absolute local path + entry.path = _upload_and_generate_remote_uri( + self._operation_scope, + self._datastore_operations, + entry.path, + ) + if entry.type == AssetTypes.URI_FOLDER and entry.path and not entry.path.endswith("/"): + entry.path = entry.path + "/" + elif ":" in entry.path or "@" in entry.path: # Check registered file or folder datastore + # If we receive a datastore path in long/short form we don't need + # to get the arm asset id + if re.match(SHORT_URI_REGEX_FORMAT, entry.path) or re.match(LONG_URI_REGEX_FORMAT, entry.path): + return + if is_private_preview_enabled() and re.match(AZUREML_REGEX_FORMAT, entry.path): + return + asset_type = AzureMLResourceType.DATA + entry.path = remove_aml_prefix(entry.path) + orchestrator = OperationOrchestrator( + self._all_operations, self._operation_scope, self._operation_config + ) + entry.path = orchestrator.get_asset_arm_id(entry.path, asset_type) + else: # relative local path, upload, transform to remote url + local_path = Path(base_path, entry.path).resolve() + entry.path = _upload_and_generate_remote_uri( + self._operation_scope, + self._datastore_operations, + local_path, + ) + if entry.type == AssetTypes.URI_FOLDER and entry.path and not entry.path.endswith("/"): + entry.path = entry.path + "/" + except (MlException, HttpResponseError) as e: + raise e + except Exception as e: + raise ValidationException( + message=f"Supported input path value are: path on the datastore, public URI, " + "a registered data asset, or a local folder path.\n" + f"Met {type(e)}:\n{e}", + target=ErrorTarget.BATCH_ENDPOINT, + no_personal_data_message="Supported input path value are: path on the datastore, " + "public URI, a registered data asset, or a local folder path.", + error=e, + error_type=ValidationErrorType.INVALID_VALUE, + ) from e diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_capability_hosts_operations.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_capability_hosts_operations.py new file mode 100644 index 00000000..e8cc03f2 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_capability_hosts_operations.py @@ -0,0 +1,304 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- +# pylint: disable=protected-access + +from typing import Any, List + +from marshmallow.exceptions import ValidationError as SchemaValidationError + +from azure.ai.ml._exception_helper import log_and_raise_error +from azure.ai.ml._restclient.v2024_10_01_preview import AzureMachineLearningWorkspaces as ServiceClient102024Preview +from azure.ai.ml._scope_dependent_operations import ( + OperationConfig, + OperationsContainer, + OperationScope, + _ScopeDependentOperations, +) +from azure.ai.ml._telemetry import ActivityType, monitor_with_activity +from azure.ai.ml._utils._experimental import experimental +from azure.ai.ml._utils._logger_utils import OpsLogger +from azure.ai.ml.constants._common import DEFAULT_STORAGE_CONNECTION_NAME, WorkspaceKind +from azure.ai.ml.entities._workspace._ai_workspaces.capability_host import CapabilityHost +from azure.ai.ml.entities._workspace.workspace import Workspace +from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationException +from azure.core.credentials import TokenCredential +from azure.core.polling import LROPoller +from azure.core.tracing.decorator import distributed_trace + +ops_logger = OpsLogger(__name__) +module_logger = ops_logger.module_logger + + +class CapabilityHostsOperations(_ScopeDependentOperations): + """CapabilityHostsOperations. + + You should not instantiate this class directly. Instead, you should create an MLClient instance that instantiates it + for you and attaches it as an attribute. + + :param operation_scope: Scope variables for the operations classes of an MLClient object. + :type operation_scope: ~azure.ai.ml._scope_dependent_operations.OperationScope + :param operation_config: Common configuration for operations classes of an MLClient object. + :type operation_config: ~azure.ai.ml._scope_dependent_operations.OperationConfig + :param service_client_10_2024: Service client to allow end users to operate on Azure Machine Learning Workspace + resources (ServiceClient102024Preview). + :type service_client_10_2024: ~azure.ai.ml._restclient.v2024_10_01_preview._azure_machine_learning_workspaces.AzureMachineLearningWorkspaces # pylint: disable=line-too-long + :param all_operations: All operations classes of an MLClient object. + :type all_operations: ~azure.ai.ml._scope_dependent_operations.OperationsContainer + :param credentials: Credential to use for authentication. + :type credentials: ~azure.core.credentials.TokenCredential + :param kwargs: Additional keyword arguments. + :type kwargs: Any + """ + + def __init__( + self, + operation_scope: OperationScope, + operation_config: OperationConfig, + service_client_10_2024: ServiceClient102024Preview, + all_operations: OperationsContainer, + credentials: TokenCredential, + **kwargs: Any, + ): + """Constructor of CapabilityHostsOperations class. + + :param operation_scope: Scope variables for the operations classes of an MLClient object. + :type operation_scope: ~azure.ai.ml._scope_dependent_operations.OperationScope + :param operation_config: Common configuration for operations classes of an MLClient object. + :type operation_config: ~azure.ai.ml._scope_dependent_operations.OperationConfig + :param service_client_10_2024: Service client to allow end users to operate on Azure Machine Learning Workspace + resources (ServiceClient102024Preview). 
+ :type service_client_10_2024: ~azure.ai.ml._restclient.v2024_10_01_preview._azure_machine_learning_workspaces.AzureMachineLearningWorkspaces # pylint: disable=line-too-long + :param all_operations: All operations classes of an MLClient object. + :type all_operations: ~azure.ai.ml._scope_dependent_operations.OperationsContainer + :param credentials: Credential to use for authentication. + :type credentials: ~azure.core.credentials.TokenCredential + :param kwargs: Additional keyword arguments. + :type kwargs: Any + """ + + super(CapabilityHostsOperations, self).__init__(operation_scope, operation_config) + ops_logger.update_filter() + self._all_operations = all_operations + self._capability_hosts_operations = service_client_10_2024.capability_hosts + self._workspace_operations = service_client_10_2024.workspaces + self._credentials = credentials + self._init_kwargs = kwargs + + @experimental + @monitor_with_activity(ops_logger, "CapabilityHost.Get", ActivityType.PUBLICAPI) + @distributed_trace + def get(self, name: str, **kwargs: Any) -> CapabilityHost: + """Retrieve a capability host resource. + + :param name: The name of the capability host to retrieve. + :type name: str + :raises ~azure.ai.ml.exceptions.ValidationException: Raised if project name or hub name + not provided while creation of MLClient object in workspacename param. + Details will be provided in the error message. + :raises ~azure.ai.ml.exceptions.ValidationException: Raised if Capabilityhost name is not provided. + Details will be provided in the error message. + :return: CapabilityHost object. + :rtype: ~azure.ai.ml.entities._workspace._ai_workspaces.capability_host.CapabilityHost + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_capability_host.py + :start-after: [START capability_host_get_operation] + :end-before: [END capability_host_get_operation] + :language: python + :dedent: 8 + :caption: Get example. + """ + + self._validate_workspace_name() + + rest_obj = self._capability_hosts_operations.get( + resource_group_name=self._resource_group_name, + workspace_name=self._workspace_name, + name=name, + **kwargs, + ) + + capability_host = CapabilityHost._from_rest_object(rest_obj) + + return capability_host + + @experimental + @monitor_with_activity(ops_logger, "CapabilityHost.BeginCreateOrUpdate", ActivityType.PUBLICAPI) + @distributed_trace + def begin_create_or_update(self, capability_host: CapabilityHost, **kwargs: Any) -> LROPoller[CapabilityHost]: + """Begin the creation of a capability host in a Hub or Project workspace. + Note that currently this method can only accept the `create` operation request + and not `update` operation request. + + :param capability_host: The CapabilityHost object containing the details of the capability host to create. + :type capability_host: ~azure.ai.ml.entities.CapabilityHost + :return: An LROPoller object that can be used to track the long-running + operation that is creation of capability host. + :rtype: ~azure.core.polling.LROPoller[~azure.ai.ml.entities._workspace._ai_workspaces.capability_host.CapabilityHost] # pylint: disable=line-too-long + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_capability_host.py + :start-after: [START capability_host_begin_create_or_update_operation] + :end-before: [END capability_host_begin_create_or_update_operation] + :language: python + :dedent: 8 + :caption: Create example. 
+ """ + try: + self._validate_workspace_name() + + workspace = self._get_workspace() + + self._validate_workspace_kind(workspace._kind) + + self._validate_properties(capability_host, workspace._kind) + + if workspace._kind == WorkspaceKind.PROJECT: + if capability_host.storage_connections is None or len(capability_host.storage_connections) == 0: + capability_host.storage_connections = self._get_default_storage_connections() + + capability_host_resource = capability_host._to_rest_object() + + poller = self._capability_hosts_operations.begin_create_or_update( + resource_group_name=self._resource_group_name, + workspace_name=self._workspace_name, + name=capability_host.name, + body=capability_host_resource, + polling=True, + **kwargs, + cls=lambda response, deserialized, headers: CapabilityHost._from_rest_object(deserialized), + ) + return poller + + except Exception as ex: + if isinstance(ex, (ValidationException, SchemaValidationError)): + log_and_raise_error(ex) + raise ex + + @experimental + @distributed_trace + @monitor_with_activity(ops_logger, "CapabilityHost.Delete", ActivityType.PUBLICAPI) + def begin_delete( + self, + name: str, + **kwargs: Any, + ) -> LROPoller[None]: + """Delete capability host. + + :param name: capability host name. + :type name: str + :return: A poller for deletion status + :rtype: ~azure.core.polling.LROPoller[None] + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_capability_host.py + :start-after: [START capability_host_delete_operation] + :end-before: [END capability_host_delete_operation] + :language: python + :dedent: 8 + :caption: Delete example. + """ + poller = self._capability_hosts_operations.begin_delete( + resource_group_name=self._resource_group_name, + workspace_name=self._workspace_name, + name=name, + polling=True, + **kwargs, + ) + return poller + + def _get_default_storage_connections(self) -> List[str]: + """Retrieve the default storage connections for a capability host. + + :return: A list of default storage connections. + :rtype: List[str] + """ + return [f"{self._workspace_name}/{DEFAULT_STORAGE_CONNECTION_NAME}"] + + def _validate_workspace_kind(self, workspace_kind: str) -> None: + """Validate the workspace kind, it should be either hub or project only. + + :param workspace_kind: The kind of the workspace, either hub or project only. + :type workspace_kind: str + :raises ~azure.ai.ml.exceptions.ValidationException: Raised if workspace kind is not Hub or Project. + Details will be provided in the error message. + :return: None, or the result of cls(response) + :rtype: None + """ + + valid_kind = workspace_kind in {WorkspaceKind.HUB, WorkspaceKind.PROJECT} + if not valid_kind: + msg = f"Invalid workspace kind: {workspace_kind}. Workspace kind should be either 'Hub' or 'Project'." + raise ValidationException( + message=msg, + target=ErrorTarget.CAPABILITY_HOST, + no_personal_data_message=msg, + error_category=ErrorCategory.USER_ERROR, + ) + + def _validate_properties(self, capability_host: CapabilityHost, workspace_kind: str) -> None: + """Validate the properties of the capability host for project workspace. + + :param capability_host: The capability host to validate. + :type capability_host: CapabilityHost + :param workspace_kind: The kind of the workspace, Project only. + :type workspace_kind: str + :raises ~azure.ai.ml.exceptions.ValidationException: Raised if the OpenAI service connection or + vector store (AISearch) connection is empty for a Project workspace kind. 
+ Details will be provided in the error message. + :return: None, or the result of cls(response) + :rtype: None + """ + + if workspace_kind == WorkspaceKind.PROJECT: + if capability_host.ai_services_connections is None or capability_host.vector_store_connections is None: + msg = "For Project workspace kind, OpenAI service connections and vector store (AISearch) connections are required." # pylint: disable=line-too-long + raise ValidationException( + message=msg, + target=ErrorTarget.CAPABILITY_HOST, + no_personal_data_message=msg, + error_category=ErrorCategory.USER_ERROR, + ) + + def _get_workspace(self) -> Workspace: + """Retrieve the workspace object. + + :raises ~azure.ai.ml.exceptions.ValidationException: Raised if specified Hub or Project do not exist. + Details will be provided in the error message. + :return: Hub or Project object if it exists + :rtype: ~azure.ai.ml.entities._workspace.workspace.Workspace + """ + rest_workspace = self._workspace_operations.get(self._resource_group_name, self._workspace_name) + workspace = Workspace._from_rest_object(rest_workspace) + if workspace is None: + msg = f"Workspace with name {self._workspace_name} does not exist." + raise ValidationException( + message=msg, + target=ErrorTarget.CAPABILITY_HOST, + no_personal_data_message=msg, + error_category=ErrorCategory.USER_ERROR, + ) + return workspace + + def _validate_workspace_name(self) -> None: + """Validates that a hub name or project name is set in the MLClient's workspace name parameter. + + :raises ~azure.ai.ml.exceptions.ValidationException: Raised if project name or hub name + not provided while creation of + MLClient object in workspacename param. Details will be provided in the error message. + :return: None, or the result of cls(response) + :rtype: None + """ + workspace_name = self._workspace_name + if not workspace_name: + msg = "Please pass either a hub name or project name to the workspace_name parameter when initializing an MLClient object." # pylint: disable=line-too-long + raise ValidationException( + message=msg, + target=ErrorTarget.CAPABILITY_HOST, + no_personal_data_message=msg, + error_category=ErrorCategory.USER_ERROR, + ) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_code_operations.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_code_operations.py new file mode 100644 index 00000000..7b0e64c3 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_code_operations.py @@ -0,0 +1,307 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
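CapabilityHostsOperations.begin_create_or_update above currently only handles creation, and for Project workspaces it requires both AI services and vector store connections (see _validate_properties). A hedged sketch; the entity import path, the constructor keywords, and the capability_hosts attribute on MLClient are assumptions, and the connection names are placeholders:

    # Sketch: creating a capability host on a Project workspace.
    from azure.ai.ml.entities import CapabilityHost

    capability_host = CapabilityHost(
        name="my-capability-host",
        ai_services_connections=["my-aoai-connection"],
        vector_store_connections=["my-aisearch-connection"],
    )
    created = ml_client.capability_hosts.begin_create_or_update(capability_host).result()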
+# --------------------------------------------------------- +import re +from os import PathLike +from pathlib import Path +from typing import Dict, Optional, Union + +from marshmallow.exceptions import ValidationError as SchemaValidationError + +from azure.ai.ml._artifacts._artifact_utilities import ( + _check_and_upload_path, + _get_datastore_name, + _get_snapshot_path_info, + get_datastore_info, +) +from azure.ai.ml._artifacts._constants import ( + ASSET_PATH_ERROR, + CHANGED_ASSET_PATH_MSG, + CHANGED_ASSET_PATH_MSG_NO_PERSONAL_DATA, +) +from azure.ai.ml._exception_helper import log_and_raise_error +from azure.ai.ml._restclient.v2021_10_01_dataplanepreview import ( + AzureMachineLearningWorkspaces as ServiceClient102021Dataplane, +) +from azure.ai.ml._restclient.v2022_10_01_preview import AzureMachineLearningWorkspaces as ServiceClient102022 +from azure.ai.ml._restclient.v2023_04_01 import AzureMachineLearningWorkspaces as ServiceClient042023 +from azure.ai.ml._scope_dependent_operations import OperationConfig, OperationScope, _ScopeDependentOperations +from azure.ai.ml._telemetry import ActivityType, monitor_with_activity +from azure.ai.ml._utils._asset_utils import ( + _get_existing_asset_name_and_version, + get_content_hash_version, + get_storage_info_for_non_registry_asset, +) +from azure.ai.ml._utils._logger_utils import OpsLogger +from azure.ai.ml._utils._registry_utils import get_asset_body_for_registry_storage, get_sas_uri_for_registry_asset +from azure.ai.ml._utils._storage_utils import get_storage_client +from azure.ai.ml.entities._assets import Code +from azure.ai.ml.exceptions import ( + AssetPathException, + ErrorCategory, + ErrorTarget, + ValidationErrorType, + ValidationException, +) +from azure.ai.ml.operations._datastore_operations import DatastoreOperations +from azure.core.exceptions import HttpResponseError + +# pylint: disable=protected-access + + +ops_logger = OpsLogger(__name__) +module_logger = ops_logger.module_logger + + +class CodeOperations(_ScopeDependentOperations): + """Represents a client for performing operations on code assets. + + You should not instantiate this class directly. Instead, you should create MLClient and use this client via the + property MLClient.code + + :param operation_scope: Scope variables for the operations classes of an MLClient object. + :type operation_scope: ~azure.ai.ml._scope_dependent_operations.OperationScope + :param operation_config: Common configuration for operations classes of an MLClient object. + :type operation_config: ~azure.ai.ml._scope_dependent_operations.OperationConfig + :param service_client: Service client to allow end users to operate on Azure Machine Learning Workspace resources. + :type service_client: typing.Union[ + ~azure.ai.ml._restclient.v2022_10_01_preview._azure_machine_learning_workspaces.AzureMachineLearningWorkspaces, + ~azure.ai.ml._restclient.v2021_10_01_dataplanepreview._azure_machine_learning_workspaces. + AzureMachineLearningWorkspaces, + ~azure.ai.ml._restclient.v2023_04_01._azure_machine_learning_workspaces.AzureMachineLearningWorkspaces] + :param datastore_operations: Represents a client for performing operations on Datastores. 
+ :type datastore_operations: ~azure.ai.ml.operations._datastore_operations.DatastoreOperations + """ + + def __init__( + self, + operation_scope: OperationScope, + operation_config: OperationConfig, + service_client: Union[ServiceClient102022, ServiceClient102021Dataplane, ServiceClient042023], + datastore_operations: DatastoreOperations, + **kwargs: Dict, + ): + super(CodeOperations, self).__init__(operation_scope, operation_config) + ops_logger.update_filter() + self._service_client = service_client + self._version_operation = service_client.code_versions + self._container_operation = service_client.code_containers + self._datastore_operation = datastore_operations + self._init_kwargs = kwargs + + @monitor_with_activity(ops_logger, "Code.CreateOrUpdate", ActivityType.PUBLICAPI) + def create_or_update(self, code: Code) -> Code: + """Returns created or updated code asset. + + If not already in storage, asset will be uploaded to the workspace's default datastore. + + :param code: Code asset object. + :type code: Code + :raises ~azure.ai.ml.exceptions.AssetPathException: Raised when the Code artifact path is + already linked to another asset + :raises ~azure.ai.ml.exceptions.ValidationException: Raised if Code cannot be successfully validated. + Details will be provided in the error message. + :raises ~azure.ai.ml.exceptions.EmptyDirectoryError: Raised if local path provided points to an empty directory. + :return: Code asset object. + :rtype: ~azure.ai.ml.entities.Code + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START code_operations_create_or_update] + :end-before: [END code_operations_create_or_update] + :language: python + :dedent: 8 + :caption: Create code asset example. + """ + try: + name = code.name + version = code.version + sas_uri = None + blob_uri = None + + if self._registry_name: + sas_uri = get_sas_uri_for_registry_asset( + service_client=self._service_client, + name=name, + version=version, + resource_group=self._resource_group_name, + registry=self._registry_name, + body=get_asset_body_for_registry_storage(self._registry_name, "codes", name, version), + ) + else: + snapshot_path_info = _get_snapshot_path_info(code) + if snapshot_path_info: + _, _, asset_hash = snapshot_path_info + existing_assets = list( + self._version_operation.list( + resource_group_name=self._resource_group_name, + workspace_name=self._workspace_name, + name=name, + hash=asset_hash, + hash_version=str(get_content_hash_version()), + ) + ) + + if len(existing_assets) > 0: + existing_asset = existing_assets[0] + name, version = _get_existing_asset_name_and_version(existing_asset) + return self.get(name=name, version=version) + sas_info = get_storage_info_for_non_registry_asset( + service_client=self._service_client, + workspace_name=self._workspace_name, + name=name, + version=version, + resource_group=self._resource_group_name, + ) + sas_uri = sas_info["sas_uri"] + blob_uri = sas_info["blob_uri"] + + code, _ = _check_and_upload_path( + artifact=code, + asset_operations=self, + sas_uri=sas_uri, + artifact_type=ErrorTarget.CODE, + show_progress=self._show_progress, + blob_uri=blob_uri, + ) + + # For anonymous code, if the code already exists in storage, we reuse the name, + # version stored in the storage metadata so the same anonymous code won't be created again. 
+ if code._is_anonymous: + name = code.name + version = code.version + + code_version_resource = code._to_rest_object() + + result = ( + self._version_operation.begin_create_or_update( + name=name, + version=version, + registry_name=self._registry_name, + resource_group_name=self._operation_scope.resource_group_name, + body=code_version_resource, + **self._init_kwargs, + ).result() + if self._registry_name + else self._version_operation.create_or_update( + name=name, + version=version, + workspace_name=self._workspace_name, + resource_group_name=self._operation_scope.resource_group_name, + body=code_version_resource, + **self._init_kwargs, + ) + ) + + if not result: + return self.get(name=name, version=version) + return Code._from_rest_object(result) + except Exception as ex: + if isinstance(ex, (ValidationException, SchemaValidationError)): + log_and_raise_error(ex) + elif isinstance(ex, HttpResponseError): + # service side raises an exception if we attempt to update an existing asset's asset path + if str(ex) == ASSET_PATH_ERROR: + raise AssetPathException( + message=CHANGED_ASSET_PATH_MSG, + target=ErrorTarget.CODE, + no_personal_data_message=CHANGED_ASSET_PATH_MSG_NO_PERSONAL_DATA, + error_category=ErrorCategory.USER_ERROR, + ) from ex + raise ex + + @monitor_with_activity(ops_logger, "Code.Get", ActivityType.PUBLICAPI) + def get(self, name: str, version: str) -> Code: + """Returns information about the specified code asset. + + :param name: Name of the code asset. + :type name: str + :param version: Version of the code asset. + :type version: str + :raises ~azure.ai.ml.exceptions.ValidationException: Raised if Code cannot be successfully validated. + Details will be provided in the error message. + :return: Code asset object. + :rtype: ~azure.ai.ml.entities.Code + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START code_operations_get] + :end-before: [END code_operations_get] + :language: python + :dedent: 8 + :caption: Get code asset example. + """ + return self._get(name=name, version=version) + + # this is a public API but CodeOperations is hidden, so it may only monitor internal calls + @monitor_with_activity(ops_logger, "Code.Download", ActivityType.PUBLICAPI) + def download(self, name: str, version: str, download_path: Union[PathLike, str]) -> None: + """Download content of a code. + + :param str name: Name of the code. + :param str version: Version of the code. + :param Union[PathLike, str] download_path: Local path as download destination, + defaults to current working directory of the current user. Contents will be overwritten. + :raise: ResourceNotFoundError if can't find a code matching provided name. + """ + output_dir = Path(download_path) + if output_dir.is_dir(): + # an OSError will be raised if the directory is not empty + output_dir.rmdir() + output_dir.mkdir(parents=True) + + code = self._get(name=name, version=version) + + # TODO: how should we maintain this regex? 
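+        # For illustration (editor's addition, hypothetical values): code.path is expected to be a
+        # workspace blob store URI of the form
+        #     https://<account_name>.blob.core.windows.net/<container_name>/<blob_name>
+        # e.g. https://mystorageacct.blob.core.windows.net/azureml-blobstore-0000/LocalUpload/abc123/src
+        # The regex below extracts account_name, container_name and blob_name from such a URI.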
+ m = re.match( + r"https://(?P<account_name>.+)\.blob\.core\.windows\.net" + r"(:[0-9]+)?/(?P<container_name>.+)/(?P<blob_name>.*)", + str(code.path), + ) + if not m: + raise ValueError(f"Invalid code path: {code.path}") + + datastore_info = get_datastore_info( + self._datastore_operation, + # always use WORKSPACE_BLOB_STORE + name=_get_datastore_name(), + container_name=m.group("container_name"), + ) + storage_client = get_storage_client(**datastore_info) + storage_client.download( + starts_with=m.group("blob_name"), + destination=output_dir.as_posix(), + ) + if not output_dir.is_dir() or not any(output_dir.iterdir()): + raise RuntimeError(f"Failed to download code to {output_dir}") + + def _get(self, name: str, version: Optional[str] = None) -> Code: + if not version: + msg = "Code asset version must be specified as part of name parameter, in format 'name:version'." + raise ValidationException( + message=msg, + target=ErrorTarget.CODE, + no_personal_data_message=msg, + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.INVALID_VALUE, + ) + code_version_resource = ( + self._version_operation.get( + name=name, + version=version, + resource_group_name=self._operation_scope.resource_group_name, + registry_name=self._registry_name, + **self._init_kwargs, + ) + if self._registry_name + else self._version_operation.get( + name=name, + version=version, + resource_group_name=self._operation_scope.resource_group_name, + workspace_name=self._workspace_name, + **self._init_kwargs, + ) + ) + return Code._from_rest_object(code_version_resource) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_component_operations.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_component_operations.py new file mode 100644 index 00000000..f9e43f1d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_component_operations.py @@ -0,0 +1,1289 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
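Editor's note: before the _component_operations.py listing continues, a brief sketch of how the CodeOperations methods above are typically invoked. The class docstring says the client is reached via the MLClient.code property, although a comment in the module also notes that CodeOperations is hidden, so treat the property name and all asset names/paths here as assumptions.

    from azure.ai.ml import MLClient
    from azure.ai.ml.entities import Code
    from azure.identity import DefaultAzureCredential

    ml_client = MLClient(
        credential=DefaultAzureCredential(),
        subscription_id="<subscription-id>",
        resource_group_name="<resource-group>",
        workspace_name="<workspace-name>",
    )

    # Upload a local folder as a code asset; if identical content was uploaded before,
    # create_or_update may return the existing asset instead of registering a new version.
    code_asset = ml_client.code.create_or_update(Code(name="train-src", version="1", path="./src"))

    # Retrieve it again, then download its contents into an empty or not-yet-existing directory.
    retrieved = ml_client.code.get(name="train-src", version="1")
    ml_client.code.download(name="train-src", version="1", download_path="./downloaded_src")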
+# --------------------------------------------------------- + +# pylint: disable=protected-access,too-many-lines +import time +import collections +import types +from functools import partial +from inspect import Parameter, signature +from os import PathLike +from pathlib import Path +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union, cast +import hashlib + +from azure.ai.ml._restclient.v2021_10_01_dataplanepreview import ( + AzureMachineLearningWorkspaces as ServiceClient102021Dataplane, +) +from azure.ai.ml._restclient.v2024_01_01_preview import ( + AzureMachineLearningWorkspaces as ServiceClient012024, +) +from azure.ai.ml._restclient.v2024_01_01_preview.models import ( + ComponentVersion, + ListViewType, +) +from azure.ai.ml._scope_dependent_operations import ( + OperationConfig, + OperationsContainer, + OperationScope, + _ScopeDependentOperations, +) +from azure.ai.ml._telemetry import ( + ActivityType, + monitor_with_activity, + monitor_with_telemetry_mixin, +) +from azure.ai.ml._utils._asset_utils import ( + _archive_or_restore, + _create_or_update_autoincrement, + _get_file_hash, + _get_latest, + _get_next_version_from_container, + _resolve_label_to_asset, + get_ignore_file, + get_upload_files_from_folder, + IgnoreFile, + delete_two_catalog_files, + create_catalog_files, +) +from azure.ai.ml._utils._azureml_polling import AzureMLPolling +from azure.ai.ml._utils._endpoint_utils import polling_wait +from azure.ai.ml._utils._logger_utils import OpsLogger +from azure.ai.ml._vendor.azure_resources.operations import DeploymentsOperations +from azure.ai.ml.constants._common import ( + DEFAULT_COMPONENT_VERSION, + DEFAULT_LABEL_NAME, + AzureMLResourceType, + DefaultOpenEncoding, + LROConfigurations, +) +from azure.ai.ml.entities import Component, ValidationResult +from azure.ai.ml.exceptions import ( + ComponentException, + ErrorCategory, + ErrorTarget, + ValidationException, +) +from azure.core.exceptions import HttpResponseError, ResourceNotFoundError + +from .._utils._cache_utils import CachedNodeResolver +from .._utils._experimental import experimental +from .._utils.utils import extract_name_and_version, is_data_binding_expression +from ..entities._builders import BaseNode +from ..entities._builders.condition_node import ConditionNode +from ..entities._builders.control_flow_node import LoopNode +from ..entities._component.automl_component import AutoMLComponent +from ..entities._component.code import ComponentCodeMixin +from ..entities._component.pipeline_component import PipelineComponent +from ..entities._job.pipeline._attr_dict import has_attr_safe +from ._code_operations import CodeOperations +from ._environment_operations import EnvironmentOperations +from ._operation_orchestrator import OperationOrchestrator, _AssetResolver +from ._workspace_operations import WorkspaceOperations + +ops_logger = OpsLogger(__name__) +logger, module_logger = ops_logger.package_logger, ops_logger.module_logger + + +class ComponentOperations(_ScopeDependentOperations): + """ComponentOperations. + + You should not instantiate this class directly. Instead, you should + create an MLClient instance that instantiates it for you and + attaches it as an attribute. + + :param operation_scope: The operation scope. + :type operation_scope: ~azure.ai.ml._scope_dependent_operations.OperationScope + :param operation_config: The operation configuration. 
+ :type operation_config: ~azure.ai.ml._scope_dependent_operations.OperationConfig + :param service_client: The service client for API operations. + :type service_client: Union[ + ~azure.ai.ml._restclient.v2022_10_01.AzureMachineLearningWorkspaces, + ~azure.ai.ml._restclient.v2021_10_01_dataplanepreview.AzureMachineLearningWorkspaces] + :param all_operations: The container for all available operations. + :type all_operations: ~azure.ai.ml._scope_dependent_operations.OperationsContainer + :param preflight_operation: The preflight operation for deployments. + :type preflight_operation: Optional[~azure.ai.ml._vendor.azure_resources.operations.DeploymentsOperations] + :param kwargs: Additional keyword arguments. + :type kwargs: Dict + """ + + def __init__( + self, + operation_scope: OperationScope, + operation_config: OperationConfig, + service_client: Union[ServiceClient012024, ServiceClient102021Dataplane], + all_operations: OperationsContainer, + preflight_operation: Optional[DeploymentsOperations] = None, + **kwargs: Dict, + ) -> None: + super(ComponentOperations, self).__init__(operation_scope, operation_config) + ops_logger.update_filter() + self._version_operation = service_client.component_versions + self._preflight_operation = preflight_operation + self._container_operation = service_client.component_containers + self._all_operations = all_operations + self._init_args = kwargs + # Maps a label to a function which given an asset name, + # returns the asset associated with the label + self._managed_label_resolver = {"latest": self._get_latest_version} + self._orchestrators = OperationOrchestrator(self._all_operations, self._operation_scope, self._operation_config) + + self._client_key: Optional[str] = None + + @property + def _code_operations(self) -> CodeOperations: + res: CodeOperations = self._all_operations.get_operation( # type: ignore[misc] + AzureMLResourceType.CODE, lambda x: isinstance(x, CodeOperations) + ) + return res + + @property + def _environment_operations(self) -> EnvironmentOperations: + return cast( + EnvironmentOperations, + self._all_operations.get_operation( # type: ignore[misc] + AzureMLResourceType.ENVIRONMENT, + lambda x: isinstance(x, EnvironmentOperations), + ), + ) + + @property + def _workspace_operations(self) -> WorkspaceOperations: + return cast( + WorkspaceOperations, + self._all_operations.get_operation( # type: ignore[misc] + AzureMLResourceType.WORKSPACE, + lambda x: isinstance(x, WorkspaceOperations), + ), + ) + + @property + def _job_operations(self) -> Any: + from ._job_operations import JobOperations + + return self._all_operations.get_operation( # type: ignore[misc] + AzureMLResourceType.JOB, lambda x: isinstance(x, JobOperations) + ) + + @monitor_with_activity(ops_logger, "Component.List", ActivityType.PUBLICAPI) + def list( + self, + name: Union[str, None] = None, + *, + list_view_type: ListViewType = ListViewType.ACTIVE_ONLY, + ) -> Iterable[Component]: + """List specific component or components of the workspace. + + :param name: Component name, if not set, list all components of the workspace + :type name: Optional[str] + :keyword list_view_type: View type for including/excluding (for example) archived components. + Default: ACTIVE_ONLY. + :type list_view_type: Optional[ListViewType] + :return: An iterator like instance of component objects + :rtype: ~azure.core.paging.ItemPaged[Component] + + .. admonition:: Example: + + .. 
literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START component_operations_list] + :end-before: [END component_operations_list] + :language: python + :dedent: 8 + :caption: List component example. + """ + + if name: + return cast( + Iterable[Component], + ( + self._version_operation.list( + name=name, + resource_group_name=self._resource_group_name, + registry_name=self._registry_name, + **self._init_args, + cls=lambda objs: [Component._from_rest_object(obj) for obj in objs], + ) + if self._registry_name + else self._version_operation.list( + name=name, + resource_group_name=self._resource_group_name, + workspace_name=self._workspace_name, + list_view_type=list_view_type, + **self._init_args, + cls=lambda objs: [Component._from_rest_object(obj) for obj in objs], + ) + ), + ) + return cast( + Iterable[Component], + ( + self._container_operation.list( + resource_group_name=self._resource_group_name, + registry_name=self._registry_name, + **self._init_args, + cls=lambda objs: [Component._from_container_rest_object(obj) for obj in objs], + ) + if self._registry_name + else self._container_operation.list( + resource_group_name=self._resource_group_name, + workspace_name=self._workspace_name, + list_view_type=list_view_type, + **self._init_args, + cls=lambda objs: [Component._from_container_rest_object(obj) for obj in objs], + ) + ), + ) + + @monitor_with_telemetry_mixin(ops_logger, "ComponentVersion.Get", ActivityType.INTERNALCALL) + def _get_component_version(self, name: str, version: Optional[str] = DEFAULT_COMPONENT_VERSION) -> ComponentVersion: + """Returns ComponentVersion information about the specified component name and version. + + :param name: Name of the code component. + :type name: str + :param version: Version of the component. + :type version: Optional[str] + :return: The ComponentVersion object of the specified component name and version. + :rtype: ~azure.ai.ml.entities.ComponentVersion + """ + result = ( + self._version_operation.get( + name=name, + version=version, + resource_group_name=self._resource_group_name, + registry_name=self._registry_name, + **self._init_args, + ) + if self._registry_name + else self._version_operation.get( + name=name, + version=version, + resource_group_name=self._resource_group_name, + workspace_name=self._workspace_name, + **self._init_args, + ) + ) + return result + + @monitor_with_telemetry_mixin(ops_logger, "Component.Get", ActivityType.PUBLICAPI) + def get(self, name: str, version: Optional[str] = None, label: Optional[str] = None) -> Component: + """Returns information about the specified component. + + :param name: Name of the code component. + :type name: str + :param version: Version of the component. + :type version: Optional[str] + :param label: Label of the component, mutually exclusive with version. + :type label: Optional[str] + :raises ~azure.ai.ml.exceptions.ValidationException: Raised if Component cannot be successfully + identified and retrieved. Details will be provided in the error message. + :return: The specified component object. + :rtype: ~azure.ai.ml.entities.Component + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START component_operations_get] + :end-before: [END component_operations_get] + :language: python + :dedent: 8 + :caption: Get component example. 
+ """ + return self._get(name=name, version=version, label=label) + + def _localize_code(self, component: Component, base_dir: Path) -> None: + if not isinstance(component, ComponentCodeMixin): + return + code = component._get_origin_code_value() + if not isinstance(code, str): + return + # registry code will keep the "azureml:" prefix can be used directly + if code.startswith("azureml://registries"): + return + + target_code_value = "./code" + self._code_operations.download( + **extract_name_and_version(code), + download_path=base_dir.joinpath(target_code_value), + ) + + setattr(component, component._get_code_field_name(), target_code_value) + + def _localize_environment(self, component: Component, base_dir: Path) -> None: + from azure.ai.ml.entities import ParallelComponent + + parent: Any = None + if hasattr(component, "environment"): + parent = component + elif isinstance(component, ParallelComponent): + parent = component.task + else: + return + + # environment can be None + if not isinstance(parent.environment, str): + return + # registry environment will keep the "azureml:" prefix can be used directly + if parent.environment.startswith("azureml://registries"): + return + + environment = self._environment_operations.get(**extract_name_and_version(parent.environment)) + environment._localize(base_path=base_dir.absolute().as_posix()) + parent.environment = environment + + @experimental + @monitor_with_telemetry_mixin(ops_logger, "Component.Download", ActivityType.PUBLICAPI) + def download( + self, + name: str, + download_path: Union[PathLike, str] = ".", + *, + version: Optional[str] = None, + ) -> None: + """Download the specified component and its dependencies to local. Local component can be used to create + the component in another workspace or for offline development. + + :param name: Name of the code component. + :type name: str + :param Union[PathLike, str] download_path: Local path as download destination, + defaults to current working directory of the current user. Will be created if not exists. + :type download_path: str + :keyword version: Version of the component. + :paramtype version: Optional[str] + :raises ~OSError: Raised if download_path is pointing to an existing directory that is not empty. + identified and retrieved. Details will be provided in the error message. + :return: The specified component object. + :rtype: ~azure.ai.ml.entities.Component + """ + download_path = Path(download_path) + component = self._get(name=name, version=version) + self._resolve_azureml_id(component) + + output_dir = Path(download_path) + if output_dir.is_dir(): + # an OSError will be raised if the directory is not empty + output_dir.rmdir() + output_dir.mkdir(parents=True) + # download code + self._localize_code(component, output_dir) + + # download environment + self._localize_environment(component, output_dir) + + component._localize(output_dir.absolute().as_posix()) + (output_dir / "component_spec.yaml").write_text(component._to_yaml(), encoding=DefaultOpenEncoding.WRITE) + + def _get(self, name: str, version: Optional[str] = None, label: Optional[str] = None) -> Component: + if version and label: + msg = "Cannot specify both version and label." 
+ raise ValidationException( + message=msg, + target=ErrorTarget.COMPONENT, + no_personal_data_message=msg, + error_category=ErrorCategory.USER_ERROR, + ) + + if not version and not label: + label = DEFAULT_LABEL_NAME + + if label == DEFAULT_LABEL_NAME: + label = None + version = DEFAULT_COMPONENT_VERSION + + if label: + return _resolve_label_to_asset(self, name, label) + + result = self._get_component_version(name, version) + component = Component._from_rest_object(result) + self._resolve_azureml_id(component, jobs_only=True) + return component + + @experimental + @monitor_with_telemetry_mixin(ops_logger, "Component.Validate", ActivityType.PUBLICAPI) + def validate( + self, + component: Union[Component, types.FunctionType], + raise_on_failure: bool = False, + **kwargs: Any, + ) -> ValidationResult: + """validate a specified component. if there are inline defined + entities, e.g. Environment, Code, they won't be created. + + :param component: The component object or a mldesigner component function that generates component object + :type component: Union[Component, types.FunctionType] + :param raise_on_failure: Whether to raise exception on validation error. Defaults to False + :type raise_on_failure: bool + :return: All validation errors + :rtype: ~azure.ai.ml.entities.ValidationResult + """ + return self._validate( + component, + raise_on_failure=raise_on_failure, + # TODO 2330505: change this to True after remote validation is ready + skip_remote_validation=kwargs.pop("skip_remote_validation", True), + ) + + @monitor_with_telemetry_mixin(ops_logger, "Component.Validate", ActivityType.INTERNALCALL) + def _validate( + self, + component: Union[Component, types.FunctionType], + raise_on_failure: bool, + skip_remote_validation: bool, + ) -> ValidationResult: + """Implementation of validate. Add this function to avoid calling validate() directly in create_or_update(), + which will impact telemetry statistics & bring experimental warning in create_or_update(). + + :param component: The component + :type component: Union[Component, types.FunctionType] + :param raise_on_failure: Whether to raise on failure. + :type raise_on_failure: bool + :param skip_remote_validation: Whether to skip remote validation. 
+ :type skip_remote_validation: bool + :return: The validation result + :rtype: ValidationResult + """ + # Update component when the input is a component function + if isinstance(component, types.FunctionType): + component = _refine_component(component) + + # local validation + result = component._validate(raise_error=raise_on_failure) + # remote validation, note that preflight_operation is not available for registry client + if not skip_remote_validation and self._preflight_operation: + workspace = self._workspace_operations.get() + remote_validation_result = self._preflight_operation.begin_validate( + resource_group_name=self._resource_group_name, + deployment_name=self._workspace_name, + parameters=component._build_rest_object_for_remote_validation( + location=workspace.location, + workspace_name=self._workspace_name, + ), + **self._init_args, + ) + result.merge_with( + # pylint: disable=protected-access + component._build_validation_result_from_rest_object(remote_validation_result.result()), + overwrite=True, + ) + # resolve location for diagnostics from remote validation + result.resolve_location_for_diagnostics(component._source_path) # type: ignore + return component._try_raise( # pylint: disable=protected-access + result, + raise_error=raise_on_failure, + ) + + def _update_flow_rest_object(self, rest_component_resource: Any) -> None: + import re + + from azure.ai.ml._utils._arm_id_utils import AMLVersionedArmId + + component_spec = rest_component_resource.properties.component_spec + code, flow_file_name = AMLVersionedArmId(component_spec["code"]), component_spec.pop("flow_file_name") + # TODO: avoid remote request here if met performance issue + created_code = self._code_operations.get(name=code.asset_name, version=code.asset_version) + # remove port number and append flow file name to get full uri for flow.dag.yaml + component_spec["flow_definition_uri"] = f"{re.sub(r':[0-9]+/', '/', created_code.path)}/{flow_file_name}" + + def _reset_version_if_no_change(self, component: Component, current_name: str, current_version: str) -> Tuple: + """Reset component version to default version if there's no change in the component. + + :param component: The component object + :type component: Component + :param current_name: The component name + :type current_name: str + :param current_version: The component version + :type current_version: str + :return: The new version and rest component resource + :rtype: Tuple[str, ComponentVersion] + """ + rest_component_resource = component._to_rest_object() + + try: + client_component_hash = rest_component_resource.properties.properties.get("client_component_hash") + remote_component_version = self._get_component_version(name=current_name) # will raise error if not found. + remote_component_hash = remote_component_version.properties.properties.get("client_component_hash") + if client_component_hash == remote_component_hash: + component.version = remote_component_version.properties.component_spec.get( + "version" + ) # only update the default version component instead of creating a new version component + logger.warning( + "The component is not modified compared to the default version " + "and the new version component registration is skipped." 
+ ) + return component.version, component._to_rest_object() + except ResourceNotFoundError as e: + logger.info("Failed to get component version, %s", e) + except Exception as e: # pylint: disable=W0718 + logger.error("Failed to compare client_component_hash, %s", e) + + return current_version, rest_component_resource + + def _create_or_update_component_version( + self, + component: Component, + name: str, + version: Optional[str], + rest_component_resource: Any, + ) -> Any: + try: + if self._registry_name: + start_time = time.time() + path_format_arguments = { + "componentName": component.name, + "resourceGroupName": self._resource_group_name, + "registryName": self._registry_name, + } + poller = self._version_operation.begin_create_or_update( + name=name, + version=version, + resource_group_name=self._operation_scope.resource_group_name, + registry_name=self._registry_name, + body=rest_component_resource, + polling=AzureMLPolling( + LROConfigurations.POLL_INTERVAL, + path_format_arguments=path_format_arguments, + ), + ) + message = f"Creating/updating registry component {component.name} with version {component.version} " + polling_wait(poller=poller, start_time=start_time, message=message, timeout=None) + + else: + # _auto_increment_version can be True for non-registry component creation operation; + # and anonymous component should use hash as version + if not component._is_anonymous and component._auto_increment_version: + return _create_or_update_autoincrement( + name=name, + body=rest_component_resource, + version_operation=self._version_operation, + container_operation=self._container_operation, + resource_group_name=self._operation_scope.resource_group_name, + workspace_name=self._workspace_name, + **self._init_args, + ) + + return self._version_operation.create_or_update( + name=name, + version=version, + resource_group_name=self._resource_group_name, + workspace_name=self._workspace_name, + body=rest_component_resource, + **self._init_args, + ) + except Exception as e: + raise e + + return None + + @monitor_with_telemetry_mixin( + logger, + "Component.CreateOrUpdate", + ActivityType.PUBLICAPI, + extra_keys=["is_anonymous"], + ) + def create_or_update( + self, + component: Component, + version: Optional[str] = None, + *, + skip_validation: bool = False, + **kwargs: Any, + ) -> Component: + """Create or update a specified component. if there're inline defined + entities, e.g. Environment, Code, they'll be created together with the + component. + + :param component: The component object or a mldesigner component function that generates component object + :type component: Union[Component, types.FunctionType] + :param version: The component version to override. + :type version: str + :keyword skip_validation: whether to skip validation before creating/updating the component, defaults to False + :paramtype skip_validation: bool + :raises ~azure.ai.ml.exceptions.ValidationException: Raised if Component cannot be successfully validated. + Details will be provided in the error message. + :raises ~azure.ai.ml.exceptions.AssetException: Raised if Component assets + (e.g. Data, Code, Model, Environment) cannot be successfully validated. + Details will be provided in the error message. + :raises ~azure.ai.ml.exceptions.ComponentException: Raised if Component type is unsupported. + Details will be provided in the error message. + :raises ~azure.ai.ml.exceptions.ModelException: Raised if Component model cannot be successfully validated. + Details will be provided in the error message. 
+ :raises ~azure.ai.ml.exceptions.EmptyDirectoryError: Raised if local path provided points to an empty directory. + :return: The specified component object. + :rtype: ~azure.ai.ml.entities.Component + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START component_operations_create_or_update] + :end-before: [END component_operations_create_or_update] + :language: python + :dedent: 8 + :caption: Create component example. + """ + # Update component when the input is a component function + if isinstance(component, types.FunctionType): + component = _refine_component(component) + if version is not None: + component.version = version + # In non-registry scenario, if component does not have version, no need to get next version here. + # As Component property version has setter that updates `_auto_increment_version` in-place, then + # a component will get a version after its creation, and it will always use this version in its + # future creation operations, which breaks version auto increment mechanism. + if self._registry_name and not component.version and component._auto_increment_version: + component.version = _get_next_version_from_container( + name=component.name, + container_operation=self._container_operation, + resource_group_name=self._operation_scope.resource_group_name, + workspace_name=self._workspace_name, + registry_name=self._registry_name, + **self._init_args, + ) + + if not component._is_anonymous: + component._is_anonymous = kwargs.pop("is_anonymous", False) + + if not skip_validation: + self._validate(component, raise_on_failure=True, skip_remote_validation=True) + + # Create all dependent resources + # Only upload dependencies if component is NOT IPP + if not component._intellectual_property: + self._resolve_arm_id_or_upload_dependencies(component) + + name, version = component._get_rest_name_version() + if not component._is_anonymous and kwargs.get("skip_if_no_change"): + version, rest_component_resource = self._reset_version_if_no_change( + component, + current_name=name, + current_version=str(version), + ) + else: + rest_component_resource = component._to_rest_object() + + # TODO: remove this after server side support directly using client created code + from azure.ai.ml.entities._component.flow import FlowComponent + + if isinstance(component, FlowComponent): + self._update_flow_rest_object(rest_component_resource) + + result = self._create_or_update_component_version( + component, + name, + version, + rest_component_resource, + ) + + if not result: + component = self.get(name=component.name, version=component.version) + else: + component = Component._from_rest_object(result) + + self._resolve_azureml_id( + component=component, + jobs_only=True, + ) + return component + + @experimental + def prepare_for_sign(self, component: Component) -> None: + ignore_file = IgnoreFile() + + if isinstance(component, ComponentCodeMixin): + with component._build_code() as code: + delete_two_catalog_files(code.path) + ignore_file = get_ignore_file(code.path) if code._ignore_file is None else ignore_file + file_list = get_upload_files_from_folder(code.path, ignore_file=ignore_file) + json_stub = {} + json_stub["HashAlgorithm"] = "SHA256" + json_stub["CatalogItems"] = {} # type: ignore + + for file_path, file_name in sorted(file_list, key=lambda x: str(x[1]).lower()): + file_hash = _get_file_hash(file_path, hashlib.sha256()).hexdigest().upper() + json_stub["CatalogItems"][file_name] = file_hash # type: ignore + + json_stub["CatalogItems"] = 
collections.OrderedDict( # type: ignore + sorted(json_stub["CatalogItems"].items()) # type: ignore + ) + create_catalog_files(code.path, json_stub) + + @monitor_with_telemetry_mixin(ops_logger, "Component.Archive", ActivityType.PUBLICAPI) + def archive( + self, + name: str, + version: Optional[str] = None, + label: Optional[str] = None, + # pylint:disable=unused-argument + **kwargs: Any, + ) -> None: + """Archive a component. + + :param name: Name of the component. + :type name: str + :param version: Version of the component. + :type version: str + :param label: Label of the component. (mutually exclusive with version). + :type label: str + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START component_operations_archive] + :end-before: [END component_operations_archive] + :language: python + :dedent: 8 + :caption: Archive component example. + """ + _archive_or_restore( + asset_operations=self, + version_operation=self._version_operation, + container_operation=self._container_operation, + is_archived=True, + name=name, + version=version, + label=label, + ) + + @monitor_with_telemetry_mixin(ops_logger, "Component.Restore", ActivityType.PUBLICAPI) + def restore( + self, + name: str, + version: Optional[str] = None, + label: Optional[str] = None, + # pylint:disable=unused-argument + **kwargs: Any, + ) -> None: + """Restore an archived component. + + :param name: Name of the component. + :type name: str + :param version: Version of the component. + :type version: str + :param label: Label of the component. (mutually exclusive with version). + :type label: str + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START component_operations_restore] + :end-before: [END component_operations_restore] + :language: python + :dedent: 8 + :caption: Restore component example. + """ + _archive_or_restore( + asset_operations=self, + version_operation=self._version_operation, + container_operation=self._container_operation, + is_archived=False, + name=name, + version=version, + label=label, + ) + + def _get_latest_version(self, component_name: str) -> Component: + """Returns the latest version of the asset with the given name. + + Latest is defined as the most recently created, not the most + recently updated. 
+ + :param component_name: The component name + :type component_name: str + :return: A latest version of the named Component + :rtype: Component + """ + + result = ( + _get_latest( + component_name, + self._version_operation, + self._resource_group_name, + workspace_name=None, + registry_name=self._registry_name, + ) + if self._registry_name + else _get_latest( + component_name, + self._version_operation, + self._resource_group_name, + self._workspace_name, + ) + ) + return Component._from_rest_object(result) + + @classmethod + def _try_resolve_environment_for_component( + cls, component: Union[BaseNode, str], _: str, resolver: _AssetResolver + ) -> None: + if isinstance(component, BaseNode): + component = component._component # pylint: disable=protected-access + + if isinstance(component, str): + return + potential_parents: List[BaseNode] = [component] + if hasattr(component, "task"): + potential_parents.append(component.task) + for parent in potential_parents: + # for internal component, environment may be a dict or InternalEnvironment object + # in these two scenarios, we don't need to resolve the environment; + # Note for not directly importing InternalEnvironment and check with `isinstance`: + # import from azure.ai.ml._internal will enable internal component feature for all users, + # therefore, use type().__name__ to avoid import and execute type check + if not hasattr(parent, "environment"): + continue + if isinstance(parent.environment, dict): + continue + if type(parent.environment).__name__ == "InternalEnvironment": + continue + parent.environment = resolver(parent.environment, azureml_type=AzureMLResourceType.ENVIRONMENT) + + def _resolve_azureml_id(self, component: Component, jobs_only: bool = False) -> None: + # TODO: remove the parameter `jobs_only`. Some tests are expecting an arm id after resolving for now. + resolver = self._orchestrators.resolve_azureml_id + self._resolve_dependencies_for_component(component, resolver, jobs_only=jobs_only) + + def _resolve_arm_id_or_upload_dependencies(self, component: Component) -> None: + resolver = OperationOrchestrator( + self._all_operations, self._operation_scope, self._operation_config + ).get_asset_arm_id + + self._resolve_dependencies_for_component(component, resolver) + + def _resolve_dependencies_for_component( + self, + component: Component, + resolver: Callable, + *, + jobs_only: bool = False, + ) -> None: + # for now, many tests are expecting long arm id instead of short id for environment and code + if not jobs_only: + if isinstance(component, AutoMLComponent): + # no extra dependency for automl component + return + + # type check for potential Job type, which is unexpected here. + if not isinstance(component, Component): + msg = f"Non supported component type: {type(component)}" + raise ValidationException( + message=msg, + target=ErrorTarget.COMPONENT, + no_personal_data_message=msg, + error_category=ErrorCategory.USER_ERROR, + ) + + # resolve component's code + _try_resolve_code_for_component(component=component, resolver=resolver) + # resolve component's environment + self._try_resolve_environment_for_component( + component=component, # type: ignore + resolver=resolver, + _="", + ) + + self._resolve_dependencies_for_pipeline_component_jobs( + component, + resolver=resolver, + ) + + def _resolve_inputs_for_pipeline_component_jobs(self, jobs: Dict[str, Any], base_path: str) -> None: + """Resolve inputs for jobs in a pipeline component. + + :param jobs: A dict of nodes in a pipeline component. 
+ :type jobs: Dict[str, Any] + :param base_path: The base path used to resolve inputs. Usually it's the base path of the pipeline component. + :type base_path: str + """ + from azure.ai.ml.entities._job.automl.automl_job import AutoMLJob + + for _, job_instance in jobs.items(): + # resolve inputs for each job's component + if isinstance(job_instance, BaseNode): + node: BaseNode = job_instance + self._job_operations._resolve_job_inputs( + # parameter group input need to be flattened first + self._job_operations._flatten_group_inputs(node.inputs), + base_path, + ) + elif isinstance(job_instance, AutoMLJob): + self._job_operations._resolve_automl_job_inputs(job_instance) + + @classmethod + def _resolve_binding_on_supported_fields_for_node(cls, node: BaseNode) -> None: + """Resolve all PipelineInput(binding from sdk) on supported fields to string. + + :param node: The node + :type node: BaseNode + """ + from azure.ai.ml.entities._job.pipeline._attr_dict import ( + try_get_non_arbitrary_attr, + ) + from azure.ai.ml.entities._job.pipeline._io import PipelineInput + + # compute binding to pipeline input is supported on node. + supported_fields = ["compute", "compute_name"] + for field_name in supported_fields: + val = try_get_non_arbitrary_attr(node, field_name) + if isinstance(val, PipelineInput): + # Put binding string to field + setattr(node, field_name, val._data_binding()) + + @classmethod + def _try_resolve_node_level_task_for_parallel_node(cls, node: BaseNode, _: str, resolver: _AssetResolver) -> None: + """Resolve node.task.code for parallel node if it's a reference to node.component.task.code. + + This is a hack operation. + + parallel_node.task.code won't be resolved directly for now, but it will be resolved if + parallel_node.task is a reference to parallel_node.component.task. Then when filling back + parallel_node.component.task.code, parallel_node.task.code will be changed as well. + + However, if we enable in-memory/on-disk cache for component resolution, such change + won't happen, so we resolve node level task code manually here. + + Note that we will always use resolved node.component.code to fill back node.task.code + given code overwrite on parallel node won't take effect for now. This is to make behaviors + consistent across os and python versions. + + The ideal solution should be done after PRS team decides how to handle parallel.task.code + + :param node: The node + :type node: BaseNode + :param _: The component name + :type _: str + :param resolver: The resolver function + :type resolver: _AssetResolver + """ + from azure.ai.ml.entities import Parallel, ParallelComponent + + if not isinstance(node, Parallel): + return + component = node._component # pylint: disable=protected-access + if not isinstance(component, ParallelComponent): + return + if not node.task: + return + + if node.task.code: + _try_resolve_code_for_component( + component, + resolver=resolver, + ) + node.task.code = component.code + if node.task.environment: + node.task.environment = resolver(component.environment, azureml_type=AzureMLResourceType.ENVIRONMENT) + + @classmethod + def _set_default_display_name_for_anonymous_component_in_node(cls, node: BaseNode, default_name: str) -> None: + """Set default display name for anonymous component in a node. + If node._component is an anonymous component and without display name, set the default display name. 
+ + :param node: The node + :type node: BaseNode + :param default_name: The default name to set + :type default_name: str + """ + if not isinstance(node, BaseNode): + return + component = node._component + if isinstance(component, PipelineComponent): + return + # Set display name as node name + # TODO: the same anonymous component with different node name will have different anonymous hash + # as their display name will be different. + if ( + isinstance(component, Component) + # check if component is anonymous and not created based on its id. We can't directly check + # node._component._is_anonymous as it will be set to True on component creation, + # which is later than this check + and not component.id + and not component.display_name + ): + component.display_name = default_name + + @classmethod + def _try_resolve_compute_for_node(cls, node: BaseNode, _: str, resolver: _AssetResolver) -> None: + """Resolve compute for base node. + + :param node: The node + :type node: BaseNode + :param _: The node name + :type _: str + :param resolver: The resolver function + :type resolver: _AssetResolver + """ + if not isinstance(node, BaseNode): + return + if not isinstance(node._component, PipelineComponent): + # Resolve compute for other type + # Keep data binding expression as they are + if not is_data_binding_expression(node.compute): + # Get compute for each job + node.compute = resolver(node.compute, azureml_type=AzureMLResourceType.COMPUTE) + if has_attr_safe(node, "compute_name") and not is_data_binding_expression(node.compute_name): + node.compute_name = resolver(node.compute_name, azureml_type=AzureMLResourceType.COMPUTE) + + @classmethod + def _divide_nodes_to_resolve_into_layers( + cls, + component: PipelineComponent, + extra_operations: List[Callable[[BaseNode, str], Any]], + ) -> List: + """Traverse the pipeline component and divide nodes to resolve into layers. Note that all leaf nodes will be + put in the last layer. + For example, for below pipeline component, assuming that all nodes need to be resolved: + A + /|\ + B C D + | | + E F + | + G + return value will be: + [ + [("B", B), ("C", C)], + [("E", E)], + [("D", D), ("F", F), ("G", G)], + ] + + :param component: The pipeline component to resolve. + :type component: PipelineComponent + :param extra_operations: Extra operations to apply on nodes during the traversing. + :type extra_operations: List[Callable[Callable[[BaseNode, str], Any]]] + :return: A list of layers of nodes to resolve. 
+ :rtype: List[List[Tuple[str, BaseNode]]] + """ + nodes_to_process = list(component.jobs.items()) + layers: List = [] + leaf_nodes = [] + + while nodes_to_process: + layers.append([]) + new_nodes_to_process = [] + for key, job_instance in nodes_to_process: + cls._resolve_binding_on_supported_fields_for_node(job_instance) + if isinstance(job_instance, LoopNode): + job_instance = job_instance.body + + for extra_operation in extra_operations: + extra_operation(job_instance, key) + + if isinstance(job_instance, BaseNode) and isinstance(job_instance._component, PipelineComponent): + # candidates for next layer + new_nodes_to_process.extend(job_instance.component.jobs.items()) + # use layers to store pipeline nodes in each layer for now + layers[-1].append((key, job_instance)) + else: + # note that LoopNode has already been replaced by its body here + leaf_nodes.append((key, job_instance)) + nodes_to_process = new_nodes_to_process + + # if there is subgraph, the last item in layers will be empty for now as all leaf nodes are stored in leaf_nodes + if len(layers) != 0: + layers.pop() + layers.append(leaf_nodes) + + return layers + + def _get_workspace_key(self) -> str: + try: + workspace_rest = self._workspace_operations._operation.get( + resource_group_name=self._resource_group_name, + workspace_name=self._workspace_name, + ) + return str(workspace_rest.workspace_id) + except HttpResponseError: + return "{}/{}/{}".format(self._subscription_id, self._resource_group_name, self._workspace_name) + + def _get_registry_key(self) -> str: + """Get key for used registry. + + Note that, although registry id is in registry discovery response, it is not in RegistryDiscoveryDto; and we'll + lose the information after deserialization. + To avoid changing related rest client, we simply use registry related information from self to construct + registry key, which means that on-disk cache will be invalid if a registry is deleted and then created + again with the same name. + + :return: The registry key + :rtype: str + """ + return "{}/{}/{}".format(self._subscription_id, self._resource_group_name, self._registry_name) + + def _get_client_key(self) -> str: + """Get key for used client. + Key should be able to uniquely identify used registry or workspace. + + :return: The client key + :rtype: str + """ + # check cache first + if self._client_key: + return self._client_key + + # registry name has a higher priority comparing to workspace name according to current __init__ implementation + # of MLClient + if self._registry_name: + self._client_key = "registry/" + self._get_registry_key() + elif self._workspace_name: + self._client_key = "workspace/" + self._get_workspace_key() + else: + # This should never happen. + raise ValueError("Either workspace name or registry name must be provided to use component operations.") + return self._client_key + + def _resolve_dependencies_for_pipeline_component_jobs( + self, + component: Union[Component, str], + resolver: _AssetResolver, + ) -> None: + """Resolve dependencies for pipeline component jobs. + Will directly return if component is not a pipeline component. + + :param component: The pipeline component to resolve. + :type component: Union[Component, str] + :param resolver: The resolver to resolve the dependencies. 
+ :type resolver: _AssetResolver + """ + if not isinstance(component, PipelineComponent) or not component.jobs: + return + + from azure.ai.ml.entities._job.automl.automl_job import AutoMLJob + + self._resolve_inputs_for_pipeline_component_jobs(component.jobs, component._base_path) + + # This is a preparation for concurrent resolution. Nodes will be resolved later layer by layer + # from bottom to top, as hash calculation of a parent node will be impacted by resolution + # of its child nodes. + layers = self._divide_nodes_to_resolve_into_layers( + component, + extra_operations=[ + # no need to do this as we now keep the original component name for anonymous components + # self._set_default_display_name_for_anonymous_component_in_node, + partial( + self._try_resolve_node_level_task_for_parallel_node, + resolver=resolver, + ), + partial(self._try_resolve_environment_for_component, resolver=resolver), + partial(self._try_resolve_compute_for_node, resolver=resolver), + # should we resolve code here after we do extra operations concurrently? + ], + ) + + # cache anonymous component only for now + # request level in-memory cache can be a better solution for other type of assets as they are + # relatively simple and of small number of distinct instances + component_cache = CachedNodeResolver( + resolver=resolver, + client_key=self._get_client_key(), + ) + + for layer in reversed(layers): + for _, job_instance in layer: + if isinstance(job_instance, AutoMLJob): + # only compute is resolved here + self._job_operations._resolve_arm_id_for_automl_job(job_instance, resolver, inside_pipeline=True) + elif isinstance(job_instance, BaseNode): + component_cache.register_node_for_lazy_resolution(job_instance) + elif isinstance(job_instance, ConditionNode): + pass + else: + msg = f"Non supported job type in Pipeline: {type(job_instance)}" + raise ComponentException( + message=msg, + target=ErrorTarget.COMPONENT, + no_personal_data_message=msg, + error_category=ErrorCategory.USER_ERROR, + ) + + component_cache.resolve_nodes() + + +def _refine_component(component_func: Any) -> Component: + """Return the component of function that is decorated by command + component decorator. + + :param component_func: Function that is decorated by command component decorator + :type component_func: types.FunctionType + :return: Component entity of target function + :rtype: Component + """ + + def check_parameter_type(f: Any) -> None: + """Check all parameter is annotated or has a default value with clear type(not None). + + :param f: The component function + :type f: types.FunctionType + """ + annotations = getattr(f, "__annotations__", {}) + func_parameters = signature(f).parameters + defaults_dict = {key: val.default for key, val in func_parameters.items()} + variable_inputs = [ + key for key, val in func_parameters.items() if val.kind in [val.VAR_POSITIONAL, val.VAR_KEYWORD] + ] + if variable_inputs: + msg = "Cannot register the component {} with variable inputs {!r}." + raise ValidationException( + message=msg.format(f.__name__, variable_inputs), + no_personal_data_message=msg.format("[keys]", "[name]"), + target=ErrorTarget.COMPONENT, + error_category=ErrorCategory.USER_ERROR, + ) + unknown_type_keys = [ + key for key, val in defaults_dict.items() if key not in annotations and val is Parameter.empty + ] + if unknown_type_keys: + msg = "Unknown type of parameter {} in pipeline func {!r}, please add type annotation." 
+ raise ValidationException( + message=msg.format(unknown_type_keys, f.__name__), + no_personal_data_message=msg.format("[keys]", "[name]"), + target=ErrorTarget.COMPONENT, + error_category=ErrorCategory.USER_ERROR, + ) + + def check_non_pipeline_inputs(f: Any) -> None: + """Check whether non_pipeline_inputs exist in pipeline builder. + + :param f: The component function + :type f: types.FunctionType + """ + if f._pipeline_builder.non_pipeline_parameter_names: + msg = "Cannot register pipeline component {!r} with non_pipeline_inputs." + raise ValidationException( + message=msg.format(f.__name__), + no_personal_data_message=msg.format(""), + target=ErrorTarget.COMPONENT, + error_category=ErrorCategory.USER_ERROR, + ) + + if hasattr(component_func, "_is_mldesigner_component") and component_func._is_mldesigner_component: + return component_func.component + if hasattr(component_func, "_is_dsl_func") and component_func._is_dsl_func: + check_non_pipeline_inputs(component_func) + check_parameter_type(component_func) + if component_func._job_settings: + module_logger.warning( + "Job settings %s on pipeline function '%s' are ignored when creating PipelineComponent.", + component_func._job_settings, + component_func.__name__, + ) + # Normally pipeline component are created when dsl.pipeline inputs are provided + # so pipeline input .result() can resolve to correct value. + # When pipeline component created without dsl.pipeline inputs, pipeline input .result() won't work. + return component_func._pipeline_builder.build(user_provided_kwargs={}) + msg = "Function must be a dsl or mldesigner component function: {!r}" + raise ValidationException( + message=msg.format(component_func), + no_personal_data_message=msg.format("component"), + error_category=ErrorCategory.USER_ERROR, + target=ErrorTarget.COMPONENT, + ) + + +def _try_resolve_code_for_component(component: Component, resolver: _AssetResolver) -> None: + if isinstance(component, ComponentCodeMixin): + with component._build_code() as code: + if code is None: + code = component._get_origin_code_value() + if code is None: + return + component._fill_back_code_value(resolver(code, azureml_type=AzureMLResourceType.CODE)) # type: ignore diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_compute_operations.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_compute_operations.py new file mode 100644 index 00000000..7990a3fa --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_compute_operations.py @@ -0,0 +1,447 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
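Editor's note: the ComponentOperations class above is the largest file in this group, so a compact usage sketch may help before the _compute_operations.py listing continues. It assumes the operations class is exposed as ml_client.components and that ./components/train.yml is a valid component YAML; both are placeholders rather than paths from this package.

    from azure.ai.ml import MLClient, load_component
    from azure.identity import DefaultAzureCredential

    ml_client = MLClient(
        credential=DefaultAzureCredential(),
        subscription_id="<subscription-id>",
        resource_group_name="<resource-group>",
        workspace_name="<workspace-name>",
    )

    # Load a component definition from YAML and register it; the version can be overridden here.
    component = load_component(source="./components/train.yml")
    registered = ml_client.components.create_or_update(component, version="1")

    # Fetch it back by name and version (or label), list its versions, then archive it.
    fetched = ml_client.components.get(name=registered.name, version="1")
    versions = list(ml_client.components.list(name=registered.name))
    ml_client.components.archive(name=registered.name, version="1")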
+# --------------------------------------------------------- + +# pylint: disable=protected-access + +from typing import Any, Dict, Iterable, Optional, cast + +from azure.ai.ml._restclient.v2023_08_01_preview import AzureMachineLearningWorkspaces as ServiceClient022023Preview +from azure.ai.ml._restclient.v2024_04_01_preview import AzureMachineLearningWorkspaces as ServiceClient042024Preview +from azure.ai.ml._restclient.v2024_04_01_preview.models import SsoSetting +from azure.ai.ml._scope_dependent_operations import OperationConfig, OperationScope, _ScopeDependentOperations +from azure.ai.ml._telemetry import ActivityType, monitor_with_activity +from azure.ai.ml._utils._experimental import experimental +from azure.ai.ml._utils._logger_utils import OpsLogger +from azure.ai.ml.constants._common import COMPUTE_UPDATE_ERROR +from azure.ai.ml.constants._compute import ComputeType +from azure.ai.ml.entities import AmlComputeNodeInfo, Compute, Usage, VmSize +from azure.core.polling import LROPoller +from azure.core.tracing.decorator import distributed_trace + +ops_logger = OpsLogger(__name__) +module_logger = ops_logger.module_logger + + +class ComputeOperations(_ScopeDependentOperations): + """ComputeOperations. + + This class should not be instantiated directly. Instead, use the `compute` attribute of an MLClient object. + + :param operation_scope: Scope variables for the operations classes of an MLClient object. + :type operation_scope: ~azure.ai.ml._scope_dependent_operations.OperationScope + :param operation_config: Common configuration for operations classes of an MLClient object. + :type operation_config: ~azure.ai.ml._scope_dependent_operations.OperationConfig + :param service_client: Service client to allow end users to operate on Azure Machine Learning + Workspace resources. + :type service_client: ~azure.ai.ml._restclient.v2023_02_01_preview.AzureMachineLearningWorkspaces + """ + + def __init__( + self, + operation_scope: OperationScope, + operation_config: OperationConfig, + service_client: ServiceClient022023Preview, + service_client_2024: ServiceClient042024Preview, + **kwargs: Dict, + ) -> None: + super(ComputeOperations, self).__init__(operation_scope, operation_config) + ops_logger.update_filter() + self._operation = service_client.compute + self._operation2024 = service_client_2024.compute + self._workspace_operations = service_client.workspaces + self._vmsize_operations = service_client.virtual_machine_sizes + self._usage_operations = service_client.usages + self._init_kwargs = kwargs + + @distributed_trace + @monitor_with_activity(ops_logger, "Compute.List", ActivityType.PUBLICAPI) + def list(self, *, compute_type: Optional[str] = None) -> Iterable[Compute]: + """List computes of the workspace. + + :keyword compute_type: The type of the compute to be listed, case-insensitive. Defaults to AMLCompute. + :paramtype compute_type: Optional[str] + :return: An iterator like instance of Compute objects. + :rtype: ~azure.core.paging.ItemPaged[~azure.ai.ml.entities.Compute] + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_compute.py + :start-after: [START compute_operations_list] + :end-before: [END compute_operations_list] + :language: python + :dedent: 8 + :caption: Retrieving a list of the AzureML Kubernetes compute resources in a workspace. 
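+
+        A minimal inline sketch, assuming an already-authenticated ``MLClient`` named
+        ``ml_client`` (the compute type filter value is illustrative):
+
+        .. code-block:: python
+
+            # Iterate all computes in the workspace, then only AmlCompute resources
+            # (the filter is matched case-insensitively).
+            for compute in ml_client.compute.list():
+                print(compute.name, compute.type)
+
+            aml_computes = list(ml_client.compute.list(compute_type="AmlCompute"))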
+ """ + + return cast( + Iterable[Compute], + self._operation.list( + self._operation_scope.resource_group_name, + self._workspace_name, + cls=lambda objs: [ + Compute._from_rest_object(obj) + for obj in objs + if compute_type is None or str(Compute._from_rest_object(obj).type).lower() == compute_type.lower() + ], + ), + ) + + @distributed_trace + @monitor_with_activity(ops_logger, "Compute.Get", ActivityType.PUBLICAPI) + def get(self, name: str) -> Compute: + """Get a compute resource. + + :param name: Name of the compute resource. + :type name: str + :return: A Compute object. + :rtype: ~azure.ai.ml.entities.Compute + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_compute.py + :start-after: [START compute_operations_get] + :end-before: [END compute_operations_get] + :language: python + :dedent: 8 + :caption: Retrieving a compute resource from a workspace. + """ + + rest_obj = self._operation.get( + self._operation_scope.resource_group_name, + self._workspace_name, + name, + ) + return Compute._from_rest_object(rest_obj) + + @distributed_trace + @monitor_with_activity(ops_logger, "Compute.ListNodes", ActivityType.PUBLICAPI) + def list_nodes(self, name: str) -> Iterable[AmlComputeNodeInfo]: + """Retrieve a list of a compute resource's nodes. + + :param name: Name of the compute resource. + :type name: str + :return: An iterator-like instance of AmlComputeNodeInfo objects. + :rtype: ~azure.core.paging.ItemPaged[~azure.ai.ml.entities.AmlComputeNodeInfo] + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_compute.py + :start-after: [START compute_operations_list_nodes] + :end-before: [END compute_operations_list_nodes] + :language: python + :dedent: 8 + :caption: Retrieving a list of nodes from a compute resource. + """ + return cast( + Iterable[AmlComputeNodeInfo], + self._operation.list_nodes( + self._operation_scope.resource_group_name, + self._workspace_name, + name, + cls=lambda objs: [AmlComputeNodeInfo._from_rest_object(obj) for obj in objs], + ), + ) + + @distributed_trace + @monitor_with_activity(ops_logger, "Compute.BeginCreateOrUpdate", ActivityType.PUBLICAPI) + def begin_create_or_update(self, compute: Compute) -> LROPoller[Compute]: + """Create and register a compute resource. + + :param compute: The compute resource definition. + :type compute: ~azure.ai.ml.entities.Compute + :return: An instance of LROPoller that returns a Compute object once the + long-running operation is complete. + :rtype: ~azure.core.polling.LROPoller[~azure.ai.ml.entities.Compute] + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_compute.py + :start-after: [START compute_operations_create_update] + :end-before: [END compute_operations_create_update] + :language: python + :dedent: 8 + :caption: Creating and registering a compute resource. 
+ """ + if compute.type != ComputeType.AMLCOMPUTE: + if compute.location: + module_logger.warning( + "Warning: 'Location' is not supported for compute type %s and will not be used.", + compute.type, + ) + compute.location = self._get_workspace_location() + + if not compute.location: + compute.location = self._get_workspace_location() + + compute._set_full_subnet_name( + self._operation_scope.subscription_id, + self._operation_scope.resource_group_name, + ) + + compute_rest_obj = compute._to_rest_object() + + poller = self._operation.begin_create_or_update( + self._operation_scope.resource_group_name, + self._workspace_name, + compute_name=compute.name, + parameters=compute_rest_obj, + polling=True, + cls=lambda response, deserialized, headers: Compute._from_rest_object(deserialized), + ) + + return poller + + @distributed_trace + @monitor_with_activity(ops_logger, "Compute.Attach", ActivityType.PUBLICAPI) + def begin_attach(self, compute: Compute, **kwargs: Any) -> LROPoller[Compute]: + """Attach a compute resource to the workspace. + + :param compute: The compute resource definition. + :type compute: ~azure.ai.ml.entities.Compute + :return: An instance of LROPoller that returns a Compute object once the + long-running operation is complete. + :rtype: ~azure.core.polling.LROPoller[~azure.ai.ml.entities.Compute] + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_compute.py + :start-after: [START compute_operations_attach] + :end-before: [END compute_operations_attach] + :language: python + :dedent: 8 + :caption: Attaching a compute resource to the workspace. + """ + return self.begin_create_or_update(compute=compute, **kwargs) + + @distributed_trace + @monitor_with_activity(ops_logger, "Compute.BeginUpdate", ActivityType.PUBLICAPI) + def begin_update(self, compute: Compute) -> LROPoller[Compute]: + """Update a compute resource. Currently only valid for AmlCompute resource types. + + :param compute: The compute resource definition. + :type compute: ~azure.ai.ml.entities.Compute + :return: An instance of LROPoller that returns a Compute object once the + long-running operation is complete. + :rtype: ~azure.core.polling.LROPoller[~azure.ai.ml.entities.Compute] + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_compute.py + :start-after: [START compute_operations_update] + :end-before: [END compute_operations_update] + :language: python + :dedent: 8 + :caption: Updating an AmlCompute resource. + """ + if not compute.type == ComputeType.AMLCOMPUTE: + COMPUTE_UPDATE_ERROR.format(compute.name, compute.type) + + compute_rest_obj = compute._to_rest_object() + + poller = self._operation.begin_create_or_update( + self._operation_scope.resource_group_name, + self._workspace_name, + compute_name=compute.name, + parameters=compute_rest_obj, + polling=True, + cls=lambda response, deserialized, headers: Compute._from_rest_object(deserialized), + ) + + return poller + + @distributed_trace + @monitor_with_activity(ops_logger, "Compute.BeginDelete", ActivityType.PUBLICAPI) + def begin_delete(self, name: str, *, action: str = "Delete") -> LROPoller[None]: + """Delete or detach a compute resource. + + :param name: The name of the compute resource. + :type name: str + :keyword action: Action to perform. Possible values: ["Delete", "Detach"]. Defaults to "Delete". + :type action: str + :return: A poller to track the operation status. + :rtype: ~azure.core.polling.LROPoller[None] + + .. admonition:: Example: + + .. 
literalinclude:: ../samples/ml_samples_compute.py + :start-after: [START compute_operations_delete] + :end-before: [END compute_operations_delete] + :language: python + :dedent: 8 + :caption: Delete compute example. + """ + return self._operation.begin_delete( + resource_group_name=self._operation_scope.resource_group_name, + workspace_name=self._workspace_name, + compute_name=name, + underlying_resource_action=action, + **self._init_kwargs, + ) + + @distributed_trace + @monitor_with_activity(ops_logger, "Compute.BeginStart", ActivityType.PUBLICAPI) + def begin_start(self, name: str) -> LROPoller[None]: + """Start a compute instance. + + :param name: The name of the compute instance. + :type name: str + :return: A poller to track the operation status. + :rtype: azure.core.polling.LROPoller[None] + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_compute.py + :start-after: [START compute_operations_start] + :end-before: [END compute_operations_start] + :language: python + :dedent: 8 + :caption: Starting a compute instance. + """ + + return self._operation.begin_start( + self._operation_scope.resource_group_name, + self._workspace_name, + name, + ) + + @distributed_trace + @monitor_with_activity(ops_logger, "Compute.BeginStop", ActivityType.PUBLICAPI) + def begin_stop(self, name: str) -> LROPoller[None]: + """Stop a compute instance. + + :param name: The name of the compute instance. + :type name: str + :return: A poller to track the operation status. + :rtype: azure.core.polling.LROPoller[None] + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_compute.py + :start-after: [START compute_operations_stop] + :end-before: [END compute_operations_stop] + :language: python + :dedent: 8 + :caption: Stopping a compute instance. + """ + return self._operation.begin_stop( + self._operation_scope.resource_group_name, + self._workspace_name, + name, + ) + + @distributed_trace + @monitor_with_activity(ops_logger, "Compute.BeginRestart", ActivityType.PUBLICAPI) + def begin_restart(self, name: str) -> LROPoller[None]: + """Restart a compute instance. + + :param name: The name of the compute instance. + :type name: str + :return: A poller to track the operation status. + :rtype: azure.core.polling.LROPoller[None] + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_compute.py + :start-after: [START compute_operations_restart] + :end-before: [END compute_operations_restart] + :language: python + :dedent: 8 + :caption: Restarting a stopped compute instance. + """ + return self._operation.begin_restart( + self._operation_scope.resource_group_name, + self._workspace_name, + name, + ) + + @distributed_trace + @monitor_with_activity(ops_logger, "Compute.ListUsage", ActivityType.PUBLICAPI) + def list_usage(self, *, location: Optional[str] = None) -> Iterable[Usage]: + """List the current usage information as well as AzureML resource limits for the + given subscription and location. + + :keyword location: The location for which resource usage is queried. + Defaults to workspace location. + :paramtype location: Optional[str] + :return: An iterator over current usage info objects. + :rtype: ~azure.core.paging.ItemPaged[~azure.ai.ml.entities.Usage] + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_compute.py + :start-after: [START compute_operations_list_usage] + :end-before: [END compute_operations_list_usage] + :language: python + :dedent: 8 + :caption: Listing resource usage for the workspace location. 
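+
+        A minimal inline sketch, assuming an authenticated ``MLClient`` named ``ml_client``;
+        the attribute names read from the returned ``Usage`` objects are assumptions based on
+        that entity and may need adjusting:
+
+        .. code-block:: python
+
+            # Print current usage against the subscription limits for the workspace region.
+            for usage in ml_client.compute.list_usage():
+                print(usage.name.value, usage.current_value, usage.limit)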
+ """ + if not location: + location = self._get_workspace_location() + return cast( + Iterable[Usage], + self._usage_operations.list( + location=location, + cls=lambda objs: [Usage._from_rest_object(obj) for obj in objs], + ), + ) + + @distributed_trace + @monitor_with_activity(ops_logger, "Compute.ListSizes", ActivityType.PUBLICAPI) + def list_sizes(self, *, location: Optional[str] = None, compute_type: Optional[str] = None) -> Iterable[VmSize]: + """List the supported VM sizes in a location. + + :keyword location: The location upon which virtual-machine-sizes is queried. + Defaults to workspace location. + :paramtype location: str + :keyword compute_type: The type of the compute to be listed, case-insensitive. Defaults to AMLCompute. + :paramtype compute_type: Optional[str] + :return: An iterator over virtual machine size objects. + :rtype: Iterable[~azure.ai.ml.entities.VmSize] + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_compute.py + :start-after: [START compute_operations_list_sizes] + :end-before: [END compute_operations_list_sizes] + :language: python + :dedent: 8 + :caption: Listing the supported VM sizes in the workspace location. + """ + if not location: + location = self._get_workspace_location() + size_list = self._vmsize_operations.list(location=location) + if not size_list: + return [] + if compute_type: + return [ + VmSize._from_rest_object(item) + for item in size_list.value + if compute_type.lower() in (supported_type.lower() for supported_type in item.supported_compute_types) + ] + return [VmSize._from_rest_object(item) for item in size_list.value] + + @distributed_trace + @monitor_with_activity(ops_logger, "Compute.enablesso", ActivityType.PUBLICAPI) + @experimental + def enable_sso(self, *, name: str, enable_sso: bool = True, **kwargs: Any) -> None: + """enable sso for a compute instance. + + :keyword name: Name of the compute instance. + :paramtype name: str + :keyword enable_sso: enable sso bool flag + Default to True + :paramtype enable_sso: bool + """ + + self._operation2024.update_sso_settings( + self._operation_scope.resource_group_name, + self._workspace_name, + name, + parameters=SsoSetting(enable_sso=enable_sso), + **kwargs, + ) + + def _get_workspace_location(self) -> str: + workspace = self._workspace_operations.get(self._resource_group_name, self._workspace_name) + return str(workspace.location) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_data_operations.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_data_operations.py new file mode 100644 index 00000000..a54b1739 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_data_operations.py @@ -0,0 +1,891 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +# pylint: disable=protected-access,no-value-for-parameter + +import os +import time +import uuid +from contextlib import contextmanager +from pathlib import Path +from typing import Any, Dict, Generator, Iterable, List, Optional, Union, cast + +from marshmallow.exceptions import ValidationError as SchemaValidationError + +from azure.ai.ml._artifacts._artifact_utilities import _check_and_upload_path +from azure.ai.ml._artifacts._constants import ( + ASSET_PATH_ERROR, + CHANGED_ASSET_PATH_MSG, + CHANGED_ASSET_PATH_MSG_NO_PERSONAL_DATA, +) +from azure.ai.ml._exception_helper import log_and_raise_error +from azure.ai.ml._restclient.v2021_10_01_dataplanepreview import ( + AzureMachineLearningWorkspaces as ServiceClient102021Dataplane, +) +from azure.ai.ml._restclient.v2023_04_01_preview import AzureMachineLearningWorkspaces as ServiceClient042023_preview +from azure.ai.ml._restclient.v2023_04_01_preview.models import ListViewType +from azure.ai.ml._restclient.v2024_01_01_preview import AzureMachineLearningWorkspaces as ServiceClient012024_preview +from azure.ai.ml._restclient.v2024_01_01_preview.models import ComputeInstanceDataMount +from azure.ai.ml._scope_dependent_operations import ( + OperationConfig, + OperationsContainer, + OperationScope, + _ScopeDependentOperations, +) +from azure.ai.ml._telemetry import ActivityType, monitor_with_activity +from azure.ai.ml._utils._asset_utils import ( + _archive_or_restore, + _check_or_modify_auto_delete_setting, + _create_or_update_autoincrement, + _get_latest_version_from_container, + _resolve_label_to_asset, + _validate_auto_delete_setting_in_data_output, + _validate_workspace_managed_datastore, +) +from azure.ai.ml._utils._data_utils import ( + download_mltable_metadata_schema, + read_local_mltable_metadata_contents, + read_remote_mltable_metadata_contents, + validate_mltable_metadata, +) +from azure.ai.ml._utils._experimental import experimental +from azure.ai.ml._utils._http_utils import HttpPipeline +from azure.ai.ml._utils._logger_utils import OpsLogger +from azure.ai.ml._utils._registry_utils import ( + get_asset_body_for_registry_storage, + get_registry_client, + get_sas_uri_for_registry_asset, +) +from azure.ai.ml._utils.utils import is_url +from azure.ai.ml.constants._common import ( + ASSET_ID_FORMAT, + MLTABLE_METADATA_SCHEMA_URL_FALLBACK, + AssetTypes, + AzureMLResourceType, +) +from azure.ai.ml.data_transfer import import_data as import_data_func +from azure.ai.ml.entities import PipelineJob, PipelineJobSettings +from azure.ai.ml.entities._assets import Data, WorkspaceAssetReference +from azure.ai.ml.entities._data.mltable_metadata import MLTableMetadata +from azure.ai.ml.entities._data_import.data_import import DataImport +from azure.ai.ml.entities._inputs_outputs import Output +from azure.ai.ml.entities._inputs_outputs.external_data import Database +from azure.ai.ml.exceptions import ( + AssetPathException, + ErrorCategory, + ErrorTarget, + MlException, + ValidationErrorType, + ValidationException, +) +from azure.ai.ml.operations._datastore_operations import DatastoreOperations +from azure.core.exceptions import HttpResponseError, ResourceNotFoundError +from azure.core.paging import ItemPaged + +ops_logger = OpsLogger(__name__) +module_logger = ops_logger.module_logger + + +class DataOperations(_ScopeDependentOperations): + """DataOperations. + + You should not instantiate this class directly. 
Instead, you should + create an MLClient instance that instantiates it for you and + attaches it as an attribute. + + :param operation_scope: Scope variables for the operations classes of an MLClient object. + :type operation_scope: ~azure.ai.ml._scope_dependent_operations.OperationScope + :param operation_config: Common configuration for operations classes of an MLClient object. + :type operation_config: ~azure.ai.ml._scope_dependent_operations.OperationConfig + :param service_client: Service client to allow end users to operate on Azure Machine Learning Workspace + resources (ServiceClient042023Preview or ServiceClient102021Dataplane). + :type service_client: typing.Union[ + ~azure.ai.ml._restclient.v2023_04_01_preview._azure_machine_learning_workspaces.AzureMachineLearningWorkspaces, + ~azure.ai.ml._restclient.v2021_10_01_dataplanepreview._azure_machine_learning_workspaces. + AzureMachineLearningWorkspaces] + :param datastore_operations: Represents a client for performing operations on Datastores. + :type datastore_operations: ~azure.ai.ml.operations._datastore_operations.DatastoreOperations + """ + + def __init__( + self, + operation_scope: OperationScope, + operation_config: OperationConfig, + service_client: Union[ServiceClient042023_preview, ServiceClient102021Dataplane], + service_client_012024_preview: ServiceClient012024_preview, + datastore_operations: DatastoreOperations, + **kwargs: Any, + ): + super(DataOperations, self).__init__(operation_scope, operation_config) + ops_logger.update_filter() + self._operation = service_client.data_versions + self._container_operation = service_client.data_containers + self._datastore_operation = datastore_operations + self._compute_operation = service_client_012024_preview.compute + self._service_client = service_client + self._init_kwargs = kwargs + self._requests_pipeline: HttpPipeline = kwargs.pop("requests_pipeline") + self._all_operations: OperationsContainer = kwargs.pop("all_operations") + # Maps a label to a function which given an asset name, + # returns the asset associated with the label + self._managed_label_resolver = {"latest": self._get_latest_version} + + @monitor_with_activity(ops_logger, "Data.List", ActivityType.PUBLICAPI) + def list( + self, + name: Optional[str] = None, + *, + list_view_type: ListViewType = ListViewType.ACTIVE_ONLY, + ) -> ItemPaged[Data]: + """List the data assets of the workspace. + + :param name: Name of a specific data asset, optional. + :type name: Optional[str] + :keyword list_view_type: View type for including/excluding (for example) archived data assets. + Default: ACTIVE_ONLY. + :type list_view_type: Optional[ListViewType] + :return: An iterator like instance of Data objects + :rtype: ~azure.core.paging.ItemPaged[Data] + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START data_operations_list] + :end-before: [END data_operations_list] + :language: python + :dedent: 8 + :caption: List data assets example. 
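+
+        A minimal inline sketch, assuming an authenticated ``MLClient`` named ``ml_client``
+        that exposes this class as ``ml_client.data`` (the asset name is an illustrative
+        placeholder):
+
+        .. code-block:: python
+
+            # List every active version of a named data asset.
+            for data_asset in ml_client.data.list(name="my-dataset"):
+                print(data_asset.name, data_asset.version)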
+ """ + if name: + return ( + self._operation.list( + name=name, + registry_name=self._registry_name, + cls=lambda objs: [Data._from_rest_object(obj) for obj in objs], + list_view_type=list_view_type, + **self._scope_kwargs, + ) + if self._registry_name + else self._operation.list( + name=name, + workspace_name=self._workspace_name, + cls=lambda objs: [Data._from_rest_object(obj) for obj in objs], + list_view_type=list_view_type, + **self._scope_kwargs, + ) + ) + return ( + self._container_operation.list( + registry_name=self._registry_name, + cls=lambda objs: [Data._from_container_rest_object(obj) for obj in objs], + list_view_type=list_view_type, + **self._scope_kwargs, + ) + if self._registry_name + else self._container_operation.list( + workspace_name=self._workspace_name, + cls=lambda objs: [Data._from_container_rest_object(obj) for obj in objs], + list_view_type=list_view_type, + **self._scope_kwargs, + ) + ) + + def _get(self, name: Optional[str], version: Optional[str] = None) -> Data: + if version: + return ( + self._operation.get( + name=name, + version=version, + registry_name=self._registry_name, + **self._scope_kwargs, + **self._init_kwargs, + ) + if self._registry_name + else self._operation.get( + resource_group_name=self._resource_group_name, + workspace_name=self._workspace_name, + name=name, + version=version, + **self._init_kwargs, + ) + ) + return ( + self._container_operation.get( + name=name, + registry_name=self._registry_name, + **self._scope_kwargs, + **self._init_kwargs, + ) + if self._registry_name + else self._container_operation.get( + resource_group_name=self._resource_group_name, + workspace_name=self._workspace_name, + name=name, + **self._init_kwargs, + ) + ) + + @monitor_with_activity(ops_logger, "Data.Get", ActivityType.PUBLICAPI) + def get(self, name: str, version: Optional[str] = None, label: Optional[str] = None) -> Data: # type: ignore + """Get the specified data asset. + + :param name: Name of data asset. + :type name: str + :param version: Version of data asset. + :type version: str + :param label: Label of the data asset. (mutually exclusive with version) + :type label: str + :raises ~azure.ai.ml.exceptions.ValidationException: Raised if Data cannot be successfully + identified and retrieved. Details will be provided in the error message. + :return: Data asset object. + :rtype: ~azure.ai.ml.entities.Data + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START data_operations_get] + :end-before: [END data_operations_get] + :language: python + :dedent: 8 + :caption: Get data assets example. + """ + try: + if version and label: + msg = "Cannot specify both version and label." + raise ValidationException( + message=msg, + target=ErrorTarget.DATA, + no_personal_data_message=msg, + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.INVALID_VALUE, + ) + + if label: + return _resolve_label_to_asset(self, name, label) + + if not version: + msg = "Must provide either version or label." 
+ raise ValidationException( + message=msg, + target=ErrorTarget.DATA, + no_personal_data_message=msg, + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.MISSING_FIELD, + ) + data_version_resource = self._get(name, version) + return Data._from_rest_object(data_version_resource) + except (ValidationException, SchemaValidationError) as ex: + log_and_raise_error(ex) + + @monitor_with_activity(ops_logger, "Data.CreateOrUpdate", ActivityType.PUBLICAPI) + def create_or_update(self, data: Data) -> Data: + """Returns created or updated data asset. + + If not already in storage, asset will be uploaded to the workspace's blob storage. + + :param data: Data asset object. + :type data: azure.ai.ml.entities.Data + :raises ~azure.ai.ml.exceptions.AssetPathException: Raised when the Data artifact path is + already linked to another asset + :raises ~azure.ai.ml.exceptions.ValidationException: Raised if Data cannot be successfully validated. + Details will be provided in the error message. + :raises ~azure.ai.ml.exceptions.EmptyDirectoryError: Raised if local path provided points to an empty directory. + :return: Data asset object. + :rtype: ~azure.ai.ml.entities.Data + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START data_operations_create_or_update] + :end-before: [END data_operations_create_or_update] + :language: python + :dedent: 8 + :caption: Create data assets example. + """ + try: + name = data.name + if not data.version and self._registry_name: + msg = "Data asset version is required for registry" + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.DATA, + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.MISSING_FIELD, + ) + version = data.version + + sas_uri = None + if self._registry_name: + # If the data asset is a workspace asset, promote to registry + if isinstance(data, WorkspaceAssetReference): + try: + self._operation.get( + name=data.name, + version=data.version, + resource_group_name=self._resource_group_name, + registry_name=self._registry_name, + ) + except Exception as err: # pylint: disable=W0718 + if isinstance(err, ResourceNotFoundError): + pass + else: + raise err + else: + msg = "An data asset with this name and version already exists in registry" + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.DATA, + error_category=ErrorCategory.USER_ERROR, + ) + data_res_obj = data._to_rest_object() + result = self._service_client.resource_management_asset_reference.begin_import_method( + resource_group_name=self._resource_group_name, + registry_name=self._registry_name, + body=data_res_obj, + ).result() + + if not result: + data_res_obj = self._get(name=data.name, version=data.version) + return Data._from_rest_object(data_res_obj) + + sas_uri = get_sas_uri_for_registry_asset( + service_client=self._service_client, + name=name, + version=version, + resource_group=self._resource_group_name, + registry=self._registry_name, + body=get_asset_body_for_registry_storage(self._registry_name, "data", name, version), + ) + + referenced_uris = self._validate(data) + if referenced_uris: + data._referenced_uris = referenced_uris + + data, _ = _check_and_upload_path( + artifact=data, + asset_operations=self, + sas_uri=sas_uri, + artifact_type=ErrorTarget.DATA, + show_progress=self._show_progress, + ) + + _check_or_modify_auto_delete_setting(data.auto_delete_setting) + + data_version_resource = 
data._to_rest_object() + auto_increment_version = data._auto_increment_version + + if auto_increment_version: + result = _create_or_update_autoincrement( + name=data.name, + body=data_version_resource, + version_operation=self._operation, + container_operation=self._container_operation, + resource_group_name=self._operation_scope.resource_group_name, + workspace_name=self._workspace_name, + **self._init_kwargs, + ) + else: + result = ( + self._operation.begin_create_or_update( + name=name, + version=version, + registry_name=self._registry_name, + body=data_version_resource, + **self._scope_kwargs, + ).result() + if self._registry_name + else self._operation.create_or_update( + name=name, + version=version, + workspace_name=self._workspace_name, + body=data_version_resource, + **self._scope_kwargs, + ) + ) + + if not result and self._registry_name: + result = self._get(name=name, version=version) + + return Data._from_rest_object(result) + except Exception as ex: + if isinstance(ex, (ValidationException, SchemaValidationError)): + log_and_raise_error(ex) + elif isinstance(ex, HttpResponseError): + # service side raises an exception if we attempt to update an existing asset's asset path + if str(ex) == ASSET_PATH_ERROR: + raise AssetPathException( + message=CHANGED_ASSET_PATH_MSG, + tartget=ErrorTarget.DATA, + no_personal_data_message=CHANGED_ASSET_PATH_MSG_NO_PERSONAL_DATA, + error_category=ErrorCategory.USER_ERROR, + ) from ex + raise ex + + @monitor_with_activity(ops_logger, "Data.ImportData", ActivityType.PUBLICAPI) + @experimental + def import_data(self, data_import: DataImport, **kwargs: Any) -> PipelineJob: + """Returns the data import job that is creating the data asset. + + :param data_import: DataImport object. + :type data_import: azure.ai.ml.entities.DataImport + :return: data import job object. + :rtype: ~azure.ai.ml.entities.PipelineJob + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START data_operations_import_data] + :end-before: [END data_operations_import_data] + :language: python + :dedent: 8 + :caption: Import data assets example. 
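+
+        A rough sketch, assuming an authenticated ``MLClient`` named ``ml_client``; the
+        connection name, query, and datastore path are illustrative placeholders, and the
+        ``DataImport``/``Database`` import paths follow the public SDK surface:
+
+        .. code-block:: python
+
+            from azure.ai.ml.entities import DataImport
+            from azure.ai.ml.data_transfer import Database
+
+            # Import the result of a SQL query into workspace blob storage; a Database
+            # source is materialized as an MLTable asset.
+            data_import = DataImport(
+                name="my_imported_table",
+                source=Database(connection="my_db_connection", query="select * from my_table"),
+                path="azureml://datastores/workspaceblobstore/paths/imports/${{name}}",
+            )
+            import_job = ml_client.data.import_data(data_import=data_import)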
+ """ + + experiment_name = "data_import_" + str(data_import.name) + data_import.type = AssetTypes.MLTABLE if isinstance(data_import.source, Database) else AssetTypes.URI_FOLDER + + # avoid specifying auto_delete_setting in job output now + _validate_auto_delete_setting_in_data_output(data_import.auto_delete_setting) + + # block cumtomer specified path on managed datastore + data_import.path = _validate_workspace_managed_datastore(data_import.path) + + if "${{name}}" not in str(data_import.path): + data_import.path = data_import.path.rstrip("/") + "/${{name}}" # type: ignore + import_job = import_data_func( + description=data_import.description or experiment_name, + display_name=experiment_name, + experiment_name=experiment_name, + compute="serverless", + source=data_import.source, + outputs={ + "sink": Output( + type=data_import.type, + path=data_import.path, # type: ignore + name=data_import.name, + version=data_import.version, + ) + }, + ) + import_pipeline = PipelineJob( + description=data_import.description or experiment_name, + tags=data_import.tags, + display_name=experiment_name, + experiment_name=experiment_name, + properties=data_import.properties or {}, + settings=PipelineJobSettings(force_rerun=True), + jobs={experiment_name: import_job}, + ) + import_pipeline.properties["azureml.materializationAssetName"] = data_import.name + return self._all_operations.all_operations[AzureMLResourceType.JOB].create_or_update( + job=import_pipeline, skip_validation=True, **kwargs + ) + + @monitor_with_activity(ops_logger, "Data.ListMaterializationStatus", ActivityType.PUBLICAPI) + def list_materialization_status( + self, + name: str, + *, + list_view_type: ListViewType = ListViewType.ACTIVE_ONLY, + **kwargs: Any, + ) -> Iterable[PipelineJob]: + """List materialization jobs of the asset. + + :param name: name of asset being created by the materialization jobs. + :type name: str + :keyword list_view_type: View type for including/excluding (for example) archived jobs. Default: ACTIVE_ONLY. + :paramtype list_view_type: Optional[ListViewType] + :return: An iterator like instance of Job objects. + :rtype: ~azure.core.paging.ItemPaged[PipelineJob] + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START data_operations_list_materialization_status] + :end-before: [END data_operations_list_materialization_status] + :language: python + :dedent: 8 + :caption: List materialization jobs example. + """ + + return cast( + Iterable[PipelineJob], + self._all_operations.all_operations[AzureMLResourceType.JOB].list( + job_type="Pipeline", + asset_name=name, + list_view_type=list_view_type, + **kwargs, + ), + ) + + @monitor_with_activity(ops_logger, "Data.Validate", ActivityType.INTERNALCALL) + def _validate(self, data: Data) -> Optional[List[str]]: + if not data.path: + msg = "Missing data path. Path is required for data." 
+ raise ValidationException( + message=msg, + no_personal_data_message=msg, + error_type=ValidationErrorType.MISSING_FIELD, + target=ErrorTarget.DATA, + error_category=ErrorCategory.USER_ERROR, + ) + + asset_path = str(data.path) + asset_type = data.type + base_path = data.base_path + + if asset_type == AssetTypes.MLTABLE: + if is_url(asset_path): + try: + metadata_contents = read_remote_mltable_metadata_contents( + base_uri=asset_path, + datastore_operations=self._datastore_operation, + requests_pipeline=self._requests_pipeline, + ) + metadata_yaml_path = None + except Exception: # pylint: disable=W0718 + # skip validation for remote MLTable when the contents cannot be read + module_logger.info("Unable to access MLTable metadata at path %s", asset_path) + return None + else: + metadata_contents = read_local_mltable_metadata_contents(path=asset_path) + metadata_yaml_path = Path(asset_path, "MLTable") + metadata = MLTableMetadata._load(metadata_contents, metadata_yaml_path) + mltable_metadata_schema = self._try_get_mltable_metadata_jsonschema(data._mltable_schema_url) + if mltable_metadata_schema and not data._skip_validation: + validate_mltable_metadata( + mltable_metadata_dict=metadata_contents, + mltable_schema=mltable_metadata_schema, + ) + return cast(Optional[List[str]], metadata.referenced_uris()) + + if is_url(asset_path): + # skip validation for remote URI_FILE or URI_FOLDER + pass + elif os.path.isabs(asset_path): + _assert_local_path_matches_asset_type(asset_path, asset_type) + else: + abs_path = Path(base_path, asset_path).resolve() + _assert_local_path_matches_asset_type(str(abs_path), asset_type) + + return None + + def _try_get_mltable_metadata_jsonschema(self, mltable_schema_url: Optional[str]) -> Optional[Dict]: + if mltable_schema_url is None: + mltable_schema_url = MLTABLE_METADATA_SCHEMA_URL_FALLBACK + try: + return cast(Optional[Dict], download_mltable_metadata_schema(mltable_schema_url, self._requests_pipeline)) + except Exception: # pylint: disable=W0718 + module_logger.info( + 'Failed to download MLTable metadata jsonschema from "%s", skipping validation', + mltable_schema_url, + ) + return None + + @monitor_with_activity(ops_logger, "Data.Archive", ActivityType.PUBLICAPI) + def archive( + self, + name: str, + version: Optional[str] = None, + label: Optional[str] = None, + # pylint:disable=unused-argument + **kwargs: Any, + ) -> None: + """Archive a data asset. + + :param name: Name of data asset. + :type name: str + :param version: Version of data asset. + :type version: str + :param label: Label of the data asset. (mutually exclusive with version) + :type label: str + :return: None + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START data_operations_archive] + :end-before: [END data_operations_archive] + :language: python + :dedent: 8 + :caption: Archive data asset example. + """ + + _archive_or_restore( + asset_operations=self, + version_operation=self._operation, + container_operation=self._container_operation, + is_archived=True, + name=name, + version=version, + label=label, + ) + + @monitor_with_activity(ops_logger, "Data.Restore", ActivityType.PUBLICAPI) + def restore( + self, + name: str, + version: Optional[str] = None, + label: Optional[str] = None, + # pylint:disable=unused-argument + **kwargs: Any, + ) -> None: + """Restore an archived data asset. + + :param name: Name of data asset. + :type name: str + :param version: Version of data asset. + :type version: str + :param label: Label of the data asset. 
(mutually exclusive with version) + :type label: str + :return: None + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START data_operations_restore] + :end-before: [END data_operations_restore] + :language: python + :dedent: 8 + :caption: Restore data asset example. + """ + + _archive_or_restore( + asset_operations=self, + version_operation=self._operation, + container_operation=self._container_operation, + is_archived=False, + name=name, + version=version, + label=label, + ) + + def _get_latest_version(self, name: str) -> Data: + """Returns the latest version of the asset with the given name. Latest is defined as the most recently created, + not the most recently updated. + + :param name: The asset name + :type name: str + :return: The latest asset + :rtype: Data + """ + latest_version = _get_latest_version_from_container( + name, + self._container_operation, + self._resource_group_name, + self._workspace_name, + self._registry_name, + ) + return self.get(name, version=latest_version) + + @monitor_with_activity(ops_logger, "data.Share", ActivityType.PUBLICAPI) + @experimental + def share( + self, + name: str, + version: str, + *, + share_with_name: str, + share_with_version: str, + registry_name: str, + **kwargs: Any, + ) -> Data: + """Share a data asset from workspace to registry. + + :param name: Name of data asset. + :type name: str + :param version: Version of data asset. + :type version: str + :keyword share_with_name: Name of data asset to share with. + :paramtype share_with_name: str + :keyword share_with_version: Version of data asset to share with. + :paramtype share_with_version: str + :keyword registry_name: Name of the destination registry. + :paramtype registry_name: str + :return: Data asset object. + :rtype: ~azure.ai.ml.entities.Data + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START data_operations_share] + :end-before: [END data_operations_share] + :language: python + :dedent: 8 + :caption: Share data asset example. + """ + + # Get workspace info to get workspace GUID + workspace = self._service_client.workspaces.get( + resource_group_name=self._resource_group_name, + workspace_name=self._workspace_name, + **kwargs, + ) + workspace_guid = workspace.workspace_id + workspace_location = workspace.location + + # Get data asset ID + asset_id = ASSET_ID_FORMAT.format( + workspace_location, + workspace_guid, + AzureMLResourceType.DATA, + name, + version, + ) + + data_ref = WorkspaceAssetReference( + name=share_with_name if share_with_name else name, + version=share_with_version if share_with_version else version, + asset_id=asset_id, + ) + + with self._set_registry_client(registry_name): + return self.create_or_update(data_ref) + + @monitor_with_activity(ops_logger, "data.Mount", ActivityType.PUBLICAPI) + @experimental + def mount( + self, + path: str, + *, + mount_point: Optional[str] = None, + mode: str = "ro_mount", + debug: bool = False, + persistent: bool = False, + **kwargs, + ) -> None: + """Mount a data asset to a local path, so that you can access data inside it + under a local path with any tools of your choice. + + :param path: The data asset path to mount, in the form of `azureml:<name>` or `azureml:<name>:<version>`. + :type path: str + :keyword mount_point: A local path used as mount point. + :type mount_point: str + :keyword mode: Mount mode. Only `ro_mount` (read-only) is supported for data asset mount. 
+ :type mode: str + :keyword debug: Whether to enable verbose logging. + :type debug: bool + :keyword persistent: Whether to persist the mount after reboot. Applies only when running on Compute Instance, + where the 'CI_NAME' environment variable is set." + :type persistent: bool + :return: None + """ + + assert mode in ["ro_mount", "rw_mount"], "mode should be either `ro_mount` or `rw_mount`" + read_only = mode == "ro_mount" + assert read_only, "read-write mount for data asset is not supported yet" + + ci_name = os.environ.get("CI_NAME") + assert not persistent or ( + persistent and ci_name is not None + ), "persistent mount is only supported on Compute Instance, where the 'CI_NAME' environment variable is set." + + try: + from azureml.dataprep import rslex_fuse_subprocess_wrapper + except ImportError as exc: + raise MlException( + "Mount operations requires package azureml-dataprep-rslex installed. " + + "You can install it with Azure ML SDK with `pip install azure-ai-ml[mount]`." + ) from exc + + uri = rslex_fuse_subprocess_wrapper.build_data_asset_uri( + self._operation_scope._subscription_id, self._resource_group_name, self._workspace_name, path + ) + if persistent and ci_name is not None: + mount_name = f"unified_mount_{str(uuid.uuid4()).replace('-', '')}" + self._compute_operation.update_data_mounts( + self._resource_group_name, + self._workspace_name, + ci_name, + [ + ComputeInstanceDataMount( + source=uri, + source_type="URI", + mount_name=mount_name, + mount_action="Mount", + mount_path=mount_point or "", + ) + ], + api_version="2021-01-01", + **kwargs, + ) + print(f"Mount requested [name: {mount_name}]. Waiting for completion ...") + while True: + compute = self._compute_operation.get(self._resource_group_name, self._workspace_name, ci_name) + mounts = compute.properties.properties.data_mounts + try: + mount = [mount for mount in mounts if mount.mount_name == mount_name][0] + if mount.mount_state == "Mounted": + print(f"Mounted [name: {mount_name}].") + break + if mount.mount_state == "MountRequested": + pass + elif mount.mount_state == "MountFailed": + msg = f"Mount failed [name: {mount_name}]: {mount.error}" + raise MlException(message=msg, no_personal_data_message=msg) + else: + msg = f"Got unexpected mount state [name: {mount_name}]: {mount.mount_state}" + raise MlException(message=msg, no_personal_data_message=msg) + except IndexError: + pass + time.sleep(5) + + else: + rslex_fuse_subprocess_wrapper.start_fuse_mount_subprocess( + uri, mount_point, read_only, debug, credential=self._operation._config.credential + ) + + @contextmanager + # pylint: disable-next=docstring-missing-return,docstring-missing-rtype + def _set_registry_client(self, registry_name: str) -> Generator: + """Sets the registry client for the data operations. + + :param registry_name: Name of the registry. 
+ :type registry_name: str + """ + rg_ = self._operation_scope._resource_group_name + sub_ = self._operation_scope._subscription_id + registry_ = self._operation_scope.registry_name + client_ = self._service_client + data_versions_operation_ = self._operation + + try: + _client, _rg, _sub = get_registry_client(self._service_client._config.credential, registry_name) + self._operation_scope.registry_name = registry_name + self._operation_scope._resource_group_name = _rg + self._operation_scope._subscription_id = _sub + self._service_client = _client + self._operation = _client.data_versions + yield + finally: + self._operation_scope.registry_name = registry_ + self._operation_scope._resource_group_name = rg_ + self._operation_scope._subscription_id = sub_ + self._service_client = client_ + self._operation = data_versions_operation_ + + +def _assert_local_path_matches_asset_type( + local_path: str, + asset_type: str, +) -> None: + # assert file system type matches asset type + if asset_type == AssetTypes.URI_FOLDER and not os.path.isdir(local_path): + raise ValidationException( + message="File path does not match asset type {}: {}".format(asset_type, local_path), + no_personal_data_message="File path does not match asset type {}".format(asset_type), + target=ErrorTarget.DATA, + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.FILE_OR_FOLDER_NOT_FOUND, + ) + if asset_type == AssetTypes.URI_FILE and not os.path.isfile(local_path): + raise ValidationException( + message="File path does not match asset type {}: {}".format(asset_type, local_path), + no_personal_data_message="File path does not match asset type {}".format(asset_type), + target=ErrorTarget.DATA, + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.FILE_OR_FOLDER_NOT_FOUND, + ) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_dataset_dataplane_operations.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_dataset_dataplane_operations.py new file mode 100644 index 00000000..d9a95074 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_dataset_dataplane_operations.py @@ -0,0 +1,32 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +import logging +from typing import List + +from azure.ai.ml._restclient.dataset_dataplane import AzureMachineLearningWorkspaces as ServiceClientDatasetDataplane +from azure.ai.ml._restclient.dataset_dataplane.models import BatchDataUriResponse, BatchGetResolvedURIs +from azure.ai.ml._scope_dependent_operations import OperationConfig, OperationScope, _ScopeDependentOperations + +module_logger = logging.getLogger(__name__) + + +class DatasetDataplaneOperations(_ScopeDependentOperations): + def __init__( + self, + operation_scope: OperationScope, + operation_config: OperationConfig, + service_client: ServiceClientDatasetDataplane, + ): + super().__init__(operation_scope, operation_config) + self._operation = service_client.data_version + + def get_batch_dataset_uris(self, dataset_ids: List[str]) -> BatchDataUriResponse: + batch_uri_request = BatchGetResolvedURIs(values=dataset_ids) + return self._operation.batch_get_resolved_uris( + self._operation_scope.subscription_id, + self._operation_scope.resource_group_name, + self._workspace_name, + body=batch_uri_request, + ) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_datastore_operations.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_datastore_operations.py new file mode 100644 index 00000000..74a518e9 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_datastore_operations.py @@ -0,0 +1,329 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=protected-access + +import time +import uuid +from typing import Dict, Iterable, Optional, cast + +from marshmallow.exceptions import ValidationError as SchemaValidationError + +from azure.ai.ml._exception_helper import log_and_raise_error +from azure.ai.ml._restclient.v2024_01_01_preview import AzureMachineLearningWorkspaces as ServiceClient012024Preview +from azure.ai.ml._restclient.v2024_01_01_preview.models import ComputeInstanceDataMount +from azure.ai.ml._restclient.v2024_07_01_preview import AzureMachineLearningWorkspaces as ServiceClient072024Preview +from azure.ai.ml._restclient.v2024_07_01_preview.models import Datastore as DatastoreData +from azure.ai.ml._restclient.v2024_07_01_preview.models import DatastoreSecrets, NoneDatastoreCredentials, SecretExpiry +from azure.ai.ml._scope_dependent_operations import OperationConfig, OperationScope, _ScopeDependentOperations +from azure.ai.ml._telemetry import ActivityType, monitor_with_activity +from azure.ai.ml._utils._experimental import experimental +from azure.ai.ml._utils._logger_utils import OpsLogger +from azure.ai.ml.entities._datastore.datastore import Datastore +from azure.ai.ml.exceptions import MlException, ValidationException + +ops_logger = OpsLogger(__name__) +module_logger = ops_logger.module_logger + + +class DatastoreOperations(_ScopeDependentOperations): + """Represents a client for performing operations on Datastores. + + You should not instantiate this class directly. Instead, you should create MLClient and use this client via the + property MLClient.datastores + + :param operation_scope: Scope variables for the operations classes of an MLClient object. + :type operation_scope: ~azure.ai.ml._scope_dependent_operations.OperationScope + :param operation_config: Common configuration for operations classes of an MLClient object. 
+ :type operation_config: ~azure.ai.ml._scope_dependent_operations.OperationConfig + :param serviceclient_2024_01_01_preview: Service client to allow end users to operate on Azure Machine Learning + Workspace resources. + :type serviceclient_2024_01_01_preview: ~azure.ai.ml._restclient.v2023_01_01_preview. + _azure_machine_learning_workspaces.AzureMachineLearningWorkspaces + :param serviceclient_2024_07_01_preview: Service client to allow end users to operate on Azure Machine Learning + Workspace resources. + :type serviceclient_2024_07_01_preview: ~azure.ai.ml._restclient.v2024_07_01_preview. + _azure_machine_learning_workspaces.AzureMachineLearningWorkspaces + """ + + def __init__( + self, + operation_scope: OperationScope, + operation_config: OperationConfig, + serviceclient_2024_01_01_preview: ServiceClient012024Preview, + serviceclient_2024_07_01_preview: ServiceClient072024Preview, + **kwargs: Dict, + ): + super(DatastoreOperations, self).__init__(operation_scope, operation_config) + ops_logger.update_filter() + self._operation = serviceclient_2024_07_01_preview.datastores + self._compute_operation = serviceclient_2024_01_01_preview.compute + self._credential = serviceclient_2024_07_01_preview._config.credential + self._init_kwargs = kwargs + + @monitor_with_activity(ops_logger, "Datastore.List", ActivityType.PUBLICAPI) + def list(self, *, include_secrets: bool = False) -> Iterable[Datastore]: + """Lists all datastores and associated information within a workspace. + + :keyword include_secrets: Include datastore secrets in returned datastores, defaults to False + :paramtype include_secrets: bool + :return: An iterator like instance of Datastore objects + :rtype: ~azure.core.paging.ItemPaged[Datastore] + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START datastore_operations_list] + :end-before: [END datastore_operations_list] + :language: python + :dedent: 8 + :caption: List datastore example. + """ + + def _list_helper(datastore_resource: Datastore, include_secrets: bool) -> Datastore: + if include_secrets: + self._fetch_and_populate_secret(datastore_resource) + return Datastore._from_rest_object(datastore_resource) + + return cast( + Iterable[Datastore], + self._operation.list( + resource_group_name=self._operation_scope.resource_group_name, + workspace_name=self._workspace_name, + cls=lambda objs: [_list_helper(obj, include_secrets) for obj in objs], + **self._init_kwargs, + ), + ) + + @monitor_with_activity(ops_logger, "Datastore.ListSecrets", ActivityType.PUBLICAPI) + def _list_secrets(self, name: str, expirable_secret: bool = False) -> DatastoreSecrets: + return self._operation.list_secrets( + name=name, + body=SecretExpiry(expirable_secret=expirable_secret), + resource_group_name=self._operation_scope.resource_group_name, + workspace_name=self._workspace_name, + **self._init_kwargs, + ) + + @monitor_with_activity(ops_logger, "Datastore.Delete", ActivityType.PUBLICAPI) + def delete(self, name: str) -> None: + """Deletes a datastore reference with the given name from the workspace. This method does not delete the actual + datastore or underlying data in the datastore. + + :param name: Name of the datastore + :type name: str + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START datastore_operations_delete] + :end-before: [END datastore_operations_delete] + :language: python + :dedent: 8 + :caption: Delete datastore example. 
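+
+        A minimal inline sketch, assuming an authenticated ``MLClient`` named ``ml_client``
+        (the datastore name is an illustrative placeholder):
+
+        .. code-block:: python
+
+            # Remove the datastore reference; the underlying storage and data are untouched.
+            ml_client.datastores.delete("my_blob_datastore")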
+ """ + + self._operation.delete( + name=name, + resource_group_name=self._operation_scope.resource_group_name, + workspace_name=self._workspace_name, + **self._init_kwargs, + ) + + @monitor_with_activity(ops_logger, "Datastore.Get", ActivityType.PUBLICAPI) + def get(self, name: str, *, include_secrets: bool = False) -> Datastore: # type: ignore + """Returns information about the datastore referenced by the given name. + + :param name: Datastore name + :type name: str + :keyword include_secrets: Include datastore secrets in the returned datastore, defaults to False + :paramtype include_secrets: bool + :return: Datastore with the specified name. + :rtype: Datastore + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START datastore_operations_get] + :end-before: [END datastore_operations_get] + :language: python + :dedent: 8 + :caption: Get datastore example. + """ + try: + datastore_resource = self._operation.get( + name=name, + resource_group_name=self._operation_scope.resource_group_name, + workspace_name=self._workspace_name, + **self._init_kwargs, + ) + if include_secrets: + self._fetch_and_populate_secret(datastore_resource) + return Datastore._from_rest_object(datastore_resource) + except (ValidationException, SchemaValidationError) as ex: + log_and_raise_error(ex) + + def _fetch_and_populate_secret(self, datastore_resource: DatastoreData) -> None: + if datastore_resource.name and not isinstance( + datastore_resource.properties.credentials, NoneDatastoreCredentials + ): + secrets = self._list_secrets(name=datastore_resource.name, expirable_secret=True) + datastore_resource.properties.credentials.secrets = secrets + + @monitor_with_activity(ops_logger, "Datastore.GetDefault", ActivityType.PUBLICAPI) + def get_default(self, *, include_secrets: bool = False) -> Datastore: # type: ignore + """Returns the workspace's default datastore. + + :keyword include_secrets: Include datastore secrets in the returned datastore, defaults to False + :paramtype include_secrets: bool + :return: The default datastore. + :rtype: Datastore + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START datastore_operations_get_default] + :end-before: [END datastore_operations_get_default] + :language: python + :dedent: 8 + :caption: Get default datastore example. + """ + try: + datastore_resource = self._operation.list( + resource_group_name=self._operation_scope.resource_group_name, + workspace_name=self._workspace_name, + is_default=True, + **self._init_kwargs, + ).next() + if include_secrets: + self._fetch_and_populate_secret(datastore_resource) + return Datastore._from_rest_object(datastore_resource) + except (ValidationException, SchemaValidationError) as ex: + log_and_raise_error(ex) + + @monitor_with_activity(ops_logger, "Datastore.CreateOrUpdate", ActivityType.PUBLICAPI) + def create_or_update(self, datastore: Datastore) -> Datastore: # type: ignore + """Attaches the passed in datastore to the workspace or updates the datastore if it already exists. + + :param datastore: The configuration of the datastore to attach. + :type datastore: Datastore + :return: The attached datastore. + :rtype: Datastore + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START datastore_operations_create_or_update] + :end-before: [END datastore_operations_create_or_update] + :language: python + :dedent: 8 + :caption: Create datastore example. 
+ """ + try: + ds_request = datastore._to_rest_object() + datastore_resource = self._operation.create_or_update( + name=datastore.name, + resource_group_name=self._operation_scope.resource_group_name, + workspace_name=self._workspace_name, + body=ds_request, + skip_validation=True, + ) + return Datastore._from_rest_object(datastore_resource) + except Exception as ex: # pylint: disable=W0718 + if isinstance(ex, (ValidationException, SchemaValidationError)): + log_and_raise_error(ex) + else: + raise ex + + @monitor_with_activity(ops_logger, "Datastore.Mount", ActivityType.PUBLICAPI) + @experimental + def mount( + self, + path: str, + *, + mount_point: Optional[str] = None, + mode: str = "ro_mount", + debug: bool = False, + persistent: bool = False, + **kwargs, + ) -> None: + """Mount a datastore to a local path, so that you can access data inside it + under a local path with any tools of your choice. + + :param path: The data store path to mount, in the form of `<name>` or `azureml://datastores/<name>`. + :type path: str + :keyword mount_point: A local path used as mount point. + :type mount_point: str + :keyword mode: Mount mode, either `ro_mount` (read-only) or `rw_mount` (read-write). + :type mode: str + :keyword debug: Whether to enable verbose logging. + :type debug: bool + :keyword persistent: Whether to persist the mount after reboot. Applies only when running on Compute Instance, + where the 'CI_NAME' environment variable is set." + :type persistent: bool + :return: None + """ + + assert mode in ["ro_mount", "rw_mount"], "mode should be either `ro_mount` or `rw_mount`" + read_only = mode == "ro_mount" + + import os + + ci_name = os.environ.get("CI_NAME") + assert not persistent or ( + persistent and ci_name is not None + ), "persistent mount is only supported on Compute Instance, where the 'CI_NAME' environment variable is set." + + try: + from azureml.dataprep import rslex_fuse_subprocess_wrapper + except ImportError as exc: + msg = "Mount operations requires package azureml-dataprep-rslex installed. \ + You can install it with Azure ML SDK with `pip install azure-ai-ml[mount]`." + raise MlException(message=msg, no_personal_data_message=msg) from exc + + uri = rslex_fuse_subprocess_wrapper.build_datastore_uri( + self._operation_scope._subscription_id, self._resource_group_name, self._workspace_name, path + ) + if persistent and ci_name is not None: + mount_name = f"unified_mount_{str(uuid.uuid4()).replace('-', '')}" + self._compute_operation.update_data_mounts( + self._resource_group_name, + self._workspace_name, + ci_name, + [ + ComputeInstanceDataMount( + source=uri, + source_type="URI", + mount_name=mount_name, + mount_action="Mount", + mount_path=mount_point or "", + ) + ], + api_version="2021-01-01", + **kwargs, + ) + print(f"Mount requested [name: {mount_name}]. 
Waiting for completion ...") + while True: + compute = self._compute_operation.get(self._resource_group_name, self._workspace_name, ci_name) + mounts = compute.properties.properties.data_mounts + try: + mount = [mount for mount in mounts if mount.mount_name == mount_name][0] + if mount.mount_state == "Mounted": + print(f"Mounted [name: {mount_name}].") + break + if mount.mount_state == "MountRequested": + pass + elif mount.mount_state == "MountFailed": + msg = f"Mount failed [name: {mount_name}]: {mount.error}" + raise MlException(message=msg, no_personal_data_message=msg) + else: + msg = f"Got unexpected mount state [name: {mount_name}]: {mount.mount_state}" + raise MlException(message=msg, no_personal_data_message=msg) + except IndexError: + pass + time.sleep(5) + else: + rslex_fuse_subprocess_wrapper.start_fuse_mount_subprocess( + uri, mount_point, read_only, debug, credential=self._operation._config.credential + ) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_environment_operations.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_environment_operations.py new file mode 100644 index 00000000..228204a7 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_environment_operations.py @@ -0,0 +1,569 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=protected-access + +from contextlib import contextmanager +from typing import Any, Generator, Iterable, Optional, Union, cast + +from marshmallow.exceptions import ValidationError as SchemaValidationError + +from azure.ai.ml._artifacts._artifact_utilities import _check_and_upload_env_build_context +from azure.ai.ml._exception_helper import log_and_raise_error +from azure.ai.ml._restclient.v2021_10_01_dataplanepreview import ( + AzureMachineLearningWorkspaces as ServiceClient102021Dataplane, +) +from azure.ai.ml._restclient.v2023_04_01_preview import AzureMachineLearningWorkspaces as ServiceClient042023Preview +from azure.ai.ml._restclient.v2023_04_01_preview.models import EnvironmentVersion, ListViewType +from azure.ai.ml._scope_dependent_operations import ( + OperationConfig, + OperationsContainer, + OperationScope, + _ScopeDependentOperations, +) +from azure.ai.ml._telemetry import ActivityType, monitor_with_activity +from azure.ai.ml._utils._asset_utils import ( + _archive_or_restore, + _get_latest, + _get_next_latest_versions_from_container, + _resolve_label_to_asset, +) +from azure.ai.ml._utils._experimental import experimental +from azure.ai.ml._utils._logger_utils import OpsLogger +from azure.ai.ml._utils._registry_utils import ( + get_asset_body_for_registry_storage, + get_registry_client, + get_sas_uri_for_registry_asset, +) +from azure.ai.ml.constants._common import ARM_ID_PREFIX, ASSET_ID_FORMAT, AzureMLResourceType +from azure.ai.ml.entities._assets import Environment, WorkspaceAssetReference +from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException +from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError + +ops_logger = OpsLogger(__name__) +module_logger = ops_logger.module_logger + + +class EnvironmentOperations(_ScopeDependentOperations): + """EnvironmentOperations. + + You should not instantiate this class directly. Instead, you should + create an MLClient instance that instantiates it for you and + attaches it as an attribute. 
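
A minimal usage sketch for the DatastoreOperations hunk above, which is normally reached through an MLClient rather than instantiated directly. Subscription, workspace, storage account, and key values below are placeholders, and AzureBlobDatastore/AccountKeyConfiguration are the documented entity types assumed here, not something defined in this diff.

    # Hedged usage sketch for DatastoreOperations (reached via MLClient.datastores).
    # All resource names and keys below are placeholders.
    from azure.ai.ml import MLClient
    from azure.ai.ml.entities import AzureBlobDatastore, AccountKeyConfiguration
    from azure.identity import DefaultAzureCredential

    ml_client = MLClient(
        credential=DefaultAzureCredential(),
        subscription_id="<subscription-id>",
        resource_group_name="<resource-group>",
        workspace_name="<workspace>",
    )

    # Attach a blob container as a datastore (create_or_update above).
    blob_datastore = AzureBlobDatastore(
        name="example_blob_store",
        account_name="<storage-account>",
        container_name="<container>",
        credentials=AccountKeyConfiguration(account_key="<account-key>"),
    )
    ml_client.datastores.create_or_update(blob_datastore)

    # list()/get()/get_default() optionally resolve secrets via _list_secrets.
    for ds in ml_client.datastores.list(include_secrets=False):
        print(ds.name, ds.type)
    default_ds = ml_client.datastores.get_default(include_secrets=True)

    # mount() is experimental and needs the azureml-dataprep-rslex extra
    # (pip install azure-ai-ml[mount]); only ro_mount and rw_mount are accepted.
    # ml_client.datastores.mount("azureml://datastores/example_blob_store",
    #                            mount_point="/tmp/blob", mode="ro_mount")
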
+ + :param operation_scope: Scope variables for the operations classes of an MLClient object. + :type operation_scope: ~azure.ai.ml._scope_dependent_operations.OperationScope + :param operation_config: Common configuration for operations classes of an MLClient object. + :type operation_config: ~azure.ai.ml._scope_dependent_operations.OperationConfig + :param service_client: Service client to allow end users to operate on Azure Machine Learning Workspace + resources (ServiceClient042023Preview or ServiceClient102021Dataplane). + :type service_client: typing.Union[ + ~azure.ai.ml._restclient.v2023_04_01_preview._azure_machine_learning_workspaces.AzureMachineLearningWorkspaces, + ~azure.ai.ml._restclient.v2021_10_01_dataplanepreview._azure_machine_learning_workspaces. + AzureMachineLearningWorkspaces] + :param all_operations: All operations classes of an MLClient object. + :type all_operations: ~azure.ai.ml._scope_dependent_operations.OperationsContainer + """ + + def __init__( + self, + operation_scope: OperationScope, + operation_config: OperationConfig, + service_client: Union[ServiceClient042023Preview, ServiceClient102021Dataplane], + all_operations: OperationsContainer, + **kwargs: Any, + ): + super(EnvironmentOperations, self).__init__(operation_scope, operation_config) + ops_logger.update_filter() + self._kwargs = kwargs + self._containers_operations = service_client.environment_containers + self._version_operations = service_client.environment_versions + self._service_client = service_client + self._all_operations = all_operations + self._datastore_operation = all_operations.all_operations[AzureMLResourceType.DATASTORE] + + # Maps a label to a function which given an asset name, + # returns the asset associated with the label + self._managed_label_resolver = {"latest": self._get_latest_version} + + @monitor_with_activity(ops_logger, "Environment.CreateOrUpdate", ActivityType.PUBLICAPI) + def create_or_update(self, environment: Environment) -> Environment: # type: ignore + """Returns created or updated environment asset. + + :param environment: Environment object + :type environment: ~azure.ai.ml.entities._assets.Environment + :raises ~azure.ai.ml.exceptions.ValidationException: Raised if Environment cannot be successfully validated. + Details will be provided in the error message. + :raises ~azure.ai.ml.exceptions.EmptyDirectoryError: Raised if local path provided points to an empty directory. + :return: Created or updated Environment object + :rtype: ~azure.ai.ml.entities.Environment + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START env_operations_create_or_update] + :end-before: [END env_operations_create_or_update] + :language: python + :dedent: 8 + :caption: Create environment. 
+ """ + try: + if not environment.version and environment._auto_increment_version: + + next_version, latest_version = _get_next_latest_versions_from_container( + name=environment.name, + container_operation=self._containers_operations, + resource_group_name=self._operation_scope.resource_group_name, + workspace_name=self._workspace_name, + registry_name=self._registry_name, + **self._kwargs, + ) + # If user not passing the version, SDK will try to update the latest version + return self._try_update_latest_version(next_version, latest_version, environment) + + sas_uri = None + if self._registry_name: + if isinstance(environment, WorkspaceAssetReference): + # verify that environment is not already in registry + try: + self._version_operations.get( + name=environment.name, + version=environment.version, + resource_group_name=self._resource_group_name, + registry_name=self._registry_name, + ) + except Exception as err: # pylint: disable=W0718 + if isinstance(err, ResourceNotFoundError): + pass + else: + raise err + else: + msg = "A environment with this name and version already exists in registry" + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.ENVIRONMENT, + error_category=ErrorCategory.USER_ERROR, + ) + + environment_rest = environment._to_rest_object() + result = self._service_client.resource_management_asset_reference.begin_import_method( + resource_group_name=self._resource_group_name, + registry_name=self._registry_name, + body=environment_rest, + **self._kwargs, + ).result() + + if not result: + env_rest_obj = self._get(name=environment.name, version=environment.version) + return Environment._from_rest_object(env_rest_obj) + + sas_uri = get_sas_uri_for_registry_asset( + service_client=self._service_client, + name=environment.name, + version=environment.version, + resource_group=self._resource_group_name, + registry=self._registry_name, + body=get_asset_body_for_registry_storage( + self._registry_name, + "environments", + environment.name, + environment.version, + ), + ) + + # upload only in case of when its not registry + # or successfully acquire sas_uri + if not self._registry_name or sas_uri: + environment = _check_and_upload_env_build_context( + environment=environment, + operations=self, + sas_uri=sas_uri, + show_progress=self._show_progress, + ) + env_version_resource = environment._to_rest_object() + env_rest_obj = ( + self._version_operations.begin_create_or_update( + name=environment.name, + version=environment.version, + registry_name=self._registry_name, + body=env_version_resource, + **self._scope_kwargs, + **self._kwargs, + ).result() + if self._registry_name + else self._version_operations.create_or_update( + name=environment.name, + version=environment.version, + workspace_name=self._workspace_name, + body=env_version_resource, + **self._scope_kwargs, + **self._kwargs, + ) + ) + if not env_rest_obj and self._registry_name: + env_rest_obj = self._get(name=str(environment.name), version=environment.version) + return Environment._from_rest_object(env_rest_obj) + except Exception as ex: # pylint: disable=W0718 + if isinstance(ex, SchemaValidationError): + log_and_raise_error(ex) + else: + raise ex + + def _try_update_latest_version( + self, next_version: str, latest_version: str, environment: Environment + ) -> Environment: + env = None + if self._registry_name: + environment.version = next_version + env = self.create_or_update(environment) + else: + environment.version = latest_version + try: # Try to update the latest version 
+ env = self.create_or_update(environment) + except Exception as ex: # pylint: disable=W0718 + if isinstance(ex, ResourceExistsError): + environment.version = next_version + env = self.create_or_update(environment) + else: + raise ex + return env + + def _get(self, name: str, version: Optional[str] = None) -> EnvironmentVersion: + if version: + return ( + self._version_operations.get( + name=name, + version=version, + registry_name=self._registry_name, + **self._scope_kwargs, + **self._kwargs, + ) + if self._registry_name + else self._version_operations.get( + name=name, + version=version, + workspace_name=self._workspace_name, + **self._scope_kwargs, + **self._kwargs, + ) + ) + return ( + self._containers_operations.get( + name=name, + registry_name=self._registry_name, + **self._scope_kwargs, + **self._kwargs, + ) + if self._registry_name + else self._containers_operations.get( + name=name, + workspace_name=self._workspace_name, + **self._scope_kwargs, + **self._kwargs, + ) + ) + + @monitor_with_activity(ops_logger, "Environment.Get", ActivityType.PUBLICAPI) + def get(self, name: str, version: Optional[str] = None, label: Optional[str] = None) -> Environment: + """Returns the specified environment asset. + + :param name: Name of the environment. + :type name: str + :param version: Version of the environment. + :type version: str + :param label: Label of the environment. (mutually exclusive with version) + :type label: str + :raises ~azure.ai.ml.exceptions.ValidationException: Raised if Environment cannot be successfully validated. + Details will be provided in the error message. + :return: Environment object + :rtype: ~azure.ai.ml.entities.Environment + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START env_operations_get] + :end-before: [END env_operations_get] + :language: python + :dedent: 8 + :caption: Get example. + """ + if version and label: + msg = "Cannot specify both version and label." + raise ValidationException( + message=msg, + target=ErrorTarget.ENVIRONMENT, + no_personal_data_message=msg, + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.INVALID_VALUE, + ) + + if label: + return _resolve_label_to_asset(self, name, label) + + if not version: + msg = "Must provide either version or label." + raise ValidationException( + message=msg, + target=ErrorTarget.ENVIRONMENT, + no_personal_data_message=msg, + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.MISSING_FIELD, + ) + name = _preprocess_environment_name(name) + env_version_resource = self._get(name, version) + + return Environment._from_rest_object(env_version_resource) + + @monitor_with_activity(ops_logger, "Environment.List", ActivityType.PUBLICAPI) + def list( + self, + name: Optional[str] = None, + *, + list_view_type: ListViewType = ListViewType.ACTIVE_ONLY, + ) -> Iterable[Environment]: + """List all environment assets in workspace. + + :param name: Name of the environment. + :type name: Optional[str] + :keyword list_view_type: View type for including/excluding (for example) archived environments. + Default: ACTIVE_ONLY. + :type list_view_type: Optional[ListViewType] + :return: An iterator like instance of Environment objects. + :rtype: ~azure.core.paging.ItemPaged[Environment] + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START env_operations_list] + :end-before: [END env_operations_list] + :language: python + :dedent: 8 + :caption: List example. 
+ """ + if name: + return cast( + Iterable[Environment], + ( + self._version_operations.list( + name=name, + registry_name=self._registry_name, + cls=lambda objs: [Environment._from_rest_object(obj) for obj in objs], + **self._scope_kwargs, + **self._kwargs, + ) + if self._registry_name + else self._version_operations.list( + name=name, + workspace_name=self._workspace_name, + cls=lambda objs: [Environment._from_rest_object(obj) for obj in objs], + list_view_type=list_view_type, + **self._scope_kwargs, + **self._kwargs, + ) + ), + ) + return cast( + Iterable[Environment], + ( + self._containers_operations.list( + registry_name=self._registry_name, + cls=lambda objs: [Environment._from_container_rest_object(obj) for obj in objs], + **self._scope_kwargs, + **self._kwargs, + ) + if self._registry_name + else self._containers_operations.list( + workspace_name=self._workspace_name, + cls=lambda objs: [Environment._from_container_rest_object(obj) for obj in objs], + list_view_type=list_view_type, + **self._scope_kwargs, + **self._kwargs, + ) + ), + ) + + @monitor_with_activity(ops_logger, "Environment.Delete", ActivityType.PUBLICAPI) + def archive( + self, + name: str, + version: Optional[str] = None, + label: Optional[str] = None, + # pylint:disable=unused-argument + **kwargs: Any, + ) -> None: + """Archive an environment or an environment version. + + :param name: Name of the environment. + :type name: str + :param version: Version of the environment. + :type version: str + :param label: Label of the environment. (mutually exclusive with version) + :type label: str + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START env_operations_archive] + :end-before: [END env_operations_archive] + :language: python + :dedent: 8 + :caption: Archive example. + """ + name = _preprocess_environment_name(name) + _archive_or_restore( + asset_operations=self, + version_operation=self._version_operations, + container_operation=self._containers_operations, + is_archived=True, + name=name, + version=version, + label=label, + ) + + @monitor_with_activity(ops_logger, "Environment.Restore", ActivityType.PUBLICAPI) + def restore( + self, + name: str, + version: Optional[str] = None, + label: Optional[str] = None, + # pylint:disable=unused-argument + **kwargs: Any, + ) -> None: + """Restore an archived environment version. + + :param name: Name of the environment. + :type name: str + :param version: Version of the environment. + :type version: str + :param label: Label of the environment. (mutually exclusive with version) + :type label: str + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START env_operations_restore] + :end-before: [END env_operations_restore] + :language: python + :dedent: 8 + :caption: Restore example. + """ + name = _preprocess_environment_name(name) + _archive_or_restore( + asset_operations=self, + version_operation=self._version_operations, + container_operation=self._containers_operations, + is_archived=False, + name=name, + version=version, + label=label, + ) + + def _get_latest_version(self, name: str) -> Environment: + """Returns the latest version of the asset with the given name. + + Latest is defined as the most recently created, not the most + recently updated. 
+ + :param name: The asset name + :type name: str + :return: The latest version of the named environment + :rtype: Environment + """ + result = _get_latest( + name, + self._version_operations, + self._resource_group_name, + self._workspace_name, + self._registry_name, + ) + return Environment._from_rest_object(result) + + @monitor_with_activity(ops_logger, "Environment.Share", ActivityType.PUBLICAPI) + @experimental + def share( + self, + name: str, + version: str, + *, + share_with_name: str, + share_with_version: str, + registry_name: str, + ) -> Environment: + """Share a environment asset from workspace to registry. + + :param name: Name of environment asset. + :type name: str + :param version: Version of environment asset. + :type version: str + :keyword share_with_name: Name of environment asset to share with. + :paramtype share_with_name: str + :keyword share_with_version: Version of environment asset to share with. + :paramtype share_with_version: str + :keyword registry_name: Name of the destination registry. + :paramtype registry_name: str + :return: Environment asset object. + :rtype: ~azure.ai.ml.entities.Environment + """ + + # Get workspace info to get workspace GUID + workspace = self._service_client.workspaces.get( + resource_group_name=self._resource_group_name, + workspace_name=self._workspace_name, + ) + workspace_guid = workspace.workspace_id + workspace_location = workspace.location + + # Get environment asset ID + asset_id = ASSET_ID_FORMAT.format( + workspace_location, + workspace_guid, + AzureMLResourceType.ENVIRONMENT, + name, + version, + ) + + environment_ref = WorkspaceAssetReference( + name=share_with_name if share_with_name else name, + version=share_with_version if share_with_version else version, + asset_id=asset_id, + ) + + with self._set_registry_client(registry_name): + return self.create_or_update(environment_ref) + + @contextmanager + # pylint: disable-next=docstring-missing-return,docstring-missing-rtype + def _set_registry_client(self, registry_name: str) -> Generator: + """Sets the registry client for the environment operations. + + :param registry_name: Name of the registry. 
+ :type registry_name: str + """ + rg_ = self._operation_scope._resource_group_name + sub_ = self._operation_scope._subscription_id + registry_ = self._operation_scope.registry_name + client_ = self._service_client + environment_versions_operation_ = self._version_operations + + try: + _client, _rg, _sub = get_registry_client(self._service_client._config.credential, registry_name) + self._operation_scope.registry_name = registry_name + self._operation_scope._resource_group_name = _rg + self._operation_scope._subscription_id = _sub + self._service_client = _client + self._version_operations = _client.environment_versions + yield + finally: + self._operation_scope.registry_name = registry_ + self._operation_scope._resource_group_name = rg_ + self._operation_scope._subscription_id = sub_ + self._service_client = client_ + self._version_operations = environment_versions_operation_ + + +def _preprocess_environment_name(environment_name: str) -> str: + if environment_name.startswith(ARM_ID_PREFIX): + return environment_name[len(ARM_ID_PREFIX) :] + return environment_name diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_evaluator_operations.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_evaluator_operations.py new file mode 100644 index 00000000..47167947 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_evaluator_operations.py @@ -0,0 +1,222 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=protected-access + +from os import PathLike +from typing import Any, Dict, Iterable, Optional, Union, cast + +from azure.ai.ml._restclient.v2021_10_01_dataplanepreview import ( + AzureMachineLearningWorkspaces as ServiceClient102021Dataplane, +) +from azure.ai.ml._restclient.v2023_08_01_preview import AzureMachineLearningWorkspaces as ServiceClient082023Preview +from azure.ai.ml._restclient.v2023_08_01_preview.models import ListViewType +from azure.ai.ml._scope_dependent_operations import ( + OperationConfig, + OperationsContainer, + OperationScope, + _ScopeDependentOperations, +) +from azure.ai.ml._telemetry import ActivityType, monitor_with_activity +from azure.ai.ml._utils._logger_utils import OpsLogger +from azure.ai.ml._utils.utils import _get_evaluator_properties, _is_evaluator +from azure.ai.ml.entities._assets import Model +from azure.ai.ml.entities._assets.workspace_asset_reference import WorkspaceAssetReference +from azure.ai.ml.exceptions import UnsupportedOperationError +from azure.ai.ml.operations._datastore_operations import DatastoreOperations +from azure.ai.ml.operations._model_operations import ModelOperations +from azure.core.exceptions import ResourceNotFoundError + +ops_logger = OpsLogger(__name__) +module_logger = ops_logger.module_logger + + +class EvaluatorOperations(_ScopeDependentOperations): + """EvaluatorOperations. + + You should not instantiate this class directly. Instead, you should create an MLClient instance that instantiates it + for you and attaches it as an attribute. + + :param operation_scope: Scope variables for the operations classes of an MLClient object. + :type operation_scope: ~azure.ai.ml._scope_dependent_operations.OperationScope + :param operation_config: Common configuration for operations classes of an MLClient object. 
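
A hedged sketch of how the EnvironmentOperations methods shown above (create_or_update, get, list, archive) are typically called through MLClient.environments. The image name, conda file path, and versions are illustrative assumptions, not values from this diff.

    # Hedged usage sketch for EnvironmentOperations; names and paths are placeholders.
    from azure.ai.ml import MLClient
    from azure.ai.ml.entities import Environment
    from azure.identity import DefaultAzureCredential

    ml_client = MLClient(
        credential=DefaultAzureCredential(),
        subscription_id="<subscription-id>",
        resource_group_name="<resource-group>",
        workspace_name="<workspace>",
    )

    # Omitting `version` takes the auto-increment path handled by
    # _try_update_latest_version in the code above.
    env = Environment(
        name="example-sklearn-env",
        description="Example environment",
        image="mcr.microsoft.com/azureml/openmpi4.1.0-ubuntu20.04",
        conda_file="./environment/conda.yaml",  # local conda specification file
    )
    created = ml_client.environments.create_or_update(env)

    # get() requires either a version or a label such as "latest".
    latest = ml_client.environments.get(name="example-sklearn-env", label="latest")
    for v in ml_client.environments.list(name="example-sklearn-env"):
        print(v.name, v.version)

    # archive() hides the version from ACTIVE_ONLY listings; restore() reverses it.
    ml_client.environments.archive(name="example-sklearn-env", version=created.version)
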
+ :type operation_config: ~azure.ai.ml._scope_dependent_operations.OperationConfig + :param service_client: Service client to allow end users to operate on Azure Machine Learning Workspace + resources (ServiceClient082023Preview or ServiceClient102021Dataplane). + :type service_client: typing.Union[ + azure.ai.ml._restclient.v2023_04_01_preview._azure_machine_learning_workspaces.AzureMachineLearningWorkspaces, + azure.ai.ml._restclient.v2021_10_01_dataplanepreview._azure_machine_learning_workspaces. + AzureMachineLearningWorkspaces] + :param datastore_operations: Represents a client for performing operations on Datastores. + :type datastore_operations: ~azure.ai.ml.operations._datastore_operations.DatastoreOperations + :param all_operations: All operations classes of an MLClient object. + :type all_operations: ~azure.ai.ml._scope_dependent_operations.OperationsContainer + :param kwargs: A dictionary of additional configuration parameters. + :type kwargs: dict + """ + + # pylint: disable=unused-argument + def __init__( + self, + operation_scope: OperationScope, + operation_config: OperationConfig, + service_client: Union[ServiceClient082023Preview, ServiceClient102021Dataplane], + datastore_operations: DatastoreOperations, + all_operations: Optional[OperationsContainer] = None, + **kwargs, + ): + super(EvaluatorOperations, self).__init__(operation_scope, operation_config) + + ops_logger.update_filter() + self._model_op = ModelOperations( + operation_scope=operation_scope, + operation_config=operation_config, + service_client=service_client, + datastore_operations=datastore_operations, + all_operations=all_operations, + **{ModelOperations._IS_EVALUATOR: True}, + **kwargs, + ) + self._operation_scope = self._model_op._operation_scope + self._datastore_operation = self._model_op._datastore_operation + + @monitor_with_activity(ops_logger, "Evaluator.CreateOrUpdate", ActivityType.PUBLICAPI) + def create_or_update( # type: ignore + self, model: Union[Model, WorkspaceAssetReference], **kwargs: Any + ) -> Model: # TODO: Are we going to implement job_name? + """Returns created or updated model asset. + + :param model: Model asset object. + :type model: ~azure.ai.ml.entities.Model + :raises ~azure.ai.ml.exceptions.AssetPathException: Raised when the Model artifact path is + already linked to another asset + :raises ~azure.ai.ml.exceptions.ValidationException: Raised if Model cannot be successfully validated. + Details will be provided in the error message. + :raises ~azure.ai.ml.exceptions.EmptyDirectoryError: Raised if local path provided points to an empty directory. + :return: Model asset object. + :rtype: ~azure.ai.ml.entities.Model + """ + model.properties.update(_get_evaluator_properties()) + return self._model_op.create_or_update(model) + + def _raise_if_not_evaluator(self, properties: Optional[Dict[str, Any]], message: str) -> None: + """ + :param properties: The properties of a model. + :type properties: dict[str, str] + :param message: The message to be set on exception. + :type message: str + :raises ~azure.ai.ml.exceptions.ValidationException: Raised if model is not an + evaluator. + """ + if properties is not None and not _is_evaluator(properties): + raise ResourceNotFoundError( + message=message, + response=None, + ) + + @monitor_with_activity(ops_logger, "Evaluator.Get", ActivityType.PUBLICAPI) + def get(self, name: str, *, version: Optional[str] = None, label: Optional[str] = None, **kwargs) -> Model: + """Returns information about the specified model asset. 
+ + :param name: Name of the model. + :type name: str + :keyword version: Version of the model. + :paramtype version: str + :keyword label: Label of the model. (mutually exclusive with version) + :paramtype label: str + :raises ~azure.ai.ml.exceptions.ValidationException: Raised if Model cannot be successfully validated. + Details will be provided in the error message. + :return: Model asset object. + :rtype: ~azure.ai.ml.entities.Model + """ + model = self._model_op.get(name, version, label) + + properties = None if model is None else model.properties + self._raise_if_not_evaluator( + properties, + f"Evaluator {name} with version {version} not found.", + ) + + return model + + @monitor_with_activity(ops_logger, "Evaluator.Download", ActivityType.PUBLICAPI) + def download(self, name: str, version: str, download_path: Union[PathLike, str] = ".", **kwargs: Any) -> None: + """Download files related to a model. + + :param name: Name of the model. + :type name: str + :param version: Version of the model. + :type version: str + :param download_path: Local path as download destination, defaults to current working directory of the current + user. Contents will be overwritten. + :type download_path: Union[PathLike, str] + :raises ResourceNotFoundError: if can't find a model matching provided name. + """ + self._model_op.download(name, version, download_path) + + @monitor_with_activity(ops_logger, "Evaluator.List", ActivityType.PUBLICAPI) + def list( + self, + name: str, + stage: Optional[str] = None, + *, + list_view_type: ListViewType = ListViewType.ACTIVE_ONLY, + **kwargs: Any, + ) -> Iterable[Model]: + """List all model assets in workspace. + + :param name: Name of the model. + :type name: str + :param stage: The Model stage + :type stage: Optional[str] + :keyword list_view_type: View type for including/excluding (for example) archived models. + Defaults to :attr:`ListViewType.ACTIVE_ONLY`. + :paramtype list_view_type: ListViewType + :return: An iterator like instance of Model objects + :rtype: ~azure.core.paging.ItemPaged[~azure.ai.ml.entities.Model] + """ + properties_str = "is-promptflow=true,is-evaluator=true" + if name: + return cast( + Iterable[Model], + ( + self._model_op._model_versions_operation.list( + name=name, + registry_name=self._model_op._registry_name, + cls=lambda objs: [Model._from_rest_object(obj) for obj in objs], + properties=properties_str, + **self._model_op._scope_kwargs, + ) + if self._registry_name + else self._model_op._model_versions_operation.list( + name=name, + workspace_name=self._model_op._workspace_name, + cls=lambda objs: [Model._from_rest_object(obj) for obj in objs], + list_view_type=list_view_type, + properties=properties_str, + stage=stage, + **self._model_op._scope_kwargs, + ) + ), + ) + # ModelContainer object does not carry properties. 
+ raise UnsupportedOperationError("list on evaluation operations without name provided") + # TODO: Implement filtering of the ModelContainerOperations list output + # return cast( + # Iterable[Model], ( + # self._model_container_operation.list( + # registry_name=self._registry_name, + # cls=lambda objs: [Model._from_container_rest_object(obj) for obj in objs], + # list_view_type=list_view_type, + # **self._scope_kwargs, + # ) + # if self._registry_name + # else self._model_container_operation.list( + # workspace_name=self._workspace_name, + # cls=lambda objs: [Model._from_container_rest_object(obj) for obj in objs], + # list_view_type=list_view_type, + # **self._scope_kwargs, + # ) + # ) + # ) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_feature_set_operations.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_feature_set_operations.py new file mode 100644 index 00000000..db51295f --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_feature_set_operations.py @@ -0,0 +1,456 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=protected-access + +import json +import os +from copy import deepcopy +from datetime import datetime +from pathlib import Path +from typing import Any, Dict, List, Optional, Union + +from marshmallow.exceptions import ValidationError as SchemaValidationError + +from azure.ai.ml._artifacts._artifact_utilities import _check_and_upload_path +from azure.ai.ml._exception_helper import log_and_raise_error +from azure.ai.ml._restclient.v2023_08_01_preview import AzureMachineLearningWorkspaces as ServiceClient082023Preview +from azure.ai.ml._restclient.v2023_10_01 import AzureMachineLearningServices as ServiceClient102023 +from azure.ai.ml._restclient.v2023_10_01.models import ( + FeaturesetVersion, + FeaturesetVersionBackfillRequest, + FeatureWindow, +) +from azure.ai.ml._scope_dependent_operations import OperationConfig, OperationScope, _ScopeDependentOperations +from azure.ai.ml._telemetry import ActivityType, monitor_with_activity +from azure.ai.ml._utils._feature_store_utils import ( + _archive_or_restore, + _datetime_to_str, + read_feature_set_metadata, + read_remote_feature_set_spec_metadata, +) +from azure.ai.ml._utils._logger_utils import OpsLogger +from azure.ai.ml._utils.utils import is_url +from azure.ai.ml.constants import ListViewType +from azure.ai.ml.entities._assets._artifacts.feature_set import FeatureSet +from azure.ai.ml.entities._feature_set.data_availability_status import DataAvailabilityStatus +from azure.ai.ml.entities._feature_set.feature import Feature +from azure.ai.ml.entities._feature_set.feature_set_backfill_metadata import FeatureSetBackfillMetadata +from azure.ai.ml.entities._feature_set.feature_set_materialization_metadata import FeatureSetMaterializationMetadata +from azure.ai.ml.entities._feature_set.featureset_spec_metadata import FeaturesetSpecMetadata +from azure.ai.ml.entities._feature_set.materialization_compute_resource import MaterializationComputeResource +from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException +from azure.ai.ml.operations._datastore_operations import DatastoreOperations +from azure.core.paging import ItemPaged +from azure.core.polling import LROPoller +from azure.core.tracing.decorator import distributed_trace + +ops_logger = OpsLogger(__name__) 
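
EvaluatorOperations above wraps ModelOperations and stamps is-promptflow/is-evaluator properties onto the asset. The sketch below assumes the class is reachable as ml_client.evaluators (attribute name assumed by analogy with the other operations classes); the model path and asset names are placeholders.

    # Hedged sketch for EvaluatorOperations; ml_client.evaluators is an assumed
    # attribute name, and all asset names/paths are placeholders.
    from azure.ai.ml import MLClient
    from azure.ai.ml.entities import Model
    from azure.identity import DefaultAzureCredential

    ml_client = MLClient(
        credential=DefaultAzureCredential(),
        subscription_id="<subscription-id>",
        resource_group_name="<resource-group>",
        workspace_name="<workspace>",
    )

    # create_or_update adds the evaluator properties before delegating to
    # ModelOperations.create_or_update.
    evaluator = Model(
        name="example-evaluator",
        version="1",
        path="./evaluator_flow",  # local folder with the evaluator assets
        description="Example evaluator asset",
    )
    ml_client.evaluators.create_or_update(evaluator)

    # get() raises ResourceNotFoundError if the model exists but is not an evaluator.
    fetched = ml_client.evaluators.get("example-evaluator", version="1")

    # list() requires a name; listing without one is unsupported (see the code above).
    for m in ml_client.evaluators.list(name="example-evaluator"):
        print(m.name, m.version)

    ml_client.evaluators.download("example-evaluator", "1", download_path="./downloaded")
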
+module_logger = ops_logger.module_logger + + +class FeatureSetOperations(_ScopeDependentOperations): + """FeatureSetOperations. + + You should not instantiate this class directly. Instead, you should + create an MLClient instance that instantiates it for you and + attaches it as an attribute. + """ + + def __init__( + self, + operation_scope: OperationScope, + operation_config: OperationConfig, + service_client: ServiceClient102023, + service_client_for_jobs: ServiceClient082023Preview, + datastore_operations: DatastoreOperations, + **kwargs: Dict, + ): + super(FeatureSetOperations, self).__init__(operation_scope, operation_config) + ops_logger.update_filter() + self._operation = service_client.featureset_versions + self._container_operation = service_client.featureset_containers + self._jobs_operation = service_client_for_jobs.jobs + self._feature_operation = service_client.features + self._service_client = service_client + self._datastore_operation = datastore_operations + self._init_kwargs = kwargs + + @distributed_trace + @monitor_with_activity(ops_logger, "FeatureSet.List", ActivityType.PUBLICAPI) + def list( + self, + name: Optional[str] = None, + *, + list_view_type: ListViewType = ListViewType.ACTIVE_ONLY, + **kwargs: Dict, + ) -> ItemPaged[FeatureSet]: + """List the FeatureSet assets of the workspace. + + :param name: Name of a specific FeatureSet asset, optional. + :type name: Optional[str] + :keyword list_view_type: View type for including/excluding (for example) archived FeatureSet assets. + Defaults to ACTIVE_ONLY. + :type list_view_type: Optional[ListViewType] + :return: An iterator like instance of FeatureSet objects + :rtype: ~azure.core.paging.ItemPaged[FeatureSet] + """ + if name: + return self._operation.list( + workspace_name=self._workspace_name, + name=name, + cls=lambda objs: [FeatureSet._from_rest_object(obj) for obj in objs], + list_view_type=list_view_type, + **self._scope_kwargs, + **kwargs, + ) + return self._container_operation.list( + workspace_name=self._workspace_name, + cls=lambda objs: [FeatureSet._from_container_rest_object(obj) for obj in objs], + list_view_type=list_view_type, + **self._scope_kwargs, + **kwargs, + ) + + def _get(self, name: str, version: Optional[str] = None, **kwargs: Dict) -> FeaturesetVersion: + return self._operation.get( + resource_group_name=self._resource_group_name, + workspace_name=self._workspace_name, + name=name, + version=version, + **self._init_kwargs, + **kwargs, + ) + + @distributed_trace + @monitor_with_activity(ops_logger, "FeatureSet.Get", ActivityType.PUBLICAPI) + def get(self, name: str, version: str, **kwargs: Dict) -> FeatureSet: # type: ignore + """Get the specified FeatureSet asset. + + :param name: Name of FeatureSet asset. + :type name: str + :param version: Version of FeatureSet asset. + :type version: str + :raises ~azure.ai.ml.exceptions.ValidationException: Raised if FeatureSet cannot be successfully + identified and retrieved. Details will be provided in the error message. + :raises ~azure.core.exceptions.HttpResponseError: Raised if the corresponding name and version cannot be + retrieved from the service. + :return: FeatureSet asset object. 
+ :rtype: ~azure.ai.ml.entities.FeatureSet + """ + try: + featureset_version_resource = self._get(name, version, **kwargs) + return FeatureSet._from_rest_object(featureset_version_resource) # type: ignore[return-value] + except (ValidationException, SchemaValidationError) as ex: + log_and_raise_error(ex) + + @distributed_trace + @monitor_with_activity(ops_logger, "FeatureSet.BeginCreateOrUpdate", ActivityType.PUBLICAPI) + def begin_create_or_update(self, featureset: FeatureSet, **kwargs: Dict) -> LROPoller[FeatureSet]: + """Create or update FeatureSet + + :param featureset: FeatureSet definition. + :type featureset: FeatureSet + :return: An instance of LROPoller that returns a FeatureSet. + :rtype: ~azure.core.polling.LROPoller[~azure.ai.ml.entities.FeatureSet] + """ + featureset_copy = deepcopy(featureset) + + featureset_spec = self._validate_and_get_feature_set_spec(featureset_copy) + featureset_copy.properties["featuresetPropertiesVersion"] = "1" + featureset_copy.properties["featuresetProperties"] = json.dumps(featureset_spec._to_dict()) + + sas_uri = None + + if not is_url(featureset_copy.path): + with open(os.path.join(str(featureset_copy.path), ".amlignore"), mode="w", encoding="utf-8") as f: + f.write(".*\n*.amltmp\n*.amltemp") + + featureset_copy, _ = _check_and_upload_path( + artifact=featureset_copy, asset_operations=self, sas_uri=sas_uri, artifact_type=ErrorTarget.FEATURE_SET + ) + + featureset_resource = FeatureSet._to_rest_object(featureset_copy) + + return self._operation.begin_create_or_update( + resource_group_name=self._resource_group_name, + workspace_name=self._workspace_name, + name=featureset_copy.name, + version=featureset_copy.version, + body=featureset_resource, + **kwargs, + cls=lambda response, deserialized, headers: FeatureSet._from_rest_object(deserialized), + ) + + @distributed_trace + @monitor_with_activity(ops_logger, "FeatureSet.BeginBackFill", ActivityType.PUBLICAPI) + def begin_backfill( + self, + *, + name: str, + version: str, + feature_window_start_time: Optional[datetime] = None, + feature_window_end_time: Optional[datetime] = None, + display_name: Optional[str] = None, + description: Optional[str] = None, + tags: Optional[Dict[str, str]] = None, + compute_resource: Optional[MaterializationComputeResource] = None, + spark_configuration: Optional[Dict[str, str]] = None, + data_status: Optional[List[Union[str, DataAvailabilityStatus]]] = None, + job_id: Optional[str] = None, + **kwargs: Dict, + ) -> LROPoller[FeatureSetBackfillMetadata]: + """Backfill. + + :keyword name: Feature set name. This is case-sensitive. + :paramtype name: str + :keyword version: Version identifier. This is case-sensitive. + :paramtype version: str + :keyword feature_window_start_time: Start time of the feature window to be materialized. + :paramtype feature_window_start_time: datetime + :keyword feature_window_end_time: End time of the feature window to be materialized. + :paramtype feature_window_end_time: datetime + :keyword display_name: Specifies description. + :paramtype display_name: str + :keyword description: Specifies description. + :paramtype description: str + :keyword tags: A set of tags. Specifies the tags. + :paramtype tags: dict[str, str] + :keyword compute_resource: Specifies the compute resource settings. + :paramtype compute_resource: ~azure.ai.ml.entities.MaterializationComputeResource + :keyword spark_configuration: Specifies the spark compute settings. 
+ :paramtype spark_configuration: dict[str, str] + :keyword data_status: Specifies the data status that you want to backfill. + :paramtype data_status: list[str or ~azure.ai.ml.entities.DataAvailabilityStatus] + :keyword job_id: The job id. + :paramtype job_id: str + :return: An instance of LROPoller that returns ~azure.ai.ml.entities.FeatureSetBackfillMetadata + :rtype: ~azure.core.polling.LROPoller[~azure.ai.ml.entities.FeatureSetBackfillMetadata] + """ + request_body: FeaturesetVersionBackfillRequest = FeaturesetVersionBackfillRequest( + description=description, + display_name=display_name, + feature_window=FeatureWindow( + feature_window_start=feature_window_start_time, feature_window_end=feature_window_end_time + ), + resource=compute_resource._to_rest_object() if compute_resource else None, + spark_configuration=spark_configuration, + data_availability_status=data_status, + job_id=job_id, + tags=tags, + ) + + return self._operation.begin_backfill( + resource_group_name=self._resource_group_name, + workspace_name=self._workspace_name, + name=name, + version=version, + body=request_body, + **kwargs, + cls=lambda response, deserialized, headers: FeatureSetBackfillMetadata._from_rest_object(deserialized), + ) + + @distributed_trace + @monitor_with_activity(ops_logger, "FeatureSet.ListMaterializationOperation", ActivityType.PUBLICAPI) + def list_materialization_operations( + self, + name: str, + version: str, + *, + feature_window_start_time: Optional[Union[str, datetime]] = None, + feature_window_end_time: Optional[Union[str, datetime]] = None, + filters: Optional[str] = None, + **kwargs: Dict, + ) -> ItemPaged[FeatureSetMaterializationMetadata]: + """List Materialization operation. + + :param name: Feature set name. + :type name: str + :param version: Feature set version. + :type version: str + :keyword feature_window_start_time: Start time of the feature window to filter materialization jobs. + :paramtype feature_window_start_time: Union[str, datetime] + :keyword feature_window_end_time: End time of the feature window to filter materialization jobs. + :paramtype feature_window_end_time: Union[str, datetime] + :keyword filters: Comma-separated list of tag names (and optionally values). Example: tag1,tag2=value2. 
+ :paramtype filters: str + :return: An iterator like instance of ~azure.ai.ml.entities.FeatureSetMaterializationMetadata objects + :rtype: ~azure.core.paging.ItemPaged[FeatureSetMaterializationMetadata] + """ + feature_window_start_time = _datetime_to_str(feature_window_start_time) if feature_window_start_time else None + feature_window_end_time = _datetime_to_str(feature_window_end_time) if feature_window_end_time else None + properties = f"azureml.FeatureSetName={name},azureml.FeatureSetVersion={version}" + if feature_window_start_time: + properties = properties + f",azureml.FeatureWindowStart={feature_window_start_time}" + if feature_window_end_time: + properties = properties + f",azureml.FeatureWindowEnd={feature_window_end_time}" + + materialization_jobs = self._jobs_operation.list( + resource_group_name=self._resource_group_name, + workspace_name=self._workspace_name, + properties=properties, + tag=filters, + cls=lambda objs: [FeatureSetMaterializationMetadata._from_rest_object(obj) for obj in objs], + **kwargs, + ) + return materialization_jobs + + @distributed_trace + @monitor_with_activity(ops_logger, "FeatureSet.ListFeatures", ActivityType.PUBLICAPI) + def list_features( + self, + feature_set_name: str, + version: str, + *, + feature_name: Optional[str] = None, + description: Optional[str] = None, + tags: Optional[str] = None, + **kwargs: Dict, + ) -> ItemPaged[Feature]: + """List features + + :param feature_set_name: Feature set name. + :type feature_set_name: str + :param version: Feature set version. + :type version: str + :keyword feature_name: feature name. + :paramtype feature_name: str + :keyword description: Description of the featureset. + :paramtype description: str + :keyword tags: Comma-separated list of tag names (and optionally values). Example: tag1,tag2=value2. + :paramtype tags: str + :return: An iterator like instance of Feature objects + :rtype: ~azure.core.paging.ItemPaged[Feature] + """ + features = self._feature_operation.list( + resource_group_name=self._resource_group_name, + workspace_name=self._workspace_name, + featureset_name=feature_set_name, + featureset_version=version, + tags=tags, + feature_name=feature_name, + description=description, + **kwargs, + cls=lambda objs: [Feature._from_rest_object(obj) for obj in objs], + ) + return features + + @distributed_trace + @monitor_with_activity(ops_logger, "FeatureSet.GetFeature", ActivityType.PUBLICAPI) + def get_feature( + self, feature_set_name: str, version: str, *, feature_name: str, **kwargs: Dict + ) -> Optional["Feature"]: + """Get Feature + + :param feature_set_name: Feature set name. + :type feature_set_name: str + :param version: Feature set version. + :type version: str + :keyword feature_name: The feature name. This argument is case-sensitive. + :paramtype feature_name: str + :return: Feature object + :rtype: ~azure.ai.ml.entities.Feature + """ + feature = self._feature_operation.get( + resource_group_name=self._resource_group_name, + workspace_name=self._workspace_name, + featureset_name=feature_set_name, + featureset_version=version, + feature_name=feature_name, + **kwargs, + ) + + return Feature._from_rest_object(feature) # type: ignore[return-value] + + @distributed_trace + @monitor_with_activity(ops_logger, "FeatureSet.Archive", ActivityType.PUBLICAPI) + def archive( + self, + name: str, + version: str, + **kwargs: Dict, + ) -> None: + """Archive a FeatureSet asset. + + :param name: Name of FeatureSet asset. + :type name: str + :param version: Version of FeatureSet asset. 
+ :type version: str + :return: None + """ + + _archive_or_restore( + asset_operations=self, + version_operation=self._operation, + is_archived=True, + name=name, + version=version, + **kwargs, + ) + + @distributed_trace + @monitor_with_activity(ops_logger, "FeatureSet.Restore", ActivityType.PUBLICAPI) + def restore( + self, + name: str, + version: str, + **kwargs: Dict, + ) -> None: + """Restore an archived FeatureSet asset. + + :param name: Name of FeatureSet asset. + :type name: str + :param version: Version of FeatureSet asset. + :type version: str + :return: None + """ + + _archive_or_restore( + asset_operations=self, + version_operation=self._operation, + is_archived=False, + name=name, + version=version, + **kwargs, + ) + + def _validate_and_get_feature_set_spec(self, featureset: FeatureSet) -> FeaturesetSpecMetadata: + if not (featureset.specification and featureset.specification.path): + msg = "Missing FeatureSet specification path. Path is required for feature set." + raise ValidationException( + message=msg, + no_personal_data_message=msg, + error_type=ValidationErrorType.MISSING_FIELD, + target=ErrorTarget.FEATURE_SET, + error_category=ErrorCategory.USER_ERROR, + ) + + featureset_spec_path: Any = str(featureset.specification.path) + if is_url(featureset_spec_path): + try: + featureset_spec_contents = read_remote_feature_set_spec_metadata( + base_uri=featureset_spec_path, + datastore_operations=self._datastore_operation, + ) + featureset_spec_path = None + except Exception as ex: + module_logger.info("Unable to access FeaturesetSpec at path %s", featureset_spec_path) + raise ex + return FeaturesetSpecMetadata._load(featureset_spec_contents, featureset_spec_path) + + if not os.path.isabs(featureset_spec_path): + featureset_spec_path = Path(featureset.base_path, featureset_spec_path).resolve() + + if not os.path.isdir(featureset_spec_path): + raise ValidationException( + message="No such directory: {}".format(featureset_spec_path), + no_personal_data_message="No such directory", + target=ErrorTarget.FEATURE_SET, + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.FILE_OR_FOLDER_NOT_FOUND, + ) + + featureset_spec_contents = read_feature_set_metadata(path=featureset_spec_path) + featureset_spec_yaml_path = Path(featureset_spec_path, "FeatureSetSpec.yaml") + return FeaturesetSpecMetadata._load(featureset_spec_contents, featureset_spec_yaml_path) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_feature_store_entity_operations.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_feature_store_entity_operations.py new file mode 100644 index 00000000..e0ce9587 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_feature_store_entity_operations.py @@ -0,0 +1,191 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
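
A hedged sketch for the FeatureSetOperations methods above (begin_create_or_update, begin_backfill, list_features). It assumes the client is scoped to a feature store workspace and that FeatureSet/FeatureSetSpecification are the documented entity types; every name, path, and time window below is a placeholder.

    # Hedged sketch for FeatureSetOperations against a feature store workspace.
    # Entity references, the spec path, and the backfill window are placeholders.
    from datetime import datetime, timedelta

    from azure.ai.ml import MLClient
    from azure.ai.ml.entities import FeatureSet, FeatureSetSpecification
    from azure.identity import DefaultAzureCredential

    fs_client = MLClient(
        credential=DefaultAzureCredential(),
        subscription_id="<subscription-id>",
        resource_group_name="<resource-group>",
        workspace_name="<feature-store-name>",
    )

    transactions = FeatureSet(
        name="transactions",
        version="1",
        description="Transaction aggregation features",
        entities=["azureml:account:1"],  # reference to an existing feature store entity
        specification=FeatureSetSpecification(path="./featuresets/transactions/spec"),
        stage="Development",
    )
    created = fs_client.feature_sets.begin_create_or_update(transactions).result()

    # Materialize a one-day window; begin_backfill arguments are keyword-only.
    backfill = fs_client.feature_sets.begin_backfill(
        name="transactions",
        version="1",
        feature_window_start_time=datetime.utcnow() - timedelta(days=1),
        feature_window_end_time=datetime.utcnow(),
    ).result()

    for feature in fs_client.feature_sets.list_features("transactions", "1"):
        print(feature.name)
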
+# --------------------------------------------------------- + +# pylint: disable=protected-access + +from typing import Dict, Optional + +from marshmallow.exceptions import ValidationError as SchemaValidationError + +from azure.ai.ml._exception_helper import log_and_raise_error +from azure.ai.ml._restclient.v2023_10_01 import AzureMachineLearningServices as ServiceClient102023 +from azure.ai.ml._restclient.v2023_10_01.models import FeaturestoreEntityVersion +from azure.ai.ml._scope_dependent_operations import OperationConfig, OperationScope, _ScopeDependentOperations +from azure.ai.ml._telemetry import ActivityType, monitor_with_activity +from azure.ai.ml._utils._feature_store_utils import _archive_or_restore +from azure.ai.ml._utils._logger_utils import OpsLogger +from azure.ai.ml.constants import ListViewType +from azure.ai.ml.entities._feature_store_entity.feature_store_entity import FeatureStoreEntity +from azure.ai.ml.exceptions import ValidationException +from azure.core.paging import ItemPaged +from azure.core.polling import LROPoller +from azure.core.tracing.decorator import distributed_trace + +ops_logger = OpsLogger(__name__) +module_logger = ops_logger.module_logger + + +class FeatureStoreEntityOperations(_ScopeDependentOperations): + """FeatureStoreEntityOperations. + + You should not instantiate this class directly. Instead, you should + create an MLClient instance that instantiates it for you and + attaches it as an attribute. + """ + + def __init__( + self, + operation_scope: OperationScope, + operation_config: OperationConfig, + service_client: ServiceClient102023, + **kwargs: Dict, + ): + super(FeatureStoreEntityOperations, self).__init__(operation_scope, operation_config) + ops_logger.update_filter() + self._operation = service_client.featurestore_entity_versions + self._container_operation = service_client.featurestore_entity_containers + self._service_client = service_client + self._init_kwargs = kwargs + + @distributed_trace + @monitor_with_activity(ops_logger, "FeatureStoreEntity.List", ActivityType.PUBLICAPI) + def list( + self, + name: Optional[str] = None, + *, + list_view_type: ListViewType = ListViewType.ACTIVE_ONLY, + **kwargs: Dict, + ) -> ItemPaged[FeatureStoreEntity]: + """List the FeatureStoreEntity assets of the workspace. + + :param name: Name of a specific FeatureStoreEntity asset, optional. + :type name: Optional[str] + :keyword list_view_type: View type for including/excluding (for example) archived FeatureStoreEntity assets. + Default: ACTIVE_ONLY. 
+ :paramtype list_view_type: Optional[ListViewType] + :return: An iterator like instance of FeatureStoreEntity objects + :rtype: ~azure.core.paging.ItemPaged[FeatureStoreEntity] + """ + if name: + return self._operation.list( + workspace_name=self._workspace_name, + name=name, + cls=lambda objs: [FeatureStoreEntity._from_rest_object(obj) for obj in objs], + list_view_type=list_view_type, + **self._scope_kwargs, + **kwargs, + ) + return self._container_operation.list( + workspace_name=self._workspace_name, + cls=lambda objs: [FeatureStoreEntity._from_container_rest_object(obj) for obj in objs], + list_view_type=list_view_type, + **self._scope_kwargs, + **kwargs, + ) + + def _get(self, name: str, version: Optional[str] = None, **kwargs: Dict) -> FeaturestoreEntityVersion: + return self._operation.get( + resource_group_name=self._resource_group_name, + workspace_name=self._workspace_name, + name=name, + version=version, + **self._init_kwargs, + **kwargs, + ) + + @distributed_trace + @monitor_with_activity(ops_logger, "FeatureStoreEntity.Get", ActivityType.PUBLICAPI) + def get(self, name: str, version: str, **kwargs: Dict) -> FeatureStoreEntity: # type: ignore + """Get the specified FeatureStoreEntity asset. + + :param name: Name of FeatureStoreEntity asset. + :type name: str + :param version: Version of FeatureStoreEntity asset. + :type version: str + :raises ~azure.ai.ml.exceptions.ValidationException: Raised if FeatureStoreEntity cannot be successfully + identified and retrieved. Details will be provided in the error message. + :return: FeatureStoreEntity asset object. + :rtype: ~azure.ai.ml.entities.FeatureStoreEntity + """ + try: + feature_store_entity_version_resource = self._get(name, version, **kwargs) + return FeatureStoreEntity._from_rest_object(feature_store_entity_version_resource) + except (ValidationException, SchemaValidationError) as ex: + log_and_raise_error(ex) + + @distributed_trace + @monitor_with_activity(ops_logger, "FeatureStoreEntity.BeginCreateOrUpdate", ActivityType.PUBLICAPI) + def begin_create_or_update( + self, feature_store_entity: FeatureStoreEntity, **kwargs: Dict + ) -> LROPoller[FeatureStoreEntity]: + """Create or update FeatureStoreEntity + + :param feature_store_entity: FeatureStoreEntity definition. + :type feature_store_entity: FeatureStoreEntity + :return: An instance of LROPoller that returns a FeatureStoreEntity. + :rtype: ~azure.core.polling.LROPoller[~azure.ai.ml.entities.FeatureStoreEntity] + """ + feature_store_entity_resource = FeatureStoreEntity._to_rest_object(feature_store_entity) + + return self._operation.begin_create_or_update( + resource_group_name=self._resource_group_name, + workspace_name=self._workspace_name, + name=feature_store_entity.name, + version=feature_store_entity.version, + body=feature_store_entity_resource, + cls=lambda response, deserialized, headers: FeatureStoreEntity._from_rest_object(deserialized), + **kwargs, + ) + + @distributed_trace + @monitor_with_activity(ops_logger, "FeatureStoreEntity.Archive", ActivityType.PUBLICAPI) + def archive( + self, + name: str, + version: str, + **kwargs: Dict, + ) -> None: + """Archive a FeatureStoreEntity asset. + + :param name: Name of FeatureStoreEntity asset. + :type name: str + :param version: Version of FeatureStoreEntity asset. 
+ :type version: str + :return: None + """ + + _archive_or_restore( + asset_operations=self, + version_operation=self._operation, + is_archived=True, + name=name, + version=version, + **kwargs, + ) + + @distributed_trace + @monitor_with_activity(ops_logger, "FeatureStoreEntity.Restore", ActivityType.PUBLICAPI) + def restore( + self, + name: str, + version: str, + **kwargs: Dict, + ) -> None: + """Restore an archived FeatureStoreEntity asset. + + :param name: Name of FeatureStoreEntity asset. + :type name: str + :param version: Version of FeatureStoreEntity asset. + :type version: str + :return: None + """ + + _archive_or_restore( + asset_operations=self, + version_operation=self._operation, + is_archived=False, + name=name, + version=version, + **kwargs, + ) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_feature_store_operations.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_feature_store_operations.py new file mode 100644 index 00000000..a7f8e93d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_feature_store_operations.py @@ -0,0 +1,566 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=protected-access + +import re +import uuid +from typing import Any, Dict, Iterable, Optional, cast + +from marshmallow import ValidationError + +from azure.ai.ml._restclient.v2024_10_01_preview import AzureMachineLearningWorkspaces as ServiceClient102024Preview +from azure.ai.ml._restclient.v2024_10_01_preview.models import ManagedNetworkProvisionOptions +from azure.ai.ml._scope_dependent_operations import OperationsContainer, OperationScope +from azure.ai.ml._telemetry import ActivityType, monitor_with_activity +from azure.ai.ml._utils._logger_utils import OpsLogger +from azure.ai.ml._utils.utils import camel_to_snake +from azure.ai.ml.constants import ManagedServiceIdentityType +from azure.ai.ml.constants._common import Scope, WorkspaceKind +from azure.ai.ml.entities import ( + IdentityConfiguration, + ManagedIdentityConfiguration, + ManagedNetworkProvisionStatus, + WorkspaceConnection, +) +from azure.ai.ml.entities._feature_store._constants import ( + OFFLINE_MATERIALIZATION_STORE_TYPE, + OFFLINE_STORE_CONNECTION_CATEGORY, + OFFLINE_STORE_CONNECTION_NAME, + ONLINE_MATERIALIZATION_STORE_TYPE, + ONLINE_STORE_CONNECTION_CATEGORY, + ONLINE_STORE_CONNECTION_NAME, + STORE_REGEX_PATTERN, +) +from azure.ai.ml.entities._feature_store.feature_store import FeatureStore +from azure.ai.ml.entities._feature_store.materialization_store import MaterializationStore +from azure.ai.ml.entities._workspace.feature_store_settings import FeatureStoreSettings +from azure.core.credentials import TokenCredential +from azure.core.exceptions import ResourceNotFoundError +from azure.core.polling import LROPoller +from azure.core.tracing.decorator import distributed_trace + +from ._workspace_operations_base import WorkspaceOperationsBase + +ops_logger = OpsLogger(__name__) +module_logger = ops_logger.module_logger + + +class FeatureStoreOperations(WorkspaceOperationsBase): + """FeatureStoreOperations. + + You should not instantiate this class directly. Instead, you should + create an MLClient instance that instantiates it for you and + attaches it as an attribute. 
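
A hedged sketch of FeatureStoreEntityOperations usage, following the methods above. DataColumn and DataColumnType are the documented entity types assumed for index columns; all names are placeholders.

    # Hedged sketch for FeatureStoreEntityOperations; all names are placeholders.
    from azure.ai.ml import MLClient
    from azure.ai.ml.entities import DataColumn, DataColumnType, FeatureStoreEntity
    from azure.identity import DefaultAzureCredential

    fs_client = MLClient(
        credential=DefaultAzureCredential(),
        subscription_id="<subscription-id>",
        resource_group_name="<resource-group>",
        workspace_name="<feature-store-name>",
    )

    account_entity = FeatureStoreEntity(
        name="account",
        version="1",
        index_columns=[DataColumn(name="accountID", type=DataColumnType.STRING)],
        description="Account entity keyed by accountID",
    )
    created = fs_client.feature_store_entities.begin_create_or_update(account_entity).result()

    fetched = fs_client.feature_store_entities.get(name="account", version="1")
    for entity in fs_client.feature_store_entities.list():
        print(entity.name)

    # archive()/restore() toggle visibility in ACTIVE_ONLY listings.
    fs_client.feature_store_entities.archive(name="account", version="1")
    fs_client.feature_store_entities.restore(name="account", version="1")
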
+ """ + + def __init__( + self, + operation_scope: OperationScope, + service_client: ServiceClient102024Preview, + all_operations: OperationsContainer, + credentials: Optional[TokenCredential] = None, + **kwargs: Dict, + ): + ops_logger.update_filter() + self._provision_network_operation = service_client.managed_network_provisions + super().__init__( + operation_scope=operation_scope, + service_client=service_client, + all_operations=all_operations, + credentials=credentials, + **kwargs, + ) + self._workspace_connection_operation = service_client.workspace_connections + + @distributed_trace + @monitor_with_activity(ops_logger, "FeatureStore.List", ActivityType.PUBLICAPI) + # pylint: disable=unused-argument + def list(self, *, scope: str = Scope.RESOURCE_GROUP, **kwargs: Dict) -> Iterable[FeatureStore]: + """List all feature stores that the user has access to in the current + resource group or subscription. + + :keyword scope: scope of the listing, "resource_group" or "subscription", defaults to "resource_group" + :paramtype scope: str + :return: An iterator like instance of FeatureStore objects + :rtype: ~azure.core.paging.ItemPaged[FeatureStore] + """ + + if scope == Scope.SUBSCRIPTION: + return cast( + Iterable[FeatureStore], + self._operation.list_by_subscription( + cls=lambda objs: [ + FeatureStore._from_rest_object(filterObj) + for filterObj in filter(lambda ws: ws.kind.lower() == WorkspaceKind.FEATURE_STORE, objs) + ], + ), + ) + return cast( + Iterable[FeatureStore], + self._operation.list_by_resource_group( + self._resource_group_name, + cls=lambda objs: [ + FeatureStore._from_rest_object(filterObj) + for filterObj in filter(lambda ws: ws.kind.lower() == WorkspaceKind.FEATURE_STORE, objs) + ], + ), + ) + + @distributed_trace + @monitor_with_activity(ops_logger, "FeatureStore.Get", ActivityType.PUBLICAPI) + # pylint: disable=arguments-renamed + def get(self, name: str, **kwargs: Any) -> FeatureStore: + """Get a feature store by name. + + :param name: Name of the feature store. + :type name: str + :raises ~azure.core.exceptions.HttpResponseError: Raised if the corresponding name and version cannot be + retrieved from the service. + :return: The feature store with the provided name. 
+ :rtype: FeatureStore + """ + + feature_store: Any = None + resource_group = kwargs.get("resource_group") or self._resource_group_name + rest_workspace_obj = kwargs.get("rest_workspace_obj", None) or self._operation.get(resource_group, name) + if ( + rest_workspace_obj + and rest_workspace_obj.kind + and rest_workspace_obj.kind.lower() == WorkspaceKind.FEATURE_STORE + ): + feature_store = FeatureStore._from_rest_object(rest_workspace_obj) + + if feature_store: + offline_store_connection = None + if ( + rest_workspace_obj.feature_store_settings + and rest_workspace_obj.feature_store_settings.offline_store_connection_name + ): + try: + offline_store_connection = self._workspace_connection_operation.get( + resource_group, name, rest_workspace_obj.feature_store_settings.offline_store_connection_name + ) + except ResourceNotFoundError: + pass + + if offline_store_connection: + if ( + offline_store_connection.properties + and offline_store_connection.properties.category == OFFLINE_STORE_CONNECTION_CATEGORY + ): + feature_store.offline_store = MaterializationStore( + type=OFFLINE_MATERIALIZATION_STORE_TYPE, target=offline_store_connection.properties.target + ) + + online_store_connection = None + if ( + rest_workspace_obj.feature_store_settings + and rest_workspace_obj.feature_store_settings.online_store_connection_name + ): + try: + online_store_connection = self._workspace_connection_operation.get( + resource_group, name, rest_workspace_obj.feature_store_settings.online_store_connection_name + ) + except ResourceNotFoundError: + pass + + if online_store_connection: + if ( + online_store_connection.properties + and online_store_connection.properties.category == ONLINE_STORE_CONNECTION_CATEGORY + ): + feature_store.online_store = MaterializationStore( + type=ONLINE_MATERIALIZATION_STORE_TYPE, target=online_store_connection.properties.target + ) + + # materialization identity = identity when created through feature store operations + if ( + offline_store_connection and offline_store_connection.name.startswith(OFFLINE_STORE_CONNECTION_NAME) + ) or (online_store_connection and online_store_connection.name.startswith(ONLINE_STORE_CONNECTION_NAME)): + if ( + feature_store.identity + and feature_store.identity.user_assigned_identities + and isinstance(feature_store.identity.user_assigned_identities[0], ManagedIdentityConfiguration) + ): + feature_store.materialization_identity = feature_store.identity.user_assigned_identities[0] + + return feature_store + + @distributed_trace + @monitor_with_activity(ops_logger, "FeatureStore.BeginCreate", ActivityType.PUBLICAPI) + # pylint: disable=arguments-differ + def begin_create( + self, + feature_store: FeatureStore, + *, + grant_materialization_permissions: bool = True, + update_dependent_resources: bool = False, + **kwargs: Dict, + ) -> LROPoller[FeatureStore]: + """Create a new FeatureStore. + + Returns the feature store if already exists. + + :param feature_store: FeatureStore definition. + :type feature_store: FeatureStore + :keyword grant_materialization_permissions: Whether or not to grant materialization permissions. + Defaults to True. + :paramtype grant_materialization_permissions: bool + :keyword update_dependent_resources: Whether or not to update dependent resources. Defaults to False. + :type update_dependent_resources: bool + :return: An instance of LROPoller that returns a FeatureStore. 
+ :rtype: ~azure.core.polling.LROPoller[~azure.ai.ml.entities.FeatureStore] + """ + resource_group = kwargs.get("resource_group", self._resource_group_name) + try: + rest_workspace_obj = self._operation.get(resource_group, feature_store.name) + if rest_workspace_obj: + return self.begin_update( + feature_store=feature_store, update_dependent_resources=update_dependent_resources, kwargs=kwargs + ) + except Exception: # pylint: disable=W0718 + pass + + if feature_store.offline_store and feature_store.offline_store.type != OFFLINE_MATERIALIZATION_STORE_TYPE: + raise ValidationError("offline store type should be azure_data_lake_gen2") + + if feature_store.online_store and feature_store.online_store.type != ONLINE_MATERIALIZATION_STORE_TYPE: + raise ValidationError("online store type should be redis") + + # generate a random suffix for online/offline store connection name, + # please don't refer to OFFLINE_STORE_CONNECTION_NAME and + # ONLINE_STORE_CONNECTION_NAME directly from FeatureStore + random_string = uuid.uuid4().hex[:8] + if feature_store._feature_store_settings is not None: + feature_store._feature_store_settings.offline_store_connection_name = ( + f"{OFFLINE_STORE_CONNECTION_NAME}-{random_string}" + ) + feature_store._feature_store_settings.online_store_connection_name = ( + f"{ONLINE_STORE_CONNECTION_NAME}-{random_string}" + if feature_store.online_store and feature_store.online_store.target + else None + ) + + def get_callback() -> FeatureStore: + return self.get(feature_store.name) + + return super().begin_create( + workspace=feature_store, + update_dependent_resources=update_dependent_resources, + get_callback=get_callback, + offline_store_target=feature_store.offline_store.target if feature_store.offline_store else None, + online_store_target=feature_store.online_store.target if feature_store.online_store else None, + materialization_identity=feature_store.materialization_identity, + grant_materialization_permissions=grant_materialization_permissions, + **kwargs, + ) + + @distributed_trace + @monitor_with_activity(ops_logger, "FeatureStore.BeginUpdate", ActivityType.PUBLICAPI) + # pylint: disable=arguments-renamed + # pylint: disable=too-many-locals, too-many-branches, too-many-statements + def begin_update( # pylint: disable=C4758 + self, + feature_store: FeatureStore, + *, + grant_materialization_permissions: bool = True, + update_dependent_resources: bool = False, + **kwargs: Any, + ) -> LROPoller[FeatureStore]: + """Update friendly name, description, online store connection, offline store connection, materialization + identities or tags of a feature store. + + :param feature_store: FeatureStore resource. + :type feature_store: FeatureStore + :keyword grant_materialization_permissions: Whether or not to grant materialization permissions. + Defaults to True. + :paramtype grant_materialization_permissions: bool + :keyword update_dependent_resources: gives your consent to update the feature store dependent resources. + Note that updating the feature store attached Azure Container Registry resource may break lineage + of previous jobs or your ability to rerun earlier jobs in this feature store. + Also, updating the feature store attached Azure Application Insights resource may break lineage of + deployed inference endpoints this feature store. Only set this argument if you are sure that you want + to perform this operation. If this argument is not set, the command to update + Azure Container Registry and Azure Application Insights will fail. 
+ :keyword application_insights: Application insights resource for feature store. Defaults to None. + :paramtype application_insights: Optional[str] + :keyword container_registry: Container registry resource for feature store. Defaults to None. + :paramtype container_registry: Optional[str] + :return: An instance of LROPoller that returns a FeatureStore. + :rtype: ~azure.core.polling.LROPoller[~azure.ai.ml.entities.FeatureStore] + """ + resource_group = kwargs.get("resource_group", self._resource_group_name) + rest_workspace_obj = self._operation.get(resource_group, feature_store.name) + if not ( + rest_workspace_obj + and rest_workspace_obj.kind + and rest_workspace_obj.kind.lower() == WorkspaceKind.FEATURE_STORE + ): + raise ValidationError("{0} is not a feature store".format(feature_store.name)) + + resource_group = kwargs.get("resource_group") or self._resource_group_name + offline_store = kwargs.get("offline_store", feature_store.offline_store) + online_store = kwargs.get("online_store", feature_store.online_store) + offline_store_target_to_update = offline_store.target if offline_store else None + online_store_target_to_update = online_store.target if online_store else None + update_workspace_role_assignment = False + update_offline_store_role_assignment = False + update_online_store_role_assignment = False + + update_offline_store_connection = False + update_online_store_connection = False + + existing_materialization_identity = None + if rest_workspace_obj.identity: + identity = IdentityConfiguration._from_workspace_rest_object(rest_workspace_obj.identity) + if ( + identity + and identity.user_assigned_identities + and isinstance(identity.user_assigned_identities[0], ManagedIdentityConfiguration) + ): + existing_materialization_identity = identity.user_assigned_identities[0] + + materialization_identity = kwargs.get( + "materialization_identity", feature_store.materialization_identity or existing_materialization_identity + ) + + if ( + feature_store.materialization_identity + and feature_store.materialization_identity.resource_id + and ( + not existing_materialization_identity + or feature_store.materialization_identity.resource_id != existing_materialization_identity.resource_id + ) + ): + update_workspace_role_assignment = True + update_offline_store_role_assignment = True + update_online_store_role_assignment = True + + self._validate_offline_store(offline_store=offline_store) + + if ( + rest_workspace_obj.feature_store_settings + and rest_workspace_obj.feature_store_settings.offline_store_connection_name + ): + existing_offline_store_connection = self._workspace_connection_operation.get( + resource_group, + feature_store.name, + rest_workspace_obj.feature_store_settings.offline_store_connection_name, + ) + + offline_store_target_to_update = ( + offline_store_target_to_update or existing_offline_store_connection.properties.target + ) + if offline_store and ( + not existing_offline_store_connection.properties + or existing_offline_store_connection.properties.target != offline_store.target + ): + update_offline_store_connection = True + update_offline_store_role_assignment = True + module_logger.info( + "Warning: You have changed the offline store connection, " + "any data that was materialized " + "earlier will not be available. You have to run backfill again." 
+ ) + elif offline_store_target_to_update: + update_offline_store_connection = True + update_offline_store_role_assignment = True + + if online_store and online_store.type != ONLINE_MATERIALIZATION_STORE_TYPE: + raise ValidationError("online store type should be redis") + + if ( + rest_workspace_obj.feature_store_settings + and rest_workspace_obj.feature_store_settings.online_store_connection_name + ): + existing_online_store_connection = self._workspace_connection_operation.get( + resource_group, + feature_store.name, + rest_workspace_obj.feature_store_settings.online_store_connection_name, + ) + + online_store_target_to_update = ( + online_store_target_to_update or existing_online_store_connection.properties.target + ) + if online_store and ( + not existing_online_store_connection.properties + or existing_online_store_connection.properties.target != online_store.target + ): + update_online_store_connection = True + update_online_store_role_assignment = True + module_logger.info( + "Warning: You have changed the online store connection, " + "any data that was materialized earlier " + "will not be available. You have to run backfill again." + ) + elif online_store_target_to_update: + update_online_store_connection = True + update_online_store_role_assignment = True + + feature_store_settings: Any = FeatureStoreSettings._from_rest_object(rest_workspace_obj.feature_store_settings) + + # generate a random suffix for online/offline store connection name + random_string = uuid.uuid4().hex[:8] + if offline_store: + if materialization_identity: + if update_offline_store_connection: + offline_store_connection_name_new = f"{OFFLINE_STORE_CONNECTION_NAME}-{random_string}" + offline_store_connection = WorkspaceConnection( + name=offline_store_connection_name_new, + type=offline_store.type, + target=offline_store.target, + credentials=materialization_identity, + ) + rest_offline_store_connection = offline_store_connection._to_rest_object() + self._workspace_connection_operation.create( + resource_group_name=resource_group, + workspace_name=feature_store.name, + connection_name=offline_store_connection_name_new, + body=rest_offline_store_connection, + ) + feature_store_settings.offline_store_connection_name = offline_store_connection_name_new + else: + module_logger.info( + "No need to update Offline store connection, name: %s.\n", + feature_store_settings.offline_store_connection_name, + ) + else: + raise ValidationError("Materialization identity is required to setup offline store connection") + + if online_store: + if materialization_identity: + if update_online_store_connection: + online_store_connection_name_new = f"{ONLINE_STORE_CONNECTION_NAME}-{random_string}" + online_store_connection = WorkspaceConnection( + name=online_store_connection_name_new, + type=online_store.type, + target=online_store.target, + credentials=materialization_identity, + ) + rest_online_store_connection = online_store_connection._to_rest_object() + self._workspace_connection_operation.create( + resource_group_name=resource_group, + workspace_name=feature_store.name, + connection_name=online_store_connection_name_new, + body=rest_online_store_connection, + ) + feature_store_settings.online_store_connection_name = online_store_connection_name_new + else: + module_logger.info( + "No need to update Online store connection, name: %s.\n", + feature_store_settings.online_store_connection_name, + ) + else: + raise ValidationError("Materialization identity is required to setup online store connection") + + if not 
offline_store_target_to_update: + update_offline_store_role_assignment = False + if not online_store_target_to_update: + update_online_store_role_assignment = False + + user_defined_cr = feature_store.compute_runtime + if ( + user_defined_cr + and user_defined_cr.spark_runtime_version != feature_store_settings.compute_runtime.spark_runtime_version + ): + # update user defined compute runtime + feature_store_settings.compute_runtime = feature_store.compute_runtime + + identity = kwargs.pop("identity", feature_store.identity) + if materialization_identity: + identity = IdentityConfiguration( + type=camel_to_snake(ManagedServiceIdentityType.SYSTEM_ASSIGNED_USER_ASSIGNED), + # At most 1 UAI can be attached to workspace when materialization is enabled + user_assigned_identities=[materialization_identity], + ) + + def deserialize_callback(rest_obj: Any) -> FeatureStore: + return self.get(rest_obj.name, rest_workspace_obj=rest_obj) + + return super().begin_update( + workspace=feature_store, + update_dependent_resources=update_dependent_resources, + deserialize_callback=deserialize_callback, + feature_store_settings=feature_store_settings, + identity=identity, + grant_materialization_permissions=grant_materialization_permissions, + update_workspace_role_assignment=update_workspace_role_assignment, + update_offline_store_role_assignment=update_offline_store_role_assignment, + update_online_store_role_assignment=update_online_store_role_assignment, + materialization_identity_id=( + materialization_identity.resource_id + if update_workspace_role_assignment + or update_offline_store_role_assignment + or update_online_store_role_assignment + else None + ), + offline_store_target=offline_store_target_to_update if update_offline_store_role_assignment else None, + online_store_target=online_store_target_to_update if update_online_store_role_assignment else None, + **kwargs, + ) + + @distributed_trace + @monitor_with_activity(ops_logger, "FeatureStore.BeginDelete", ActivityType.PUBLICAPI) + def begin_delete(self, name: str, *, delete_dependent_resources: bool = False, **kwargs: Any) -> LROPoller[None]: + """Delete a FeatureStore. + + :param name: Name of the FeatureStore + :type name: str + :keyword delete_dependent_resources: Whether to delete resources associated with the feature store, + i.e., container registry, storage account, key vault, and application insights. + The default is False. Set to True to delete these resources. + :paramtype delete_dependent_resources: bool + :return: A poller to track the operation status. + :rtype: ~azure.core.polling.LROPoller[None] + """ + resource_group = kwargs.get("resource_group") or self._resource_group_name + rest_workspace_obj = self._operation.get(resource_group, name) + if not ( + rest_workspace_obj + and rest_workspace_obj.kind + and rest_workspace_obj.kind.lower() == WorkspaceKind.FEATURE_STORE + ): + raise ValidationError("{0} is not a feature store".format(name)) + + return super().begin_delete(name=name, delete_dependent_resources=delete_dependent_resources, **kwargs) + + @distributed_trace + @monitor_with_activity(ops_logger, "FeatureStore.BeginProvisionNetwork", ActivityType.PUBLICAPI) + def begin_provision_network( + self, + *, + feature_store_name: Optional[str] = None, + include_spark: bool = False, + **kwargs: Any, + ) -> LROPoller[ManagedNetworkProvisionStatus]: + """Triggers the feature store to provision the managed network. Specifying spark enabled + as true prepares the feature store managed network for supporting Spark. 
+ + :keyword feature_store_name: Name of the feature store. + :paramtype feature_store_name: str + :keyword include_spark: Whether to include spark in the network provisioning. Defaults to False. + :paramtype include_spark: bool + :return: An instance of LROPoller. + :rtype: ~azure.core.polling.LROPoller[~azure.ai.ml.entities.ManagedNetworkProvisionStatus] + """ + workspace_name = self._check_workspace_name(feature_store_name) + + poller = self._provision_network_operation.begin_provision_managed_network( + self._resource_group_name, + workspace_name, + ManagedNetworkProvisionOptions(include_spark=include_spark), + polling=True, + cls=lambda response, deserialized, headers: ManagedNetworkProvisionStatus._from_rest_object(deserialized), + **kwargs, + ) + module_logger.info("Provision network request initiated for feature store: %s\n", workspace_name) + return poller + + def _validate_offline_store(self, offline_store: MaterializationStore) -> None: + store_regex = re.compile(STORE_REGEX_PATTERN) + if offline_store and store_regex.match(offline_store.target) is None: + raise ValidationError(f"Invalid AzureML offlinestore target ARM Id {offline_store.target}") + if offline_store and offline_store.type != OFFLINE_MATERIALIZATION_STORE_TYPE: + raise ValidationError("offline store type should be azure_data_lake_gen2") diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_index_operations.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_index_operations.py new file mode 100644 index 00000000..28e409c7 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_index_operations.py @@ -0,0 +1,483 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +# pylint: disable=protected-access +import json +from typing import Any, Callable, Dict, Iterable, List, Optional, Union + +from azure.ai.ml._artifacts._artifact_utilities import _check_and_upload_path + +# cspell:disable-next-line +from azure.ai.ml._restclient.azure_ai_assets_v2024_04_01.azureaiassetsv20240401 import ( + MachineLearningServicesClient as AzureAiAssetsClient042024, +) + +# cspell:disable-next-line +from azure.ai.ml._restclient.azure_ai_assets_v2024_04_01.azureaiassetsv20240401.models import Index as RestIndex +from azure.ai.ml._restclient.v2023_04_01_preview.models import ListViewType +from azure.ai.ml._scope_dependent_operations import ( + OperationConfig, + OperationsContainer, + OperationScope, + _ScopeDependentOperations, +) +from azure.ai.ml._telemetry import ActivityType, monitor_with_activity +from azure.ai.ml._utils._asset_utils import ( + _resolve_label_to_asset, + _validate_auto_delete_setting_in_data_output, + _validate_workspace_managed_datastore, +) +from azure.ai.ml._utils._http_utils import HttpPipeline +from azure.ai.ml._utils._logger_utils import OpsLogger +from azure.ai.ml._utils.utils import _get_base_urls_from_discovery_service +from azure.ai.ml.constants._common import AssetTypes, AzureMLResourceType, WorkspaceDiscoveryUrlKey +from azure.ai.ml.dsl import pipeline +from azure.ai.ml.entities import PipelineJob, PipelineJobSettings +from azure.ai.ml.entities._assets import Index +from azure.ai.ml.entities._credentials import ManagedIdentityConfiguration, UserIdentityConfiguration +from azure.ai.ml.entities._indexes import ( + AzureAISearchConfig, + GitSource, + IndexDataSource, + LocalSource, + ModelConfiguration, +) +from azure.ai.ml.entities._indexes.data_index_func import index_data as index_data_func +from azure.ai.ml.entities._indexes.entities.data_index import ( + CitationRegex, + Data, + DataIndex, + Embedding, + IndexSource, + IndexStore, +) +from azure.ai.ml.entities._indexes.utils._open_ai_utils import build_connection_id, build_open_ai_protocol +from azure.ai.ml.entities._inputs_outputs import Input +from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException +from azure.ai.ml.operations._datastore_operations import DatastoreOperations +from azure.core.credentials import TokenCredential + +ops_logger = OpsLogger(__name__) +module_logger = ops_logger.module_logger + + +class IndexOperations(_ScopeDependentOperations): + """Represents a client for performing operations on index assets. + + You should not instantiate this class directly. 
Instead, you should create MLClient and use this client via the + property MLClient.index + """ + + def __init__( + self, + *, + operation_scope: OperationScope, + operation_config: OperationConfig, + credential: TokenCredential, + datastore_operations: DatastoreOperations, + all_operations: OperationsContainer, + **kwargs: Any, + ): + super().__init__(operation_scope, operation_config) + ops_logger.update_filter() + self._credential = credential + # Dataplane service clients are lazily created as they are needed + self.__azure_ai_assets_client: Optional[AzureAiAssetsClient042024] = None + # Kwargs to propagate to dataplane service clients + self._service_client_kwargs: Dict[str, Any] = kwargs.pop("_service_client_kwargs", {}) + self._all_operations = all_operations + + self._datastore_operation: DatastoreOperations = datastore_operations + self._requests_pipeline: HttpPipeline = kwargs.pop("requests_pipeline") + + # Maps a label to a function which given an asset name, + # returns the asset associated with the label + self._managed_label_resolver: Dict[str, Callable[[str], Index]] = {"latest": self._get_latest_version} + + @property + def _azure_ai_assets(self) -> AzureAiAssetsClient042024: + """Lazily instantiated client for azure ai assets api. + + .. note:: + + Property is lazily instantiated since the api's base url depends on the user's workspace, and + must be fetched remotely. + """ + if self.__azure_ai_assets_client is None: + workspace_operations = self._all_operations.all_operations[AzureMLResourceType.WORKSPACE] + + endpoint = _get_base_urls_from_discovery_service( + workspace_operations, self._operation_scope.workspace_name, self._requests_pipeline + )[WorkspaceDiscoveryUrlKey.API] + + self.__azure_ai_assets_client = AzureAiAssetsClient042024( + endpoint=endpoint, + subscription_id=self._operation_scope.subscription_id, + resource_group_name=self._operation_scope.resource_group_name, + workspace_name=self._operation_scope.workspace_name, + credential=self._credential, + **self._service_client_kwargs, + ) + + return self.__azure_ai_assets_client + + @monitor_with_activity(ops_logger, "Index.CreateOrUpdate", ActivityType.PUBLICAPI) + def create_or_update(self, index: Index, **kwargs) -> Index: + """Returns created or updated index asset. + + If not already in storage, asset will be uploaded to the workspace's default datastore. + + :param index: Index asset object. + :type index: Index + :return: Index asset object. + :rtype: ~azure.ai.ml.entities.Index + """ + + if not index.name: + msg = "Must specify a name." + raise ValidationException( + message=msg, + target=ErrorTarget.INDEX, + no_personal_data_message=msg, + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.MISSING_FIELD, + ) + + if not index.version: + if not index._auto_increment_version: + msg = "Must specify a version." + raise ValidationException( + message=msg, + target=ErrorTarget.INDEX, + no_personal_data_message=msg, + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.MISSING_FIELD, + ) + + next_version = self._azure_ai_assets.indexes.get_next_version(index.name).next_version + + if next_version is None: + msg = "Version not specified, could not automatically increment version. Set a version to resolve." 
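+ # The assets service reported no next version for this index name, so auto-increment cannot proceed; the caller must set an explicit version.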
+ raise ValidationException( + message=msg, + target=ErrorTarget.INDEX, + no_personal_data_message=msg, + error_category=ErrorCategory.SYSTEM_ERROR, + error_type=ValidationErrorType.MISSING_FIELD, + ) + + index.version = str(next_version) + + _ = _check_and_upload_path( + artifact=index, + asset_operations=self, + datastore_name=index.datastore, + artifact_type=ErrorTarget.INDEX, + show_progress=self._show_progress, + ) + + return Index._from_rest_object( + self._azure_ai_assets.indexes.create_or_update( + name=index.name, version=index.version, body=index._to_rest_object(), **kwargs + ) + ) + + @monitor_with_activity(ops_logger, "Index.Get", ActivityType.PUBLICAPI) + def get(self, name: str, *, version: Optional[str] = None, label: Optional[str] = None, **kwargs) -> Index: + """Returns information about the specified index asset. + + :param str name: Name of the index asset. + :keyword Optional[str] version: Version of the index asset. Mutually exclusive with `label`. + :keyword Optional[str] label: The label of the index asset. Mutually exclusive with `version`. + :raises ~azure.ai.ml.exceptions.ValidationException: Raised if Index cannot be successfully validated. + Details will be provided in the error message. + :return: Index asset object. + :rtype: ~azure.ai.ml.entities.Index + """ + if version and label: + msg = "Cannot specify both version and label." + raise ValidationException( + message=msg, + target=ErrorTarget.INDEX, + no_personal_data_message=msg, + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.INVALID_VALUE, + ) + + if label: + return _resolve_label_to_asset(self, name, label) + + if not version: + msg = "Must provide either version or label." + raise ValidationException( + message=msg, + target=ErrorTarget.INDEX, + no_personal_data_message=msg, + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.MISSING_FIELD, + ) + + index_version_resource = self._azure_ai_assets.indexes.get(name=name, version=version, **kwargs) + + return Index._from_rest_object(index_version_resource) + + def _get_latest_version(self, name: str) -> Index: + return Index._from_rest_object(self._azure_ai_assets.indexes.get_latest(name)) + + @monitor_with_activity(ops_logger, "Index.List", ActivityType.PUBLICAPI) + def list( + self, name: Optional[str] = None, *, list_view_type: ListViewType = ListViewType.ACTIVE_ONLY, **kwargs + ) -> Iterable[Index]: + """List all Index assets in workspace. + + If an Index is specified by name, list all version of that Index. + + :param name: Name of the model. + :type name: Optional[str] + :keyword list_view_type: View type for including/excluding (for example) archived models. + Defaults to :attr:`ListViewType.ACTIVE_ONLY`. 
+ :paramtype list_view_type: ListViewType
+ :return: An iterator like instance of Index objects
+ :rtype: ~azure.core.paging.ItemPaged[~azure.ai.ml.entities.Index]
+ """
+
+ def cls(rest_indexes: Iterable[RestIndex]) -> List[Index]:
+ return [Index._from_rest_object(i) for i in rest_indexes]
+
+ if name is None:
+ return self._azure_ai_assets.indexes.list_latest(cls=cls, **kwargs)
+
+ return self._azure_ai_assets.indexes.list(name, list_view_type=list_view_type, cls=cls, **kwargs)
+
+ def build_index(
+ self,
+ *,
+ ######## required args ##########
+ name: str,
+ embeddings_model_config: ModelConfiguration,
+ ######## chunking information ##########
+ data_source_citation_url: Optional[str] = None,
+ tokens_per_chunk: Optional[int] = None,
+ token_overlap_across_chunks: Optional[int] = None,
+ input_glob: Optional[str] = None,
+ ######## other generic args ########
+ document_path_replacement_regex: Optional[str] = None,
+ ######## Azure AI Search index info ########
+ index_config: Optional[AzureAISearchConfig] = None, # todo better name?
+ ######## data source info ########
+ input_source: Union[IndexDataSource, str],
+ input_source_credential: Optional[Union[ManagedIdentityConfiguration, UserIdentityConfiguration]] = None,
+ ) -> Union["Index", "Job"]: # type: ignore[name-defined]
+ """Builds an index on the cloud using the Azure AI Resources service.
+
+ :keyword name: The name of the index to be created.
+ :paramtype name: str
+ :keyword embeddings_model_config: Model config for the embedding model.
+ :paramtype embeddings_model_config: ~azure.ai.ml.entities._indexes.ModelConfiguration
+ :keyword data_source_citation_url: The URL of the data source.
+ :paramtype data_source_citation_url: Optional[str]
+ :keyword tokens_per_chunk: The size of chunks to be used for indexing.
+ :paramtype tokens_per_chunk: Optional[int]
+ :keyword token_overlap_across_chunks: The amount of overlap between chunks.
+ :paramtype token_overlap_across_chunks: Optional[int]
+ :keyword input_glob: The glob pattern to be used for indexing.
+ :paramtype input_glob: Optional[str]
+ :keyword document_path_replacement_regex: The regex pattern for replacing document paths.
+ :paramtype document_path_replacement_regex: Optional[str]
+ :keyword index_config: The configuration for the ACS output.
+ :paramtype index_config: Optional[~azure.ai.ml.entities._indexes.AzureAISearchConfig]
+ :keyword input_source: The input source for the index.
+ :paramtype input_source: Union[~azure.ai.ml.entities._indexes.IndexDataSource, str]
+ :keyword input_source_credential: The identity to be used for the index.
+ :paramtype input_source_credential: Optional[Union[~azure.ai.ml.entities.ManagedIdentityConfiguration,
+ ~azure.ai.ml.entities.UserIdentityConfiguration]]
+ :return: If the `input_source` is a GitSource, returns a created DataIndex Job object.
+ :rtype: Union[~azure.ai.ml.entities.Index, ~azure.ai.ml.entities.Job]
+ :raises ValueError: If the `input_source` is not of type ~typing.Str or
+ ~azure.ai.ml.entities._indexes.LocalSource.
+ """ + if document_path_replacement_regex: + document_path_replacement_regex = json.loads(document_path_replacement_regex) + + data_index = DataIndex( + name=name, + source=IndexSource( + input_data=Data( + type="uri_folder", + path=".", + ), + input_glob=input_glob, + chunk_size=tokens_per_chunk, + chunk_overlap=token_overlap_across_chunks, + citation_url=data_source_citation_url, + citation_url_replacement_regex=( + CitationRegex( + match_pattern=document_path_replacement_regex["match_pattern"], # type: ignore[index] + replacement_pattern=document_path_replacement_regex[ + "replacement_pattern" # type: ignore[index] + ], + ) + if document_path_replacement_regex + else None + ), + ), + embedding=Embedding( + model=build_open_ai_protocol( + model=embeddings_model_config.model_name, deployment=embeddings_model_config.deployment_name + ), + connection=build_connection_id(embeddings_model_config.connection_name, self._operation_scope), + ), + index=( + IndexStore( + type="acs", + connection=build_connection_id(index_config.connection_id, self._operation_scope), + name=index_config.index_name, + ) + if index_config is not None + else IndexStore(type="faiss") + ), + # name is replaced with a unique value each time the job is run + path=f"azureml://datastores/workspaceblobstore/paths/indexes/{name}/{{name}}", + ) + + if isinstance(input_source, LocalSource): + data_index.source.input_data = Data( + type="uri_folder", + path=input_source.input_data.path, + ) + + return self._create_data_indexing_job(data_index=data_index, identity=input_source_credential) + + if isinstance(input_source, str): + data_index.source.input_data = Data( + type="uri_folder", + path=input_source, + ) + + return self._create_data_indexing_job(data_index=data_index, identity=input_source_credential) + + if isinstance(input_source, GitSource): + from azure.ai.ml import MLClient + + ml_registry = MLClient(credential=self._credential, registry_name="azureml") + git_clone_component = ml_registry.components.get("llm_rag_git_clone", label="latest") + + # Clone Git Repo and use as input to index_job + @pipeline(default_compute="serverless") # type: ignore[call-overload] + def git_to_index( + git_url, + branch_name="", + git_connection_id="", + ): + git_clone = git_clone_component(git_repository=git_url, branch_name=branch_name) + git_clone.environment_variables["AZUREML_WORKSPACE_CONNECTION_ID_GIT"] = git_connection_id + + index_job = index_data_func( + description=data_index.description, + data_index=data_index, + input_data_override=git_clone.outputs.output_data, + ml_client=MLClient( + subscription_id=self._subscription_id, + resource_group_name=self._resource_group_name, + workspace_name=self._workspace_name, + credential=self._credential, + ), + ) + # pylint: disable=no-member + return index_job.outputs + + git_index_job = git_to_index( + git_url=input_source.url, + branch_name=input_source.branch_name, + git_connection_id=input_source.connection_id, + ) + # Ensure repo cloned each run to get latest, comment out to have first clone reused. 
+ git_index_job.settings.force_rerun = True + + # Submit the DataIndex Job + return self._all_operations.all_operations[AzureMLResourceType.JOB].create_or_update(git_index_job) + raise ValueError(f"Unsupported input source type {type(input_source)}") + + def _create_data_indexing_job( + self, + data_index: DataIndex, + identity: Optional[Union[ManagedIdentityConfiguration, UserIdentityConfiguration]] = None, + compute: str = "serverless", + serverless_instance_type: Optional[str] = None, + input_data_override: Optional[Input] = None, + submit_job: bool = True, + **kwargs, + ) -> PipelineJob: + """ + Returns the data import job that is creating the data asset. + + :param data_index: DataIndex object. + :type data_index: azure.ai.ml.entities._dataindex + :param identity: Identity configuration for the job. + :type identity: Optional[Union[ManagedIdentityConfiguration, UserIdentityConfiguration]] + :param compute: The compute target to use for the job. Default: "serverless". + :type compute: str + :param serverless_instance_type: The instance type to use for serverless compute. + :type serverless_instance_type: Optional[str] + :param input_data_override: Input data override for the job. + Used to pipe output of step into DataIndex Job in a pipeline. + :type input_data_override: Optional[Input] + :param submit_job: Whether to submit the job to the service. Default: True. + :type submit_job: bool + :return: data import job object. + :rtype: ~azure.ai.ml.entities.PipelineJob. + """ + # pylint: disable=no-member + from azure.ai.ml import MLClient + + default_name = "data_index_" + data_index.name if data_index.name is not None else "" + experiment_name = kwargs.pop("experiment_name", None) or default_name + data_index.type = AssetTypes.URI_FOLDER + + # avoid specifying auto_delete_setting in job output now + _validate_auto_delete_setting_in_data_output(data_index.auto_delete_setting) + + # block customer specified path on managed datastore + data_index.path = _validate_workspace_managed_datastore(data_index.path) + + if "${{name}}" not in str(data_index.path) and "{name}" not in str(data_index.path): + data_index.path = str(data_index.path).rstrip("/") + "/${{name}}" + + index_job = index_data_func( + description=data_index.description or kwargs.pop("description", None) or default_name, + name=data_index.name or kwargs.pop("name", None), + display_name=kwargs.pop("display_name", None) or default_name, + experiment_name=experiment_name, + compute=compute, + serverless_instance_type=serverless_instance_type, + data_index=data_index, + ml_client=MLClient( + subscription_id=self._subscription_id, + resource_group_name=self._resource_group_name, + workspace_name=self._workspace_name, + credential=self._credential, + ), + identity=identity, + input_data_override=input_data_override, + **kwargs, + ) + index_pipeline = PipelineJob( + description=index_job.description, + tags=index_job.tags, + name=index_job.name, + display_name=index_job.display_name, + experiment_name=experiment_name, + properties=index_job.properties or {}, + settings=PipelineJobSettings(force_rerun=True, default_compute=compute), + jobs={default_name: index_job}, + ) + index_pipeline.properties["azureml.mlIndexAssetName"] = data_index.name + index_pipeline.properties["azureml.mlIndexAssetKind"] = data_index.index.type + index_pipeline.properties["azureml.mlIndexAssetSource"] = kwargs.pop("mlindex_asset_source", "Data Asset") + + if submit_job: + return self._all_operations.all_operations[AzureMLResourceType.JOB].create_or_update( + 
job=index_pipeline, skip_validation=True, **kwargs + ) + return index_pipeline diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_job_operations.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_job_operations.py new file mode 100644 index 00000000..0003c8cd --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_job_operations.py @@ -0,0 +1,1677 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=protected-access, too-many-instance-attributes, too-many-statements, too-many-lines +import json +import os.path +from os import PathLike +from pathlib import Path +from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Set, Union, cast + +import jwt +from marshmallow import ValidationError + +from azure.ai.ml._artifacts._artifact_utilities import ( + _upload_and_generate_remote_uri, + aml_datastore_path_exists, + download_artifact_from_aml_uri, +) +from azure.ai.ml._azure_environments import ( + _get_aml_resource_id_from_metadata, + _get_base_url_from_metadata, + _resource_to_scopes, +) +from azure.ai.ml._exception_helper import log_and_raise_error +from azure.ai.ml._restclient.dataset_dataplane import AzureMachineLearningWorkspaces as ServiceClientDatasetDataplane +from azure.ai.ml._restclient.model_dataplane import AzureMachineLearningWorkspaces as ServiceClientModelDataplane +from azure.ai.ml._restclient.runhistory import AzureMachineLearningWorkspaces as ServiceClientRunHistory +from azure.ai.ml._restclient.runhistory.models import Run +from azure.ai.ml._restclient.v2023_04_01_preview import AzureMachineLearningWorkspaces as ServiceClient022023Preview +from azure.ai.ml._restclient.v2023_04_01_preview.models import JobBase, ListViewType, UserIdentity +from azure.ai.ml._restclient.v2023_08_01_preview.models import JobType as RestJobType +from azure.ai.ml._restclient.v2024_01_01_preview.models import JobBase as JobBase_2401 +from azure.ai.ml._restclient.v2024_10_01_preview.models import JobType as RestJobType_20241001Preview +from azure.ai.ml._scope_dependent_operations import ( + OperationConfig, + OperationsContainer, + OperationScope, + _ScopeDependentOperations, +) +from azure.ai.ml._telemetry import ActivityType, monitor_with_activity, monitor_with_telemetry_mixin +from azure.ai.ml._utils._http_utils import HttpPipeline +from azure.ai.ml._utils._logger_utils import OpsLogger +from azure.ai.ml._utils.utils import ( + create_requests_pipeline_with_retry, + download_text_from_url, + is_data_binding_expression, + is_private_preview_enabled, + is_url, +) +from azure.ai.ml.constants._common import ( + AZUREML_RESOURCE_PROVIDER, + BATCH_JOB_CHILD_RUN_OUTPUT_NAME, + COMMON_RUNTIME_ENV_VAR, + DEFAULT_ARTIFACT_STORE_OUTPUT_NAME, + GIT_PATH_PREFIX, + LEVEL_ONE_NAMED_RESOURCE_ID_FORMAT, + LOCAL_COMPUTE_TARGET, + SERVERLESS_COMPUTE, + SHORT_URI_FORMAT, + SWEEP_JOB_BEST_CHILD_RUN_ID_PROPERTY_NAME, + TID_FMT, + AssetTypes, + AzureMLResourceType, + WorkspaceDiscoveryUrlKey, +) +from azure.ai.ml.constants._compute import ComputeType +from azure.ai.ml.constants._job.pipeline import PipelineConstants +from azure.ai.ml.entities import Compute, Job, PipelineJob, ServiceInstance, ValidationResult +from azure.ai.ml.entities._assets._artifacts.code import Code +from azure.ai.ml.entities._builders import BaseNode, Command, Spark +from azure.ai.ml.entities._datastore._constants 
import WORKSPACE_BLOB_STORE +from azure.ai.ml.entities._inputs_outputs import Input +from azure.ai.ml.entities._job.automl.automl_job import AutoMLJob +from azure.ai.ml.entities._job.base_job import _BaseJob +from azure.ai.ml.entities._job.distillation.distillation_job import DistillationJob +from azure.ai.ml.entities._job.finetuning.finetuning_job import FineTuningJob +from azure.ai.ml.entities._job.import_job import ImportJob +from azure.ai.ml.entities._job.job import _is_pipeline_child_job +from azure.ai.ml.entities._job.parallel.parallel_job import ParallelJob +from azure.ai.ml.entities._job.to_rest_functions import to_rest_job_object +from azure.ai.ml.entities._validation import PathAwareSchemaValidatableMixin +from azure.ai.ml.exceptions import ( + ComponentException, + ErrorCategory, + ErrorTarget, + JobException, + JobParsingError, + MlException, + PipelineChildJobError, + UserErrorException, + ValidationErrorType, + ValidationException, +) +from azure.ai.ml.operations._run_history_constants import RunHistoryConstants +from azure.ai.ml.sweep import SweepJob +from azure.core.credentials import TokenCredential +from azure.core.exceptions import HttpResponseError, ResourceNotFoundError +from azure.core.polling import LROPoller +from azure.core.tracing.decorator import distributed_trace + +from ..constants._component import ComponentSource +from ..entities._builders.control_flow_node import ControlFlowNode +from ..entities._job.pipeline._io import InputOutputBase, PipelineInput, _GroupAttrDict +from ._component_operations import ComponentOperations +from ._compute_operations import ComputeOperations +from ._dataset_dataplane_operations import DatasetDataplaneOperations +from ._job_ops_helper import get_git_properties, get_job_output_uris_from_dataplane, stream_logs_until_completion +from ._local_job_invoker import is_local_run, start_run_if_local +from ._model_dataplane_operations import ModelDataplaneOperations +from ._operation_orchestrator import ( + OperationOrchestrator, + _AssetResolver, + is_ARM_id_for_resource, + is_registry_id_for_resource, + is_singularity_full_name_for_resource, + is_singularity_id_for_resource, + is_singularity_short_name_for_resource, +) +from ._run_operations import RunOperations +from ._virtual_cluster_operations import VirtualClusterOperations + +try: + pass +except ImportError: + pass + +if TYPE_CHECKING: + from azure.ai.ml.operations import DatastoreOperations + +ops_logger = OpsLogger(__name__) +module_logger = ops_logger.module_logger + + +class JobOperations(_ScopeDependentOperations): + """Initiates an instance of JobOperations + + This class should not be instantiated directly. Instead, use the `jobs` attribute of an MLClient object. + + :param operation_scope: Scope variables for the operations classes of an MLClient object. + :type operation_scope: ~azure.ai.ml._scope_dependent_operations.OperationScope + :param operation_config: Common configuration for operations classes of an MLClient object. + :type operation_config: ~azure.ai.ml._scope_dependent_operations.OperationConfig + :param service_client_02_2023_preview: Service client to allow end users to operate on Azure Machine Learning + Workspace resources. + :type service_client_02_2023_preview: ~azure.ai.ml._restclient.v2023_02_01_preview.AzureMachineLearningWorkspaces + :param all_operations: All operations classes of an MLClient object. + :type all_operations: ~azure.ai.ml._scope_dependent_operations.OperationsContainer + :param credential: Credential to use for authentication. 
+ :type credential: ~azure.core.credentials.TokenCredential + """ + + def __init__( + self, + operation_scope: OperationScope, + operation_config: OperationConfig, + service_client_02_2023_preview: ServiceClient022023Preview, + all_operations: OperationsContainer, + credential: TokenCredential, + **kwargs: Any, + ) -> None: + super(JobOperations, self).__init__(operation_scope, operation_config) + ops_logger.update_filter() + + self._operation_2023_02_preview = service_client_02_2023_preview.jobs + self._service_client = service_client_02_2023_preview + self._all_operations = all_operations + self._stream_logs_until_completion = stream_logs_until_completion + # Dataplane service clients are lazily created as they are needed + self._runs_operations_client: Optional[RunOperations] = None + self._dataset_dataplane_operations_client: Optional[DatasetDataplaneOperations] = None + self._model_dataplane_operations_client: Optional[ModelDataplaneOperations] = None + # Kwargs to propagate to dataplane service clients + self._service_client_kwargs = kwargs.pop("_service_client_kwargs", {}) + self._api_base_url: Optional[str] = None + self._container = "azureml" + self._credential = credential + self._orchestrators = OperationOrchestrator(self._all_operations, self._operation_scope, self._operation_config) + + self.service_client_01_2024_preview = kwargs.pop("service_client_01_2024_preview", None) + self.service_client_10_2024_preview = kwargs.pop("service_client_10_2024_preview", None) + self.service_client_01_2025_preview = kwargs.pop("service_client_01_2025_preview", None) + self._kwargs = kwargs + + self._requests_pipeline: HttpPipeline = kwargs.pop("requests_pipeline") + + @property + def _component_operations(self) -> ComponentOperations: + return cast( + ComponentOperations, + self._all_operations.get_operation( # type: ignore[misc] + AzureMLResourceType.COMPONENT, lambda x: isinstance(x, ComponentOperations) + ), + ) + + @property + def _compute_operations(self) -> ComputeOperations: + return cast( + ComputeOperations, + self._all_operations.get_operation( # type: ignore[misc] + AzureMLResourceType.COMPUTE, lambda x: isinstance(x, ComputeOperations) + ), + ) + + @property + def _virtual_cluster_operations(self) -> VirtualClusterOperations: + return cast( + VirtualClusterOperations, + self._all_operations.get_operation( # type: ignore[misc] + AzureMLResourceType.VIRTUALCLUSTER, + lambda x: isinstance(x, VirtualClusterOperations), + ), + ) + + @property + def _datastore_operations(self) -> "DatastoreOperations": + from azure.ai.ml.operations import DatastoreOperations + + return cast(DatastoreOperations, self._all_operations.all_operations[AzureMLResourceType.DATASTORE]) + + @property + def _runs_operations(self) -> RunOperations: + if not self._runs_operations_client: + service_client_run_history = ServiceClientRunHistory( + self._credential, base_url=self._api_url, **self._service_client_kwargs + ) + self._runs_operations_client = RunOperations( + self._operation_scope, self._operation_config, service_client_run_history + ) + return self._runs_operations_client + + @property + def _dataset_dataplane_operations(self) -> DatasetDataplaneOperations: + if not self._dataset_dataplane_operations_client: + service_client_dataset_dataplane = ServiceClientDatasetDataplane( + self._credential, base_url=self._api_url, **self._service_client_kwargs + ) + self._dataset_dataplane_operations_client = DatasetDataplaneOperations( + self._operation_scope, + self._operation_config, + 
service_client_dataset_dataplane, + ) + return self._dataset_dataplane_operations_client + + @property + def _model_dataplane_operations(self) -> ModelDataplaneOperations: + if not self._model_dataplane_operations_client: + service_client_model_dataplane = ServiceClientModelDataplane( + self._credential, base_url=self._api_url, **self._service_client_kwargs + ) + self._model_dataplane_operations_client = ModelDataplaneOperations( + self._operation_scope, + self._operation_config, + service_client_model_dataplane, + ) + return self._model_dataplane_operations_client + + @property + def _api_url(self) -> str: + if not self._api_base_url: + self._api_base_url = self._get_workspace_url(url_key=WorkspaceDiscoveryUrlKey.API) + return self._api_base_url + + @distributed_trace + @monitor_with_activity(ops_logger, "Job.List", ActivityType.PUBLICAPI) + def list( + self, + *, + parent_job_name: Optional[str] = None, + list_view_type: ListViewType = ListViewType.ACTIVE_ONLY, + **kwargs: Any, + ) -> Iterable[Job]: + """Lists jobs in the workspace. + + :keyword parent_job_name: When provided, only returns jobs that are children of the named job. Defaults to None, + listing all jobs in the workspace. + :paramtype parent_job_name: Optional[str] + :keyword list_view_type: The view type for including/excluding archived jobs. Defaults to + ~azure.mgt.machinelearningservices.models.ListViewType.ACTIVE_ONLY, excluding archived jobs. + :paramtype list_view_type: ~azure.mgmt.machinelearningservices.models.ListViewType + :return: An iterator-like instance of Job objects. + :rtype: ~azure.core.paging.ItemPaged[~azure.ai.ml.entities.Job] + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START job_operations_list] + :end-before: [END job_operations_list] + :language: python + :dedent: 8 + :caption: Retrieving a list of the archived jobs in a workspace with parent job named + "iris-dataset-jobs". + """ + + schedule_defined = kwargs.pop("schedule_defined", None) + scheduled_job_name = kwargs.pop("scheduled_job_name", None) + max_results = kwargs.pop("max_results", None) + + if parent_job_name: + parent_job = self.get(parent_job_name) + return self._runs_operations.get_run_children(parent_job.name, max_results=max_results) + + return cast( + Iterable[Job], + self.service_client_01_2024_preview.jobs.list( + self._operation_scope.resource_group_name, + self._workspace_name, + cls=lambda objs: [self._handle_rest_errors(obj) for obj in objs], + list_view_type=list_view_type, + scheduled=schedule_defined, + schedule_id=scheduled_job_name, + **self._kwargs, + **kwargs, + ), + ) + + def _handle_rest_errors(self, job_object: Union[JobBase, Run]) -> Optional[Job]: + """Handle errors while resolving azureml_id's during list operation. + + :param job_object: The REST object to turn into a Job + :type job_object: Union[JobBase, Run] + :return: The resolved job + :rtype: Optional[Job] + """ + try: + return self._resolve_azureml_id(Job._from_rest_object(job_object)) + except JobParsingError: + return None + + @distributed_trace + @monitor_with_telemetry_mixin(ops_logger, "Job.Get", ActivityType.PUBLICAPI) + def get(self, name: str) -> Job: + """Gets a job resource. + + :param name: The name of the job. + :type name: str + :raises azure.core.exceptions.ResourceNotFoundError: Raised if no job with the given name can be found. + :raises ~azure.ai.ml.exceptions.UserErrorException: Raised if the name parameter is not a string. + :return: Job object retrieved from the service. 
+ :rtype: ~azure.ai.ml.entities.Job
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_misc.py
+ :start-after: [START job_operations_get]
+ :end-before: [END job_operations_get]
+ :language: python
+ :dedent: 8
+ :caption: Retrieving a job named "iris-dataset-job-1".
+ """
+ if not isinstance(name, str):
+ raise UserErrorException(f"{name} is an invalid input for client.jobs.get().")
+ job_object = self._get_job(name)
+
+ job: Any = None
+ if not _is_pipeline_child_job(job_object):
+ job = Job._from_rest_object(job_object)
+ if job_object.properties.job_type != RestJobType.AUTO_ML:
+ # resolvers do not work with the old contract, leave the ids as is
+ job = self._resolve_azureml_id(job)
+ else:
+ # Child jobs are no longer available through MFE, fetch
+ # through run history instead
+ job = self._runs_operations._translate_from_rest_object(self._runs_operations.get_run(name))
+
+ return job
+
+ @distributed_trace
+ @monitor_with_telemetry_mixin(ops_logger, "Job.ShowServices", ActivityType.PUBLICAPI)
+ def show_services(self, name: str, node_index: int = 0) -> Optional[Dict[str, ServiceInstance]]:
+ """Gets services associated with a job's node.
+
+ :param name: The name of the job.
+ :type name: str
+ :param node_index: The node's index (zero-based). Defaults to 0.
+ :type node_index: int
+ :return: The services associated with the job for the given node.
+ :rtype: dict[str, ~azure.ai.ml.entities.ServiceInstance]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_misc.py
+ :start-after: [START job_operations_show_services]
+ :end-before: [END job_operations_show_services]
+ :language: python
+ :dedent: 8
+ :caption: Retrieving the services associated with a job's 1st node.
+ """
+
+ service_instances_dict = self._runs_operations._operation.get_run_service_instances(
+ self._subscription_id,
+ self._operation_scope.resource_group_name,
+ self._workspace_name,
+ name,
+ node_index,
+ )
+ if not service_instances_dict.instances:
+ return None
+
+ return {
+ k: ServiceInstance._from_rest_object(v, node_index) for k, v in service_instances_dict.instances.items()
+ }
+
+ @distributed_trace
+ @monitor_with_activity(ops_logger, "Job.Cancel", ActivityType.PUBLICAPI)
+ def begin_cancel(self, name: str, **kwargs: Any) -> LROPoller[None]:
+ """Cancels a job.
+
+ :param name: The name of the job.
+ :type name: str
+ :raises azure.core.exceptions.ResourceNotFoundError: Raised if no job with the given name can be found.
+ :return: A poller to track the operation status.
+ :rtype: ~azure.core.polling.LROPoller[None]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_misc.py
+ :start-after: [START job_operations_begin_cancel]
+ :end-before: [END job_operations_begin_cancel]
+ :language: python
+ :dedent: 8
+ :caption: Canceling the job named "iris-dataset-job-1" and checking the poller for status.
+ """
+ tag = kwargs.pop("tag", None)
+
+ if not tag:
+ return self._operation_2023_02_preview.begin_cancel(
+ id=name,
+ resource_group_name=self._operation_scope.resource_group_name,
+ workspace_name=self._workspace_name,
+ **self._kwargs,
+ **kwargs,
+ )
+
+ # Note: Below batch cancel is experimental and for private usage
+ results = []
+ jobs = self.list(tag=tag)
+ # TODO: Do we need to show error message when no jobs are returned for the given tag?
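+ # Batch cancel: cancel every job carrying the given tag and collect the pollers so the caller can wait on each cancellation.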
+ for job in jobs: + result = self._operation_2023_02_preview.begin_cancel( + id=job.name, + resource_group_name=self._operation_scope.resource_group_name, + workspace_name=self._workspace_name, + **self._kwargs, + ) + results.append(result) + return results + + def _try_get_compute_arm_id(self, compute: Union[Compute, str]) -> Optional[Union[Compute, str]]: + # pylint: disable=too-many-return-statements + # TODO: Remove in PuP with native import job/component type support in MFE/Designer + # DataFactory 'clusterless' job + if str(compute) == ComputeType.ADF: + return compute + + if compute is not None: + # Singularity + if isinstance(compute, str) and is_singularity_id_for_resource(compute): + return self._virtual_cluster_operations.get(compute)["id"] + if isinstance(compute, str) and is_singularity_full_name_for_resource(compute): + return self._orchestrators._get_singularity_arm_id_from_full_name(compute) + if isinstance(compute, str) and is_singularity_short_name_for_resource(compute): + return self._orchestrators._get_singularity_arm_id_from_short_name(compute) + # other compute + if is_ARM_id_for_resource(compute, resource_type=AzureMLResourceType.COMPUTE): + # compute is not a sub-workspace resource + compute_name = compute.split("/")[-1] # type: ignore + elif isinstance(compute, Compute): + compute_name = compute.name + elif isinstance(compute, str): + compute_name = compute + elif isinstance(compute, PipelineInput): + compute_name = str(compute) + else: + raise ValueError( + "compute must be either an arm id of Compute, a Compute object or a compute name but" + f" got {type(compute)}" + ) + + if is_data_binding_expression(compute_name): + return compute_name + if compute_name == SERVERLESS_COMPUTE: + return compute_name + try: + return self._compute_operations.get(compute_name).id + except ResourceNotFoundError as e: + # the original error is not helpful (Operation returned an invalid status 'Not Found'), + # so we raise a more helpful one + response = e.response + response.reason = "Not found compute with name {}".format(compute_name) + raise ResourceNotFoundError(response=response) from e + return None + + @distributed_trace + @monitor_with_telemetry_mixin(ops_logger, "Job.Validate", ActivityType.PUBLICAPI) + def validate(self, job: Job, *, raise_on_failure: bool = False, **kwargs: Any) -> ValidationResult: + """Validates a Job object before submitting to the service. Anonymous assets may be created if there are inline + defined entities such as Component, Environment, and Code. Only pipeline jobs are supported for validation + currently. + + :param job: The job object to be validated. + :type job: ~azure.ai.ml.entities.Job + :keyword raise_on_failure: Specifies if an error should be raised if validation fails. Defaults to False. + :paramtype raise_on_failure: bool + :return: A ValidationResult object containing all found errors. + :rtype: ~azure.ai.ml.entities.ValidationResult + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START job_operations_validate] + :end-before: [END job_operations_validate] + :language: python + :dedent: 8 + :caption: Validating a PipelineJob object and printing out the found errors. 
+ """ + return self._validate(job, raise_on_failure=raise_on_failure, **kwargs) + + @monitor_with_telemetry_mixin(ops_logger, "Job.Validate", ActivityType.INTERNALCALL) + def _validate( + self, + job: Job, + *, + raise_on_failure: bool = False, + # pylint:disable=unused-argument + **kwargs: Any, + ) -> ValidationResult: + """Implementation of validate. + + Add this function to avoid calling validate() directly in + create_or_update(), which will impact telemetry statistics & + bring experimental warning in create_or_update(). + + :param job: The job to validate + :type job: Job + :keyword raise_on_failure: Whether to raise on validation failure + :paramtype raise_on_failure: bool + :return: The validation result + :rtype: ValidationResult + """ + git_code_validation_result = PathAwareSchemaValidatableMixin._create_empty_validation_result() + # TODO: move this check to Job._validate after validation is supported for all job types + # If private features are enable and job has code value of type str we need to check + # that it is a valid git path case. Otherwise we should throw a ValidationException + # saying that the code value is not a valid code value + if ( + hasattr(job, "code") + and job.code is not None + and isinstance(job.code, str) + and job.code.startswith(GIT_PATH_PREFIX) + and not is_private_preview_enabled() + ): + git_code_validation_result.append_error( + message=f"Invalid code value: {job.code}. Git paths are not supported.", + yaml_path="code", + ) + + if not isinstance(job, PathAwareSchemaValidatableMixin): + + def error_func(msg: str, no_personal_data_msg: str) -> ValidationException: + return ValidationException( + message=msg, + no_personal_data_message=no_personal_data_msg, + error_target=ErrorTarget.JOB, + error_category=ErrorCategory.USER_ERROR, + ) + + return git_code_validation_result.try_raise( + raise_error=raise_on_failure, + error_func=error_func, + ) + + validation_result = job._validate(raise_error=raise_on_failure) + validation_result.merge_with(git_code_validation_result) + # fast return to avoid remote call if local validation not passed + # TODO: use remote call to validate the entire job after MFE API is ready + if validation_result.passed and isinstance(job, PipelineJob): + try: + job.compute = self._try_get_compute_arm_id(job.compute) + except Exception as e: # pylint: disable=W0718 + validation_result.append_error(yaml_path="compute", message=str(e)) + + for node_name, node in job.jobs.items(): + try: + # TODO(1979547): refactor, not all nodes have compute + if not isinstance(node, ControlFlowNode): + node.compute = self._try_get_compute_arm_id(node.compute) + except Exception as e: # pylint: disable=W0718 + validation_result.append_error(yaml_path=f"jobs.{node_name}.compute", message=str(e)) + + validation_result.resolve_location_for_diagnostics(str(job._source_path)) + return job._try_raise(validation_result, raise_error=raise_on_failure) # pylint: disable=protected-access + + @distributed_trace + @monitor_with_telemetry_mixin(ops_logger, "Job.CreateOrUpdate", ActivityType.PUBLICAPI) + def create_or_update( + self, + job: Job, + *, + description: Optional[str] = None, + compute: Optional[str] = None, + tags: Optional[dict] = None, + experiment_name: Optional[str] = None, + skip_validation: bool = False, + **kwargs: Any, + ) -> Job: + """Creates or updates a job. If entities such as Environment or Code are defined inline, they'll be created + together with the job. + + :param job: The job object. 
+ :type job: ~azure.ai.ml.entities.Job + :keyword description: The job description. + :paramtype description: Optional[str] + :keyword compute: The compute target for the job. + :paramtype compute: Optional[str] + :keyword tags: The tags for the job. + :paramtype tags: Optional[dict] + :keyword experiment_name: The name of the experiment the job will be created under. If None is provided, + job will be created under experiment 'Default'. + :paramtype experiment_name: Optional[str] + :keyword skip_validation: Specifies whether or not to skip validation before creating or updating the job. Note + that validation for dependent resources such as an anonymous component will not be skipped. Defaults to + False. + :paramtype skip_validation: bool + :raises Union[~azure.ai.ml.exceptions.UserErrorException, ~azure.ai.ml.exceptions.ValidationException]: Raised + if Job cannot be successfully validated. Details will be provided in the error message. + :raises ~azure.ai.ml.exceptions.AssetException: Raised if Job assets + (e.g. Data, Code, Model, Environment) cannot be successfully validated. + Details will be provided in the error message. + :raises ~azure.ai.ml.exceptions.ModelException: Raised if Job model cannot be successfully validated. + Details will be provided in the error message. + :raises ~azure.ai.ml.exceptions.JobException: Raised if Job object or attributes correctly formatted. + Details will be provided in the error message. + :raises ~azure.ai.ml.exceptions.EmptyDirectoryError: Raised if local path provided points to an empty + directory. + :raises ~azure.ai.ml.exceptions.DockerEngineNotAvailableError: Raised if Docker Engine is not available for + local job. + :return: Created or updated job. + :rtype: ~azure.ai.ml.entities.Job + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START job_operations_create_and_update] + :end-before: [END job_operations_create_and_update] + :language: python + :dedent: 8 + :caption: Creating a new job and then updating its compute. + """ + if isinstance(job, BaseNode) and not ( + isinstance(job, (Command, Spark)) + ): # Command/Spark objects can be used directly + job = job._to_job() + + # Set job properties before submission + if description is not None: + job.description = description + if compute is not None: + job.compute = compute + if tags is not None: + job.tags = tags + if experiment_name is not None: + job.experiment_name = experiment_name + + if job.compute == LOCAL_COMPUTE_TARGET: + job.environment_variables[COMMON_RUNTIME_ENV_VAR] = "true" # type: ignore + + # TODO: why we put log logic here instead of inside self._validate()? + try: + if not skip_validation: + self._validate(job, raise_on_failure=True) + + # Create all dependent resources + self._resolve_arm_id_or_upload_dependencies(job) + except (ValidationException, ValidationError) as ex: + log_and_raise_error(ex) + + git_props = get_git_properties() + # Do not add git props if they already exist in job properties. + # This is for update specifically-- if the user switches branches and tries to update + # their job, the request will fail since the git props will be repopulated. 
+ # MFE does not allow existing properties to be updated, only for new props to be added + if not any(prop_name in job.properties for prop_name in git_props): + job.properties = {**job.properties, **git_props} + rest_job_resource = to_rest_job_object(job) + + # Make a copy of self._kwargs instead of contaminate the original one + kwargs = {**self._kwargs} + # set headers with user aml token if job is a pipeline or has a user identity setting + if (rest_job_resource.properties.job_type == RestJobType.PIPELINE) or ( + hasattr(rest_job_resource.properties, "identity") + and (isinstance(rest_job_resource.properties.identity, UserIdentity)) + ): + self._set_headers_with_user_aml_token(kwargs) + + result = self._create_or_update_with_different_version_api(rest_job_resource=rest_job_resource, **kwargs) + + if is_local_run(result): + ws_base_url = self._all_operations.all_operations[ + AzureMLResourceType.WORKSPACE + ]._operation._client._base_url + snapshot_id = start_run_if_local( + result, + self._credential, + ws_base_url, + self._requests_pipeline, + ) + # in case of local run, the first create/update call to MFE returns the + # request for submitting to ES. Once we request to ES and start the run, we + # need to put the same body to MFE to append user tags etc. + if rest_job_resource.properties.job_type == RestJobType.PIPELINE: + job_object = self._get_job_2401(rest_job_resource.name) + else: + job_object = self._get_job(rest_job_resource.name) + if result.properties.tags is not None: + for tag_name, tag_value in rest_job_resource.properties.tags.items(): + job_object.properties.tags[tag_name] = tag_value + if result.properties.properties is not None: + for ( + prop_name, + prop_value, + ) in rest_job_resource.properties.properties.items(): + job_object.properties.properties[prop_name] = prop_value + if snapshot_id is not None: + job_object.properties.properties["ContentSnapshotId"] = snapshot_id + + result = self._create_or_update_with_different_version_api(rest_job_resource=job_object, **kwargs) + + return self._resolve_azureml_id(Job._from_rest_object(result)) + + def _create_or_update_with_different_version_api(self, rest_job_resource: JobBase, **kwargs: Any) -> JobBase: + service_client_operation = self._operation_2023_02_preview + if rest_job_resource.properties.job_type == RestJobType_20241001Preview.FINE_TUNING: + service_client_operation = self.service_client_10_2024_preview.jobs + if rest_job_resource.properties.job_type == RestJobType.PIPELINE: + service_client_operation = self.service_client_01_2024_preview.jobs + if rest_job_resource.properties.job_type == RestJobType.AUTO_ML: + service_client_operation = self.service_client_01_2024_preview.jobs + if rest_job_resource.properties.job_type == RestJobType.SWEEP: + service_client_operation = self.service_client_01_2024_preview.jobs + if rest_job_resource.properties.job_type == RestJobType.COMMAND: + service_client_operation = self.service_client_01_2025_preview.jobs + + result = service_client_operation.create_or_update( + id=rest_job_resource.name, + resource_group_name=self._operation_scope.resource_group_name, + workspace_name=self._workspace_name, + body=rest_job_resource, + **kwargs, + ) + + return result + + def _create_or_update_with_latest_version_api(self, rest_job_resource: JobBase, **kwargs: Any) -> JobBase: + service_client_operation = self.service_client_01_2024_preview.jobs + result = service_client_operation.create_or_update( + id=rest_job_resource.name, + resource_group_name=self._operation_scope.resource_group_name, 
+ workspace_name=self._workspace_name, + body=rest_job_resource, + **kwargs, + ) + + return result + + def _archive_or_restore(self, name: str, is_archived: bool) -> None: + job_object = self._get_job(name) + if job_object.properties.job_type == RestJobType.PIPELINE: + job_object = self._get_job_2401(name) + if _is_pipeline_child_job(job_object): + raise PipelineChildJobError(job_id=job_object.id) + job_object.properties.is_archived = is_archived + + self._create_or_update_with_different_version_api(rest_job_resource=job_object) + + @distributed_trace + @monitor_with_telemetry_mixin(ops_logger, "Job.Archive", ActivityType.PUBLICAPI) + def archive(self, name: str) -> None: + """Archives a job. + + :param name: The name of the job. + :type name: str + :raises azure.core.exceptions.ResourceNotFoundError: Raised if no job with the given name can be found. + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START job_operations_archive] + :end-before: [END job_operations_archive] + :language: python + :dedent: 8 + :caption: Archiving a job. + """ + + self._archive_or_restore(name=name, is_archived=True) + + @distributed_trace + @monitor_with_telemetry_mixin(ops_logger, "Job.Restore", ActivityType.PUBLICAPI) + def restore(self, name: str) -> None: + """Restores an archived job. + + :param name: The name of the job. + :type name: str + :raises azure.core.exceptions.ResourceNotFoundError: Raised if no job with the given name can be found. + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START job_operations_restore] + :end-before: [END job_operations_restore] + :language: python + :dedent: 8 + :caption: Restoring an archived job. + """ + + self._archive_or_restore(name=name, is_archived=False) + + @distributed_trace + @monitor_with_activity(ops_logger, "Job.Stream", ActivityType.PUBLICAPI) + def stream(self, name: str) -> None: + """Streams the logs of a running job. + + :param name: The name of the job. + :type name: str + :raises azure.core.exceptions.ResourceNotFoundError: Raised if no job with the given name can be found. + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START job_operations_stream_logs] + :end-before: [END job_operations_stream_logs] + :language: python + :dedent: 8 + :caption: Streaming a running job. + """ + job_object = self._get_job(name) + + if _is_pipeline_child_job(job_object): + raise PipelineChildJobError(job_id=job_object.id) + + self._stream_logs_until_completion( + self._runs_operations, + job_object, + self._datastore_operations, + requests_pipeline=self._requests_pipeline, + ) + + @distributed_trace + @monitor_with_activity(ops_logger, "Job.Download", ActivityType.PUBLICAPI) + def download( + self, + name: str, + *, + download_path: Union[PathLike, str] = ".", + output_name: Optional[str] = None, + all: bool = False, # pylint: disable=redefined-builtin + ) -> None: + """Downloads the logs and output of a job. + + :param name: The name of a job. + :type name: str + :keyword download_path: The local path to be used as the download destination. Defaults to ".". + :paramtype download_path: Union[PathLike, str] + :keyword output_name: The name of the output to download. Defaults to None. + :paramtype output_name: Optional[str] + :keyword all: Specifies if all logs and named outputs should be downloaded. Defaults to False. 
+ :paramtype all: bool + :raises ~azure.ai.ml.exceptions.JobException: Raised if Job is not yet in a terminal state. + Details will be provided in the error message. + :raises ~azure.ai.ml.exceptions.MlException: Raised if logs and outputs cannot be successfully downloaded. + Details will be provided in the error message. + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START job_operations_download] + :end-before: [END job_operations_download] + :language: python + :dedent: 8 + :caption: Downloading all logs and named outputs of the job "job-1" into local directory "job-1-logs". + """ + job_details = self.get(name) + # job is reused, get reused job to download + if ( + job_details.properties.get(PipelineConstants.REUSED_FLAG_FIELD) == PipelineConstants.REUSED_FLAG_TRUE + and PipelineConstants.REUSED_JOB_ID in job_details.properties + ): + reused_job_name = job_details.properties[PipelineConstants.REUSED_JOB_ID] + reused_job_detail = self.get(reused_job_name) + module_logger.info( + "job %s reuses previous job %s, download from the reused job.", + name, + reused_job_name, + ) + name, job_details = reused_job_name, reused_job_detail + job_status = job_details.status + if job_status not in RunHistoryConstants.TERMINAL_STATUSES: + msg = "This job is in state {}. Download is allowed only in states {}".format( + job_status, RunHistoryConstants.TERMINAL_STATUSES + ) + raise JobException( + message=msg, + target=ErrorTarget.JOB, + no_personal_data_message=msg, + error_category=ErrorCategory.USER_ERROR, + ) + + is_batch_job = ( + job_details.tags.get("azureml.batchrun", None) == "true" + and job_details.tags.get("azureml.jobtype", None) != PipelineConstants.PIPELINE_JOB_TYPE + ) + outputs = {} + download_path = Path(download_path) + artifact_directory_name = "artifacts" + output_directory_name = "named-outputs" + + def log_missing_uri(what: str) -> None: + module_logger.debug( + 'Could not download %s for job "%s" (job status: %s)', + what, + job_details.name, + job_details.status, + ) + + if isinstance(job_details, SweepJob): + best_child_run_id = job_details.properties.get(SWEEP_JOB_BEST_CHILD_RUN_ID_PROPERTY_NAME, None) + if best_child_run_id: + self.download( + best_child_run_id, + download_path=download_path, + output_name=output_name, + all=all, + ) + else: + log_missing_uri(what="from best child run") + + if output_name: + # Don't need to download anything from the parent + return + # only download default artifacts (logs + default outputs) from parent + artifact_directory_name = "hd-artifacts" + output_name = None + all = False + + if is_batch_job: + scoring_uri = self._get_batch_job_scoring_output_uri(job_details.name) + if scoring_uri: + outputs = {BATCH_JOB_CHILD_RUN_OUTPUT_NAME: scoring_uri} + else: + log_missing_uri("batch job scoring file") + elif output_name: + outputs = self._get_named_output_uri(name, output_name) + + if output_name not in outputs: + log_missing_uri(what=f'output "{output_name}"') + elif all: + outputs = self._get_named_output_uri(name) + + if DEFAULT_ARTIFACT_STORE_OUTPUT_NAME not in outputs: + log_missing_uri(what="logs") + else: + outputs = self._get_named_output_uri(name, DEFAULT_ARTIFACT_STORE_OUTPUT_NAME) + + if DEFAULT_ARTIFACT_STORE_OUTPUT_NAME not in outputs: + log_missing_uri(what="logs") + + # Download all requested artifacts + for item_name, uri in outputs.items(): + if is_batch_job: + destination = download_path + elif item_name == DEFAULT_ARTIFACT_STORE_OUTPUT_NAME: + destination = download_path / 
artifact_directory_name + else: + destination = download_path / output_directory_name / item_name + + module_logger.info("Downloading artifact %s to %s", uri, destination) + download_artifact_from_aml_uri( + uri=uri, + destination=destination, # type: ignore[arg-type] + datastore_operation=self._datastore_operations, + ) + + def _get_named_output_uri( + self, job_name: Optional[str], output_names: Optional[Union[Iterable[str], str]] = None + ) -> Dict[str, str]: + """Gets the URIs to the specified named outputs of job. + + :param job_name: Run ID of the job + :type job_name: str + :param output_names: Either an output name, or an iterable of output names. If omitted, all outputs are + returned. + :type output_names: Optional[Union[Iterable[str], str]] + :return: Map of output_names to URIs. Note that URIs that could not be found will not be present in the map. + :rtype: Dict[str, str] + """ + + if isinstance(output_names, str): + output_names = {output_names} + elif output_names: + output_names = set(output_names) + + outputs = get_job_output_uris_from_dataplane( + job_name, + self._runs_operations, + self._dataset_dataplane_operations, + self._model_dataplane_operations, + output_names=output_names, + ) + + missing_outputs: Set = set() + if output_names is not None: + missing_outputs = set(output_names).difference(outputs.keys()) + else: + missing_outputs = set().difference(outputs.keys()) + + # Include default artifact store in outputs + if (not output_names) or DEFAULT_ARTIFACT_STORE_OUTPUT_NAME in missing_outputs: + try: + job = self.get(job_name) + artifact_store_uri = job.outputs[DEFAULT_ARTIFACT_STORE_OUTPUT_NAME] + if artifact_store_uri is not None and artifact_store_uri.path: + outputs[DEFAULT_ARTIFACT_STORE_OUTPUT_NAME] = artifact_store_uri.path + except (AttributeError, KeyError): + outputs[DEFAULT_ARTIFACT_STORE_OUTPUT_NAME] = SHORT_URI_FORMAT.format( + "workspaceartifactstore", f"ExperimentRun/dcid.{job_name}/" + ) + missing_outputs.discard(DEFAULT_ARTIFACT_STORE_OUTPUT_NAME) + + # A job's output is not always reported in the outputs dict, but + # doesn't currently have a user configurable location. 
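+ # The probing below assumes SHORT_URI_FORMAT expands to the short datastore-URI form
+ # "azureml://datastores/{}/paths/{}"; for example, a missing output "score" of job
+ # "job-1" would be tried at (paths are illustrative):
+ #     azureml://datastores/<default_datastore>/paths/azureml/job-1/score/
+ #     azureml://datastores/<default_datastore>/paths/dataset/job-1/score/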
+ # Perform a search of known paths to find output + # TODO: Remove once job output locations are reliably returned from the service + + default_datastore = self._datastore_operations.get_default().name + + for name in missing_outputs: + potential_uris = [ + SHORT_URI_FORMAT.format(default_datastore, f"azureml/{job_name}/{name}/"), + SHORT_URI_FORMAT.format(default_datastore, f"dataset/{job_name}/{name}/"), + ] + + for potential_uri in potential_uris: + if aml_datastore_path_exists(potential_uri, self._datastore_operations): + outputs[name] = potential_uri + break + + return outputs + + def _get_batch_job_scoring_output_uri(self, job_name: str) -> Optional[str]: + uri = None + # Download scoring output, which is the "score" output of the child job named "batchscoring" + # Batch Jobs are pipeline jobs with only one child, so this should terminate after an iteration + for child in self._runs_operations.get_run_children(job_name): + uri = self._get_named_output_uri(child.name, BATCH_JOB_CHILD_RUN_OUTPUT_NAME).get( + BATCH_JOB_CHILD_RUN_OUTPUT_NAME, None + ) + # After the correct child is found, break to prevent unnecessary looping + if uri is not None: + break + return uri + + def _get_job(self, name: str) -> JobBase: + job = self.service_client_01_2024_preview.jobs.get( + id=name, + resource_group_name=self._operation_scope.resource_group_name, + workspace_name=self._workspace_name, + **self._kwargs, + ) + + if ( + hasattr(job, "properties") + and job.properties + and hasattr(job.properties, "job_type") + and job.properties.job_type == RestJobType_20241001Preview.FINE_TUNING + ): + return self.service_client_10_2024_preview.jobs.get( + id=name, + resource_group_name=self._operation_scope.resource_group_name, + workspace_name=self._workspace_name, + **self._kwargs, + ) + + return job + + # Upgrade api from 2023-04-01-preview to 2024-01-01-preview for pipeline job + # We can remove this function once `_get_job` function has also been upgraded to the same version with pipeline + def _get_job_2401(self, name: str) -> JobBase_2401: + service_client_operation = self.service_client_01_2024_preview.jobs + return service_client_operation.get( + id=name, + resource_group_name=self._operation_scope.resource_group_name, + workspace_name=self._workspace_name, + **self._kwargs, + ) + + def _get_workspace_url(self, url_key: WorkspaceDiscoveryUrlKey) -> str: + discovery_url = ( + self._all_operations.all_operations[AzureMLResourceType.WORKSPACE] + .get(self._operation_scope.workspace_name) + .discovery_url + ) + all_urls = json.loads( + download_text_from_url( + discovery_url, + create_requests_pipeline_with_retry(requests_pipeline=self._requests_pipeline), + ) + ) + return all_urls[url_key] + + def _resolve_arm_id_or_upload_dependencies(self, job: Job) -> None: + """This method converts name or name:version to ARM id. Or it + registers/uploads nested dependencies. 
+ + :param job: the job resource entity + :type job: Job + :return: the job resource entity that nested dependencies are resolved + :rtype: Job + """ + + self._resolve_arm_id_or_azureml_id(job, self._orchestrators.get_asset_arm_id) + + if isinstance(job, PipelineJob): + # Resolve top-level inputs + self._resolve_job_inputs(self._flatten_group_inputs(job.inputs), job._base_path) + # inputs in sub-pipelines has been resolved in + # self._resolve_arm_id_or_azureml_id(job, self._orchestrators.get_asset_arm_id) + # as they are part of the pipeline component + elif isinstance(job, AutoMLJob): + self._resolve_automl_job_inputs(job) + elif isinstance(job, FineTuningJob): + self._resolve_finetuning_job_inputs(job) + elif isinstance(job, DistillationJob): + self._resolve_distillation_job_inputs(job) + elif isinstance(job, Spark): + self._resolve_job_inputs(job._job_inputs.values(), job._base_path) + elif isinstance(job, Command): + # TODO: switch to use inputs of Command objects, once the inputs/outputs building + # logic is removed from the BaseNode constructor. + try: + self._resolve_job_inputs(job._job_inputs.values(), job._base_path) + except AttributeError: + # If the job object doesn't have "inputs" attribute, we don't need to resolve. E.g. AutoML jobs + pass + else: + try: + self._resolve_job_inputs(job.inputs.values(), job._base_path) # type: ignore + except AttributeError: + # If the job object doesn't have "inputs" attribute, we don't need to resolve. E.g. AutoML jobs + pass + + def _resolve_automl_job_inputs(self, job: AutoMLJob) -> None: + """This method resolves the inputs for AutoML jobs. + + :param job: the job resource entity + :type job: AutoMLJob + """ + if isinstance(job, AutoMLJob): + self._resolve_job_input(job.training_data, job._base_path) + if job.validation_data is not None: + self._resolve_job_input(job.validation_data, job._base_path) + if hasattr(job, "test_data") and job.test_data is not None: + self._resolve_job_input(job.test_data, job._base_path) + + def _resolve_finetuning_job_inputs(self, job: FineTuningJob) -> None: + """This method resolves the inputs for FineTuning jobs. + + :param job: the job resource entity + :type job: FineTuningJob + """ + from azure.ai.ml.entities._job.finetuning.finetuning_vertical import FineTuningVertical + + if isinstance(job, FineTuningVertical): + # self._resolve_job_input(job.model, job._base_path) + self._resolve_job_input(job.training_data, job._base_path) + if job.validation_data is not None: + self._resolve_job_input(job.validation_data, job._base_path) + + def _resolve_distillation_job_inputs(self, job: DistillationJob) -> None: + """This method resolves the inputs for Distillation jobs. + + :param job: the job resource entity + :type job: DistillationJob + """ + if isinstance(job, DistillationJob): + if job.training_data: + self._resolve_job_input(job.training_data, job._base_path) + if job.validation_data is not None: + self._resolve_job_input(job.validation_data, job._base_path) + + def _resolve_azureml_id(self, job: Job) -> Job: + """This method converts ARM id to name or name:version for nested + entities. 
+ + :param job: the job resource entity + :type job: Job + :return: the job resource entity that nested dependencies are resolved + :rtype: Job + """ + self._append_tid_to_studio_url(job) + self._resolve_job_inputs_arm_id(job) + return self._resolve_arm_id_or_azureml_id(job, self._orchestrators.resolve_azureml_id) + + def _resolve_compute_id(self, resolver: _AssetResolver, target: Any) -> Any: + # special case for local runs + if target is not None and target.lower() == LOCAL_COMPUTE_TARGET: + return LOCAL_COMPUTE_TARGET + try: + modified_target_name = target + if target.lower().startswith(AzureMLResourceType.VIRTUALCLUSTER + "/"): + # Compute target can be either workspace-scoped compute type, + # or AML scoped VC. In the case of VC, resource name will be of form + # azureml:virtualClusters/<name> to disambiguate from azureml:name (which is always compute) + modified_target_name = modified_target_name[len(AzureMLResourceType.VIRTUALCLUSTER) + 1 :] + modified_target_name = LEVEL_ONE_NAMED_RESOURCE_ID_FORMAT.format( + self._operation_scope.subscription_id, + self._operation_scope.resource_group_name, + AZUREML_RESOURCE_PROVIDER, + AzureMLResourceType.VIRTUALCLUSTER, + modified_target_name, + ) + return resolver( + modified_target_name, + azureml_type=AzureMLResourceType.VIRTUALCLUSTER, + sub_workspace_resource=False, + ) + except Exception: # pylint: disable=W0718 + return resolver(target, azureml_type=AzureMLResourceType.COMPUTE) + + def _resolve_job_inputs(self, entries: Iterable[Union[Input, str, bool, int, float]], base_path: str) -> None: + """resolve job inputs as ARM id or remote url. + + :param entries: An iterable of job inputs + :type entries: Iterable[Union[Input, str, bool, int, float]] + :param base_path: The base path + :type base_path: str + """ + for entry in entries: + self._resolve_job_input(entry, base_path) + + # TODO: move it to somewhere else? + @classmethod + def _flatten_group_inputs( + cls, inputs: Dict[str, Union[Input, str, bool, int, float]] + ) -> List[Union[Input, str, bool, int, float]]: + """Get flatten values from an InputDict. + + :param inputs: The input dict + :type inputs: Dict[str, Union[Input, str, bool, int, float]] + :return: A list of values from the Input Dict + :rtype: List[Union[Input, str, bool, int, float]] + """ + input_values: List = [] + # Flatten inputs for pipeline job + for key, item in inputs.items(): + if isinstance(item, _GroupAttrDict): + input_values.extend(item.flatten(group_parameter_name=key)) + else: + if not isinstance(item, (str, bool, int, float)): + # skip resolving inferred optional input without path (in do-while + dynamic input case) + if isinstance(item._data, Input): # type: ignore + if not item._data.path and item._meta._is_inferred_optional: # type: ignore + continue + input_values.append(item._data) # type: ignore + return input_values + + def _resolve_job_input(self, entry: Union[Input, str, bool, int, float], base_path: str) -> None: + """resolve job input as ARM id or remote url. + + :param entry: The job input + :type entry: Union[Input, str, bool, int, float] + :param base_path: The base path + :type base_path: str + """ + + # path can be empty if the job was created from builder functions + if isinstance(entry, Input) and not entry.path: + msg = "Input path can't be empty for jobs." 
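+ # A typical trigger is an inline input declared without a path, e.g. (illustrative only):
+ #     job.inputs["training_data"] = Input(type=AssetTypes.URI_FOLDER)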
+ raise ValidationException( + message=msg, + target=ErrorTarget.JOB, + no_personal_data_message=msg, + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.MISSING_FIELD, + ) + + if ( + not isinstance(entry, Input) + or is_ARM_id_for_resource(entry.path) + or is_url(entry.path) + or is_data_binding_expression(entry.path) # literal value but set mode in pipeline yaml + ): # Literal value, ARM id or remote url. Pass through + return + try: + datastore_name = ( + entry.datastore if hasattr(entry, "datastore") and entry.datastore else WORKSPACE_BLOB_STORE + ) + + # absolute local path, upload, transform to remote url + if os.path.isabs(entry.path): # type: ignore + if entry.type == AssetTypes.URI_FOLDER and not os.path.isdir(entry.path): # type: ignore + raise ValidationException( + message="There is no dir on target path: {}".format(entry.path), + target=ErrorTarget.JOB, + no_personal_data_message="There is no dir on target path", + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.FILE_OR_FOLDER_NOT_FOUND, + ) + if entry.type == AssetTypes.URI_FILE and not os.path.isfile(entry.path): # type: ignore + raise ValidationException( + message="There is no file on target path: {}".format(entry.path), + target=ErrorTarget.JOB, + no_personal_data_message="There is no file on target path", + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.FILE_OR_FOLDER_NOT_FOUND, + ) + # absolute local path + entry.path = _upload_and_generate_remote_uri( + self._operation_scope, + self._datastore_operations, + entry.path, + datastore_name=datastore_name, + show_progress=self._show_progress, + ) + # TODO : Move this part to a common place + if entry.type == AssetTypes.URI_FOLDER and entry.path and not entry.path.endswith("/"): + entry.path = entry.path + "/" + # Check for AzureML id, is there a better way? + elif ":" in entry.path or "@" in entry.path: # type: ignore + asset_type = AzureMLResourceType.DATA + if entry.type in [AssetTypes.MLFLOW_MODEL, AssetTypes.CUSTOM_MODEL]: + asset_type = AzureMLResourceType.MODEL + + entry.path = self._orchestrators.get_asset_arm_id(entry.path, asset_type) # type: ignore + else: # relative local path, upload, transform to remote url + # Base path will be None for dsl pipeline component for now. We have 2 choices if the dsl pipeline + # function is imported from another file: + # 1) Use cwd as default base path; + # 2) Use the file path of the dsl pipeline function as default base path. + # Pick solution 1 for now as defining input path in the script to submit is a more common scenario. + local_path = Path(base_path or Path.cwd(), entry.path).resolve() # type: ignore + entry.path = _upload_and_generate_remote_uri( + self._operation_scope, + self._datastore_operations, + local_path, + datastore_name=datastore_name, + show_progress=self._show_progress, + ) + # TODO : Move this part to a common place + if entry.type == AssetTypes.URI_FOLDER and entry.path and not entry.path.endswith("/"): + entry.path = entry.path + "/" + except (MlException, HttpResponseError) as e: + raise e + except Exception as e: + raise ValidationException( + message=f"Supported input path value are ARM id, AzureML id, remote uri or local path.\n" + f"Met {type(e)}:\n{e}", + target=ErrorTarget.JOB, + no_personal_data_message=( + "Supported input path value are ARM id, AzureML id, " "remote uri or local path." 
+ ), + error=e, + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.INVALID_VALUE, + ) from e + + def _resolve_job_inputs_arm_id(self, job: Job) -> None: + try: + inputs: Dict[str, Union[Input, InputOutputBase, str, bool, int, float]] = job.inputs # type: ignore + for _, entry in inputs.items(): + if isinstance(entry, InputOutputBase): + # extract original input form input builder. + entry = entry._data + if not isinstance(entry, Input) or is_url(entry.path): # Literal value or remote url + continue + # ARM id + entry.path = self._orchestrators.resolve_azureml_id(entry.path) + + except AttributeError: + # If the job object doesn't have "inputs" attribute, we don't need to resolve. E.g. AutoML jobs + pass + + def _resolve_arm_id_or_azureml_id(self, job: Job, resolver: Union[Callable, _AssetResolver]) -> Job: + """Resolve arm_id for a given job. + + + :param job: The job + :type job: Job + :param resolver: The asset resolver function + :type resolver: _AssetResolver + :return: The provided job, with fields resolved to full ARM IDs + :rtype: Job + """ + # TODO: this will need to be parallelized when multiple tasks + # are required. Also consider the implications for dependencies. + + if isinstance(job, _BaseJob): + job.compute = self._resolve_compute_id(resolver, job.compute) + elif isinstance(job, Command): + job = self._resolve_arm_id_for_command_job(job, resolver) + elif isinstance(job, ImportJob): + job = self._resolve_arm_id_for_import_job(job, resolver) + elif isinstance(job, Spark): + job = self._resolve_arm_id_for_spark_job(job, resolver) + elif isinstance(job, ParallelJob): + job = self._resolve_arm_id_for_parallel_job(job, resolver) + elif isinstance(job, SweepJob): + job = self._resolve_arm_id_for_sweep_job(job, resolver) + elif isinstance(job, AutoMLJob): + job = self._resolve_arm_id_for_automl_job(job, resolver, inside_pipeline=False) + elif isinstance(job, PipelineJob): + job = self._resolve_arm_id_for_pipeline_job(job, resolver) + elif isinstance(job, FineTuningJob): + pass + elif isinstance(job, DistillationJob): + pass + else: + msg = f"Non supported job type: {type(job)}" + raise ValidationException( + message=msg, + target=ErrorTarget.JOB, + no_personal_data_message=msg, + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.INVALID_VALUE, + ) + return job + + def _resolve_arm_id_for_command_job(self, job: Command, resolver: _AssetResolver) -> Command: + """Resolve arm_id for CommandJob. 
+ + + :param job: The Command job + :type job: Command + :param resolver: The asset resolver function + :type resolver: _AssetResolver + :return: The provided Command job, with resolved fields + :rtype: Command + """ + if job.code is not None and is_registry_id_for_resource(job.code): + msg = "Format not supported for code asset: {}" + raise ValidationException( + message=msg.format(job.code), + target=ErrorTarget.JOB, + no_personal_data_message=msg.format("[job.code]"), + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.INVALID_VALUE, + ) + + if job.code is not None and not is_ARM_id_for_resource(job.code, AzureMLResourceType.CODE): + job.code = resolver( # type: ignore + Code(base_path=job._base_path, path=job.code), + azureml_type=AzureMLResourceType.CODE, + ) + job.environment = resolver(job.environment, azureml_type=AzureMLResourceType.ENVIRONMENT) + job.compute = self._resolve_compute_id(resolver, job.compute) + return job + + def _resolve_arm_id_for_spark_job(self, job: Spark, resolver: _AssetResolver) -> Spark: + """Resolve arm_id for SparkJob. + + :param job: The Spark job + :type job: Spark + :param resolver: The asset resolver function + :type resolver: _AssetResolver + :return: The provided SparkJob, with resolved fields + :rtype: Spark + """ + if job.code is not None and is_registry_id_for_resource(job.code): + msg = "Format not supported for code asset: {}" + raise JobException( + message=msg.format(job.code), + target=ErrorTarget.JOB, + no_personal_data_message=msg.format("[job.code]"), + error_category=ErrorCategory.USER_ERROR, + ) + + if job.code is not None and not is_ARM_id_for_resource(job.code, AzureMLResourceType.CODE): + job.code = resolver( # type: ignore + Code(base_path=job._base_path, path=job.code), + azureml_type=AzureMLResourceType.CODE, + ) + job.environment = resolver(job.environment, azureml_type=AzureMLResourceType.ENVIRONMENT) + job.compute = self._resolve_compute_id(resolver, job.compute) + return job + + def _resolve_arm_id_for_import_job(self, job: ImportJob, resolver: _AssetResolver) -> ImportJob: + """Resolve arm_id for ImportJob. + + :param job: The Import job + :type job: ImportJob + :param resolver: The asset resolver function + :type resolver: _AssetResolver + :return: The provided ImportJob, with resolved fields + :rtype: ImportJob + """ + # compute property will be no longer applicable once import job type is ready on MFE in PuP + # for PrP, we use command job type instead for import job where compute property is required + # However, MFE only validates compute resource url format. Execution service owns the real + # validation today but supports reserved compute names like AmlCompute, ContainerInstance and + # DataFactory here for 'clusterless' jobs + job.compute = self._resolve_compute_id(resolver, ComputeType.ADF) + return job + + def _resolve_arm_id_for_parallel_job(self, job: ParallelJob, resolver: _AssetResolver) -> ParallelJob: + """Resolve arm_id for ParallelJob. 
+ + :param job: The Parallel job + :type job: ParallelJob + :param resolver: The asset resolver function + :type resolver: _AssetResolver + :return: The provided ParallelJob, with resolved fields + :rtype: ParallelJob + """ + if job.code is not None and not is_ARM_id_for_resource(job.code, AzureMLResourceType.CODE): # type: ignore + job.code = resolver( # type: ignore + Code(base_path=job._base_path, path=job.code), # type: ignore + azureml_type=AzureMLResourceType.CODE, + ) + job.environment = resolver(job.environment, azureml_type=AzureMLResourceType.ENVIRONMENT) # type: ignore + job.compute = self._resolve_compute_id(resolver, job.compute) + return job + + def _resolve_arm_id_for_sweep_job(self, job: SweepJob, resolver: _AssetResolver) -> SweepJob: + """Resolve arm_id for SweepJob. + + :param job: The Sweep job + :type job: SweepJob + :param resolver: The asset resolver function + :type resolver: _AssetResolver + :return: The provided SweepJob, with resolved fields + :rtype: SweepJob + """ + if ( + job.trial is not None + and job.trial.code is not None + and not is_ARM_id_for_resource(job.trial.code, AzureMLResourceType.CODE) + ): + job.trial.code = resolver( # type: ignore[assignment] + Code(base_path=job._base_path, path=job.trial.code), + azureml_type=AzureMLResourceType.CODE, + ) + if ( + job.trial is not None + and job.trial.environment is not None + and not is_ARM_id_for_resource(job.trial.environment, AzureMLResourceType.ENVIRONMENT) + ): + job.trial.environment = resolver( # type: ignore[assignment] + job.trial.environment, azureml_type=AzureMLResourceType.ENVIRONMENT + ) + job.compute = self._resolve_compute_id(resolver, job.compute) + return job + + def _resolve_arm_id_for_automl_job( + self, job: AutoMLJob, resolver: _AssetResolver, inside_pipeline: bool + ) -> AutoMLJob: + """Resolve arm_id for AutoMLJob. + + :param job: The AutoML job + :type job: AutoMLJob + :param resolver: The asset resolver function + :type resolver: _AssetResolver + :param inside_pipeline: Whether the job is within a pipeline + :type inside_pipeline: bool + :return: The provided AutoMLJob, with resolved fields + :rtype: AutoMLJob + """ + # AutoML does not have dependency uploads. Only need to resolve reference to arm id. + + # automl node in pipeline has optional compute + if inside_pipeline and job.compute is None: + return job + job.compute = resolver(job.compute, azureml_type=AzureMLResourceType.COMPUTE) + return job + + def _resolve_arm_id_for_pipeline_job(self, pipeline_job: PipelineJob, resolver: _AssetResolver) -> PipelineJob: + """Resolve arm_id for pipeline_job. 
+ + :param pipeline_job: The pipeline job + :type pipeline_job: PipelineJob + :param resolver: The asset resolver function + :type resolver: _AssetResolver + :return: The provided PipelineJob, with resolved fields + :rtype: PipelineJob + """ + # Get top-level job compute + _get_job_compute_id(pipeline_job, resolver) + + # Process job defaults: + if pipeline_job.settings: + pipeline_job.settings.default_datastore = resolver( + pipeline_job.settings.default_datastore, + azureml_type=AzureMLResourceType.DATASTORE, + ) + pipeline_job.settings.default_compute = resolver( + pipeline_job.settings.default_compute, + azureml_type=AzureMLResourceType.COMPUTE, + ) + + # Process each component job + try: + self._component_operations._resolve_dependencies_for_pipeline_component_jobs( + pipeline_job.component, resolver + ) + except ComponentException as e: + raise JobException( + message=e.message, + target=ErrorTarget.JOB, + no_personal_data_message=e.no_personal_data_message, + error_category=e.error_category, + ) from e + + # Create a pipeline component for pipeline job if user specified component in job yaml. + if ( + not isinstance(pipeline_job.component, str) + and getattr(pipeline_job.component, "_source", None) == ComponentSource.YAML_COMPONENT + ): + pipeline_job.component = resolver( # type: ignore + pipeline_job.component, + azureml_type=AzureMLResourceType.COMPONENT, + ) + + return pipeline_job + + def _append_tid_to_studio_url(self, job: Job) -> None: + """Appends the user's tenant ID to the end of the studio URL. + + Allows the UI to authenticate against the correct tenant. + + :param job: The job + :type job: Job + """ + try: + if job.services is not None: + studio_endpoint = job.services.get("Studio", None) + studio_url = studio_endpoint.endpoint + default_scopes = _resource_to_scopes(_get_base_url_from_metadata()) + module_logger.debug("default_scopes used: `%s`\n", default_scopes) + # Extract the tenant id from the credential using PyJWT + decode = jwt.decode( + self._credential.get_token(*default_scopes).token, + options={"verify_signature": False, "verify_aud": False}, + ) + tid = decode["tid"] + formatted_tid = TID_FMT.format(tid) + studio_endpoint.endpoint = studio_url + formatted_tid + except Exception: # pylint: disable=W0718 + module_logger.info("Proceeding with no tenant id appended to studio URL\n") + + def _set_headers_with_user_aml_token(self, kwargs: Any) -> None: + aml_resource_id = _get_aml_resource_id_from_metadata() + azure_ml_scopes = _resource_to_scopes(aml_resource_id) + module_logger.debug("azure_ml_scopes used: `%s`\n", azure_ml_scopes) + aml_token = self._credential.get_token(*azure_ml_scopes).token + # validate token has aml audience + decoded_token = jwt.decode( + aml_token, + options={"verify_signature": False, "verify_aud": False}, + ) + if decoded_token.get("aud") != aml_resource_id: + msg = """AAD token with aml scope could not be fetched using the credentials being used. + Please validate if token with {0} scope can be fetched using credentials provided to MLClient. 
+ Token with {0} scope can be fetched using credentials.get_token({0}) + """ + raise ValidationException( + message=msg.format(*azure_ml_scopes), + target=ErrorTarget.JOB, + error_type=ValidationErrorType.RESOURCE_NOT_FOUND, + no_personal_data_message=msg.format("[job.code]"), + error_category=ErrorCategory.USER_ERROR, + ) + + headers = kwargs.pop("headers", {}) + headers["x-azureml-token"] = aml_token + kwargs["headers"] = headers + + +def _get_job_compute_id(job: Union[Job, Command], resolver: _AssetResolver) -> None: + job.compute = resolver(job.compute, azureml_type=AzureMLResourceType.COMPUTE) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_job_ops_helper.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_job_ops_helper.py new file mode 100644 index 00000000..4c0802d9 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_job_ops_helper.py @@ -0,0 +1,513 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=protected-access + +import json +import logging +import os +import re +import subprocess +import sys +import time +from typing import Any, Dict, Iterable, List, Optional, TextIO, Union + +from azure.ai.ml._artifacts._artifact_utilities import get_datastore_info, list_logs_in_datastore +from azure.ai.ml._restclient.runhistory.models import Run, RunDetails, TypedAssetReference +from azure.ai.ml._restclient.v2022_02_01_preview.models import DataType +from azure.ai.ml._restclient.v2022_02_01_preview.models import JobType as RestJobType +from azure.ai.ml._restclient.v2022_02_01_preview.models import ModelType +from azure.ai.ml._restclient.v2022_10_01.models import JobBase +from azure.ai.ml._utils._http_utils import HttpPipeline +from azure.ai.ml._utils.utils import create_requests_pipeline_with_retry, download_text_from_url +from azure.ai.ml.constants._common import GitProperties +from azure.ai.ml.constants._job.job import JobLogPattern, JobType +from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, JobException +from azure.ai.ml.operations._dataset_dataplane_operations import DatasetDataplaneOperations +from azure.ai.ml.operations._datastore_operations import DatastoreOperations +from azure.ai.ml.operations._model_dataplane_operations import ModelDataplaneOperations +from azure.ai.ml.operations._run_history_constants import JobStatus, RunHistoryConstants +from azure.ai.ml.operations._run_operations import RunOperations + +STATUS_KEY = "status" + +module_logger = logging.getLogger(__name__) + + +def _get_sorted_filtered_logs( + logs_iterable: Iterable[str], + job_type: str, + processed_logs: Optional[Dict[str, int]] = None, + only_streamable: bool = True, +) -> List[str]: + """Filters log file names, sorts, and returns list starting with where we left off last iteration. + + :param logs_iterable: An iterable of log paths. 
+ :type logs_iterable: Iterable[str] + :param job_type: the job type to filter log files + :type job_type: str + :param processed_logs: dictionary tracking the state of how many lines of each file have been written out + :type processed_logs: dict[str, int] + :param only_streamable: Whether to only get streamable logs + :type only_streamable: bool + :return: List of logs to continue from + :rtype: list[str] + """ + processed_logs = processed_logs if processed_logs else {} + # First, attempt to read logs in new Common Runtime form + output_logs_pattern = ( + JobLogPattern.COMMON_RUNTIME_STREAM_LOG_PATTERN + if only_streamable + else JobLogPattern.COMMON_RUNTIME_ALL_USER_LOG_PATTERN + ) + logs = list(logs_iterable) + filtered_logs = [x for x in logs if re.match(output_logs_pattern, x)] + + # fall back to legacy log format + if filtered_logs is None or len(filtered_logs) == 0: + job_type = job_type.lower() + if job_type in JobType.COMMAND: + output_logs_pattern = JobLogPattern.COMMAND_JOB_LOG_PATTERN + elif job_type in JobType.PIPELINE: + output_logs_pattern = JobLogPattern.PIPELINE_JOB_LOG_PATTERN + elif job_type in JobType.SWEEP: + output_logs_pattern = JobLogPattern.SWEEP_JOB_LOG_PATTERN + + filtered_logs = [x for x in logs if re.match(output_logs_pattern, x)] + filtered_logs.sort() + previously_printed_index = 0 + for i, v in enumerate(filtered_logs): + if processed_logs.get(v): + previously_printed_index = i + else: + break + # Slice inclusive from the last printed log (can be updated before printing new files) + return filtered_logs[previously_printed_index:] + + +def _incremental_print(log: str, processed_logs: Dict[str, int], current_log_name: str, fileout: TextIO) -> None: + """Incremental print. + + :param log: + :type log: str + :param processed_logs: The record of how many lines have been written for each log file + :type processed_logs: dict[str, int] + :param current_log_name: the file name being read out, used in header writing and accessing processed_logs + :type current_log_name: str + :param fileout: + :type fileout: TestIOWrapper + """ + lines = log.splitlines() + doc_length = len(lines) + if doc_length == 0: + # If a file is empty, skip writing out. + # This addresses issue where batch endpoint jobs can create log files before they are needed. + return + previous_printed_lines = processed_logs.get(current_log_name, 0) + # when a new log is first being written to console, print spacing and label + if previous_printed_lines == 0: + fileout.write("\n") + fileout.write("Streaming " + current_log_name + "\n") + fileout.write("=" * (len(current_log_name) + 10) + "\n") + fileout.write("\n") + # otherwise, continue to log the file where we left off + for line in lines[previous_printed_lines:]: + fileout.write(line + "\n") + # update state to record number of lines written for this log file + processed_logs[current_log_name] = doc_length + + +def _get_last_log_primary_instance(logs: List) -> Any: + """Return last log for primary instance. + + :param logs: + :type logs: builtin.list + :return: Returns the last log primary instance. 
+ :rtype: + """ + primary_ranks = ["rank_0", "worker_0"] + rank_match_re = re.compile(r"(.*)_(.*?_.*?)\.txt") + last_log_name = logs[-1] + + last_log_match = rank_match_re.match(last_log_name) + if not last_log_match: + return last_log_name + + last_log_prefix = last_log_match.group(1) + matching_logs = sorted(filter(lambda x: x.startswith(last_log_prefix), logs)) + + # we have some specific ranks that denote the primary, use those if found + for log_name in matching_logs: + match = rank_match_re.match(log_name) + if not match: + continue + if match.group(2) in primary_ranks: + return log_name + + # no definitively primary instance, just return the highest sorted + return matching_logs[0] + + +def _wait_before_polling(current_seconds: float) -> int: + if current_seconds < 0: + msg = "current_seconds must be positive" + raise JobException( + message=msg, + target=ErrorTarget.JOB, + no_personal_data_message=msg, + error_category=ErrorCategory.USER_ERROR, + ) + import math + + # Sigmoid that tapers off near the_get_logs max at ~ 3 min + duration = int( + int(RunHistoryConstants._WAIT_COMPLETION_POLLING_INTERVAL_MAX) / (1.0 + 100 * math.exp(-current_seconds / 20.0)) + ) + return max(int(RunHistoryConstants._WAIT_COMPLETION_POLLING_INTERVAL_MIN), duration) + + +def list_logs(run_operations: RunOperations, job_resource: JobBase) -> Dict: + details: RunDetails = run_operations.get_run_details(job_resource.name) + logs_dict = details.log_files + keys = _get_sorted_filtered_logs(logs_dict, job_resource.properties.job_type) + return {key: logs_dict[key] for key in keys} + + +# pylint: disable=too-many-statements,too-many-locals +def stream_logs_until_completion( + run_operations: RunOperations, + job_resource: JobBase, + datastore_operations: Optional[DatastoreOperations] = None, + raise_exception_on_failed_job: bool = True, + *, + requests_pipeline: HttpPipeline +) -> None: + """Stream the experiment run output to the specified file handle. By default the the file handle points to stdout. + + :param run_operations: The run history operations class. + :type run_operations: RunOperations + :param job_resource: The job to stream + :type job_resource: JobBase + :param datastore_operations: Optional, the datastore operations class, used to get logs from datastore + :type datastore_operations: Optional[DatastoreOperations] + :param raise_exception_on_failed_job: Should this method fail if job fails + :type raise_exception_on_failed_job: Boolean + :keyword requests_pipeline: The HTTP pipeline to use for requests. + :type requests_pipeline: ~azure.ai.ml._utils._http_utils.HttpPipeline + :return: + :rtype: None + """ + job_type = job_resource.properties.job_type + job_name = job_resource.name + studio_endpoint = job_resource.properties.services.get("Studio", None) + studio_endpoint = studio_endpoint.endpoint if studio_endpoint else None + # Feature store jobs should be linked to the Feature Store Workspace UI. 
+ # Todo: Consolidate this logic to service side + if "azureml.FeatureStoreJobType" in job_resource.properties.properties: + url_format = ( + "https://ml.azure.com/featureStore/{fs_name}/featureSets/{fset_name}/{fset_version}/matJobs/" + "jobs/{run_id}?wsid=/subscriptions/{fs_sub_id}/resourceGroups/{fs_rg_name}/providers/" + "Microsoft.MachineLearningServices/workspaces/{fs_name}" + ) + studio_endpoint = url_format.format( + fs_name=job_resource.properties.properties["azureml.FeatureStoreName"], + fs_sub_id=run_operations._subscription_id, + fs_rg_name=run_operations._resource_group_name, + fset_name=job_resource.properties.properties["azureml.FeatureSetName"], + fset_version=job_resource.properties.properties["azureml.FeatureSetVersion"], + run_id=job_name, + ) + elif "FeatureStoreJobType" in job_resource.properties.properties: + url_format = ( + "https://ml.azure.com/featureStore/{fs_name}/featureSets/{fset_name}/{fset_version}/matJobs/" + "jobs/{run_id}?wsid=/subscriptions/{fs_sub_id}/resourceGroups/{fs_rg_name}/providers/" + "Microsoft.MachineLearningServices/workspaces/{fs_name}" + ) + studio_endpoint = url_format.format( + fs_name=job_resource.properties.properties["FeatureStoreName"], + fs_sub_id=run_operations._subscription_id, + fs_rg_name=run_operations._resource_group_name, + fset_name=job_resource.properties.properties["FeatureSetName"], + fset_version=job_resource.properties.properties["FeatureSetVersion"], + run_id=job_name, + ) + + file_handle = sys.stdout + ds_properties = None + prefix = None + if ( + hasattr(job_resource.properties, "outputs") + and job_resource.properties.job_type != RestJobType.AUTO_ML + and datastore_operations + ): + # Get default output location + + default_output = ( + job_resource.properties.outputs.get("default", None) if job_resource.properties.outputs else None + ) + is_uri_folder = default_output and default_output.job_output_type == DataType.URI_FOLDER + if is_uri_folder: + output_uri = default_output.uri # type: ignore + # Parse the uri format + output_uri = output_uri.split("datastores/")[1] + datastore_name, prefix = output_uri.split("/", 1) + ds_properties = get_datastore_info(datastore_operations, datastore_name) + + try: + file_handle.write("RunId: {}\n".format(job_name)) + file_handle.write("Web View: {}\n".format(studio_endpoint)) + + _current_details: RunDetails = run_operations.get_run_details(job_name) + + processed_logs: Dict = {} + + poll_start_time = time.time() + pipeline_with_retries = create_requests_pipeline_with_retry(requests_pipeline=requests_pipeline) + while ( + _current_details.status in RunHistoryConstants.IN_PROGRESS_STATUSES + or _current_details.status == JobStatus.FINALIZING + ): + file_handle.flush() + time.sleep(_wait_before_polling(time.time() - poll_start_time)) + _current_details = run_operations.get_run_details(job_name) # TODO use FileWatcher + if job_type.lower() in JobType.PIPELINE: + legacy_folder_name = "/logs/azureml/" + else: + legacy_folder_name = "/azureml-logs/" + _current_logs_dict = ( + list_logs_in_datastore( + ds_properties, + prefix=str(prefix), + legacy_log_folder_name=legacy_folder_name, + ) + if ds_properties is not None + else _current_details.log_files + ) + # Get the list of new logs available after filtering out the processed ones + available_logs = _get_sorted_filtered_logs(_current_logs_dict, job_type, processed_logs) + content = "" + for current_log in available_logs: + content = download_text_from_url( + _current_logs_dict[current_log], + pipeline_with_retries, + 
timeout=RunHistoryConstants._DEFAULT_GET_CONTENT_TIMEOUT, + ) + + _incremental_print(content, processed_logs, current_log, file_handle) + + # TODO: Temporary solution to wait for all the logs to be printed in the finalizing state. + if ( + _current_details.status not in RunHistoryConstants.IN_PROGRESS_STATUSES + and _current_details.status == JobStatus.FINALIZING + and "The activity completed successfully. Finalizing run..." in content + ): + break + + file_handle.write("\n") + file_handle.write("Execution Summary\n") + file_handle.write("=================\n") + file_handle.write("RunId: {}\n".format(job_name)) + file_handle.write("Web View: {}\n".format(studio_endpoint)) + + warnings = _current_details.warnings + if warnings: + messages = [x.message for x in warnings if x.message] + if len(messages) > 0: + file_handle.write("\nWarnings:\n") + for message in messages: + file_handle.write(message + "\n") + file_handle.write("\n") + + if _current_details.status == JobStatus.FAILED: + error = ( + _current_details.error.as_dict() + if _current_details.error + else "Detailed error not set on the Run. Please check the logs for details." + ) + # If we are raising the error later on, so we don't double print. + if not raise_exception_on_failed_job: + file_handle.write("\nError:\n") + file_handle.write(json.dumps(error, indent=4)) + file_handle.write("\n") + else: + raise JobException( + message="Exception : \n {} ".format(json.dumps(error, indent=4)), + target=ErrorTarget.JOB, + no_personal_data_message="Exception raised on failed job.", + error_category=ErrorCategory.SYSTEM_ERROR, + ) + + file_handle.write("\n") + file_handle.flush() + except KeyboardInterrupt as e: + error_message = ( + "The output streaming for the run interrupted.\n" + "But the run is still executing on the compute target. \n" + "Details for canceling the run can be found here: " + "https://aka.ms/aml-docs-cancel-run" + ) + raise JobException( + message=error_message, + target=ErrorTarget.JOB, + no_personal_data_message=error_message, + error_category=ErrorCategory.USER_ERROR, + ) from e + + +def get_git_properties() -> Dict[str, str]: + """Gather Git tracking info from the local environment. + + :return: Properties dictionary. + :rtype: dict + """ + + def _clean_git_property_bool(value: Any) -> Optional[bool]: + if value is None: + return None + return str(value).strip().lower() in ["true", "1"] + + def _clean_git_property_str(value: Any) -> Optional[str]: + if value is None: + return None + return str(value).strip() or None + + def _run_git_cmd(args: Iterable[str]) -> Optional[str]: + """Runs git with the provided arguments + + :param args: A iterable of arguments for a git command. Should not include leading "git" + :type args: Iterable[str] + :return: The output of running git with arguments, or None if it fails. + :rtype: Optional[str] + """ + try: + with open(os.devnull, "wb") as devnull: + return subprocess.check_output(["git"] + list(args), stderr=devnull).decode() + except KeyboardInterrupt: + raise + except BaseException: # pylint: disable=W0718 + return None + + # Check for environment variable overrides. 
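The block that follows reads environment-variable overrides before falling back to git commands, so an explicit variable always wins. A self-contained sketch of that precedence, using a hypothetical variable name (the real names come from GitProperties) and a simplified stand-in for _run_git_cmd:

import os
import subprocess
from typing import Optional

def _git(*args: str) -> Optional[str]:
    # Mirrors _run_git_cmd above: swallow failures and return None.
    try:
        return subprocess.check_output(["git", *args], stderr=subprocess.DEVNULL).decode().strip()
    except Exception:
        return None

# The override wins; git is only consulted when the variable is unset.
commit = os.environ.get("EXAMPLE_GIT_COMMIT") or _git("rev-parse", "HEAD")
print(commit)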
+ repository_uri = os.environ.get(GitProperties.ENV_REPOSITORY_URI, None) + branch = os.environ.get(GitProperties.ENV_BRANCH, None) + commit = os.environ.get(GitProperties.ENV_COMMIT, None) + dirty: Optional[Union[str, bool]] = os.environ.get(GitProperties.ENV_DIRTY, None) + build_id = os.environ.get(GitProperties.ENV_BUILD_ID, None) + build_uri = os.environ.get(GitProperties.ENV_BUILD_URI, None) + + is_git_repo = _run_git_cmd(["rev-parse", "--is-inside-work-tree"]) + if _clean_git_property_bool(is_git_repo): + repository_uri = repository_uri or _run_git_cmd(["ls-remote", "--get-url"]) + branch = branch or _run_git_cmd(["symbolic-ref", "--short", "HEAD"]) + commit = commit or _run_git_cmd(["rev-parse", "HEAD"]) + dirty = dirty or _run_git_cmd(["status", "--porcelain", "."]) and True + + # Parsing logic. + repository_uri = _clean_git_property_str(repository_uri) + commit = _clean_git_property_str(commit) + branch = _clean_git_property_str(branch) + dirty = _clean_git_property_bool(dirty) + build_id = _clean_git_property_str(build_id) + build_uri = _clean_git_property_str(build_uri) + + # Return with appropriate labels. + properties = {} + + if repository_uri is not None: + properties[GitProperties.PROP_MLFLOW_GIT_REPO_URL] = repository_uri + + if branch is not None: + properties[GitProperties.PROP_MLFLOW_GIT_BRANCH] = branch + + if commit is not None: + properties[GitProperties.PROP_MLFLOW_GIT_COMMIT] = commit + + if dirty is not None: + properties[GitProperties.PROP_DIRTY] = str(dirty) + + if build_id is not None: + properties[GitProperties.PROP_BUILD_ID] = build_id + + if build_uri is not None: + properties[GitProperties.PROP_BUILD_URI] = build_uri + + return properties + + +def get_job_output_uris_from_dataplane( + job_name: Optional[str], + run_operations: RunOperations, + dataset_dataplane_operations: DatasetDataplaneOperations, + model_dataplane_operations: Optional[ModelDataplaneOperations], + output_names: Optional[Union[Iterable[str], str]] = None, +) -> Dict[str, str]: + """Returns the output path for the given output in cloud storage of the given job. + + If no output names are given, the output paths for all outputs will be returned. + URIs obtained from the service will be in the long-form azureml:// format. + + For example: + azureml://subscriptions/<sub>/resource[gG]roups/<rg_name>/workspaces/<ws_name>/datastores/<ds_name>/paths/<ds_path> + + :param job_name: The job name + :type job_name: str + :param run_operations: The RunOperations used to fetch run data for the job + :type run_operations: RunOperations + :param dataset_dataplane_operations: The DatasetDataplaneOperations used to fetch dataset uris + :type dataset_dataplane_operations: DatasetDataplaneOperations + :param model_dataplane_operations: The ModelDataplaneOperations used to fetch dataset uris + :type model_dataplane_operations: ModelDataplaneOperations + :param output_names: The output name(s) to fetch. If not specified, retrieves all. 
+ :type output_names: Optional[Union[Iterable[str] str]] + :return: Dictionary mapping user-defined output name to output uri + :rtype: Dict[str, str] + """ + run_metadata: Run = run_operations.get_run_data(str(job_name)).run_metadata + run_outputs: Dict[str, TypedAssetReference] = run_metadata.outputs or {} + + # Create a reverse mapping from internal asset id to user-defined output name + asset_id_to_output_name = {v.asset_id: k for k, v in run_outputs.items()} + if not output_names: + # Assume all outputs are needed if no output name is provided + output_names = run_outputs.keys() + else: + if isinstance(output_names, str): + output_names = [output_names] + output_names = [o for o in output_names if o in run_outputs] + + # Collect all output ids that correspond to data assets + dataset_ids = [ + run_outputs[output_name].asset_id + for output_name in output_names + if run_outputs[output_name].type in [o.value for o in DataType] + ] + + # Collect all output ids that correspond to models + model_ids = [ + run_outputs[output_name].asset_id + for output_name in output_names + if run_outputs[output_name].type in [o.value for o in ModelType] + ] + + output_name_to_dataset_uri = {} + if dataset_ids: + # Get the data paths from the service + dataset_uris = dataset_dataplane_operations.get_batch_dataset_uris(dataset_ids) + # Map the user-defined output name to the output uri + # The service returns a mapping from internal asset id to output metadata, so we need the reverse map + # defined above to get the user-defined output name from the internal asset id. + output_name_to_dataset_uri = {asset_id_to_output_name[k]: v.uri for k, v in dataset_uris.values.items()} + + # This is a repeat of the logic above for models. + output_name_to_model_uri = {} + if model_ids: + model_uris = ( + model_dataplane_operations.get_batch_model_uris(model_ids) + if model_dataplane_operations is not None + else None + ) + output_name_to_model_uri = { + asset_id_to_output_name[k]: v.path for k, v in model_uris.values.items() # type: ignore + } + return {**output_name_to_dataset_uri, **output_name_to_model_uri} diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_local_deployment_helper.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_local_deployment_helper.py new file mode 100644 index 00000000..f49ff59d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_local_deployment_helper.py @@ -0,0 +1,390 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
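The body of get_job_output_uris_from_dataplane above hinges on a reverse map from internal asset ids back to user-defined output names, because the dataplane services key their answers by asset id. A toy illustration with made-up ids and URIs:

# Hypothetical data shapes for illustration; the real values come from the
# run metadata and the dataset/model dataplane batch calls.
run_outputs = {"trained_model": "asset-id-123", "eval_report": "asset-id-456"}
asset_id_to_output_name = {asset_id: name for name, asset_id in run_outputs.items()}

# The service answers keyed by asset id ...
service_uris = {"asset-id-123": "azureml://datastores/workspaceblobstore/paths/model/"}

# ... so the reverse map restores the user-facing output names.
print({asset_id_to_output_name[k]: v for k, v in service_uris.items()})
# {'trained_model': 'azureml://datastores/workspaceblobstore/paths/model/'}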
+# --------------------------------------------------------- + +# pylint: disable=protected-access,too-many-locals + +import json +import logging +import os +import shutil +from pathlib import Path +from typing import Any, Iterable, Optional, Union + +from marshmallow.exceptions import ValidationError as SchemaValidationError + +from azure.ai.ml._exception_helper import log_and_raise_error +from azure.ai.ml._local_endpoints import AzureMlImageContext, DockerfileResolver, LocalEndpointMode +from azure.ai.ml._local_endpoints.docker_client import ( + DockerClient, + get_deployment_json_from_container, + get_status_from_container, +) +from azure.ai.ml._local_endpoints.mdc_config_resolver import MdcConfigResolver +from azure.ai.ml._local_endpoints.validators.code_validator import get_code_configuration_artifacts +from azure.ai.ml._local_endpoints.validators.environment_validator import get_environment_artifacts +from azure.ai.ml._local_endpoints.validators.model_validator import get_model_artifacts +from azure.ai.ml._scope_dependent_operations import OperationsContainer +from azure.ai.ml._utils._endpoint_utils import local_endpoint_polling_wrapper +from azure.ai.ml._utils.utils import DockerProxy +from azure.ai.ml.constants._common import AzureMLResourceType, DefaultOpenEncoding +from azure.ai.ml.constants._endpoint import LocalEndpointConstants +from azure.ai.ml.entities import OnlineDeployment +from azure.ai.ml.exceptions import InvalidLocalEndpointError, LocalEndpointNotFoundError, ValidationException + +docker = DockerProxy() +module_logger = logging.getLogger(__name__) + + +class _LocalDeploymentHelper(object): + """A helper class to interact with Azure ML endpoints locally. + + Use this helper to manage Azure ML endpoints locally, e.g. create, invoke, show, list, delete. + """ + + def __init__( + self, + operation_container: OperationsContainer, + ): + self._docker_client = DockerClient() + self._model_operations: Any = operation_container.all_operations.get(AzureMLResourceType.MODEL) + self._code_operations: Any = operation_container.all_operations.get(AzureMLResourceType.CODE) + self._environment_operations: Any = operation_container.all_operations.get(AzureMLResourceType.ENVIRONMENT) + + def create_or_update( # type: ignore + self, + deployment: OnlineDeployment, + local_endpoint_mode: LocalEndpointMode, + local_enable_gpu: Optional[bool] = False, + ) -> OnlineDeployment: + """Create or update an deployment locally using Docker. + + :param deployment: OnlineDeployment object with information from user yaml. + :type deployment: OnlineDeployment + :param local_endpoint_mode: Mode for how to create the local user container. + :type local_endpoint_mode: LocalEndpointMode + :param local_enable_gpu: enable local container to access gpu + :type local_enable_gpu: bool + """ + try: + if deployment is None: + msg = "The entity provided for local endpoint was null. Please provide valid entity." 
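This helper backs the local=True path of the SDK's online endpoint and deployment operations. A hedged usage sketch, assuming a configured workspace, a running Docker daemon, and purely illustrative model, environment, and code paths:

from azure.ai.ml import MLClient
from azure.ai.ml.entities import (
    CodeConfiguration,
    Environment,
    ManagedOnlineDeployment,
    ManagedOnlineEndpoint,
    Model,
)
from azure.identity import DefaultAzureCredential

ml_client = MLClient(DefaultAzureCredential(), "<subscription-id>", "<resource-group>", "<workspace>")

endpoint = ManagedOnlineEndpoint(name="my-local-endpoint")
deployment = ManagedOnlineDeployment(
    name="blue",
    endpoint_name="my-local-endpoint",
    model=Model(path="./model"),  # illustrative local paths
    environment=Environment(
        image="mcr.microsoft.com/azureml/openmpi4.1.0-ubuntu20.04:latest",
        conda_file="./environment/conda.yaml",
    ),
    code_configuration=CodeConfiguration(code="./onlinescoring", scoring_script="score.py"),
    instance_count=1,
)

# local=True routes these calls through the Docker-based helpers in this module.
ml_client.online_endpoints.begin_create_or_update(endpoint, local=True)
ml_client.online_deployments.begin_create_or_update(deployment, local=True)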
+ raise InvalidLocalEndpointError(message=msg, no_personal_data_message=msg) + + endpoint_metadata: Any = None + try: + self.get(endpoint_name=str(deployment.endpoint_name), deployment_name=str(deployment.name)) + endpoint_metadata = json.dumps( + self._docker_client.get_endpoint(endpoint_name=str(deployment.endpoint_name)) + ) + operation_message = "Updating local deployment" + except LocalEndpointNotFoundError: + operation_message = "Creating local deployment" + + deployment_metadata = json.dumps(deployment._to_dict()) + endpoint_metadata = ( + endpoint_metadata + if endpoint_metadata + else _get_stubbed_endpoint_metadata(endpoint_name=str(deployment.endpoint_name)) + ) + local_endpoint_polling_wrapper( + func=self._create_deployment, + message=f"{operation_message} ({deployment.endpoint_name} / {deployment.name}) ", + endpoint_name=deployment.endpoint_name, + deployment=deployment, + local_endpoint_mode=local_endpoint_mode, + local_enable_gpu=local_enable_gpu, + endpoint_metadata=endpoint_metadata, + deployment_metadata=deployment_metadata, + ) + return self.get(endpoint_name=str(deployment.endpoint_name), deployment_name=str(deployment.name)) + except Exception as ex: # pylint: disable=W0718 + if isinstance(ex, (ValidationException, SchemaValidationError)): + log_and_raise_error(ex) + else: + raise ex + + def get_deployment_logs(self, endpoint_name: str, deployment_name: str, lines: int) -> str: + """Get logs from a local endpoint. + + :param endpoint_name: Name of endpoint to invoke. + :type endpoint_name: str + :param deployment_name: Name of specific deployment to invoke. + :type deployment_name: str + :param lines: Number of most recent lines from container logs. + :type lines: int + :return: The deployment logs + :rtype: str + """ + return str(self._docker_client.logs(endpoint_name=endpoint_name, deployment_name=deployment_name, lines=lines)) + + def get(self, endpoint_name: str, deployment_name: str) -> OnlineDeployment: + """Get a local deployment. + + :param endpoint_name: Name of endpoint. + :type endpoint_name: str + :param deployment_name: Name of deployment. + :type deployment_name: str + :return: The deployment + :rtype: OnlineDeployment + """ + container = self._docker_client.get_endpoint_container( + endpoint_name=endpoint_name, + deployment_name=deployment_name, + include_stopped=True, + ) + if container is None: + raise LocalEndpointNotFoundError(endpoint_name=endpoint_name, deployment_name=deployment_name) + return _convert_container_to_deployment(container=container) + + def list(self) -> Iterable[OnlineDeployment]: + """List all local endpoints. + + :return: The OnlineDeployments + :rtype: Iterable[OnlineDeployment] + """ + containers = self._docker_client.list_containers() + deployments = [] + for container in containers: + deployments.append(_convert_container_to_deployment(container=container)) + return deployments + + def delete(self, name: str, deployment_name: Optional[str] = None) -> None: + """Delete a local deployment. + + :param name: Name of endpoint associated with the deployment to delete. + :type name: str + :param deployment_name: Name of specific deployment to delete. 
+ :type deployment_name: str + """ + self._docker_client.delete(endpoint_name=name, deployment_name=deployment_name) + try: + build_directory = _get_deployment_directory(endpoint_name=name, deployment_name=deployment_name) + shutil.rmtree(build_directory) + except (PermissionError, OSError): + pass + + def _create_deployment( + self, + endpoint_name: str, + deployment: OnlineDeployment, + local_endpoint_mode: LocalEndpointMode, + local_enable_gpu: Optional[bool] = False, + endpoint_metadata: Optional[dict] = None, + deployment_metadata: Optional[dict] = None, + ) -> None: + """Create deployment locally using Docker. + + :param endpoint_name: OnlineDeployment object with information from user yaml. + :type endpoint_name: str + :param deployment: Deployment to create + :type deployment: OnlineDeployment + :param local_endpoint_mode: Mode for local endpoint. + :type local_endpoint_mode: LocalEndpointMode + :param local_enable_gpu: enable local container to access gpu + :type local_enable_gpu: bool + :param endpoint_metadata: Endpoint metadata (json serialied Endpoint entity) + :type endpoint_metadata: dict + :param deployment_metadata: Deployment metadata (json serialied Deployment entity) + :type deployment_metadata: dict + """ + deployment_name = deployment.name + deployment_directory = _create_build_directory( + endpoint_name=endpoint_name, deployment_name=str(deployment_name) + ) + deployment_directory_path = str(deployment_directory.resolve()) + + # Get assets for mounting into the container + # If code_directory_path is None, consider NCD flow + code_directory_path = get_code_configuration_artifacts( + endpoint_name=endpoint_name, + deployment=deployment, + code_operations=self._code_operations, + download_path=deployment_directory_path, + ) + # We always require the model, however it may be anonymous for local (model_name=None) + ( + model_name, + model_version, + model_directory_path, + ) = get_model_artifacts( # type: ignore[misc] + endpoint_name=endpoint_name, + deployment=deployment, + model_operations=self._model_operations, + download_path=deployment_directory_path, + ) + + # Resolve the environment information + # - Image + conda file - environment.image / environment.conda_file + # - Docker context - environment.build + ( + yaml_base_image_name, + yaml_env_conda_file_path, + yaml_env_conda_file_contents, + downloaded_build_context, + yaml_dockerfile, + inference_config, + ) = get_environment_artifacts( # type: ignore[misc] + endpoint_name=endpoint_name, + deployment=deployment, + environment_operations=self._environment_operations, + download_path=deployment_directory, # type: ignore[arg-type] + ) + # Retrieve AzureML specific information + # - environment variables required for deployment + # - volumes to mount into container + image_context = AzureMlImageContext( + endpoint_name=endpoint_name, + deployment_name=str(deployment_name), + yaml_code_directory_path=str(code_directory_path), + yaml_code_scoring_script_file_name=( + deployment.code_configuration.scoring_script if code_directory_path else None # type: ignore + ), + model_directory_path=model_directory_path, + model_mount_path=f"/{model_name}/{model_version}" if model_name else "", + ) + + # Construct Dockerfile if necessary, ie. 
+ # - User did not provide environment.inference_config, then this is not BYOC flow, cases below: + # --- user provided environment.build + # --- user provided environment.image + # --- user provided environment.image + environment.conda_file + is_byoc = inference_config is not None + dockerfile: Any = None + if not is_byoc: + install_debugpy = local_endpoint_mode is LocalEndpointMode.VSCodeDevContainer + if yaml_env_conda_file_path: + _write_conda_file( + conda_contents=yaml_env_conda_file_contents, + directory_path=deployment_directory, + conda_file_name=LocalEndpointConstants.CONDA_FILE_NAME, + ) + dockerfile = DockerfileResolver( + dockerfile=yaml_dockerfile, + docker_base_image=yaml_base_image_name, + docker_azureml_app_path=image_context.docker_azureml_app_path, + docker_conda_file_name=LocalEndpointConstants.CONDA_FILE_NAME, + docker_port=LocalEndpointConstants.DOCKER_PORT, + install_debugpy=install_debugpy, + ) + else: + dockerfile = DockerfileResolver( + dockerfile=yaml_dockerfile, + docker_base_image=yaml_base_image_name, + docker_azureml_app_path=image_context.docker_azureml_app_path, + docker_conda_file_name=None, + docker_port=LocalEndpointConstants.DOCKER_PORT, + install_debugpy=install_debugpy, + ) + dockerfile.write_file(directory_path=deployment_directory_path) + + # Merge AzureML environment variables and user environment variables + user_environment_variables = deployment.environment_variables + environment_variables = { + **image_context.environment, + **user_environment_variables, + } + + volumes = {} + volumes.update(image_context.volumes) + + if deployment.data_collector: + mdc_config = MdcConfigResolver(deployment.data_collector) + mdc_config.write_file(deployment_directory_path) + + environment_variables.update(mdc_config.environment_variables) + volumes.update(mdc_config.volumes) + + # Determine whether we need to use local context or downloaded context + build_directory = downloaded_build_context if downloaded_build_context else deployment_directory + self._docker_client.create_deployment( + endpoint_name=endpoint_name, + deployment_name=str(deployment_name), + endpoint_metadata=endpoint_metadata, # type: ignore[arg-type] + deployment_metadata=deployment_metadata, # type: ignore[arg-type] + build_directory=str(build_directory), + dockerfile_path=None if is_byoc else dockerfile.local_path, # type: ignore[arg-type] + conda_source_path=yaml_env_conda_file_path, + conda_yaml_contents=yaml_env_conda_file_contents, + volumes=volumes, + environment=environment_variables, + azureml_port=( + inference_config.scoring_route.port if is_byoc else LocalEndpointConstants.DOCKER_PORT # type: ignore + ), + local_endpoint_mode=local_endpoint_mode, + prebuilt_image_name=yaml_base_image_name if is_byoc else None, + local_enable_gpu=local_enable_gpu, + ) + + +# Bug Item number: 2885719 +def _convert_container_to_deployment( + # Bug Item number: 2885719 + container: "docker.models.containers.Container", # type: ignore +) -> OnlineDeployment: + """Converts provided Container for local deployment to OnlineDeployment entity. + + :param container: Container for a local deployment. 
+ :type container: docker.models.containers.Container + :return: The OnlineDeployment entity + :rtype: OnlineDeployment + """ + deployment_json = get_deployment_json_from_container(container=container) + provisioning_state = get_status_from_container(container=container) + if provisioning_state == LocalEndpointConstants.CONTAINER_EXITED: + return _convert_json_to_deployment( + deployment_json=deployment_json, + instance_type=LocalEndpointConstants.ENDPOINT_STATE_LOCATION, + provisioning_state=LocalEndpointConstants.ENDPOINT_STATE_FAILED, + ) + return _convert_json_to_deployment( + deployment_json=deployment_json, + instance_type=LocalEndpointConstants.ENDPOINT_STATE_LOCATION, + provisioning_state=LocalEndpointConstants.ENDPOINT_STATE_SUCCEEDED, + ) + + +def _write_conda_file(conda_contents: str, directory_path: Union[str, os.PathLike], conda_file_name: str) -> None: + """Writes out conda file to provided directory. + + :param conda_contents: contents of conda yaml file provided by user + :type conda_contents: str + :param directory_path: directory on user's local system to write conda file + :type directory_path: str + :param conda_file_name: The filename to write to + :type conda_file_name: str + """ + conda_file_path = f"{directory_path}/{conda_file_name}" + p = Path(conda_file_path) + p.write_text(conda_contents, encoding=DefaultOpenEncoding.WRITE) + + +def _convert_json_to_deployment(deployment_json: Optional[dict], **kwargs: Any) -> OnlineDeployment: + """Converts metadata json and kwargs to OnlineDeployment entity. + + :param deployment_json: dictionary representation of OnlineDeployment entity. + :type deployment_json: dict + :returns: The OnlineDeployment entity + :rtype: OnlineDeployment + """ + params_override = [] + for k, v in kwargs.items(): + params_override.append({k: v}) + return OnlineDeployment._load(data=deployment_json, params_override=params_override) + + +def _get_stubbed_endpoint_metadata(endpoint_name: str) -> str: + return json.dumps({"name": endpoint_name}) + + +def _create_build_directory(endpoint_name: str, deployment_name: str) -> Path: + build_directory = _get_deployment_directory(endpoint_name=endpoint_name, deployment_name=deployment_name) + build_directory.mkdir(parents=True, exist_ok=True) + return build_directory + + +def _get_deployment_directory(endpoint_name: str, deployment_name: Optional[str]) -> Path: + if deployment_name is not None: + return Path(Path.home(), ".azureml", "inferencing", endpoint_name, deployment_name) + + return Path(Path.home(), ".azureml", "inferencing", endpoint_name, "") diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_local_endpoint_helper.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_local_endpoint_helper.py new file mode 100644 index 00000000..341c55b0 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_local_endpoint_helper.py @@ -0,0 +1,205 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
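A standalone sketch of the staging convention used by _get_deployment_directory and _create_build_directory above: artifacts for a local deployment are kept under ~/.azureml/inferencing/<endpoint_name>/<deployment_name>.

from pathlib import Path

def sketch_deployment_directory(endpoint_name: str, deployment_name: str) -> Path:
    # Same layout as _get_deployment_directory above.
    return Path(Path.home(), ".azureml", "inferencing", endpoint_name, deployment_name)

build_dir = sketch_deployment_directory("my-local-endpoint", "blue")
build_dir.mkdir(parents=True, exist_ok=True)  # mirrors _create_build_directory
print(build_dir)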
+# --------------------------------------------------------- + +# pylint: disable=protected-access + +import json +import logging +from typing import Any, Iterable, List, Optional + +from marshmallow.exceptions import ValidationError as SchemaValidationError + +from azure.ai.ml._exception_helper import log_and_raise_error +from azure.ai.ml._local_endpoints import EndpointStub +from azure.ai.ml._local_endpoints.docker_client import ( + DockerClient, + get_endpoint_json_from_container, + get_scoring_uri_from_container, + get_status_from_container, +) +from azure.ai.ml._utils._endpoint_utils import local_endpoint_polling_wrapper +from azure.ai.ml._utils._http_utils import HttpPipeline +from azure.ai.ml._utils.utils import DockerProxy +from azure.ai.ml.constants._endpoint import EndpointInvokeFields, LocalEndpointConstants +from azure.ai.ml.entities import OnlineEndpoint +from azure.ai.ml.exceptions import InvalidLocalEndpointError, LocalEndpointNotFoundError, ValidationException + +docker = DockerProxy() +module_logger = logging.getLogger(__name__) + + +class _LocalEndpointHelper(object): + """A helper class to interact with Azure ML endpoints locally. + + Use this helper to manage Azure ML endpoints locally, e.g. create, invoke, show, list, delete. + """ + + def __init__(self, *, requests_pipeline: HttpPipeline): + self._docker_client = DockerClient() + self._endpoint_stub = EndpointStub() + self._requests_pipeline = requests_pipeline + + def create_or_update(self, endpoint: OnlineEndpoint) -> OnlineEndpoint: # type: ignore + """Create or update an endpoint locally using Docker. + + :param endpoint: OnlineEndpoint object with information from user yaml. + :type endpoint: OnlineEndpoint + """ + try: + if endpoint is None: + msg = "The entity provided for local endpoint was null. Please provide valid entity." + raise InvalidLocalEndpointError(message=msg, no_personal_data_message=msg) + + try: + self.get(endpoint_name=str(endpoint.name)) + operation_message = "Updating local endpoint" + except LocalEndpointNotFoundError: + operation_message = "Creating local endpoint" + + local_endpoint_polling_wrapper( + func=self._endpoint_stub.create_or_update, + message=f"{operation_message} ({endpoint.name}) ", + endpoint=endpoint, + ) + return self.get(endpoint_name=str(endpoint.name)) + except Exception as ex: # pylint: disable=W0718 + if isinstance(ex, (ValidationException, SchemaValidationError)): + log_and_raise_error(ex) + else: + raise ex + + def invoke(self, endpoint_name: str, data: dict, deployment_name: Optional[str] = None) -> str: + """Invoke a local endpoint. + + :param endpoint_name: Name of endpoint to invoke. + :type endpoint_name: str + :param data: json data to pass + :type data: dict + :param deployment_name: Name of specific deployment to invoke. 
+ :type deployment_name: (str, optional) + :return: The text response + :rtype: str + """ + # get_scoring_uri will throw user error if there are multiple deployments and no deployment_name is specified + scoring_uri = self._docker_client.get_scoring_uri(endpoint_name=endpoint_name, deployment_name=deployment_name) + if scoring_uri: + headers = {} + if deployment_name is not None: + headers[EndpointInvokeFields.MODEL_DEPLOYMENT] = deployment_name + return str(self._requests_pipeline.post(scoring_uri, json=data, headers=headers).text()) + endpoint_stub = self._endpoint_stub.get(endpoint_name=endpoint_name) + if endpoint_stub: + return str(self._endpoint_stub.invoke()) + raise LocalEndpointNotFoundError(endpoint_name=endpoint_name, deployment_name=deployment_name) + + def get(self, endpoint_name: str) -> OnlineEndpoint: + """Get a local endpoint. + + :param endpoint_name: Name of endpoint. + :type endpoint_name: str + :return OnlineEndpoint: + """ + endpoint = self._endpoint_stub.get(endpoint_name=endpoint_name) + container = self._docker_client.get_endpoint_container(endpoint_name=endpoint_name, include_stopped=True) + if endpoint: + if container: + return _convert_container_to_endpoint(container=container, endpoint_json=endpoint.dump()) + return endpoint + if container: + return _convert_container_to_endpoint(container=container) + raise LocalEndpointNotFoundError(endpoint_name=endpoint_name) + + def list(self) -> Iterable[OnlineEndpoint]: + """List all local endpoints. + + :return: An iterable of local endpoints + :rtype: Iterable[OnlineEndpoint] + """ + endpoints: List = [] + containers = self._docker_client.list_containers() + endpoint_stubs = self._endpoint_stub.list() + # Iterate through all cached endpoint files + for endpoint_file in endpoint_stubs: + endpoint_json = json.loads(endpoint_file.read_text()) + container = self._docker_client.get_endpoint_container( + endpoint_name=endpoint_json.get("name"), include_stopped=True + ) + # If a deployment is associated with endpoint, + # override certain endpoint properties with deployment information and remove it from containers list. + # Otherwise, return endpoint spec. + if container: + endpoints.append(_convert_container_to_endpoint(endpoint_json=endpoint_json, container=container)) + containers.remove(container) + else: + endpoints.append( + OnlineEndpoint._load( + data=endpoint_json, + params_override=[{"location": LocalEndpointConstants.ENDPOINT_STATE_LOCATION}], + ) + ) + # Iterate through any deployments that don't have an explicit local endpoint stub. + for container in containers: + endpoints.append(_convert_container_to_endpoint(container=container)) + return endpoints + + def delete(self, name: str) -> None: + """Delete a local endpoint. + + :param name: Name of endpoint to delete. + :type name: str + """ + endpoint_stub = self._endpoint_stub.get(endpoint_name=name) + if endpoint_stub: + self._endpoint_stub.delete(endpoint_name=name) + endpoint_container = self._docker_client.get_endpoint_container(endpoint_name=name) + if endpoint_container: + self._docker_client.delete(endpoint_name=name) + else: + raise LocalEndpointNotFoundError(endpoint_name=name) + + +def _convert_container_to_endpoint( + # Bug Item number: 2885719 + container: "docker.models.containers.Container", # type: ignore + endpoint_json: Optional[dict] = None, +) -> OnlineEndpoint: + """Converts provided Container for local deployment to OnlineEndpoint entity. + + :param container: Container for a local deployment. 
+ :type container: docker.models.containers.Container + :param endpoint_json: The endpoint json + :type endpoint_json: Optional[dict] + :return: The OnlineEndpoint entity + :rtype: OnlineEndpoint + """ + if endpoint_json is None: + endpoint_json = get_endpoint_json_from_container(container=container) + provisioning_state = get_status_from_container(container=container) + if provisioning_state == LocalEndpointConstants.CONTAINER_EXITED: + return _convert_json_to_endpoint( + endpoint_json=endpoint_json, + location=LocalEndpointConstants.ENDPOINT_STATE_LOCATION, + provisioning_state=LocalEndpointConstants.ENDPOINT_STATE_FAILED, + ) + scoring_uri = get_scoring_uri_from_container(container=container) + return _convert_json_to_endpoint( + endpoint_json=endpoint_json, + location=LocalEndpointConstants.ENDPOINT_STATE_LOCATION, + provisioning_state=LocalEndpointConstants.ENDPOINT_STATE_SUCCEEDED, + scoring_uri=scoring_uri, + ) + + +def _convert_json_to_endpoint(endpoint_json: Optional[dict], **kwargs: Any) -> OnlineEndpoint: + """Converts metadata json and kwargs to OnlineEndpoint entity. + + :param endpoint_json: dictionary representation of OnlineEndpoint entity. + :type endpoint_json: dict + :return: The OnlineEndpoint entity + :rtype: OnlineEndpoint + """ + params_override = [] + for k, v in kwargs.items(): + params_override.append({k: v}) + return OnlineEndpoint._load(data=endpoint_json, params_override=params_override) # type: ignore diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_local_job_invoker.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_local_job_invoker.py new file mode 100644 index 00000000..90914e4c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_local_job_invoker.py @@ -0,0 +1,432 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
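The _convert_json_to_deployment and _convert_json_to_endpoint helpers above share a small pattern: keyword overrides are packed into a list of single-key dicts that the entity _load machinery applies on top of the serialized JSON. A minimal sketch of that packing:

from typing import Any, Dict, List

def to_params_override(**kwargs: Any) -> List[Dict[str, Any]]:
    # One {key: value} dict per override, in keyword order.
    return [{k: v} for k, v in kwargs.items()]

print(to_params_override(location="local", provisioning_state="Succeeded"))
# [{'location': 'local'}, {'provisioning_state': 'Succeeded'}]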
+# --------------------------------------------------------- + +import base64 +import io +import json +import logging +import os +import re +import shutil +import subprocess +import tarfile +import tempfile +import urllib.parse +import zipfile +from pathlib import Path +from threading import Thread +from typing import Any, Dict, Optional, Tuple + +from azure.ai.ml._restclient.v2022_02_01_preview.models import JobBaseData +from azure.ai.ml._utils._http_utils import HttpPipeline +from azure.ai.ml._utils.utils import DockerProxy +from azure.ai.ml.constants._common import ( + AZUREML_RUN_SETUP_DIR, + AZUREML_RUNS_DIR, + EXECUTION_SERVICE_URL_KEY, + INVOCATION_BASH_FILE, + INVOCATION_BAT_FILE, + LOCAL_JOB_FAILURE_MSG, + DefaultOpenEncoding, +) +from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, JobException, MlException +from azure.core.credentials import TokenCredential +from azure.core.exceptions import AzureError + +docker = DockerProxy() +module_logger = logging.getLogger(__name__) + + +def unzip_to_temporary_file(job_definition: JobBaseData, zip_content: Any) -> Path: + temp_dir = Path(tempfile.gettempdir(), AZUREML_RUNS_DIR, job_definition.name) + temp_dir.mkdir(parents=True, exist_ok=True) + with zipfile.ZipFile(io.BytesIO(zip_content)) as zip_ref: + zip_ref.extractall(temp_dir) + return temp_dir + + +def _get_creationflags_and_startupinfo_for_background_process( + os_override: Optional[str] = None, +) -> Dict: + args: Dict = { + "startupinfo": None, + "creationflags": None, + "stdin": None, + "stdout": None, + "stderr": None, + "shell": False, + } + os_name = os_override if os_override is not None else os.name + if os_name == "nt": + # Windows process creation flag to not reuse the parent console. + + # Without this, the background service is associated with the + # starting process's console, and will block that console from + # exiting until the background service self-terminates. Elsewhere, + # fork just does the right thing. + + CREATE_NEW_CONSOLE = 0x00000010 + args["creationflags"] = CREATE_NEW_CONSOLE + + # Bug Item number: 2895261 + startupinfo = subprocess.STARTUPINFO() # type: ignore + startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW # type: ignore + startupinfo.wShowWindow = subprocess.SW_HIDE # type: ignore + args["startupinfo"] = startupinfo + + else: + # On MacOS, the child inherits the parent's stdio descriptors by + # default this can block the parent's stdout/stderr from closing even + # after the parent has exited. + + args["stdin"] = subprocess.DEVNULL + args["stdout"] = subprocess.DEVNULL + args["stderr"] = subprocess.STDOUT + + # filter entries with value None + return {k: v for (k, v) in args.items() if v} + + +def patch_invocation_script_serialization(invocation_path: Path) -> None: + content = invocation_path.read_text() + searchRes = re.search(r"([\s\S]*)(--snapshots \'.*\')([\s\S]*)", content) + if searchRes: + patched_json = searchRes.group(2).replace('"', '\\"') + patched_json = patched_json.replace("'", '"') + invocation_path.write_text(searchRes.group(1) + patched_json + searchRes.group(3)) + + +def invoke_command(project_temp_dir: Path) -> None: + if os.name == "nt": + invocation_script = project_temp_dir / AZUREML_RUN_SETUP_DIR / INVOCATION_BAT_FILE + # There is a bug in Execution service on the serialized json for snapshots. 
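patch_invocation_script_serialization above rewrites the --snapshots '...' argument so that the embedded JSON survives shell quoting on Windows. A self-contained demonstration of the same substitution on a made-up invocation line:

import re

line = """start.bat --run-id abc --snapshots '{"id": "snap1"}' --other-flag"""

match = re.search(r"([\s\S]*)(--snapshots \'.*\')([\s\S]*)", line)
if match:
    patched = match.group(2).replace('"', '\\"')  # escape the inner double quotes
    patched = patched.replace("'", '"')           # swap the outer quotes to double quotes
    line = match.group(1) + patched + match.group(3)

print(line)
# start.bat --run-id abc --snapshots "{\"id\": \"snap1\"}" --other-flag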
+ # This is a client-side patch until the service fixes it, at which point it should + # be a no-op + patch_invocation_script_serialization(invocation_script) + invoked_command = ["cmd.exe", "/c", "{0}".format(invocation_script)] + else: + invocation_script = project_temp_dir / AZUREML_RUN_SETUP_DIR / INVOCATION_BASH_FILE + subprocess.check_output(["chmod", "+x", invocation_script]) + invoked_command = ["/bin/bash", "-c", "{0}".format(invocation_script)] + + env = os.environ.copy() + env.pop("AZUREML_TARGET_TYPE", None) + subprocess.Popen( # pylint: disable=consider-using-with + invoked_command, + cwd=project_temp_dir, + env=env, + **_get_creationflags_and_startupinfo_for_background_process(), + ) + + +def get_execution_service_response( + job_definition: JobBaseData, token: str, requests_pipeline: HttpPipeline +) -> Tuple[Dict[str, str], str]: + """Get zip file containing local run information from Execution Service. + + MFE will send down a mock job contract, with service 'local'. + This will have the URL for contacting Execution Service, with a URL-encoded JSON object following the '&fake=' + string (aka EXECUTION_SERVICE_URL_KEY constant below). The encoded JSON should be the body to pass from the + client to ES. The ES response will be a zip file containing all the scripts required to invoke a local run. + + :param job_definition: Job definition data + :type job_definition: JobBaseData + :param token: The bearer token to use when retrieving information from Execution Service + :type token: str + :param requests_pipeline: The HttpPipeline to use when sending network requests + :type requests_pipeline: HttpPipeline + :return: Execution service response and snapshot ID + :rtype: Tuple[Dict[str, str], str] + """ + try: + local = job_definition.properties.services.get("Local", None) + + (url, encodedBody) = local.endpoint.split(EXECUTION_SERVICE_URL_KEY) + body = urllib.parse.unquote_plus(encodedBody) + body_dict: Dict = json.loads(body) + response = requests_pipeline.post(url, json=body_dict, headers={"Authorization": "Bearer " + token}) + response.raise_for_status() + return (response.content, body_dict.get("SnapshotId", None)) + except AzureError as err: + raise SystemExit(err) from err + except Exception as e: + msg = "Failed to read in local executable job" + raise JobException( + message=msg, + target=ErrorTarget.LOCAL_JOB, + no_personal_data_message=msg, + error_category=ErrorCategory.SYSTEM_ERROR, + ) from e + + +def is_local_run(job_definition: JobBaseData) -> bool: + if not job_definition.properties.services: + return False + local = job_definition.properties.services.get("Local", None) + return local is not None and EXECUTION_SERVICE_URL_KEY in local.endpoint + + +class CommonRuntimeHelper: + COMMON_RUNTIME_BOOTSTRAPPER_INFO = "common_runtime_bootstrapper_info.json" + COMMON_RUNTIME_JOB_SPEC = "common_runtime_job_spec.json" + VM_BOOTSTRAPPER_FILE_NAME = "vm-bootstrapper" + LOCAL_JOB_ENV_VARS = { + "RUST_LOG": "1", + "AZ_BATCHAI_CLUSTER_NAME": "fake_cluster_name", + "AZ_BATCH_NODE_ID": "fake_id", + "AZ_BATCH_NODE_ROOT_DIR": ".", + "AZ_BATCH_CERTIFICATES_DIR": ".", + "AZ_BATCH_NODE_SHARED_DIR": ".", + "AZ_LS_CERT_THUMBPRINT": "fake_thumbprint", + } + DOCKER_IMAGE_WARNING_MSG = ( + "Failed to pull required Docker image. " + "Please try removing all unused containers to free up space and then re-submit your job." + ) + DOCKER_CLIENT_FAILURE_MSG = ( + "Failed to create Docker client. 
Is Docker running/installed?\n " + "For local submissions, we need to build a Docker container to run your job in.\n Detailed message: {}" + ) + DOCKER_DAEMON_FAILURE_MSG = ( + "Unable to communicate with Docker daemon. Is Docker running/installed?\n " + "For local submissions, we need to build a Docker container to run your job in.\n Detailed message: {}" + ) + DOCKER_LOGIN_FAILURE_MSG = "Login to Docker registry '{}' failed. See error message: {}" + BOOTSTRAP_BINARY_FAILURE_MSG = ( + "Azure Common Runtime execution failed. See detailed message below for troubleshooting " + "information or re-submit with flag --use-local-runtime to try running on your local runtime: {}" + ) + + def __init__(self, job_name: str): + self.common_runtime_temp_folder = os.path.join(Path.home(), ".azureml-common-runtime", job_name) + if os.path.exists(self.common_runtime_temp_folder): + shutil.rmtree(self.common_runtime_temp_folder) + Path(self.common_runtime_temp_folder).mkdir(parents=True) + self.vm_bootstrapper_full_path = os.path.join( + self.common_runtime_temp_folder, + CommonRuntimeHelper.VM_BOOTSTRAPPER_FILE_NAME, + ) + self.stdout = open( # pylint: disable=consider-using-with + os.path.join(self.common_runtime_temp_folder, "stdout"), "w+", encoding=DefaultOpenEncoding.WRITE + ) + self.stderr = open( # pylint: disable=consider-using-with + os.path.join(self.common_runtime_temp_folder, "stderr"), "w+", encoding=DefaultOpenEncoding.WRITE + ) + + # Bug Item number: 2885723 + def get_docker_client(self, registry: Dict) -> "docker.DockerClient": # type: ignore + """Retrieves the Docker client for performing docker operations. + + :param registry: Registry information + :type registry: Dict[str, str] + :return: Docker client + :rtype: docker.DockerClient + """ + try: + client = docker.from_env(version="auto") + except docker.errors.DockerException as e: + msg = self.DOCKER_CLIENT_FAILURE_MSG.format(e) + raise MlException(message=msg, no_personal_data_message=msg) from e + + try: + client.version() + except Exception as e: + msg = self.DOCKER_DAEMON_FAILURE_MSG.format(e) + raise MlException(message=msg, no_personal_data_message=msg) from e + + if registry: + try: + client.login( + username=registry.get("username"), + password=registry.get("password"), + registry=registry.get("url"), + ) + except Exception as e: + raise RuntimeError(self.DOCKER_LOGIN_FAILURE_MSG.format(registry.get("url"), e)) from e + else: + raise RuntimeError("Registry information is missing from bootstrapper configuration.") + + return client + + # Bug Item number: 2885719 + def copy_bootstrapper_from_container(self, container: "docker.models.containers.Container") -> None: # type: ignore + """Copy file/folder from container to local machine. + + :param container: Docker container + :type container: docker.models.containers.Container + """ + path_in_container = CommonRuntimeHelper.VM_BOOTSTRAPPER_FILE_NAME + path_in_host = self.vm_bootstrapper_full_path + + try: + data_stream, _ = container.get_archive(path_in_container) + tar_file = path_in_host + ".tar" + with open(tar_file, "wb") as f: + for chunk in data_stream: + f.write(chunk) + with tarfile.open(tar_file, mode="r") as tar: + for file_name in tar.getnames(): + tar.extract(file_name, os.path.dirname(path_in_host)) + os.remove(tar_file) + except docker.errors.APIError as e: + msg = f"Copying {path_in_container} from container has failed. 
Detailed message: {e}" + raise MlException(message=msg, no_personal_data_message=msg) from e + + def get_common_runtime_info_from_response(self, response: Any) -> Tuple[Dict[str, str], str]: + """Extract common-runtime info from Execution Service response. + + :param response: Content of zip file from Execution Service containing all the + scripts required to invoke a local run. + :type response: Dict[str, str] + :return: Bootstrapper info and job specification + :rtype: Tuple[Dict[str, str], str] + """ + + with zipfile.ZipFile(io.BytesIO(response)) as zip_ref: + bootstrapper_path = f"{AZUREML_RUN_SETUP_DIR}/{self.COMMON_RUNTIME_BOOTSTRAPPER_INFO}" + job_spec_path = f"{AZUREML_RUN_SETUP_DIR}/{self.COMMON_RUNTIME_JOB_SPEC}" + if not all(file_path in zip_ref.namelist() for file_path in [bootstrapper_path, job_spec_path]): + raise RuntimeError(f"{bootstrapper_path}, {job_spec_path} are not in the execution service response.") + + with zip_ref.open(bootstrapper_path, "r") as bootstrapper_file: + bootstrapper_json = json.loads(base64.b64decode(bootstrapper_file.read())) + with zip_ref.open(job_spec_path, "r") as job_spec_file: + job_spec = job_spec_file.read().decode("utf-8") + + return bootstrapper_json, job_spec + + def get_bootstrapper_binary(self, bootstrapper_info: Dict) -> None: + """Copy bootstrapper binary from the bootstrapper image to local machine. + + :param bootstrapper_info: + :type bootstrapper_info: Dict[str, str] + """ + Path(self.common_runtime_temp_folder).mkdir(parents=True, exist_ok=True) + + # Pull and build the docker image + registry: Any = bootstrapper_info.get("registry") + docker_client = self.get_docker_client(registry) + repo_prefix = bootstrapper_info.get("repo_prefix") + repository = registry.get("url") + tag = bootstrapper_info.get("tag") + + if repo_prefix: + bootstrapper_image = f"{repository}/{repo_prefix}/boot/vm-bootstrapper/binimage/linux:{tag}" + else: + bootstrapper_image = f"{repository}/boot/vm-bootstrapper/binimage/linux:{tag}" + + try: + boot_img = docker_client.images.pull(bootstrapper_image) + except Exception as e: + module_logger.warning(self.DOCKER_IMAGE_WARNING_MSG) + raise e + + boot_container = docker_client.containers.create(image=boot_img, command=[""]) + self.copy_bootstrapper_from_container(boot_container) + + boot_container.stop() + boot_container.remove() + + def execute_bootstrapper(self, bootstrapper_binary: str, job_spec: str) -> subprocess.Popen: + """Runs vm-bootstrapper with the job specification passed to it. This will build the Docker container, create + all necessary files and directories, and run the job locally. Command args are defined by Common Runtime team + here: https://msdata.visualstudio.com/Vienna/_git/vienna?path=/src/azureml- job-runtime/common- + runtime/bootstrapper/vm-bootstrapper/src/main.rs&ver + sion=GBmaster&line=764&lineEnd=845&lineStartColumn=1&lineEndColumn=6&li neStyle=plain&_a=contents. + + :param bootstrapper_binary: Binary file path for VM bootstrapper + (".azureml-common-runtime/<job_name>/vm-bootstrapper") + :type bootstrapper_binary: str + :param job_spec: JSON content of job specification + :type job_spec: str + :return process: Subprocess running the bootstrapper + :rtype process: subprocess.Popen + """ + cmd = [ + bootstrapper_binary, + "--job-spec", + job_spec, + "--skip-auto-update", # Skip the auto update + # "Disable the standard Identity Responder and use a dummy command instead." + "--disable-identity-responder", + "--skip-cleanup", # "Keep containers and volumes for debug." 
+ ] + + env = self.LOCAL_JOB_ENV_VARS + + process = subprocess.Popen( # pylint: disable=consider-using-with + cmd, + env=env, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + cwd=self.common_runtime_temp_folder, + encoding="utf-8", + ) + _log_subprocess(process.stdout, self.stdout) + _log_subprocess(process.stderr, self.stderr) + + if self.check_bootstrapper_process_status(process): + return process + process.terminate() + process.kill() + raise RuntimeError(LOCAL_JOB_FAILURE_MSG.format(self.stderr.read())) + + def check_bootstrapper_process_status(self, bootstrapper_process: subprocess.Popen) -> Optional[int]: + """Check if bootstrapper process status is non-zero. + + :param bootstrapper_process: bootstrapper process + :type bootstrapper_process: subprocess.Popen + :return: return_code + :rtype: int + """ + return_code = bootstrapper_process.poll() + if return_code: + self.stderr.seek(0) + raise RuntimeError(self.BOOTSTRAP_BINARY_FAILURE_MSG.format(self.stderr.read())) + return return_code + + +def start_run_if_local( + job_definition: JobBaseData, + credential: TokenCredential, + ws_base_url: str, + requests_pipeline: HttpPipeline, +) -> str: + """Request execution bundle from ES and run job. If Linux or WSL environment, unzip and invoke job using job spec + and bootstrapper. Otherwise, invoke command locally. + + :param job_definition: Job definition data + :type job_definition: JobBaseData + :param credential: Credential to use for authentication + :type credential: TokenCredential + :param ws_base_url: Base url to workspace + :type ws_base_url: str + :param requests_pipeline: The HttpPipeline to use when sending network requests + :type requests_pipeline: HttpPipeline + :return: snapshot ID + :rtype: str + """ + token = credential.get_token(ws_base_url + "/.default").token + (zip_content, snapshot_id) = get_execution_service_response(job_definition, token, requests_pipeline) + + try: + temp_dir = unzip_to_temporary_file(job_definition, zip_content) + invoke_command(temp_dir) + except Exception as e: + msg = LOCAL_JOB_FAILURE_MSG.format(e) + raise MlException(message=msg, no_personal_data_message=msg) from e + + return snapshot_id + + +def _log_subprocess(output_io: Any, file: Any, show_in_console: bool = False) -> None: + def log_subprocess() -> None: + for line in iter(output_io.readline, ""): + if show_in_console: + print(line, end="") + file.write(line) + + thread = Thread(target=log_subprocess) + thread.daemon = True + thread.start() diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_marketplace_subscription_operations.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_marketplace_subscription_operations.py new file mode 100644 index 00000000..36683e3a --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_marketplace_subscription_operations.py @@ -0,0 +1,122 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
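_log_subprocess above tees a child process's output to a file from a daemon thread so the caller can keep polling or terminating the process. A minimal standalone sketch of the same pattern:

import subprocess
import sys
from threading import Thread

proc = subprocess.Popen(
    [sys.executable, "-c", "print('hello from the child process')"],
    stdout=subprocess.PIPE,
    encoding="utf-8",
)

def drain(stream, sink):
    # Mirror every line from the pipe into the sink without blocking the caller.
    for line in iter(stream.readline, ""):
        sink.write(line)

with open("child_stdout.log", "w", encoding="utf-8") as sink:
    t = Thread(target=drain, args=(proc.stdout, sink), daemon=True)
    t.start()
    proc.wait()       # the caller is free to poll/terminate while the thread drains
    t.join(timeout=5)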
+# --------------------------------------------------------- + +# pylint: disable=protected-access + +from typing import Iterable + +from azure.ai.ml._restclient.v2024_01_01_preview import ( + AzureMachineLearningWorkspaces as ServiceClient202401Preview, +) +from azure.ai.ml._scope_dependent_operations import ( + OperationConfig, + OperationScope, + _ScopeDependentOperations, +) +from azure.ai.ml._telemetry import ActivityType, monitor_with_activity +from azure.ai.ml._utils._experimental import experimental +from azure.ai.ml._utils._logger_utils import OpsLogger +from azure.ai.ml.entities._autogen_entities.models import MarketplaceSubscription +from azure.core.polling import LROPoller + +ops_logger = OpsLogger(__name__) +module_logger = ops_logger.module_logger + + +class MarketplaceSubscriptionOperations(_ScopeDependentOperations): + """MarketplaceSubscriptionOperations. + + You should not instantiate this class directly. Instead, you should + create an MLClient instance that instantiates it for you and + attaches it as an attribute. + """ + + def __init__( + self, + operation_scope: OperationScope, + operation_config: OperationConfig, + service_client: ServiceClient202401Preview, + ): + super().__init__(operation_scope, operation_config) + ops_logger.update_filter() + self._service_client = service_client.marketplace_subscriptions + + @experimental + @monitor_with_activity( + ops_logger, + "MarketplaceSubscription.BeginCreateOrUpdate", + ActivityType.PUBLICAPI, + ) + def begin_create_or_update( + self, marketplace_subscription: MarketplaceSubscription, **kwargs + ) -> LROPoller[MarketplaceSubscription]: + """Create or update a Marketplace Subscription. + + :param marketplace_subscription: The marketplace subscription entity. + :type marketplace_subscription: ~azure.ai.ml.entities.MarketplaceSubscription + :return: A poller to track the operation status + :rtype: ~azure.core.polling.LROPoller[~azure.ai.ml.entities.MarketplaceSubscription] + """ + return self._service_client.begin_create_or_update( + self._resource_group_name, + self._workspace_name, + marketplace_subscription.name, + marketplace_subscription._to_rest_object(), # type: ignore + cls=lambda response, deserialized, headers: MarketplaceSubscription._from_rest_object( # type: ignore + deserialized + ), + **kwargs, + ) + + @experimental + @monitor_with_activity(ops_logger, "MarketplaceSubscription.Get", ActivityType.PUBLICAPI) + def get(self, name: str, **kwargs) -> MarketplaceSubscription: + """Get a Marketplace Subscription resource. + + :param name: Name of the marketplace subscription. + :type name: str + :return: Marketplace subscription object retrieved from the service. + :rtype: ~azure.ai.ml.entities.MarketplaceSubscription + """ + return self._service_client.get( + self._resource_group_name, + self._workspace_name, + name, + cls=lambda response, deserialized, headers: MarketplaceSubscription._from_rest_object( # type: ignore + deserialized + ), + **kwargs, + ) + + @experimental + @monitor_with_activity(ops_logger, "MarketplaceSubscription.List", ActivityType.PUBLICAPI) + def list(self, **kwargs) -> Iterable[MarketplaceSubscription]: + """List marketplace subscriptions of the workspace. 
+ + :return: A list of marketplace subscriptions + :rtype: ~typing.Iterable[~azure.ai.ml.entities.MarketplaceSubscription] + """ + return self._service_client.list( + self._resource_group_name, + self._workspace_name, + cls=lambda objs: [MarketplaceSubscription._from_rest_object(obj) for obj in objs], # type: ignore + **kwargs, + ) + + @experimental + @monitor_with_activity(ops_logger, "MarketplaceSubscription.BeginDelete", ActivityType.PUBLICAPI) + def begin_delete(self, name: str, **kwargs) -> LROPoller[None]: + """Delete a Marketplace Subscription. + + :param name: Name of the marketplace subscription. + :type name: str + :return: A poller to track the operation status. + :rtype: ~azure.core.polling.LROPoller[None] + """ + return self._service_client.begin_delete( + self._resource_group_name, + self._workspace_name, + name=name, + **kwargs, + ) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_model_dataplane_operations.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_model_dataplane_operations.py new file mode 100644 index 00000000..da15d079 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_model_dataplane_operations.py @@ -0,0 +1,32 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +import logging +from typing import List + +from azure.ai.ml._restclient.model_dataplane import AzureMachineLearningWorkspaces as ServiceClientModelDataplane +from azure.ai.ml._restclient.model_dataplane.models import BatchGetResolvedUrisDto, BatchModelPathResponseDto +from azure.ai.ml._scope_dependent_operations import OperationConfig, OperationScope, _ScopeDependentOperations + +module_logger = logging.getLogger(__name__) + + +class ModelDataplaneOperations(_ScopeDependentOperations): + def __init__( + self, + operation_scope: OperationScope, + operation_config: OperationConfig, + service_client: ServiceClientModelDataplane, + ): + super().__init__(operation_scope, operation_config) + self._operation = service_client.models + + def get_batch_model_uris(self, model_ids: List[str]) -> BatchModelPathResponseDto: + batch_uri_request = BatchGetResolvedUrisDto(values=model_ids) + return self._operation.batch_get_resolved_uris( + self._operation_scope.subscription_id, + self._operation_scope.resource_group_name, + self._workspace_name, + body=batch_uri_request, + ) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_model_operations.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_model_operations.py new file mode 100644 index 00000000..be468089 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_model_operations.py @@ -0,0 +1,833 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
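A hedged usage sketch for the MarketplaceSubscriptionOperations class above, assuming a configured workspace and that the client exposes these operations as marketplace_subscriptions; the bracketed values are placeholders for real resource names:

from azure.ai.ml import MLClient
from azure.identity import DefaultAzureCredential

ml_client = MLClient(DefaultAzureCredential(), "<subscription-id>", "<resource-group>", "<workspace>")

# Enumerate existing marketplace subscriptions in the workspace.
for subscription in ml_client.marketplace_subscriptions.list():
    print(subscription.name)

# Deletion returns an LROPoller; .result() blocks until the operation completes.
ml_client.marketplace_subscriptions.begin_delete(name="<subscription-name>").result()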
+# --------------------------------------------------------- + +# pylint: disable=protected-access,disable=docstring-missing-return,docstring-missing-param,docstring-missing-rtype,line-too-long,too-many-statements + +import re +from contextlib import contextmanager +from os import PathLike, path +from typing import Any, Dict, Generator, Iterable, Optional, Union, cast + +from marshmallow.exceptions import ValidationError as SchemaValidationError + +from azure.ai.ml._artifacts._artifact_utilities import ( + _check_and_upload_path, + _get_default_datastore_info, + _update_metadata, +) +from azure.ai.ml._artifacts._constants import ( + ASSET_PATH_ERROR, + CHANGED_ASSET_PATH_MSG, + CHANGED_ASSET_PATH_MSG_NO_PERSONAL_DATA, +) +from azure.ai.ml._exception_helper import log_and_raise_error +from azure.ai.ml._restclient.v2021_10_01_dataplanepreview import ( + AzureMachineLearningWorkspaces as ServiceClient102021Dataplane, +) +from azure.ai.ml._restclient.v2023_08_01_preview import AzureMachineLearningWorkspaces as ServiceClient082023Preview +from azure.ai.ml._restclient.v2023_08_01_preview.models import ListViewType, ModelVersion +from azure.ai.ml._scope_dependent_operations import ( + OperationConfig, + OperationsContainer, + OperationScope, + _ScopeDependentOperations, +) +from azure.ai.ml._telemetry import ActivityType, monitor_with_activity +from azure.ai.ml._utils._arm_id_utils import AMLVersionedArmId, is_ARM_id_for_resource +from azure.ai.ml._utils._asset_utils import ( + _archive_or_restore, + _get_latest, + _get_next_version_from_container, + _resolve_label_to_asset, +) +from azure.ai.ml._utils._experimental import experimental +from azure.ai.ml._utils._logger_utils import OpsLogger +from azure.ai.ml._utils._registry_utils import ( + get_asset_body_for_registry_storage, + get_registry_client, + get_sas_uri_for_registry_asset, + get_storage_details_for_registry_assets, +) +from azure.ai.ml._utils._storage_utils import get_ds_name_and_path_prefix, get_storage_client +from azure.ai.ml._utils.utils import _is_evaluator, resolve_short_datastore_url, validate_ml_flow_folder +from azure.ai.ml.constants._common import ARM_ID_PREFIX, ASSET_ID_FORMAT, REGISTRY_URI_FORMAT, AzureMLResourceType +from azure.ai.ml.entities import AzureDataLakeGen2Datastore +from azure.ai.ml.entities._assets import Environment, Model, ModelPackage +from azure.ai.ml.entities._assets._artifacts.code import Code +from azure.ai.ml.entities._assets.workspace_asset_reference import WorkspaceAssetReference +from azure.ai.ml.entities._credentials import AccountKeyConfiguration +from azure.ai.ml.exceptions import ( + AssetPathException, + ErrorCategory, + ErrorTarget, + ValidationErrorType, + ValidationException, +) +from azure.ai.ml.operations._datastore_operations import DatastoreOperations +from azure.core.exceptions import ResourceNotFoundError + +from ._operation_orchestrator import OperationOrchestrator + +ops_logger = OpsLogger(__name__) +module_logger = ops_logger.module_logger + + +# pylint: disable=too-many-instance-attributes +class ModelOperations(_ScopeDependentOperations): + """ModelOperations. + + You should not instantiate this class directly. Instead, you should create an MLClient instance that instantiates it + for you and attaches it as an attribute. + + :param operation_scope: Scope variables for the operations classes of an MLClient object. 
+ :type operation_scope: ~azure.ai.ml._scope_dependent_operations.OperationScope + :param operation_config: Common configuration for operations classes of an MLClient object. + :type operation_config: ~azure.ai.ml._scope_dependent_operations.OperationConfig + :param service_client: Service client to allow end users to operate on Azure Machine Learning Workspace + resources (ServiceClient082023Preview or ServiceClient102021Dataplane). + :type service_client: typing.Union[ + azure.ai.ml._restclient.v2023_04_01_preview._azure_machine_learning_workspaces.AzureMachineLearningWorkspaces, + azure.ai.ml._restclient.v2021_10_01_dataplanepreview._azure_machine_learning_workspaces. + AzureMachineLearningWorkspaces] + :param datastore_operations: Represents a client for performing operations on Datastores. + :type datastore_operations: ~azure.ai.ml.operations._datastore_operations.DatastoreOperations + :param all_operations: All operations classes of an MLClient object. + :type all_operations: ~azure.ai.ml._scope_dependent_operations.OperationsContainer + """ + + _IS_EVALUATOR = "__is_evaluator" + + # pylint: disable=unused-argument + def __init__( + self, + operation_scope: OperationScope, + operation_config: OperationConfig, + service_client: Union[ServiceClient082023Preview, ServiceClient102021Dataplane], + datastore_operations: DatastoreOperations, + all_operations: Optional[OperationsContainer] = None, + **kwargs, + ): + super(ModelOperations, self).__init__(operation_scope, operation_config) + ops_logger.update_filter() + self._model_versions_operation = service_client.model_versions + self._model_container_operation = service_client.model_containers + self._service_client = service_client + self._datastore_operation = datastore_operations + self._all_operations = all_operations + self._control_plane_client: Any = kwargs.get("control_plane_client", None) + self._workspace_rg = kwargs.pop("workspace_rg", None) + self._workspace_sub = kwargs.pop("workspace_sub", None) + self._registry_reference = kwargs.pop("registry_reference", None) + + # Maps a label to a function which given an asset name, + # returns the asset associated with the label + self._managed_label_resolver = {"latest": self._get_latest_version} + self.__is_evaluator = kwargs.pop(ModelOperations._IS_EVALUATOR, False) + + @monitor_with_activity(ops_logger, "Model.CreateOrUpdate", ActivityType.PUBLICAPI) + def create_or_update( # type: ignore + self, model: Union[Model, WorkspaceAssetReference] + ) -> Model: # TODO: Are we going to implement job_name? + """Returns created or updated model asset. + + :param model: Model asset object. + :type model: ~azure.ai.ml.entities.Model + :raises ~azure.ai.ml.exceptions.AssetPathException: Raised when the Model artifact path is + already linked to another asset + :raises ~azure.ai.ml.exceptions.ValidationException: Raised if Model cannot be successfully validated. + Details will be provided in the error message. + :raises ~azure.ai.ml.exceptions.EmptyDirectoryError: Raised if local path provided points to an empty directory. + :return: Model asset object. + :rtype: ~azure.ai.ml.entities.Model + """ + # Check if we have the model with the same name and it is an + # evaluator. In this aces raise the exception do not create the model. + if not self.__is_evaluator and _is_evaluator(model.properties): + msg = ( + "Unable to create the evaluator using ModelOperations. To create " + "evaluator, please use EvaluatorOperations by calling " + "ml_client.evaluators.create_or_update(model) instead." 
+ ) + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.MODEL, + error_category=ErrorCategory.USER_ERROR, + ) + if model.name is not None: + model_properties = self._get_model_properties(model.name) + if model_properties is not None and _is_evaluator(model_properties) != _is_evaluator(model.properties): + if _is_evaluator(model.properties): + msg = ( + f"Unable to create the model with name {model.name} " + "because this version of model was marked as promptflow evaluator, but the previous " + "version is a regular model. " + "Please change the model name and try again." + ) + else: + msg = ( + f"Unable to create the model with name {model.name} " + "because previous version of model was marked as promptflow evaluator, but this " + "version is a regular model. " + "Please change the model name and try again." + ) + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.MODEL, + error_category=ErrorCategory.USER_ERROR, + ) + try: + name = model.name + if not model.version and model._auto_increment_version: + model.version = _get_next_version_from_container( + name=model.name, + container_operation=self._model_container_operation, + resource_group_name=self._operation_scope.resource_group_name, + workspace_name=self._workspace_name, + registry_name=self._registry_name, + ) + + version = model.version + + sas_uri = None + + if self._registry_name: + # Case of copy model to registry + if isinstance(model, WorkspaceAssetReference): + # verify that model is not already in registry + try: + self._model_versions_operation.get( + name=model.name, + version=model.version, + resource_group_name=self._resource_group_name, + registry_name=self._registry_name, + ) + except Exception as err: # pylint: disable=W0718 + if isinstance(err, ResourceNotFoundError): + pass + else: + raise err + else: + msg = "A model with this name and version already exists in registry" + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.MODEL, + error_category=ErrorCategory.USER_ERROR, + ) + + model_rest = model._to_rest_object() + result = self._service_client.resource_management_asset_reference.begin_import_method( + resource_group_name=self._resource_group_name, + registry_name=self._registry_name, + body=model_rest, + ).result() + + if not result: + model_rest_obj = self._get(name=str(model.name), version=model.version) + return Model._from_rest_object(model_rest_obj) + + sas_uri = get_sas_uri_for_registry_asset( + service_client=self._service_client, + name=model.name, + version=model.version, + resource_group=self._resource_group_name, + registry=self._registry_name, + body=get_asset_body_for_registry_storage(self._registry_name, "models", model.name, model.version), + ) + + model, indicator_file = _check_and_upload_path( # type: ignore[type-var] + artifact=model, + asset_operations=self, + sas_uri=sas_uri, + artifact_type=ErrorTarget.MODEL, + show_progress=self._show_progress, + ) + + model.path = resolve_short_datastore_url(model.path, self._operation_scope) # type: ignore + validate_ml_flow_folder(model.path, model.type) # type: ignore + model_version_resource = model._to_rest_object() + auto_increment_version = model._auto_increment_version + try: + result = ( + self._model_versions_operation.begin_create_or_update( + name=name, + version=version, + body=model_version_resource, + registry_name=self._registry_name, + **self._scope_kwargs, + ).result() + if self._registry_name + else 
self._model_versions_operation.create_or_update( + name=name, + version=version, + body=model_version_resource, + workspace_name=self._workspace_name, + **self._scope_kwargs, + ) + ) + + if not result and self._registry_name: + result = self._get(name=str(model.name), version=model.version) + + except Exception as e: + # service side raises an exception if we attempt to update an existing asset's path + if str(e) == ASSET_PATH_ERROR: + raise AssetPathException( + message=CHANGED_ASSET_PATH_MSG, + target=ErrorTarget.MODEL, + no_personal_data_message=CHANGED_ASSET_PATH_MSG_NO_PERSONAL_DATA, + error_category=ErrorCategory.USER_ERROR, + ) from e + raise e + + model = Model._from_rest_object(result) + if auto_increment_version and indicator_file: + datastore_info = _get_default_datastore_info(self._datastore_operation) + _update_metadata(model.name, model.version, indicator_file, datastore_info) # update version in storage + + return model + except Exception as ex: # pylint: disable=W0718 + if isinstance(ex, SchemaValidationError): + log_and_raise_error(ex) + else: + raise ex + + def _get(self, name: str, version: Optional[str] = None) -> ModelVersion: # name:latest + if version: + return ( + self._model_versions_operation.get( + name=name, + version=version, + registry_name=self._registry_name, + **self._scope_kwargs, + ) + if self._registry_name + else self._model_versions_operation.get( + name=name, + version=version, + workspace_name=self._workspace_name, + **self._scope_kwargs, + ) + ) + + return ( + self._model_container_operation.get(name=name, registry_name=self._registry_name, **self._scope_kwargs) + if self._registry_name + else self._model_container_operation.get( + name=name, workspace_name=self._workspace_name, **self._scope_kwargs + ) + ) + + @monitor_with_activity(ops_logger, "Model.Get", ActivityType.PUBLICAPI) + def get(self, name: str, version: Optional[str] = None, label: Optional[str] = None) -> Model: + """Returns information about the specified model asset. + + :param name: Name of the model. + :type name: str + :param version: Version of the model. + :type version: str + :param label: Label of the model. (mutually exclusive with version) + :type label: str + :raises ~azure.ai.ml.exceptions.ValidationException: Raised if Model cannot be successfully validated. + Details will be provided in the error message. + :return: Model asset object. + :rtype: ~azure.ai.ml.entities.Model + """ + if version and label: + msg = "Cannot specify both version and label." + raise ValidationException( + message=msg, + target=ErrorTarget.MODEL, + no_personal_data_message=msg, + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.INVALID_VALUE, + ) + + if label: + return _resolve_label_to_asset(self, name, label) + + if not version: + msg = "Must provide either version or label" + raise ValidationException( + message=msg, + target=ErrorTarget.MODEL, + no_personal_data_message=msg, + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.MISSING_FIELD, + ) + # TODO: We should consider adding an exception trigger for internal_model=None + model_version_resource = self._get(name, version) + + return Model._from_rest_object(model_version_resource) + + @monitor_with_activity(ops_logger, "Model.Download", ActivityType.PUBLICAPI) + def download(self, name: str, version: str, download_path: Union[PathLike, str] = ".") -> None: + """Download files related to a model. + + :param name: Name of the model. + :type name: str + :param version: Version of the model. 
+ :type version: str + :param download_path: Local path as download destination, defaults to current working directory of the current + user. Contents will be overwritten. + :type download_path: Union[PathLike, str] + :raises ResourceNotFoundError: if can't find a model matching provided name. + """ + + model_uri = self.get(name=name, version=version).path + ds_name, path_prefix = get_ds_name_and_path_prefix(model_uri, self._registry_name) + if self._registry_name: + sas_uri, auth_type = get_storage_details_for_registry_assets( + service_client=self._service_client, + asset_name=name, + asset_version=version, + reg_name=self._registry_name, + asset_type=AzureMLResourceType.MODEL, + rg_name=self._resource_group_name, + uri=model_uri, + ) + if auth_type == "SAS": + storage_client = get_storage_client(credential=None, storage_account=None, account_url=sas_uri) + else: + parts = sas_uri.split("/") + storage_account = parts[2].split(".")[0] + container_name = parts[3] + storage_client = get_storage_client( + credential=None, + storage_account=storage_account, + container_name=container_name, + ) + + else: + ds = self._datastore_operation.get(ds_name, include_secrets=True) + acc_name = ds.account_name + + if isinstance(ds.credentials, AccountKeyConfiguration): + credential = ds.credentials.account_key + else: + try: + credential = ds.credentials.sas_token + except Exception as e: # pylint: disable=W0718 + if not hasattr(ds.credentials, "sas_token"): + credential = self._datastore_operation._credential + else: + raise e + + if isinstance(ds, AzureDataLakeGen2Datastore): + container = ds.filesystem + try: + from azure.identity import ClientSecretCredential + + token_credential = ClientSecretCredential( + tenant_id=ds.credentials["tenant_id"], + client_id=ds.credentials["client_id"], + client_secret=ds.credentials["client_secret"], + authority=ds.credentials["authority_url"], + ) + credential = token_credential + except (KeyError, TypeError): + pass + + else: + container = ds.container_name + datastore_type = ds.type + storage_client = get_storage_client( + credential=credential, + container_name=container, + storage_account=acc_name, + storage_type=datastore_type, + ) + + path_file = "{}{}{}".format(download_path, path.sep, name) + is_directory = storage_client.exists(f"{path_prefix.rstrip('/')}/") + if is_directory: + path_file = path.join(path_file, path.basename(path_prefix.rstrip("/"))) + module_logger.info("Downloading the model %s at %s\n", path_prefix, path_file) + storage_client.download(starts_with=path_prefix, destination=path_file) + + @monitor_with_activity(ops_logger, "Model.Archive", ActivityType.PUBLICAPI) + def archive( + self, + name: str, + version: Optional[str] = None, + label: Optional[str] = None, + **kwargs: Any, + ) -> None: + """Archive a model asset. + + :param name: Name of model asset. + :type name: str + :param version: Version of model asset. + :type version: str + :param label: Label of the model asset. (mutually exclusive with version) + :type label: str + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START model_operations_archive] + :end-before: [END model_operations_archive] + :language: python + :dedent: 8 + :caption: Archive a model. 
+ """ + _archive_or_restore( + asset_operations=self, + version_operation=self._model_versions_operation, + container_operation=self._model_container_operation, + is_archived=True, + name=name, + version=version, + label=label, + ) + + @monitor_with_activity(ops_logger, "Model.Restore", ActivityType.PUBLICAPI) + def restore( + self, + name: str, + version: Optional[str] = None, + label: Optional[str] = None, + **kwargs: Any, + ) -> None: + """Restore an archived model asset. + + :param name: Name of model asset. + :type name: str + :param version: Version of model asset. + :type version: str + :param label: Label of the model asset. (mutually exclusive with version) + :type label: str + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_misc.py + :start-after: [START model_operations_restore] + :end-before: [END model_operations_restore] + :language: python + :dedent: 8 + :caption: Restore an archived model. + """ + _archive_or_restore( + asset_operations=self, + version_operation=self._model_versions_operation, + container_operation=self._model_container_operation, + is_archived=False, + name=name, + version=version, + label=label, + ) + + @monitor_with_activity(ops_logger, "Model.List", ActivityType.PUBLICAPI) + def list( + self, + name: Optional[str] = None, + stage: Optional[str] = None, + *, + list_view_type: ListViewType = ListViewType.ACTIVE_ONLY, + ) -> Iterable[Model]: + """List all model assets in workspace. + + :param name: Name of the model. + :type name: Optional[str] + :param stage: The Model stage + :type stage: Optional[str] + :keyword list_view_type: View type for including/excluding (for example) archived models. + Defaults to :attr:`ListViewType.ACTIVE_ONLY`. + :paramtype list_view_type: ListViewType + :return: An iterator like instance of Model objects + :rtype: ~azure.core.paging.ItemPaged[~azure.ai.ml.entities.Model] + """ + if name: + return cast( + Iterable[Model], + ( + self._model_versions_operation.list( + name=name, + registry_name=self._registry_name, + cls=lambda objs: [Model._from_rest_object(obj) for obj in objs], + **self._scope_kwargs, + ) + if self._registry_name + else self._model_versions_operation.list( + name=name, + workspace_name=self._workspace_name, + cls=lambda objs: [Model._from_rest_object(obj) for obj in objs], + list_view_type=list_view_type, + stage=stage, + **self._scope_kwargs, + ) + ), + ) + + return cast( + Iterable[Model], + ( + self._model_container_operation.list( + registry_name=self._registry_name, + cls=lambda objs: [Model._from_container_rest_object(obj) for obj in objs], + list_view_type=list_view_type, + **self._scope_kwargs, + ) + if self._registry_name + else self._model_container_operation.list( + workspace_name=self._workspace_name, + cls=lambda objs: [Model._from_container_rest_object(obj) for obj in objs], + list_view_type=list_view_type, + **self._scope_kwargs, + ) + ), + ) + + @monitor_with_activity(ops_logger, "Model.Share", ActivityType.PUBLICAPI) + @experimental + def share( + self, name: str, version: str, *, share_with_name: str, share_with_version: str, registry_name: str + ) -> Model: + """Share a model asset from workspace to registry. + + :param name: Name of model asset. + :type name: str + :param version: Version of model asset. + :type version: str + :keyword share_with_name: Name of model asset to share with. + :paramtype share_with_name: str + :keyword share_with_version: Version of model asset to share with. 
+ :paramtype share_with_version: str + :keyword registry_name: Name of the destination registry. + :paramtype registry_name: str + :return: Model asset object. + :rtype: ~azure.ai.ml.entities.Model + """ + + # Get workspace info to get workspace GUID + workspace = self._service_client.workspaces.get( + resource_group_name=self._resource_group_name, + workspace_name=self._workspace_name, + ) + workspace_guid = workspace.workspace_id + workspace_location = workspace.location + + # Get model asset ID + asset_id = ASSET_ID_FORMAT.format( + workspace_location, + workspace_guid, + AzureMLResourceType.MODEL, + name, + version, + ) + + model_ref = WorkspaceAssetReference( + name=share_with_name if share_with_name else name, + version=share_with_version if share_with_version else version, + asset_id=asset_id, + ) + + with self._set_registry_client(registry_name): + return self.create_or_update(model_ref) + + def _get_latest_version(self, name: str) -> Model: + """Returns the latest version of the asset with the given name. + + Latest is defined as the most recently created, not the most recently updated. + """ + result = _get_latest( + name, + self._model_versions_operation, + self._resource_group_name, + self._workspace_name, + self._registry_name, + ) + return Model._from_rest_object(result) + + @contextmanager + def _set_registry_client(self, registry_name: str) -> Generator: + """Sets the registry client for the model operations. + + :param registry_name: Name of the registry. + :type registry_name: str + """ + rg_ = self._operation_scope._resource_group_name + sub_ = self._operation_scope._subscription_id + registry_ = self._operation_scope.registry_name + client_ = self._service_client + model_versions_operation_ = self._model_versions_operation + + try: + _client, _rg, _sub = get_registry_client(self._service_client._config.credential, registry_name) + self._operation_scope.registry_name = registry_name + self._operation_scope._resource_group_name = _rg + self._operation_scope._subscription_id = _sub + self._service_client = _client + self._model_versions_operation = _client.model_versions + yield + finally: + self._operation_scope.registry_name = registry_ + self._operation_scope._resource_group_name = rg_ + self._operation_scope._subscription_id = sub_ + self._service_client = client_ + self._model_versions_operation = model_versions_operation_ + + @experimental + @monitor_with_activity(ops_logger, "Model.Package", ActivityType.PUBLICAPI) + def package(self, name: str, version: str, package_request: ModelPackage, **kwargs: Any) -> Environment: + """Package a model asset + + :param name: Name of model asset. + :type name: str + :param version: Version of model asset. + :type version: str + :param package_request: Model package request. 
+ :type package_request: ~azure.ai.ml.entities.ModelPackage + :return: Environment object + :rtype: ~azure.ai.ml.entities.Environment + """ + + is_deployment_flow = kwargs.pop("skip_to_rest", False) + if not is_deployment_flow: + orchestrators = OperationOrchestrator( + operation_container=self._all_operations, # type: ignore[arg-type] + operation_scope=self._operation_scope, + operation_config=self._operation_config, + ) + + # Create a code asset if code is not already an ARM ID + if hasattr(package_request.inferencing_server, "code_configuration"): + if package_request.inferencing_server.code_configuration and not is_ARM_id_for_resource( + package_request.inferencing_server.code_configuration.code, + AzureMLResourceType.CODE, + ): + if package_request.inferencing_server.code_configuration.code.startswith(ARM_ID_PREFIX): + package_request.inferencing_server.code_configuration.code = orchestrators.get_asset_arm_id( + package_request.inferencing_server.code_configuration.code[len(ARM_ID_PREFIX) :], + azureml_type=AzureMLResourceType.CODE, + ) + else: + package_request.inferencing_server.code_configuration.code = orchestrators.get_asset_arm_id( + Code( + base_path=package_request._base_path, + path=package_request.inferencing_server.code_configuration.code, + ), + azureml_type=AzureMLResourceType.CODE, + ) + if package_request.inferencing_server.code_configuration and hasattr( + package_request.inferencing_server.code_configuration, "code" + ): + package_request.inferencing_server.code_configuration.code = ( + "azureml:/" + package_request.inferencing_server.code_configuration.code + ) + + if package_request.base_environment_source and hasattr( + package_request.base_environment_source, "resource_id" + ): + if not package_request.base_environment_source.resource_id.startswith(REGISTRY_URI_FORMAT): + package_request.base_environment_source.resource_id = orchestrators.get_asset_arm_id( + package_request.base_environment_source.resource_id, + azureml_type=AzureMLResourceType.ENVIRONMENT, + ) + + package_request.base_environment_source.resource_id = ( + "azureml:/" + package_request.base_environment_source.resource_id + if not package_request.base_environment_source.resource_id.startswith(ARM_ID_PREFIX) + else package_request.base_environment_source.resource_id + ) + + # create ARM id for the target environment + if self._operation_scope._workspace_location and self._operation_scope._workspace_id: + package_request.target_environment_id = f"azureml://locations/{self._operation_scope._workspace_location}/workspaces/{self._operation_scope._workspace_id}/environments/{package_request.target_environment_id}" + else: + if self._all_operations is not None: + ws: Any = self._all_operations.all_operations.get("workspaces") + ws_details = ws.get(self._workspace_name) + workspace_location, workspace_id = ( + ws_details.location, + ws_details._workspace_id, + ) + package_request.target_environment_id = f"azureml://locations/{workspace_location}/workspaces/{workspace_id}/environments/{package_request.target_environment_id}" + + if package_request.environment_version is not None: + package_request.target_environment_id = ( + package_request.target_environment_id + f"/versions/{package_request.environment_version}" + ) + package_request = package_request._to_rest_object() + + if self._registry_reference: + package_request.target_environment_id = 
f"azureml://locations/{self._operation_scope._workspace_location}/workspaces/{self._operation_scope._workspace_id}/environments/{package_request.target_environment_id}" + package_out = ( + self._model_versions_operation.begin_package( + name=name, + version=version, + registry_name=self._registry_name if self._registry_name else self._registry_reference, + body=package_request, + **self._scope_kwargs, + ).result() + if self._registry_name or self._registry_reference + else self._model_versions_operation.begin_package( + name=name, + version=version, + workspace_name=self._workspace_name, + body=package_request, + **self._scope_kwargs, + ).result() + ) + if is_deployment_flow: # No need to go through the schema, as this is for deployment notification only + return package_out + if hasattr(package_out, "target_environment_id"): + environment_id = package_out.target_environment_id + else: + environment_id = package_out.additional_properties["targetEnvironmentId"] + + pattern = r"azureml://locations/(\w+)/workspaces/([\w-]+)/environments/([\w.-]+)/versions/(\d+)" + parsed_id: Any = re.search(pattern, environment_id) + + if parsed_id: + environment_name = parsed_id.group(3) + environment_version = parsed_id.group(4) + else: + parsed_id = AMLVersionedArmId(environment_id) + environment_name = parsed_id.asset_name + environment_version = parsed_id.asset_version + + module_logger.info("\nPackage Created") + if package_out is not None and package_out.__class__.__name__ == "PackageResponse": + if self._registry_name: + current_rg = self._scope_kwargs.pop("resource_group_name", None) + self._scope_kwargs["resource_group_name"] = self._workspace_rg + self._control_plane_client._config.subscription_id = self._workspace_sub + env_out = self._control_plane_client.environment_versions.get( + name=environment_name, + version=environment_version, + workspace_name=self._workspace_name, + **self._scope_kwargs, + ) + package_out = Environment._from_rest_object(env_out) + self._scope_kwargs["resource_group_name"] = current_rg + else: + if self._all_operations is not None: + environment_operation = self._all_operations.all_operations[AzureMLResourceType.ENVIRONMENT] + package_out = environment_operation.get(name=environment_name, version=environment_version) + + return package_out + + def _get_model_properties( + self, name: str, version: Optional[str] = None, label: Optional[str] = None + ) -> Optional[Dict[str, Any]]: + """ + Return the model properties if the model with this name exists. + + :param name: Model name. + :type name: str + :param version: Model version. + :type version: Optional[str] + :param label: model label. + :type label: Optional[str] + :return: Model properties, if the model exists, or None. + """ + try: + if version or label: + return self.get(name, version, label).properties + return self._get_latest_version(name).properties + except (ResourceNotFoundError, ValidationException): + return None diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_online_deployment_operations.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_online_deployment_operations.py new file mode 100644 index 00000000..13fe2357 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_online_deployment_operations.py @@ -0,0 +1,415 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
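# --- Illustrative usage sketch (editor's note, not part of the diffed source) ---
# The ModelOperations class above is reached through MLClient as `ml_client.models`.
# This sketch registers a local MLflow model folder, retrieves the registered version,
# and downloads its files; all names and paths are placeholders.
from azure.identity import DefaultAzureCredential
from azure.ai.ml import MLClient
from azure.ai.ml.entities import Model
from azure.ai.ml.constants import AssetTypes

ml_client = MLClient(
    credential=DefaultAzureCredential(),
    subscription_id="<subscription-id>",
    resource_group_name="<resource-group>",
    workspace_name="<workspace>",
)

# create_or_update uploads the local artifact and registers a new model version.
model = ml_client.models.create_or_update(
    Model(name="my-model", path="./mlruns/model", type=AssetTypes.MLFLOW_MODEL)
)

# get() resolves a specific version (label="latest" is also supported); download()
# copies the model artifacts to a local folder.
registered = ml_client.models.get(name="my-model", version=model.version)
ml_client.models.download(name="my-model", version=model.version, download_path="./downloaded")
# --- end illustrative sketch ---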
+# --------------------------------------------------------- + +# pylint: disable=protected-access,broad-except + +import random +import re +import subprocess +from typing import Any, Dict, Optional + +from marshmallow.exceptions import ValidationError as SchemaValidationError + +from azure.ai.ml._exception_helper import log_and_raise_error +from azure.ai.ml._local_endpoints import LocalEndpointMode +from azure.ai.ml._restclient.v2022_02_01_preview.models import DeploymentLogsRequest +from azure.ai.ml._restclient.v2023_04_01_preview import AzureMachineLearningWorkspaces as ServiceClient042023Preview +from azure.ai.ml._scope_dependent_operations import ( + OperationConfig, + OperationsContainer, + OperationScope, + _ScopeDependentOperations, +) +from azure.ai.ml._telemetry import ActivityType, monitor_with_activity +from azure.ai.ml._utils._arm_id_utils import AMLVersionedArmId +from azure.ai.ml._utils._azureml_polling import AzureMLPolling +from azure.ai.ml._utils._endpoint_utils import upload_dependencies, validate_scoring_script +from azure.ai.ml._utils._logger_utils import OpsLogger +from azure.ai.ml._utils._package_utils import package_deployment +from azure.ai.ml.constants._common import ARM_ID_PREFIX, AzureMLResourceType, LROConfigurations +from azure.ai.ml.constants._deployment import DEFAULT_MDC_PATH, EndpointDeploymentLogContainerType, SmallSKUs +from azure.ai.ml.entities import Data, OnlineDeployment +from azure.ai.ml.exceptions import ( + ErrorCategory, + ErrorTarget, + InvalidVSCodeRequestError, + LocalDeploymentGPUNotAvailable, + ValidationErrorType, + ValidationException, +) +from azure.core.credentials import TokenCredential +from azure.core.paging import ItemPaged +from azure.core.polling import LROPoller +from azure.core.tracing.decorator import distributed_trace + +from ._local_deployment_helper import _LocalDeploymentHelper +from ._operation_orchestrator import OperationOrchestrator + +ops_logger = OpsLogger(__name__) +module_logger = ops_logger.module_logger + + +class OnlineDeploymentOperations(_ScopeDependentOperations): + """OnlineDeploymentOperations. + + You should not instantiate this class directly. Instead, you should create an MLClient instance that instantiates it + for you and attaches it as an attribute. + """ + + def __init__( + self, + operation_scope: OperationScope, + operation_config: OperationConfig, + service_client_04_2023_preview: ServiceClient042023Preview, + all_operations: OperationsContainer, + local_deployment_helper: _LocalDeploymentHelper, + credentials: Optional[TokenCredential] = None, + **kwargs: Dict, + ): + super(OnlineDeploymentOperations, self).__init__(operation_scope, operation_config) + ops_logger.update_filter() + self._local_deployment_helper = local_deployment_helper + self._online_deployment = service_client_04_2023_preview.online_deployments + self._online_endpoint_operations = service_client_04_2023_preview.online_endpoints + self._all_operations = all_operations + self._credentials = credentials + self._init_kwargs = kwargs + + @distributed_trace + @monitor_with_activity(ops_logger, "OnlineDeployment.BeginCreateOrUpdate", ActivityType.PUBLICAPI) + def begin_create_or_update( + self, + deployment: OnlineDeployment, + *, + local: bool = False, + vscode_debug: bool = False, + skip_script_validation: bool = False, + local_enable_gpu: bool = False, + **kwargs: Any, + ) -> LROPoller[OnlineDeployment]: + """Create or update a deployment. 
+ + :param deployment: the deployment entity + :type deployment: ~azure.ai.ml.entities.OnlineDeployment + :keyword local: Whether deployment should be created locally, defaults to False + :paramtype local: bool + :keyword vscode_debug: Whether to open VSCode instance to debug local deployment, defaults to False + :paramtype vscode_debug: bool + :keyword skip_script_validation: Whether or not to skip validation of the deployment script. Defaults to False. + :paramtype skip_script_validation: bool + :keyword local_enable_gpu: enable local container to access gpu + :paramtype local_enable_gpu: bool + :raises ~azure.ai.ml.exceptions.ValidationException: Raised if OnlineDeployment cannot + be successfully validated. Details will be provided in the error message. + :raises ~azure.ai.ml.exceptions.AssetException: Raised if OnlineDeployment assets + (e.g. Data, Code, Model, Environment) cannot be successfully validated. + Details will be provided in the error message. + :raises ~azure.ai.ml.exceptions.ModelException: Raised if OnlineDeployment model cannot be + successfully validated. Details will be provided in the error message. + :raises ~azure.ai.ml.exceptions.DeploymentException: Raised if OnlineDeployment type is unsupported. + Details will be provided in the error message. + :raises ~azure.ai.ml.exceptions.LocalEndpointNotFoundError: Raised if local endpoint resource does not exist. + :raises ~azure.ai.ml.exceptions.LocalEndpointInFailedStateError: Raised if local endpoint is in a failed state. + :raises ~azure.ai.ml.exceptions.InvalidLocalEndpointError: Raised if Docker image cannot be + found for local deployment. + :raises ~azure.ai.ml.exceptions.LocalEndpointImageBuildError: Raised if Docker image cannot be + successfully built for local deployment. + :raises ~azure.ai.ml.exceptions.RequiredLocalArtifactsNotFoundError: Raised if local artifacts cannot be + found for local deployment. + :raises ~azure.ai.ml.exceptions.InvalidVSCodeRequestError: Raised if VS Debug is invoked with a remote endpoint. + VSCode debug is only supported for local endpoints. + :raises ~azure.ai.ml.exceptions.LocalDeploymentGPUNotAvailable: Raised if Nvidia GPU is not available in the + system and local_enable_gpu is set while local deployment + :raises ~azure.ai.ml.exceptions.VSCodeCommandNotFound: Raised if VSCode instance cannot be instantiated. + :return: A poller to track the operation status + :rtype: ~azure.core.polling.LROPoller[~azure.ai.ml.entities.OnlineDeployment] + """ + try: + if vscode_debug and not local: + raise InvalidVSCodeRequestError( + msg="VSCode Debug is only support for local endpoints. Please set local to True." + ) + if local: + if local_enable_gpu: + try: + subprocess.run("nvidia-smi", check=True) + except Exception as ex: + raise LocalDeploymentGPUNotAvailable( + msg=( + "Nvidia GPU is not available in your local system." + " Use nvidia-smi command to see the available GPU" + ) + ) from ex + return self._local_deployment_helper.create_or_update( + deployment=deployment, + local_endpoint_mode=self._get_local_endpoint_mode(vscode_debug), + local_enable_gpu=local_enable_gpu, + ) + if deployment and deployment.instance_type and deployment.instance_type.lower() in SmallSKUs: + module_logger.warning( + "Instance type %s may be too small for compute resources. " + "Minimum recommended compute SKU is Standard_DS3_v2 for general purpose endpoints. 
Learn more about SKUs here: " # pylint: disable=line-too-long + "https://learn.microsoft.com/azure/machine-learning/referencemanaged-online-endpoints-vm-sku-list", + deployment.instance_type, + ) + if ( + not skip_script_validation + and deployment + and deployment.code_configuration + and not deployment.code_configuration.code.startswith(ARM_ID_PREFIX) # type: ignore[union-attr] + and not re.match(AMLVersionedArmId.REGEX_PATTERN, deployment.code_configuration.code) # type: ignore + ): + validate_scoring_script(deployment) + + path_format_arguments = { + "endpointName": deployment.name, + "resourceGroupName": self._resource_group_name, + "workspaceName": self._workspace_name, + } + + # This get() is to ensure, the endpoint exists and fail before even start the deployment + module_logger.info("Check: endpoint %s exists", deployment.endpoint_name) + self._online_endpoint_operations.get( + resource_group_name=self._resource_group_name, + workspace_name=self._workspace_name, + endpoint_name=deployment.endpoint_name, + ) + orchestrators = OperationOrchestrator( + operation_container=self._all_operations, + operation_scope=self._operation_scope, + operation_config=self._operation_config, + ) + if deployment.data_collector: + self._register_collection_data_assets(deployment=deployment) + + upload_dependencies(deployment, orchestrators) + try: + location = self._get_workspace_location() + is_package_model = deployment.package_model if hasattr(deployment, "package_model") else False + if kwargs.pop("package_model", False) or is_package_model: + deployment = package_deployment(deployment, self._all_operations.all_operations["models"]) + module_logger.info("\nStarting deployment") + + deployment_rest = deployment._to_rest_object(location=location) # type: ignore + + poller = self._online_deployment.begin_create_or_update( + resource_group_name=self._resource_group_name, + workspace_name=self._workspace_name, + endpoint_name=deployment.endpoint_name, + deployment_name=deployment.name, + body=deployment_rest, + polling=AzureMLPolling( + LROConfigurations.POLL_INTERVAL, + path_format_arguments=path_format_arguments, + **self._init_kwargs, + ), + polling_interval=LROConfigurations.POLL_INTERVAL, + **self._init_kwargs, + cls=lambda response, deserialized, headers: OnlineDeployment._from_rest_object(deserialized), + ) + return poller + except Exception as ex: + raise ex + except Exception as ex: # pylint: disable=W0718 + if isinstance(ex, (ValidationException, SchemaValidationError)): + log_and_raise_error(ex) + else: + raise ex + + @distributed_trace + @monitor_with_activity(ops_logger, "OnlineDeployment.Get", ActivityType.PUBLICAPI) + def get(self, name: str, endpoint_name: str, *, local: Optional[bool] = False) -> OnlineDeployment: + """Get a deployment resource. + + :param name: The name of the deployment + :type name: str + :param endpoint_name: The name of the endpoint + :type endpoint_name: str + :keyword local: Whether deployment should be retrieved from local docker environment, defaults to False + :paramtype local: Optional[bool] + :raises ~azure.ai.ml.exceptions.LocalEndpointNotFoundError: Raised if local endpoint resource does not exist. 
+ :return: a deployment entity + :rtype: ~azure.ai.ml.entities.OnlineDeployment + """ + if local: + deployment = self._local_deployment_helper.get(endpoint_name=endpoint_name, deployment_name=name) + else: + deployment = OnlineDeployment._from_rest_object( + self._online_deployment.get( + endpoint_name=endpoint_name, + deployment_name=name, + resource_group_name=self._resource_group_name, + workspace_name=self._workspace_name, + **self._init_kwargs, + ) + ) + + deployment.endpoint_name = endpoint_name + return deployment + + @distributed_trace + @monitor_with_activity(ops_logger, "OnlineDeployment.Delete", ActivityType.PUBLICAPI) + def begin_delete(self, name: str, endpoint_name: str, *, local: Optional[bool] = False) -> LROPoller[None]: + """Delete a deployment. + + :param name: The name of the deployment + :type name: str + :param endpoint_name: The name of the endpoint + :type endpoint_name: str + :keyword local: Whether deployment should be retrieved from local docker environment, defaults to False + :paramtype local: Optional[bool] + :raises ~azure.ai.ml.exceptions.LocalEndpointNotFoundError: Raised if local endpoint resource does not exist. + :return: A poller to track the operation status + :rtype: ~azure.core.polling.LROPoller[None] + """ + if local: + return self._local_deployment_helper.delete(name=endpoint_name, deployment_name=name) + return self._online_deployment.begin_delete( + endpoint_name=endpoint_name, + deployment_name=name, + resource_group_name=self._resource_group_name, + workspace_name=self._workspace_name, + **self._init_kwargs, + ) + + @distributed_trace + @monitor_with_activity(ops_logger, "OnlineDeployment.GetLogs", ActivityType.PUBLICAPI) + def get_logs( + self, + name: str, + endpoint_name: str, + lines: int, + *, + container_type: Optional[str] = None, + local: bool = False, + ) -> str: + """Retrive the logs from online deployment. + + :param name: The name of the deployment + :type name: str + :param endpoint_name: The name of the endpoint + :type endpoint_name: str + :param lines: The maximum number of lines to tail + :type lines: int + :keyword container_type: The type of container to retrieve logs from. Possible values include: + "StorageInitializer", "InferenceServer", defaults to None + :type container_type: Optional[str] + :keyword local: [description], defaults to False + :paramtype local: bool + :return: the logs + :rtype: str + """ + if local: + return self._local_deployment_helper.get_deployment_logs( + endpoint_name=endpoint_name, deployment_name=name, lines=lines + ) + if container_type: + container_type = self._validate_deployment_log_container_type(container_type) # type: ignore + log_request = DeploymentLogsRequest(container_type=container_type, tail=lines) + return str( + self._online_deployment.get_logs( + resource_group_name=self._resource_group_name, + workspace_name=self._workspace_name, + endpoint_name=endpoint_name, + deployment_name=name, + body=log_request, + **self._init_kwargs, + ).content + ) + + @distributed_trace + @monitor_with_activity(ops_logger, "OnlineDeployment.List", ActivityType.PUBLICAPI) + def list(self, endpoint_name: str, *, local: bool = False) -> ItemPaged[OnlineDeployment]: + """List a deployment resource. 
+ + :param endpoint_name: The name of the endpoint + :type endpoint_name: str + :keyword local: Whether deployment should be retrieved from local docker environment, defaults to False + :paramtype local: bool + :return: an iterator of deployment entities + :rtype: Iterable[~azure.ai.ml.entities.OnlineDeployment] + """ + if local: + return self._local_deployment_helper.list() + return self._online_deployment.list( + endpoint_name=endpoint_name, + resource_group_name=self._resource_group_name, + workspace_name=self._workspace_name, + cls=lambda objs: [OnlineDeployment._from_rest_object(obj) for obj in objs], + **self._init_kwargs, + ) + + def _validate_deployment_log_container_type(self, container_type: EndpointDeploymentLogContainerType) -> str: + if container_type == EndpointDeploymentLogContainerType.INFERENCE_SERVER: + return EndpointDeploymentLogContainerType.INFERENCE_SERVER_REST + + if container_type == EndpointDeploymentLogContainerType.STORAGE_INITIALIZER: + return EndpointDeploymentLogContainerType.STORAGE_INITIALIZER_REST + + msg = "Invalid container type '{}'. Supported container types are {} and {}" + msg = msg.format( + container_type, + EndpointDeploymentLogContainerType.INFERENCE_SERVER, + EndpointDeploymentLogContainerType.STORAGE_INITIALIZER, + ) + raise ValidationException( + message=msg, + target=ErrorTarget.ONLINE_DEPLOYMENT, + no_personal_data_message=msg, + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.INVALID_VALUE, + ) + + def _get_ARM_deployment_name(self, name: str) -> str: + random.seed(version=2) + return f"{self._workspace_name}-{name}-{random.randint(1, 10000000)}" + + def _get_workspace_location(self) -> str: + """Get the workspace location + + TODO[TASK 1260265]: can we cache this information and only refresh when the operation_scope is changed? 
+ + :return: The workspace location + :rtype: str + """ + return str( + self._all_operations.all_operations[AzureMLResourceType.WORKSPACE].get(self._workspace_name).location + ) + + def _get_local_endpoint_mode(self, vscode_debug: Any) -> LocalEndpointMode: + return LocalEndpointMode.VSCodeDevContainer if vscode_debug else LocalEndpointMode.DetachedContainer + + def _register_collection_data_assets(self, deployment: OnlineDeployment) -> None: + for name, value in deployment.data_collector.collections.items(): + data_name = f"{deployment.endpoint_name}-{deployment.name}-{name}" + data_version = "1" + data_path = f"{DEFAULT_MDC_PATH}/{deployment.endpoint_name}/{deployment.name}/{name}" + if value.data: + if value.data.name: + data_name = value.data.name + + if value.data.version: + data_version = value.data.version + + if value.data.path: + data_path = value.data.path + + data_object = Data( + name=data_name, + version=data_version, + path=data_path, + ) + + try: + result = self._all_operations._all_operations[AzureMLResourceType.DATA].create_or_update(data_object) + except Exception as e: + if "already exists" in str(e): + result = self._all_operations._all_operations[AzureMLResourceType.DATA].get(data_name, data_version) + else: + raise e + deployment.data_collector.collections[name].data = ( + f"/subscriptions/{self._subscription_id}/resourceGroups/{self._resource_group_name}" + f"/providers/Microsoft.MachineLearningServices/workspaces/{self._workspace_name}" + f"/data/{result.name}/versions/{result.version}" + ) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_online_endpoint_operations.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_online_endpoint_operations.py new file mode 100644 index 00000000..6dce4283 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_online_endpoint_operations.py @@ -0,0 +1,471 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
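# --- Illustrative usage sketch (editor's note, not part of the diffed source) ---
# The OnlineDeploymentOperations class above is reached through MLClient as
# `ml_client.online_deployments`. This sketch assumes a managed online endpoint named
# "my-endpoint" already exists and that "my-model:1" is a registered MLflow model
# (so no code configuration or environment is needed); names and SKUs are placeholders.
from azure.identity import DefaultAzureCredential
from azure.ai.ml import MLClient
from azure.ai.ml.entities import ManagedOnlineDeployment

ml_client = MLClient(
    credential=DefaultAzureCredential(),
    subscription_id="<subscription-id>",
    resource_group_name="<resource-group>",
    workspace_name="<workspace>",
)

deployment = ManagedOnlineDeployment(
    name="blue",
    endpoint_name="my-endpoint",
    model="azureml:my-model:1",
    instance_type="Standard_DS3_v2",
    instance_count=1,
)

# begin_create_or_update returns an LROPoller; .result() blocks until provisioning ends.
deployment = ml_client.online_deployments.begin_create_or_update(deployment).result()

# Tail the logs of the new deployment's inference server.
print(ml_client.online_deployments.get_logs(name="blue", endpoint_name="my-endpoint", lines=50))
# --- end illustrative sketch ---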
+# --------------------------------------------------------- + +# pylint: disable=protected-access + +import json +from typing import Any, Dict, Optional, Union + +from marshmallow.exceptions import ValidationError as SchemaValidationError + +from azure.ai.ml._azure_environments import _resource_to_scopes +from azure.ai.ml._exception_helper import log_and_raise_error +from azure.ai.ml._restclient.v2022_02_01_preview import AzureMachineLearningWorkspaces as ServiceClient022022Preview +from azure.ai.ml._restclient.v2022_02_01_preview.models import KeyType, RegenerateEndpointKeysRequest +from azure.ai.ml._scope_dependent_operations import ( + OperationConfig, + OperationsContainer, + OperationScope, + _ScopeDependentOperations, +) +from azure.ai.ml._telemetry import ActivityType, monitor_with_activity +from azure.ai.ml._utils._azureml_polling import AzureMLPolling +from azure.ai.ml._utils._endpoint_utils import validate_response +from azure.ai.ml._utils._http_utils import HttpPipeline +from azure.ai.ml._utils._logger_utils import OpsLogger +from azure.ai.ml.constants._common import ( + AAD_TOKEN, + AAD_TOKEN_RESOURCE_ENDPOINT, + EMPTY_CREDENTIALS_ERROR, + KEY, + AzureMLResourceType, + LROConfigurations, +) +from azure.ai.ml.constants._endpoint import EndpointInvokeFields, EndpointKeyType +from azure.ai.ml.entities import OnlineDeployment, OnlineEndpoint +from azure.ai.ml.entities._assets import Data +from azure.ai.ml.entities._endpoint.online_endpoint import EndpointAadToken, EndpointAuthKeys, EndpointAuthToken +from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, MlException, ValidationErrorType, ValidationException +from azure.ai.ml.operations._local_endpoint_helper import _LocalEndpointHelper +from azure.core.credentials import TokenCredential +from azure.core.paging import ItemPaged +from azure.core.polling import LROPoller +from azure.core.tracing.decorator import distributed_trace + +from ._operation_orchestrator import OperationOrchestrator + +ops_logger = OpsLogger(__name__) +module_logger = ops_logger.module_logger + + +def _strip_zeroes_from_traffic(traffic: Dict[str, str]) -> Dict[str, str]: + return {k.lower(): v for k, v in traffic.items() if v and int(v) != 0} + + +class OnlineEndpointOperations(_ScopeDependentOperations): + """OnlineEndpointOperations. + + You should not instantiate this class directly. Instead, you should + create an MLClient instance that instantiates it for you and + attaches it as an attribute. + """ + + def __init__( + self, + operation_scope: OperationScope, + operation_config: OperationConfig, + service_client_02_2022_preview: ServiceClient022022Preview, + all_operations: OperationsContainer, + local_endpoint_helper: _LocalEndpointHelper, + credentials: Optional[TokenCredential] = None, + **kwargs: Dict, + ): + super(OnlineEndpointOperations, self).__init__(operation_scope, operation_config) + ops_logger.update_filter() + self._online_operation = service_client_02_2022_preview.online_endpoints + self._online_deployment_operation = service_client_02_2022_preview.online_deployments + self._all_operations = all_operations + self._local_endpoint_helper = local_endpoint_helper + self._credentials = credentials + self._init_kwargs = kwargs + + self._requests_pipeline: HttpPipeline = kwargs.pop("requests_pipeline") + + @distributed_trace + @monitor_with_activity(ops_logger, "OnlineEndpoint.List", ActivityType.PUBLICAPI) + def list(self, *, local: bool = False) -> ItemPaged[OnlineEndpoint]: + """List endpoints of the workspace. 
+ + :keyword local: (Optional) Flag to indicate whether to interact with endpoints in local Docker environment. + Default: False + :type local: bool + :return: A list of endpoints + :rtype: ~azure.core.paging.ItemPaged[~azure.ai.ml.entities.OnlineEndpoint] + """ + if local: + return self._local_endpoint_helper.list() + return self._online_operation.list( + resource_group_name=self._resource_group_name, + workspace_name=self._workspace_name, + cls=lambda objs: [OnlineEndpoint._from_rest_object(obj) for obj in objs], + **self._init_kwargs, + ) + + @distributed_trace + @monitor_with_activity(ops_logger, "OnlineEndpoint.ListKeys", ActivityType.PUBLICAPI) + def get_keys(self, name: str) -> Union[EndpointAuthKeys, EndpointAuthToken, EndpointAadToken]: + """Get the auth credentials. + + :param name: The endpoint name + :type name: str + :raise: Exception if cannot get online credentials + :return: Depending on the auth mode in the endpoint, returns either keys or token + :rtype: Union[~azure.ai.ml.entities.EndpointAuthKeys, ~azure.ai.ml.entities.EndpointAuthToken] + """ + return self._get_online_credentials(name=name) + + @distributed_trace + @monitor_with_activity(ops_logger, "OnlineEndpoint.Get", ActivityType.PUBLICAPI) + def get( + self, + name: str, + *, + local: bool = False, + ) -> OnlineEndpoint: + """Get a Endpoint resource. + + :param name: Name of the endpoint. + :type name: str + :keyword local: Indicates whether to interact with endpoints in local Docker environment. Defaults to False. + :paramtype local: Optional[bool] + :raises ~azure.ai.ml.exceptions.LocalEndpointNotFoundError: Raised if local endpoint resource does not exist. + :return: Endpoint object retrieved from the service. + :rtype: ~azure.ai.ml.entities.OnlineEndpoint + """ + # first get the endpoint + if local: + return self._local_endpoint_helper.get(endpoint_name=name) + + endpoint = self._online_operation.get( + resource_group_name=self._resource_group_name, + workspace_name=self._workspace_name, + endpoint_name=name, + **self._init_kwargs, + ) + + deployments_list = self._online_deployment_operation.list( + endpoint_name=name, + resource_group_name=self._resource_group_name, + workspace_name=self._workspace_name, + cls=lambda objs: [OnlineDeployment._from_rest_object(obj) for obj in objs], + **self._init_kwargs, + ) + + # populate deployments without traffic with zeroes in traffic map + converted_endpoint = OnlineEndpoint._from_rest_object(endpoint) + if deployments_list: + for deployment in deployments_list: + if not converted_endpoint.traffic.get(deployment.name) and not converted_endpoint.mirror_traffic.get( + deployment.name + ): + converted_endpoint.traffic[deployment.name] = 0 + + return converted_endpoint + + @distributed_trace + @monitor_with_activity(ops_logger, "OnlineEndpoint.BeginDelete", ActivityType.PUBLICAPI) + def begin_delete(self, name: Optional[str] = None, *, local: bool = False) -> LROPoller[None]: + """Delete an Online Endpoint. + + :param name: Name of the endpoint. + :type name: str + :keyword local: Whether to interact with the endpoint in local Docker environment. Defaults to False. + :paramtype local: bool + :raises ~azure.ai.ml.exceptions.LocalEndpointNotFoundError: Raised if local endpoint resource does not exist. + :return: A poller to track the operation status if remote, else returns None if local. 
+ :rtype: ~azure.core.polling.LROPoller[None] + """ + if local: + return self._local_endpoint_helper.delete(name=str(name)) + + path_format_arguments = { + "endpointName": name, + "resourceGroupName": self._resource_group_name, + "workspaceName": self._workspace_name, + } + + delete_poller = self._online_operation.begin_delete( + resource_group_name=self._resource_group_name, + workspace_name=self._workspace_name, + endpoint_name=name, + polling=AzureMLPolling( + LROConfigurations.POLL_INTERVAL, + path_format_arguments=path_format_arguments, + **self._init_kwargs, + ), + polling_interval=LROConfigurations.POLL_INTERVAL, + **self._init_kwargs, + ) + return delete_poller + + @distributed_trace + @monitor_with_activity(ops_logger, "OnlineEndpoint.BeginDeleteOrUpdate", ActivityType.PUBLICAPI) + def begin_create_or_update(self, endpoint: OnlineEndpoint, *, local: bool = False) -> LROPoller[OnlineEndpoint]: + """Create or update an endpoint. + + :param endpoint: The endpoint entity. + :type endpoint: ~azure.ai.ml.entities.OnlineEndpoint + :keyword local: Whether to interact with the endpoint in local Docker environment. Defaults to False. + :paramtype local: bool + :raises ~azure.ai.ml.exceptions.ValidationException: Raised if OnlineEndpoint cannot be successfully validated. + Details will be provided in the error message. + :raises ~azure.ai.ml.exceptions.AssetException: Raised if OnlineEndpoint assets + (e.g. Data, Code, Model, Environment) cannot be successfully validated. + Details will be provided in the error message. + :raises ~azure.ai.ml.exceptions.ModelException: Raised if OnlineEndpoint model cannot be successfully validated. + Details will be provided in the error message. + :raises ~azure.ai.ml.exceptions.EmptyDirectoryError: Raised if local path provided points to an empty directory. + :raises ~azure.ai.ml.exceptions.LocalEndpointNotFoundError: Raised if local endpoint resource does not exist. + :return: A poller to track the operation status if remote, else returns None if local. 
+ :rtype: ~azure.core.polling.LROPoller[~azure.ai.ml.entities.OnlineEndpoint] + """ + try: + if local: + return self._local_endpoint_helper.create_or_update(endpoint=endpoint) + + try: + location = self._get_workspace_location() + + if endpoint.traffic: + endpoint.traffic = _strip_zeroes_from_traffic(endpoint.traffic) + + if endpoint.mirror_traffic: + endpoint.mirror_traffic = _strip_zeroes_from_traffic(endpoint.mirror_traffic) + + endpoint_resource = endpoint._to_rest_online_endpoint(location=location) + orchestrators = OperationOrchestrator( + operation_container=self._all_operations, + operation_scope=self._operation_scope, + operation_config=self._operation_config, + ) + if hasattr(endpoint_resource.properties, "compute"): + endpoint_resource.properties.compute = orchestrators.get_asset_arm_id( + endpoint_resource.properties.compute, + azureml_type=AzureMLResourceType.COMPUTE, + ) + poller = self._online_operation.begin_create_or_update( + resource_group_name=self._resource_group_name, + workspace_name=self._workspace_name, + endpoint_name=endpoint.name, + body=endpoint_resource, + cls=lambda response, deserialized, headers: OnlineEndpoint._from_rest_object(deserialized), + **self._init_kwargs, + ) + return poller + + except Exception as ex: + raise ex + except Exception as ex: # pylint: disable=W0718 + if isinstance(ex, (ValidationException, SchemaValidationError)): + log_and_raise_error(ex) + else: + raise ex + + @distributed_trace + @monitor_with_activity(ops_logger, "OnlineEndpoint.BeginGenerateKeys", ActivityType.PUBLICAPI) + def begin_regenerate_keys( + self, + name: str, + *, + key_type: str = EndpointKeyType.PRIMARY_KEY_TYPE, + ) -> LROPoller[None]: + """Regenerate keys for endpoint. + + :param name: The endpoint name. + :type name: The endpoint type. Defaults to ONLINE_ENDPOINT_TYPE. + :keyword key_type: One of "primary", "secondary". Defaults to "primary". + :paramtype key_type: str + :return: A poller to track the operation status. + :rtype: ~azure.core.polling.LROPoller[None] + """ + endpoint = self._online_operation.get( + resource_group_name=self._resource_group_name, + workspace_name=self._workspace_name, + endpoint_name=name, + **self._init_kwargs, + ) + + if endpoint.properties.auth_mode.lower() == "key": + return self._regenerate_online_keys(name=name, key_type=key_type) + raise ValidationException( + message=f"Endpoint '{name}' does not use keys for authentication.", + target=ErrorTarget.ONLINE_ENDPOINT, + no_personal_data_message="Endpoint does not use keys for authentication.", + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.INVALID_VALUE, + ) + + @distributed_trace + @monitor_with_activity(ops_logger, "OnlineEndpoint.Invoke", ActivityType.PUBLICAPI) + def invoke( + self, + endpoint_name: str, + *, + request_file: Optional[str] = None, + deployment_name: Optional[str] = None, + # pylint: disable=unused-argument + input_data: Optional[Union[str, Data]] = None, + params_override: Any = None, + local: bool = False, + **kwargs: Any, + ) -> str: + """Invokes the endpoint with the provided payload. + + :param endpoint_name: The endpoint name + :type endpoint_name: str + :keyword request_file: File containing the request payload. This is only valid for online endpoint. + :paramtype request_file: Optional[str] + :keyword deployment_name: Name of a specific deployment to invoke. This is optional. + By default requests are routed to any of the deployments according to the traffic rules. 
+ :paramtype deployment_name: Optional[str] + :keyword input_data: To use a pre-registered data asset, pass str in format + :paramtype input_data: Optional[Union[str, Data]] + :keyword params_override: A dictionary of payload parameters to override and their desired values. + :paramtype params_override: Any + :keyword local: Indicates whether to interact with endpoints in local Docker environment. Defaults to False. + :paramtype local: Optional[bool] + :raises ~azure.ai.ml.exceptions.LocalEndpointNotFoundError: Raised if local endpoint resource does not exist. + :raises ~azure.ai.ml.exceptions.MultipleLocalDeploymentsFoundError: Raised if there are multiple deployments + and no deployment_name is specified. + :raises ~azure.ai.ml.exceptions.InvalidLocalEndpointError: Raised if local endpoint is None. + :return: Prediction output for online endpoint. + :rtype: str + """ + params_override = params_override or [] + # Until this bug is resolved https://msdata.visualstudio.com/Vienna/_workitems/edit/1446538 + if deployment_name: + self._validate_deployment_name(endpoint_name, deployment_name) + + with open(request_file, "rb") as f: # type: ignore[arg-type] + data = json.loads(f.read()) + if local: + return self._local_endpoint_helper.invoke( + endpoint_name=endpoint_name, data=data, deployment_name=deployment_name + ) + endpoint = self._online_operation.get( + resource_group_name=self._resource_group_name, + workspace_name=self._workspace_name, + endpoint_name=endpoint_name, + **self._init_kwargs, + ) + keys = self._get_online_credentials(name=endpoint_name, auth_mode=endpoint.properties.auth_mode) + if isinstance(keys, EndpointAuthKeys): + key = keys.primary_key + elif isinstance(keys, (EndpointAuthToken, EndpointAadToken)): + key = keys.access_token + else: + key = "" + headers = EndpointInvokeFields.DEFAULT_HEADER + if key: + headers[EndpointInvokeFields.AUTHORIZATION] = f"Bearer {key}" + if deployment_name: + headers[EndpointInvokeFields.MODEL_DEPLOYMENT] = deployment_name + + response = self._requests_pipeline.post(endpoint.properties.scoring_uri, json=data, headers=headers) + validate_response(response) + return str(response.text()) + + def _get_workspace_location(self) -> str: + return str( + self._all_operations.all_operations[AzureMLResourceType.WORKSPACE].get(self._workspace_name).location + ) + + def _get_online_credentials( + self, name: str, auth_mode: Optional[str] = None + ) -> Union[EndpointAuthKeys, EndpointAuthToken, EndpointAadToken]: + if not auth_mode: + endpoint = self._online_operation.get( + resource_group_name=self._resource_group_name, + workspace_name=self._workspace_name, + endpoint_name=name, + **self._init_kwargs, + ) + auth_mode = endpoint.properties.auth_mode + if auth_mode is not None and auth_mode.lower() == KEY: + return self._online_operation.list_keys( + resource_group_name=self._resource_group_name, + workspace_name=self._workspace_name, + endpoint_name=name, + # pylint: disable=protected-access + cls=lambda x, response, z: EndpointAuthKeys._from_rest_object(response), + **self._init_kwargs, + ) + + if auth_mode is not None and auth_mode.lower() == AAD_TOKEN: + if self._credentials: + return EndpointAadToken(self._credentials.get_token(*_resource_to_scopes(AAD_TOKEN_RESOURCE_ENDPOINT))) + msg = EMPTY_CREDENTIALS_ERROR + raise MlException(message=msg, no_personal_data_message=msg) + + return self._online_operation.get_token( + resource_group_name=self._resource_group_name, + workspace_name=self._workspace_name, + endpoint_name=name, + # pylint: 
disable=protected-access + cls=lambda x, response, z: EndpointAuthToken._from_rest_object(response), + **self._init_kwargs, + ) + + def _regenerate_online_keys( + self, + name: str, + key_type: str = EndpointKeyType.PRIMARY_KEY_TYPE, + ) -> LROPoller[None]: + keys = self._online_operation.list_keys( + resource_group_name=self._resource_group_name, + workspace_name=self._workspace_name, + endpoint_name=name, + **self._init_kwargs, + ) + if key_type.lower() == EndpointKeyType.PRIMARY_KEY_TYPE: + key_request = RegenerateEndpointKeysRequest(key_type=KeyType.Primary, key_value=keys.primary_key) + elif key_type.lower() == EndpointKeyType.SECONDARY_KEY_TYPE: + key_request = RegenerateEndpointKeysRequest(key_type=KeyType.Secondary, key_value=keys.secondary_key) + else: + msg = "Key type must be 'primary' or 'secondary'." + raise ValidationException( + message=msg, + target=ErrorTarget.ONLINE_ENDPOINT, + no_personal_data_message=msg, + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.INVALID_VALUE, + ) + + poller = self._online_operation.begin_regenerate_keys( + resource_group_name=self._resource_group_name, + workspace_name=self._workspace_name, + endpoint_name=name, + body=key_request, + **self._init_kwargs, + ) + + return poller + + def _validate_deployment_name(self, endpoint_name: str, deployment_name: str) -> None: + deployments_list = self._online_deployment_operation.list( + endpoint_name=endpoint_name, + resource_group_name=self._resource_group_name, + workspace_name=self._workspace_name, + cls=lambda objs: [obj.name for obj in objs], + **self._init_kwargs, + ) + + if deployments_list: + if deployment_name not in deployments_list: + raise ValidationException( + message=f"Deployment name {deployment_name} not found for this endpoint", + target=ErrorTarget.ONLINE_ENDPOINT, + no_personal_data_message="Deployment name not found for this endpoint", + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.RESOURCE_NOT_FOUND, + ) + else: + msg = "No deployment exists for this endpoint" + raise ValidationException( + message=msg, + target=ErrorTarget.ONLINE_ENDPOINT, + no_personal_data_message=msg, + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.RESOURCE_NOT_FOUND, + ) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_operation_orchestrator.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_operation_orchestrator.py new file mode 100644 index 00000000..2ad44bca --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_operation_orchestrator.py @@ -0,0 +1,571 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
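A minimal usage sketch for the online endpoint operations above, assuming an authenticated MLClient (exposed here as ml_client.online_endpoints) and a key-authenticated endpoint; every name below is illustrative:

from azure.ai.ml import MLClient
from azure.identity import DefaultAzureCredential

ml_client = MLClient(
    DefaultAzureCredential(), "<subscription-id>", "<resource-group>", "<workspace-name>"
)

# Score against the endpoint; traffic rules pick the deployment unless
# deployment_name pins a specific one.
prediction = ml_client.online_endpoints.invoke(
    endpoint_name="my-endpoint",
    request_file="sample-request.json",
)

# Rotate the secondary key. This is only valid for key-authenticated endpoints;
# begin_regenerate_keys raises a ValidationException otherwise.
ml_client.online_endpoints.begin_regenerate_keys(
    name="my-endpoint", key_type="secondary"
).result()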
+# --------------------------------------------------------- + +# pylint: disable=protected-access + +import logging +import re +from os import PathLike +from typing import Any, Optional, Tuple, Union + +from typing_extensions import Protocol + +from azure.ai.ml._artifacts._artifact_utilities import _check_and_upload_env_build_context, _check_and_upload_path +from azure.ai.ml._scope_dependent_operations import ( + OperationConfig, + OperationsContainer, + OperationScope, + _ScopeDependentOperations, +) +from azure.ai.ml._utils._arm_id_utils import ( + AMLLabelledArmId, + AMLNamedArmId, + AMLVersionedArmId, + get_arm_id_with_version, + is_ARM_id_for_resource, + is_registry_id_for_resource, + is_singularity_full_name_for_resource, + is_singularity_id_for_resource, + is_singularity_short_name_for_resource, + parse_name_label, + parse_prefixed_name_version, +) +from azure.ai.ml._utils._asset_utils import _resolve_label_to_asset, get_storage_info_for_non_registry_asset +from azure.ai.ml._utils._storage_utils import AzureMLDatastorePathUri +from azure.ai.ml.constants._common import ( + ARM_ID_PREFIX, + AZUREML_RESOURCE_PROVIDER, + CURATED_ENV_PREFIX, + DEFAULT_LABEL_NAME, + FILE_PREFIX, + FOLDER_PREFIX, + HTTPS_PREFIX, + JOB_URI_REGEX_FORMAT, + LABELLED_RESOURCE_ID_FORMAT, + LABELLED_RESOURCE_NAME, + MLFLOW_URI_REGEX_FORMAT, + NAMED_RESOURCE_ID_FORMAT, + REGISTRY_VERSION_PATTERN, + SINGULARITY_FULL_NAME_REGEX_FORMAT, + SINGULARITY_ID_FORMAT, + SINGULARITY_SHORT_NAME_REGEX_FORMAT, + VERSIONED_RESOURCE_ID_FORMAT, + VERSIONED_RESOURCE_NAME, + AzureMLResourceType, +) +from azure.ai.ml.entities import Component +from azure.ai.ml.entities._assets import Code, Data, Environment, Model +from azure.ai.ml.entities._assets.asset import Asset +from azure.ai.ml.exceptions import ( + AssetException, + EmptyDirectoryError, + ErrorCategory, + ErrorTarget, + MlException, + ModelException, + ValidationErrorType, + ValidationException, +) +from azure.core.exceptions import HttpResponseError, ResourceNotFoundError + +module_logger = logging.getLogger(__name__) + + +class OperationOrchestrator(object): + def __init__( + self, + operation_container: OperationsContainer, + operation_scope: OperationScope, + operation_config: OperationConfig, + ): + self._operation_container = operation_container + self._operation_scope = operation_scope + self._operation_config = operation_config + + @property + def _datastore_operation(self) -> _ScopeDependentOperations: + return self._operation_container.all_operations[AzureMLResourceType.DATASTORE] + + @property + def _code_assets(self) -> _ScopeDependentOperations: + return self._operation_container.all_operations[AzureMLResourceType.CODE] + + @property + def _model(self) -> _ScopeDependentOperations: + return self._operation_container.all_operations[AzureMLResourceType.MODEL] + + @property + def _environments(self) -> _ScopeDependentOperations: + return self._operation_container.all_operations[AzureMLResourceType.ENVIRONMENT] + + @property + def _data(self) -> _ScopeDependentOperations: + return self._operation_container.all_operations[AzureMLResourceType.DATA] + + @property + def _component(self) -> _ScopeDependentOperations: + return self._operation_container.all_operations[AzureMLResourceType.COMPONENT] + + @property + def _virtual_cluster(self) -> _ScopeDependentOperations: + return self._operation_container.all_operations[AzureMLResourceType.VIRTUALCLUSTER] + + def get_asset_arm_id( + self, + asset: Optional[Union[str, Asset]], + azureml_type: str, + register_asset: bool = 
True, + sub_workspace_resource: bool = True, + ) -> Optional[Union[str, Asset]]: + """This method converts AzureML Id to ARM Id. Or if the given asset is entity object, it tries to + register/upload the asset based on register_asset and azureml_type. + + :param asset: The asset to resolve/register. It can be a ARM id or a entity's object. + :type asset: Optional[Union[str, Asset]] + :param azureml_type: The AzureML resource type. Defined in AzureMLResourceType. + :type azureml_type: str + :param register_asset: Indicates if the asset should be registered, defaults to True. + :type register_asset: Optional[bool] + :param sub_workspace_resource: + :type sub_workspace_resource: Optional[bool] + :raises ~azure.ai.ml.exceptions.ValidationException: Raised if asset's ID cannot be converted + or asset cannot be successfully registered. + :return: The ARM Id or entity object + :rtype: Optional[Union[str, ~azure.ai.ml.entities.Asset]] + """ + # pylint: disable=too-many-return-statements, too-many-branches + if ( + asset is None + or is_ARM_id_for_resource(asset, azureml_type, sub_workspace_resource) + or is_registry_id_for_resource(asset) + or is_singularity_id_for_resource(asset) + ): + return asset + if is_singularity_full_name_for_resource(asset): + return self._get_singularity_arm_id_from_full_name(str(asset)) + if is_singularity_short_name_for_resource(asset): + return self._get_singularity_arm_id_from_short_name(str(asset)) + if isinstance(asset, str): + if azureml_type in AzureMLResourceType.NAMED_TYPES: + return NAMED_RESOURCE_ID_FORMAT.format( + self._operation_scope.subscription_id, + self._operation_scope.resource_group_name, + AZUREML_RESOURCE_PROVIDER, + self._operation_scope.workspace_name, + azureml_type, + asset, + ) + if azureml_type in AzureMLResourceType.VERSIONED_TYPES: + # Short form of curated env will be expanded on the backend side. + # CLI strips off azureml: in the schema, appending it back as required by backend + if azureml_type == AzureMLResourceType.ENVIRONMENT: + azureml_prefix = "azureml:" + # return the same value if resolved result is passed in + _asset = asset[len(azureml_prefix) :] if asset.startswith(azureml_prefix) else asset + if _asset.startswith(CURATED_ENV_PREFIX) or re.match( + REGISTRY_VERSION_PATTERN, f"{azureml_prefix}{_asset}" + ): + return f"{azureml_prefix}{_asset}" + + name, label = parse_name_label(asset) + # TODO: remove this condition after label is fully supported for all versioned resources + if label == DEFAULT_LABEL_NAME and azureml_type == AzureMLResourceType.COMPONENT: + return LABELLED_RESOURCE_ID_FORMAT.format( + self._operation_scope.subscription_id, + self._operation_scope.resource_group_name, + AZUREML_RESOURCE_PROVIDER, + self._operation_scope.workspace_name, + azureml_type, + name, + label, + ) + name, version = self._resolve_name_version_from_name_label(asset, azureml_type) + if not version: + name, version = parse_prefixed_name_version(asset) + + if not version: + msg = ( + "Failed to extract version when parsing asset {} of type {} as arm id. " + "Version must be provided." + ) + raise ValidationException( + message=msg.format(asset, azureml_type), + target=ErrorTarget.ASSET, + no_personal_data_message=msg.format("", azureml_type), + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.MISSING_FIELD, + ) + if self._operation_scope.registry_name: + # Short form for env not supported with registry flow except when it's a curated env. 
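For reference, a few illustrative (hypothetical) strings that the surrounding branch resolves, plus the fully qualified form it requires for custom environments when the client is scoped to a registry:

workspace_refs = [
    "my-env:3",       # name:version resolved to a versioned ARM id in the workspace
    "my-env@latest",  # name@label, with the label resolved to a concrete version first
]
registry_env_ref = "azureml://registries/<registry-name>/environments/<env-name>/versions/<version>"

Curated-environment short names (those starting with CURATED_ENV_PREFIX) are passed through unchanged in either case.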
+ # Adding a graceful error message for the scenario + if not asset.startswith(CURATED_ENV_PREFIX): + msg = ( + "Use fully qualified name to reference custom environments " + "when creating assets in registry. " + "The syntax for fully qualified names is " + "azureml://registries/azureml/environments/{{env-name}}/versions/{{version}}" + ) + raise ValidationException( + message=msg.format(asset, azureml_type), + target=ErrorTarget.ASSET, + no_personal_data_message=msg.format("", azureml_type), + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.INVALID_VALUE, + ) + + return VERSIONED_RESOURCE_ID_FORMAT.format( + self._operation_scope.subscription_id, + self._operation_scope.resource_group_name, + AZUREML_RESOURCE_PROVIDER, + self._operation_scope.workspace_name, + azureml_type, + name, + version, + ) + msg = "Unsupported azureml type {} for asset: {}" + raise ValidationException( + message=msg.format(azureml_type, asset), + target=ErrorTarget.ASSET, + no_personal_data_message=msg.format(azureml_type, ""), + error_type=ValidationErrorType.INVALID_VALUE, + ) + if isinstance(asset, Asset): + try: + result: Any = None + # TODO: once the asset redesign is finished, this logic can be replaced with unified API + if azureml_type == AzureMLResourceType.CODE and isinstance(asset, Code): + result = self._get_code_asset_arm_id(asset, register_asset=register_asset) + elif azureml_type == AzureMLResourceType.ENVIRONMENT and isinstance(asset, Environment): + result = self._get_environment_arm_id(asset, register_asset=register_asset) + elif azureml_type == AzureMLResourceType.MODEL and isinstance(asset, Model): + result = self._get_model_arm_id(asset, register_asset=register_asset) + elif azureml_type == AzureMLResourceType.DATA and isinstance(asset, Data): + result = self._get_data_arm_id(asset, register_asset=register_asset) + elif azureml_type == AzureMLResourceType.COMPONENT and isinstance(asset, Component): + result = self._get_component_arm_id(asset) + else: + msg = "Unsupported azureml type {} for asset: {}" + raise ValidationException( + message=msg.format(azureml_type, asset), + target=ErrorTarget.ASSET, + no_personal_data_message=msg.format(azureml_type, ""), + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.INVALID_VALUE, + ) + except EmptyDirectoryError as e: + msg = f"Error creating {azureml_type} asset : {e.message}" + raise AssetException( + message=msg.format(azureml_type, e.message), + target=ErrorTarget.ASSET, + no_personal_data_message=msg.format(azureml_type, ""), + error=e, + error_category=ErrorCategory.SYSTEM_ERROR, + ) from e + return result + msg = f"Error creating {azureml_type} asset: must be type Optional[Union[str, Asset]]" + raise ValidationException( + message=msg, + target=ErrorTarget.ASSET, + no_personal_data_message=msg, + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.INVALID_VALUE, + ) + + def _get_code_asset_arm_id(self, code_asset: Code, register_asset: bool = True) -> Union[Code, str]: + try: + self._validate_datastore_name(code_asset.path) + if register_asset: + code_asset = self._code_assets.create_or_update(code_asset) # type: ignore[attr-defined] + return str(code_asset.id) + sas_info = get_storage_info_for_non_registry_asset( + service_client=self._code_assets._service_client, # type: ignore[attr-defined] + workspace_name=self._operation_scope.workspace_name, + name=code_asset.name, + version=code_asset.version, + resource_group=self._operation_scope.resource_group_name, + ) + 
uploaded_code_asset, _ = _check_and_upload_path( + artifact=code_asset, + asset_operations=self._code_assets, # type: ignore[arg-type] + artifact_type=ErrorTarget.CODE, + show_progress=self._operation_config.show_progress, + sas_uri=sas_info["sas_uri"], + blob_uri=sas_info["blob_uri"], + ) + uploaded_code_asset._id = get_arm_id_with_version( + self._operation_scope, + AzureMLResourceType.CODE, + code_asset.name, + code_asset.version, + ) + return uploaded_code_asset + except (MlException, HttpResponseError) as e: + raise e + except Exception as e: + raise AssetException( + message=f"Error with code: {e}", + target=ErrorTarget.ASSET, + no_personal_data_message="Error getting code asset", + error=e, + error_category=ErrorCategory.SYSTEM_ERROR, + ) from e + + def _get_environment_arm_id(self, environment: Environment, register_asset: bool = True) -> Union[str, Environment]: + if register_asset: + if environment.id: + return environment.id + env_response = self._environments.create_or_update(environment) # type: ignore[attr-defined] + return env_response.id + environment = _check_and_upload_env_build_context( + environment=environment, + operations=self._environments, # type: ignore[arg-type] + show_progress=self._operation_config.show_progress, + ) + environment._id = get_arm_id_with_version( + self._operation_scope, + AzureMLResourceType.ENVIRONMENT, + environment.name, + environment.version, + ) + return environment + + def _get_model_arm_id(self, model: Model, register_asset: bool = True) -> Union[str, Model]: + try: + self._validate_datastore_name(model.path) + + if register_asset: + if model.id: + return model.id + return self._model.create_or_update(model).id # type: ignore[attr-defined] + uploaded_model, _ = _check_and_upload_path( + artifact=model, + asset_operations=self._model, # type: ignore[arg-type] + artifact_type=ErrorTarget.MODEL, + show_progress=self._operation_config.show_progress, + ) + uploaded_model._id = get_arm_id_with_version( + self._operation_scope, + AzureMLResourceType.MODEL, + model.name, + model.version, + ) + return uploaded_model + except (MlException, HttpResponseError) as e: + raise e + except Exception as e: + raise ModelException( + message=f"Error with model: {e}", + target=ErrorTarget.MODEL, + no_personal_data_message="Error getting model", + error=e, + error_category=ErrorCategory.SYSTEM_ERROR, + ) from e + + def _get_data_arm_id(self, data_asset: Data, register_asset: bool = True) -> Union[str, Data]: + self._validate_datastore_name(data_asset.path) + + if register_asset: + return self._data.create_or_update(data_asset).id # type: ignore[attr-defined] + data_asset, _ = _check_and_upload_path( + artifact=data_asset, + asset_operations=self._data, # type: ignore[arg-type] + artifact_type=ErrorTarget.DATA, + show_progress=self._operation_config.show_progress, + ) + return data_asset + + def _get_component_arm_id(self, component: Component) -> str: + """Gets the component ARM ID. + + :param component: The component + :type component: Component + :return: The component id + :rtype: str + """ + + # If component arm id is already resolved, return the id otherwise get arm id via remote call. + # Register the component if necessary, and FILL BACK the arm id to component to reduce remote call. 
+ if not component.id: + component._id = self._component.create_or_update( # type: ignore[attr-defined] + component, is_anonymous=True, show_progress=self._operation_config.show_progress + ).id + return str(component.id) + + def _get_singularity_arm_id_from_full_name(self, singularity: str) -> str: + match = re.match(SINGULARITY_FULL_NAME_REGEX_FORMAT, singularity) + subscription_id = match.group("subscription_id") if match is not None else "" + resource_group_name = match.group("resource_group_name") if match is not None else "" + vc_name = match.group("name") if match is not None else "" + arm_id = SINGULARITY_ID_FORMAT.format(subscription_id, resource_group_name, vc_name) + vc = self._virtual_cluster.get(arm_id) # type: ignore[attr-defined] + return str(vc["id"]) + + def _get_singularity_arm_id_from_short_name(self, singularity: str) -> str: + match = re.match(SINGULARITY_SHORT_NAME_REGEX_FORMAT, singularity) + vc_name = match.group("name") if match is not None else "" + # below list operation can be time-consuming, may need an optimization on this + match_vcs = [vc for vc in self._virtual_cluster.list() if vc["name"] == vc_name] # type: ignore[attr-defined] + num_match_vc = len(match_vcs) + if num_match_vc != 1: + if num_match_vc == 0: + msg = "The virtual cluster {} could not be found." + else: + msg = "More than one match virtual clusters {} found." + raise ValidationException( + message=msg.format(vc_name), + no_personal_data_message=msg.format(""), + target=ErrorTarget.COMPUTE, + error_type=ValidationErrorType.INVALID_VALUE, + ) + return str(match_vcs[0]["id"]) + + def _resolve_name_version_from_name_label(self, aml_id: str, azureml_type: str) -> Tuple[str, Optional[str]]: + """Given an AzureML id of the form name@label, resolves the label to the actual ID. + + :param aml_id: AzureML id of the form name@label + :type aml_id: str + :param azureml_type: The AzureML resource type. Defined in AzureMLResourceType. + :type azureml_type: str + :return: Returns tuple (name, version) on success, (name@label, None) if resolution fails + :rtype: Tuple[str, Optional[str]] + """ + name, label = parse_name_label(aml_id) + if ( + azureml_type not in AzureMLResourceType.VERSIONED_TYPES + or azureml_type == AzureMLResourceType.CODE + or not label + ): + return aml_id, None + + return ( + name, + _resolve_label_to_asset( + self._operation_container.all_operations[azureml_type], + name, + label=label, + ).version, + ) + + # pylint: disable=unused-argument + def resolve_azureml_id(self, arm_id: Optional[str] = None, **kwargs: Any) -> Optional[str]: + """This function converts ARM id to name or name:version AzureML id. It parses the ARM id and matches the + subscription Id, resource group name and workspace_name. + + TODO: It is debatable whether this method should be in operation_orchestrator. 
+ + :param arm_id: entity's ARM id, defaults to None + :type arm_id: str + :return: AzureML id + :rtype: str + """ + + if arm_id: + if not isinstance(arm_id, str): + msg = "arm_id cannot be resolved: str expected but got {}".format(type(arm_id)) # type: ignore + raise ValidationException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.GENERAL, + error_type=ValidationErrorType.INVALID_VALUE, + ) + try: + arm_id_obj = AMLVersionedArmId(arm_id) + if arm_id_obj.is_registry_id: + return arm_id + if self._match(arm_id_obj): + return str(VERSIONED_RESOURCE_NAME.format(arm_id_obj.asset_name, arm_id_obj.asset_version)) + except ValidationException: + pass # fall back to named arm id + try: + arm_id_obj = AMLLabelledArmId(arm_id) + if self._match(arm_id_obj): + return str(LABELLED_RESOURCE_NAME.format(arm_id_obj.asset_name, arm_id_obj.asset_label)) + except ValidationException: + pass # fall back to named arm id + try: + arm_id_obj = AMLNamedArmId(arm_id) + if self._match(arm_id_obj): + return str(arm_id_obj.asset_name) + except ValidationException: + pass # fall back to be not a ARM_id + return arm_id + + def _match(self, id_: Any) -> bool: + return bool( + ( + id_.subscription_id == self._operation_scope.subscription_id + and id_.resource_group_name == self._operation_scope.resource_group_name + and id_.workspace_name == self._operation_scope.workspace_name + ) + ) + + def _validate_datastore_name(self, datastore_uri: Optional[Union[str, PathLike]]) -> None: + if datastore_uri: + try: + if isinstance(datastore_uri, str): + if datastore_uri.startswith(FILE_PREFIX): + datastore_uri = datastore_uri[len(FILE_PREFIX) :] + elif datastore_uri.startswith(FOLDER_PREFIX): + datastore_uri = datastore_uri[len(FOLDER_PREFIX) :] + elif isinstance(datastore_uri, PathLike): + return + + if datastore_uri.startswith(HTTPS_PREFIX) and datastore_uri.count("/") == 7: + # only long-form (i.e. "https://x.blob.core.windows.net/datastore/LocalUpload/guid/x/x") + # format includes datastore + datastore_name = datastore_uri.split("/")[3] + elif datastore_uri.startswith(ARM_ID_PREFIX) and not ( + re.match(MLFLOW_URI_REGEX_FORMAT, datastore_uri) or re.match(JOB_URI_REGEX_FORMAT, datastore_uri) + ): + datastore_name = AzureMLDatastorePathUri(datastore_uri).datastore + else: + # local path + return + + if datastore_name.startswith(ARM_ID_PREFIX): + datastore_name = datastore_name[len(ARM_ID_PREFIX) :] + + self._datastore_operation.get(datastore_name) # type: ignore[attr-defined] + except ResourceNotFoundError as e: + msg = "The datastore {} could not be found in this workspace." + raise ValidationException( + message=msg.format(datastore_name), + target=ErrorTarget.DATASTORE, + no_personal_data_message=msg.format(""), + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.RESOURCE_NOT_FOUND, + ) from e + + +class _AssetResolver(Protocol): + """Describes the type of a function used by operation classes like :py:class:`JobOperations` and + :py:class:`ComponentOperations` to resolve Assets + + .. see-also:: methods :py:method:`OperationOrchestrator.get_asset_arm_id`, + :py:method:`OperationOrchestrator.resolve_azureml_id` + + """ + + def __call__( + self, + asset: Optional[Union[str, Asset]], + azureml_type: str, + register_asset: bool = True, + sub_workspace_resource: bool = True, + ) -> Optional[Union[str, Asset]]: + """Resolver function + + :param asset: The asset to resolve/register. It can be a ARM id or a entity's object. 
+ :type asset: Optional[Union[str, Asset]] + :param azureml_type: The AzureML resource type. Defined in AzureMLResourceType. + :type azureml_type: str + :param register_asset: Indicates if the asset should be registered, defaults to True. + :type register_asset: Optional[bool] + :param sub_workspace_resource: + :type sub_workspace_resource: Optional[bool] + :raises ~azure.ai.ml.exceptions.ValidationException: Raised if asset's ID cannot be converted + or asset cannot be successfully registered. + :return: The ARM Id or entity object + :rtype: Optional[Union[str, ~azure.ai.ml.entities.Asset]] + """ diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_registry_operations.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_registry_operations.py new file mode 100644 index 00000000..6cd1a326 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_registry_operations.py @@ -0,0 +1,168 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=protected-access,unused-argument + +from typing import Dict, Iterable, Optional, cast + +from azure.ai.ml._restclient.v2022_10_01_preview import AzureMachineLearningWorkspaces as ServiceClient102022 +from azure.ai.ml._scope_dependent_operations import OperationsContainer, OperationScope +from azure.ai.ml._telemetry import ActivityType, monitor_with_activity +from azure.ai.ml._utils._logger_utils import OpsLogger +from azure.ai.ml.entities import Registry +from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationException +from azure.core.credentials import TokenCredential +from azure.core.polling import LROPoller + +from .._utils._azureml_polling import AzureMLPolling +from ..constants._common import LROConfigurations, Scope + +ops_logger = OpsLogger(__name__) +module_logger = ops_logger.module_logger + + +class RegistryOperations: + """RegistryOperations. + + You should not instantiate this class directly. Instead, you should create an MLClient instance that instantiates it + for you and attaches it as an attribute. + """ + + def __init__( + self, + operation_scope: OperationScope, + service_client: ServiceClient102022, + all_operations: OperationsContainer, + credentials: Optional[TokenCredential] = None, + **kwargs: Dict, + ): + ops_logger.update_filter() + self._subscription_id = operation_scope.subscription_id + self._resource_group_name = operation_scope.resource_group_name + self._default_registry_name = operation_scope.registry_name + self._operation = service_client.registries + self._all_operations = all_operations + self._credentials = credentials + self.containerRegistry = "none" + self._init_kwargs = kwargs + + @monitor_with_activity(ops_logger, "Registry.List", ActivityType.PUBLICAPI) + def list(self, *, scope: str = Scope.RESOURCE_GROUP) -> Iterable[Registry]: + """List all registries that the user has access to in the current resource group or subscription. 
+ + :keyword scope: scope of the listing, "resource_group" or "subscription", defaults to "resource_group" + :paramtype scope: str + :return: An iterator like instance of Registry objects + :rtype: ~azure.core.paging.ItemPaged[Registry] + """ + if scope.lower() == Scope.SUBSCRIPTION: + return cast( + Iterable[Registry], + self._operation.list_by_subscription( + cls=lambda objs: [Registry._from_rest_object(obj) for obj in objs] + ), + ) + return cast( + Iterable[Registry], + self._operation.list( + cls=lambda objs: [Registry._from_rest_object(obj) for obj in objs], + resource_group_name=self._resource_group_name, + ), + ) + + @monitor_with_activity(ops_logger, "Registry.Get", ActivityType.PUBLICAPI) + def get(self, name: Optional[str] = None) -> Registry: + """Get a registry by name. + + :param name: Name of the registry. + :type name: str + :raises ~azure.ai.ml.exceptions.ValidationException: Raised if Registry name cannot be + successfully validated. Details will be provided in the error message. + :raises ~azure.core.exceptions.HttpResponseError: Raised if the corresponding name and version cannot be + retrieved from the service. + :return: The registry with the provided name. + :rtype: ~azure.ai.ml.entities.Registry + """ + + registry_name = self._check_registry_name(name) + resource_group = self._resource_group_name + obj = self._operation.get(resource_group, registry_name) + return Registry._from_rest_object(obj) # type: ignore[return-value] + + def _check_registry_name(self, name: Optional[str]) -> str: + registry_name = name or self._default_registry_name + if not registry_name: + msg = "Please provide a registry name or use a MLClient with a registry name set." + raise ValidationException( + message=msg, + target=ErrorTarget.REGISTRY, + no_personal_data_message=msg, + error_category=ErrorCategory.USER_ERROR, + ) + return registry_name + + def _get_polling(self, name: str) -> AzureMLPolling: + """Return the polling with custom poll interval. + + :param name: The registry name + :type name: str + :return: A poller with custom poll interval. + :rtype: AzureMLPolling + """ + path_format_arguments = { + "registryName": name, + "resourceGroupName": self._resource_group_name, + } + return AzureMLPolling( + LROConfigurations.POLL_INTERVAL, + path_format_arguments=path_format_arguments, + ) + + @monitor_with_activity(ops_logger, "Registry.BeginCreate", ActivityType.PUBLICAPI) + def begin_create( + self, + registry: Registry, + **kwargs: Dict, + ) -> LROPoller[Registry]: + """Create a new Azure Machine Learning Registry, or try to update if it already exists. + + Note: Due to service limitations we have to sleep for + an additional 30~45 seconds AFTER the LRO Poller concludes + before the registry will be consistently deleted from the + perspective of subsequent operations. + If a deletion is required for subsequent operations to + work properly, callers should implement that logic until the + service has been fixed to return a reliable LRO. + + :param registry: Registry definition. + :type registry: Registry + :return: A poller to track the operation status. 
+ :rtype: LROPoller + """ + registry_data = registry._to_rest_object() + poller = self._operation.begin_create_or_update( + resource_group_name=self._resource_group_name, + registry_name=registry.name, + body=registry_data, + polling=self._get_polling(str(registry.name)), + cls=lambda response, deserialized, headers: Registry._from_rest_object(deserialized), + ) + + return poller + + @monitor_with_activity(ops_logger, "Registry.BeginDelete", ActivityType.PUBLICAPI) + def begin_delete(self, *, name: str, **kwargs: Dict) -> LROPoller[None]: + """Delete a registry if it exists. Returns nothing on a successful operation. + + :keyword name: Name of the registry + :paramtype name: str + :return: A poller to track the operation status. + :rtype: LROPoller + """ + resource_group = kwargs.get("resource_group") or self._resource_group_name + return self._operation.begin_delete( + resource_group_name=resource_group, + registry_name=name, + **self._init_kwargs, + ) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_run_history_constants.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_run_history_constants.py new file mode 100644 index 00000000..17aaac29 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_run_history_constants.py @@ -0,0 +1,82 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +import os + +# default timeout of session for getting content in the job run, +# the 1st element is conn timeout, the 2nd is the read timeout. + + +class JobStatus(object): + """ + * NotStarted - This is a temporary state client-side Run objects are in before cloud submission. + * Starting - The Run has started being processed in the cloud. The caller has a run ID at this point. + * Provisioning - Returned when on-demand compute is being created for a given job submission. + * Preparing - The run environment is being prepared: + * docker image build + * conda environment setup + * Queued - The job is queued in the compute target. For example, in BatchAI the job is in queued state + while waiting for all the requested nodes to be ready. + * Running - The job started to run in the compute target. + * Finalizing - User code has completed and the run is in post-processing stages. + * CancelRequested - Cancellation has been requested for the job. + * Completed - The run completed successfully. This includes both the user code and run + post-processing stages. + * Failed - The run failed. Usually the Error property on a run will provide details as to why. + * Canceled - Follows a cancellation request and indicates that the run is now successfully cancelled. + * NotResponding - For runs that have Heartbeats enabled, no heartbeat has been recently sent. 
+ """ + + # Ordered by transition order + QUEUED = "Queued" + NOT_STARTED = "NotStarted" + PREPARING = "Preparing" + PROVISIONING = "Provisioning" + STARTING = "Starting" + RUNNING = "Running" + CANCEL_REQUESTED = "CancelRequested" + CANCELED = "Canceled" # Not official yet + FINALIZING = "Finalizing" + COMPLETED = "Completed" + FAILED = "Failed" + PAUSED = "Paused" + NOTRESPONDING = "NotResponding" + + +class RunHistoryConstants(object): + _DEFAULT_GET_CONTENT_TIMEOUT = (5, 120) + _WAIT_COMPLETION_POLLING_INTERVAL_MIN = os.environ.get("AZUREML_RUN_POLLING_INTERVAL_MIN", 2) + _WAIT_COMPLETION_POLLING_INTERVAL_MAX = os.environ.get("AZUREML_RUN_POLLING_INTERVAL_MAX", 60) + ALL_STATUSES = [ + JobStatus.QUEUED, + JobStatus.PREPARING, + JobStatus.PROVISIONING, + JobStatus.STARTING, + JobStatus.RUNNING, + JobStatus.CANCEL_REQUESTED, + JobStatus.CANCELED, + JobStatus.FINALIZING, + JobStatus.COMPLETED, + JobStatus.FAILED, + JobStatus.NOT_STARTED, + JobStatus.FAILED, + JobStatus.PAUSED, + JobStatus.NOTRESPONDING, + ] + IN_PROGRESS_STATUSES = [ + JobStatus.NOT_STARTED, + JobStatus.QUEUED, + JobStatus.PREPARING, + JobStatus.PROVISIONING, + JobStatus.STARTING, + JobStatus.RUNNING, + ] + POST_PROCESSING_STATUSES = [JobStatus.CANCEL_REQUESTED, JobStatus.FINALIZING] + TERMINAL_STATUSES = [ + JobStatus.COMPLETED, + JobStatus.FAILED, + JobStatus.CANCELED, + JobStatus.NOTRESPONDING, + JobStatus.PAUSED, + ] diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_run_operations.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_run_operations.py new file mode 100644 index 00000000..627b1936 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_run_operations.py @@ -0,0 +1,94 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +# pylint: disable=protected-access + +import logging +from typing import Any, Iterable, Optional, cast + +from azure.ai.ml._restclient.runhistory import AzureMachineLearningWorkspaces as RunHistoryServiceClient +from azure.ai.ml._restclient.runhistory.models import GetRunDataRequest, GetRunDataResult, Run, RunDetails +from azure.ai.ml._scope_dependent_operations import OperationConfig, OperationScope, _ScopeDependentOperations +from azure.ai.ml.constants._common import AZUREML_RESOURCE_PROVIDER, NAMED_RESOURCE_ID_FORMAT, AzureMLResourceType +from azure.ai.ml.entities._job.base_job import _BaseJob +from azure.ai.ml.entities._job.job import Job +from azure.ai.ml.exceptions import JobParsingError + +module_logger = logging.getLogger(__name__) + + +class RunOperations(_ScopeDependentOperations): + def __init__( + self, + operation_scope: OperationScope, + operation_config: OperationConfig, + service_client: RunHistoryServiceClient, + ): + super(RunOperations, self).__init__(operation_scope, operation_config) + self._operation = service_client.runs + + def get_run(self, run_id: str) -> Run: + return self._operation.get( + self._operation_scope.subscription_id, + self._operation_scope.resource_group_name, + self._workspace_name, + run_id, + ) + + def get_run_details(self, run_id: str) -> RunDetails: + return self._operation.get_details( + self._operation_scope.subscription_id, + self._operation_scope.resource_group_name, + self._workspace_name, + run_id, + ) + + def get_run_children(self, run_id: str, **kwargs) -> Iterable[_BaseJob]: + return cast( + Iterable[_BaseJob], + self._operation.get_child( + self._subscription_id, + self._resource_group_name, + self._workspace_name, + run_id, + top=kwargs.pop("max_results", None), + cls=lambda objs: [self._translate_from_rest_object(obj) for obj in objs], + ), + ) + + def _translate_from_rest_object(self, job_object: Run) -> Optional[_BaseJob]: + """Handle errors during list operation. + + :param job_object: The job object + :type job_object: Run + :return: A job entity if parsing was successful + :rtype: Optional[_BaseJob] + """ + try: + from_rest_job: Any = Job._from_rest_object(job_object) + from_rest_job._id = NAMED_RESOURCE_ID_FORMAT.format( + self._subscription_id, + self._resource_group_name, + AZUREML_RESOURCE_PROVIDER, + self._workspace_name, + AzureMLResourceType.JOB, + from_rest_job.name, + ) + return from_rest_job + except JobParsingError: + return None + + def get_run_data(self, run_id: str) -> GetRunDataResult: + run_data_request = GetRunDataRequest( + run_id=run_id, + select_run_metadata=True, + select_run_definition=True, + select_job_specification=True, + ) + return self._operation.get_run_data( + self._subscription_id, + self._resource_group_name, + self._workspace_name, + body=run_data_request, + ) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_schedule_operations.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_schedule_operations.py new file mode 100644 index 00000000..8c34af40 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_schedule_operations.py @@ -0,0 +1,608 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
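RunOperations above is an internal helper that the SDK wires up itself (it is not an MLClient attribute); given an already-constructed instance, a hedged usage sketch looks like this, with the run id purely illustrative:

# Assumes run_ops is an existing RunOperations instance and the run id exists.
run = run_ops.get_run("my-run-id")              # raw run-history Run record
details = run_ops.get_run_details("my-run-id")  # detailed record (status, log files, error info)
children = list(run_ops.get_run_children("my-run-id", max_results=10))  # child jobs as _BaseJob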
+# --------------------------------------------------------- +# pylint: disable=protected-access +from datetime import datetime, timezone +from typing import Any, Iterable, List, Optional, Tuple, cast + +from azure.ai.ml._restclient.v2023_06_01_preview import AzureMachineLearningWorkspaces as ServiceClient062023Preview +from azure.ai.ml._restclient.v2024_01_01_preview import AzureMachineLearningWorkspaces as ServiceClient012024Preview +from azure.ai.ml._scope_dependent_operations import ( + OperationConfig, + OperationsContainer, + OperationScope, + _ScopeDependentOperations, +) +from azure.ai.ml._telemetry import ActivityType, monitor_with_activity, monitor_with_telemetry_mixin +from azure.ai.ml._utils._logger_utils import OpsLogger +from azure.ai.ml.entities import Job, JobSchedule, Schedule +from azure.ai.ml.entities._inputs_outputs.input import Input +from azure.ai.ml.entities._monitoring.schedule import MonitorSchedule +from azure.ai.ml.entities._monitoring.signals import ( + BaselineDataRange, + FADProductionData, + GenerationTokenStatisticsSignal, + LlmData, + ProductionData, + ReferenceData, +) +from azure.ai.ml.entities._monitoring.target import MonitoringTarget +from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ScheduleException +from azure.core.credentials import TokenCredential +from azure.core.polling import LROPoller +from azure.core.tracing.decorator import distributed_trace + +from .._restclient.v2022_10_01.models import ScheduleListViewType +from .._restclient.v2024_01_01_preview.models import TriggerOnceRequest +from .._utils._arm_id_utils import AMLNamedArmId, AMLVersionedArmId, is_ARM_id_for_parented_resource +from .._utils._azureml_polling import AzureMLPolling +from .._utils.utils import snake_to_camel +from ..constants._common import ( + ARM_ID_PREFIX, + AZUREML_RESOURCE_PROVIDER, + NAMED_RESOURCE_ID_FORMAT_WITH_PARENT, + AzureMLResourceType, + LROConfigurations, +) +from ..constants._monitoring import ( + DEPLOYMENT_MODEL_INPUTS_COLLECTION_KEY, + DEPLOYMENT_MODEL_INPUTS_NAME_KEY, + DEPLOYMENT_MODEL_INPUTS_VERSION_KEY, + DEPLOYMENT_MODEL_OUTPUTS_COLLECTION_KEY, + DEPLOYMENT_MODEL_OUTPUTS_NAME_KEY, + DEPLOYMENT_MODEL_OUTPUTS_VERSION_KEY, + MonitorDatasetContext, + MonitorSignalType, +) +from ..entities._schedule.schedule import ScheduleTriggerResult +from . import DataOperations, JobOperations, OnlineDeploymentOperations +from ._job_ops_helper import stream_logs_until_completion +from ._operation_orchestrator import OperationOrchestrator + +ops_logger = OpsLogger(__name__) +module_logger = ops_logger.module_logger + + +class ScheduleOperations(_ScopeDependentOperations): + # pylint: disable=too-many-instance-attributes + """ + ScheduleOperations + + You should not instantiate this class directly. + Instead, you should create an MLClient instance that instantiates it for you and attaches it as an attribute. 
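A minimal sketch of driving the schedule operations below through MLClient, assuming ml_client exists and pipeline_job is a previously built job; the schedule name and cadence are illustrative:

from azure.ai.ml.entities import JobSchedule, RecurrenceTrigger

schedule = JobSchedule(
    name="nightly-training",
    trigger=RecurrenceTrigger(frequency="day", interval=1),
    create_job=pipeline_job,
)

# Create (or update) the schedule, then pause it without deleting it.
ml_client.schedules.begin_create_or_update(schedule).result()
ml_client.schedules.begin_disable("nightly-training").result()

# Fire a single run outside the recurrence via the trigger() operation defined below.
run_info = ml_client.schedules.trigger("nightly-training")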
+ """ + + def __init__( + self, + operation_scope: OperationScope, + operation_config: OperationConfig, + service_client_06_2023_preview: ServiceClient062023Preview, + service_client_01_2024_preview: ServiceClient012024Preview, + all_operations: OperationsContainer, + credential: TokenCredential, + **kwargs: Any, + ): + super(ScheduleOperations, self).__init__(operation_scope, operation_config) + ops_logger.update_filter() + self.service_client = service_client_06_2023_preview.schedules + # Note: Trigger once is supported since 24_01, we don't upgrade other operations' client because there are + # some breaking changes, for example: AzMonMonitoringAlertNotificationSettings is removed. + self.schedule_trigger_service_client = service_client_01_2024_preview.schedules + self._all_operations = all_operations + self._stream_logs_until_completion = stream_logs_until_completion + # Dataplane service clients are lazily created as they are needed + self._runs_operations_client = None + self._dataset_dataplane_operations_client = None + self._model_dataplane_operations_client = None + # Kwargs to propagate to dataplane service clients + self._service_client_kwargs = kwargs.pop("_service_client_kwargs", {}) + self._api_base_url = None + self._container = "azureml" + self._credential = credential + self._orchestrators = OperationOrchestrator(self._all_operations, self._operation_scope, self._operation_config) + + self._kwargs = kwargs + + @property + def _job_operations(self) -> JobOperations: + return cast( + JobOperations, + self._all_operations.get_operation( # type: ignore[misc] + AzureMLResourceType.JOB, lambda x: isinstance(x, JobOperations) + ), + ) + + @property + def _online_deployment_operations(self) -> OnlineDeploymentOperations: + return cast( + OnlineDeploymentOperations, + self._all_operations.get_operation( # type: ignore[misc] + AzureMLResourceType.ONLINE_DEPLOYMENT, lambda x: isinstance(x, OnlineDeploymentOperations) + ), + ) + + @property + def _data_operations(self) -> DataOperations: + return cast( + DataOperations, + self._all_operations.get_operation( # type: ignore[misc] + AzureMLResourceType.DATA, lambda x: isinstance(x, DataOperations) + ), + ) + + @distributed_trace + @monitor_with_activity(ops_logger, "Schedule.List", ActivityType.PUBLICAPI) + def list( + self, + *, + list_view_type: ScheduleListViewType = ScheduleListViewType.ENABLED_ONLY, + **kwargs: Any, + ) -> Iterable[Schedule]: + """List schedules in specified workspace. + + :keyword list_view_type: View type for including/excluding (for example) + archived schedules. Default: ENABLED_ONLY. + :type list_view_type: Optional[ScheduleListViewType] + :return: An iterator to list Schedule. + :rtype: Iterable[Schedule] + """ + + def safe_from_rest_object(objs: Any) -> List: + result = [] + for obj in objs: + try: + result.append(Schedule._from_rest_object(obj)) + except Exception as e: # pylint: disable=W0718 + print(f"Translate {obj.name} to Schedule failed with: {e}") + return result + + return cast( + Iterable[Schedule], + self.service_client.list( + resource_group_name=self._operation_scope.resource_group_name, + workspace_name=self._workspace_name, + list_view_type=list_view_type, + cls=safe_from_rest_object, + **self._kwargs, + **kwargs, + ), + ) + + def _get_polling(self, name: Optional[str]) -> AzureMLPolling: + """Return the polling with custom poll interval. 
+ + :param name: The schedule name + :type name: str + :return: The AzureMLPolling object + :rtype: AzureMLPolling + """ + path_format_arguments = { + "scheduleName": name, + "resourceGroupName": self._resource_group_name, + "workspaceName": self._workspace_name, + } + return AzureMLPolling( + LROConfigurations.POLL_INTERVAL, + path_format_arguments=path_format_arguments, + ) + + @distributed_trace + @monitor_with_activity(ops_logger, "Schedule.Delete", ActivityType.PUBLICAPI) + def begin_delete( + self, + name: str, + **kwargs: Any, + ) -> LROPoller[None]: + """Delete schedule. + + :param name: Schedule name. + :type name: str + :return: A poller for deletion status + :rtype: LROPoller[None] + """ + poller = self.service_client.begin_delete( + resource_group_name=self._operation_scope.resource_group_name, + workspace_name=self._workspace_name, + name=name, + polling=self._get_polling(name), + **self._kwargs, + **kwargs, + ) + return poller + + @distributed_trace + @monitor_with_telemetry_mixin(ops_logger, "Schedule.Get", ActivityType.PUBLICAPI) + def get( + self, + name: str, + **kwargs: Any, + ) -> Schedule: + """Get a schedule. + + :param name: Schedule name. + :type name: str + :return: The schedule object. + :rtype: Schedule + """ + return self.service_client.get( + resource_group_name=self._operation_scope.resource_group_name, + workspace_name=self._workspace_name, + name=name, + cls=lambda _, obj, __: Schedule._from_rest_object(obj), + **self._kwargs, + **kwargs, + ) + + @distributed_trace + @monitor_with_telemetry_mixin(ops_logger, "Schedule.CreateOrUpdate", ActivityType.PUBLICAPI) + def begin_create_or_update( + self, + schedule: Schedule, + **kwargs: Any, + ) -> LROPoller[Schedule]: + """Create or update schedule. + + :param schedule: Schedule definition. + :type schedule: Schedule + :return: An instance of LROPoller that returns Schedule if no_wait=True, or Schedule if no_wait=False + :rtype: Union[LROPoller, Schedule] + :rtype: Union[LROPoller, ~azure.ai.ml.entities.Schedule] + """ + + if isinstance(schedule, JobSchedule): + schedule._validate(raise_error=True) + if isinstance(schedule.create_job, Job): + # Create all dependent resources for job inside schedule + self._job_operations._resolve_arm_id_or_upload_dependencies(schedule.create_job) + elif isinstance(schedule, MonitorSchedule): + # resolve ARM id for target, compute, and input datasets for each signal + self._resolve_monitor_schedule_arm_id(schedule) + # Create schedule + schedule_data = schedule._to_rest_object() # type: ignore + poller = self.service_client.begin_create_or_update( + resource_group_name=self._operation_scope.resource_group_name, + workspace_name=self._workspace_name, + name=schedule.name, + cls=lambda _, obj, __: Schedule._from_rest_object(obj), + body=schedule_data, + polling=self._get_polling(schedule.name), + **self._kwargs, + **kwargs, + ) + return poller + + @distributed_trace + @monitor_with_activity(ops_logger, "Schedule.Enable", ActivityType.PUBLICAPI) + def begin_enable( + self, + name: str, + **kwargs: Any, + ) -> LROPoller[Schedule]: + """Enable a schedule. + + :param name: Schedule name. 
+ :type name: str + :return: An instance of LROPoller that returns Schedule + :rtype: LROPoller + """ + schedule = self.get(name=name) + schedule._is_enabled = True + return self.begin_create_or_update(schedule, **kwargs) + + @distributed_trace + @monitor_with_activity(ops_logger, "Schedule.Disable", ActivityType.PUBLICAPI) + def begin_disable( + self, + name: str, + **kwargs: Any, + ) -> LROPoller[Schedule]: + """Disable a schedule. + + :param name: Schedule name. + :type name: str + :return: An instance of LROPoller that returns Schedule if no_wait=True, or Schedule if no_wait=False + :rtype: LROPoller + """ + schedule = self.get(name=name) + schedule._is_enabled = False + return self.begin_create_or_update(schedule, **kwargs) + + @distributed_trace + @monitor_with_activity(ops_logger, "Schedule.Trigger", ActivityType.PUBLICAPI) + def trigger( + self, + name: str, + **kwargs: Any, + ) -> ScheduleTriggerResult: + """Trigger a schedule once. + + :param name: Schedule name. + :type name: str + :return: TriggerRunSubmissionDto, or the result of cls(response) + :rtype: ~azure.ai.ml.entities.ScheduleTriggerResult + """ + schedule_time = kwargs.pop("schedule_time", datetime.now(timezone.utc).isoformat()) + return self.schedule_trigger_service_client.trigger( + name=name, + resource_group_name=self._operation_scope.resource_group_name, + workspace_name=self._workspace_name, + body=TriggerOnceRequest(schedule_time=schedule_time), + cls=lambda _, obj, __: ScheduleTriggerResult._from_rest_object(obj), + **kwargs, + ) + + def _resolve_monitor_schedule_arm_id( # pylint:disable=too-many-branches,too-many-statements,too-many-locals + self, schedule: MonitorSchedule + ) -> None: + # resolve target ARM ID + model_inputs_name, model_outputs_name = None, None + app_traces_name, app_traces_version = None, None + model_inputs_version, model_outputs_version = None, None + mdc_input_enabled, mdc_output_enabled = False, False + target = schedule.create_monitor.monitoring_target + if target and target.endpoint_deployment_id: + endpoint_name, deployment_name = self._process_and_get_endpoint_deployment_names_from_id(target) + online_deployment = self._online_deployment_operations.get(deployment_name, endpoint_name) + deployment_data_collector = online_deployment.data_collector + if deployment_data_collector: + in_reg = AMLVersionedArmId(deployment_data_collector.collections.get("model_inputs").data) + out_reg = AMLVersionedArmId(deployment_data_collector.collections.get("model_outputs").data) + if "app_traces" in deployment_data_collector.collections: + app_traces = AMLVersionedArmId(deployment_data_collector.collections.get("app_traces").data) + app_traces_name = app_traces.asset_name + app_traces_version = app_traces.asset_version + model_inputs_name = in_reg.asset_name + model_inputs_version = in_reg.asset_version + model_outputs_name = out_reg.asset_name + model_outputs_version = out_reg.asset_version + mdc_input_enabled_str = deployment_data_collector.collections.get("model_inputs").enabled + mdc_output_enabled_str = deployment_data_collector.collections.get("model_outputs").enabled + else: + model_inputs_name = online_deployment.tags.get(DEPLOYMENT_MODEL_INPUTS_NAME_KEY) + model_inputs_version = online_deployment.tags.get(DEPLOYMENT_MODEL_INPUTS_VERSION_KEY) + model_outputs_name = online_deployment.tags.get(DEPLOYMENT_MODEL_OUTPUTS_NAME_KEY) + model_outputs_version = online_deployment.tags.get(DEPLOYMENT_MODEL_OUTPUTS_VERSION_KEY) + mdc_input_enabled_str = 
online_deployment.tags.get(DEPLOYMENT_MODEL_INPUTS_COLLECTION_KEY) + mdc_output_enabled_str = online_deployment.tags.get(DEPLOYMENT_MODEL_OUTPUTS_COLLECTION_KEY) + if mdc_input_enabled_str and mdc_input_enabled_str.lower() == "true": + mdc_input_enabled = True + if mdc_output_enabled_str and mdc_output_enabled_str.lower() == "true": + mdc_output_enabled = True + elif target and target.model_id: + target.model_id = self._orchestrators.get_asset_arm_id( # type: ignore + target.model_id, + AzureMLResourceType.MODEL, + register_asset=False, + ) + + if not schedule.create_monitor.monitoring_signals: + if mdc_input_enabled and mdc_output_enabled: + schedule._create_default_monitor_definition() + else: + msg = ( + "An ARM id for a deployment with data collector enabled for model inputs and outputs must be " + "given if monitoring_signals is None" + ) + raise ScheduleException( + message=msg, + no_personal_data_message=msg, + target=ErrorTarget.SCHEDULE, + error_category=ErrorCategory.USER_ERROR, + ) + # resolve ARM id for each signal and populate any defaults if needed + for signal_name, signal in schedule.create_monitor.monitoring_signals.items(): # type: ignore + if signal.type == MonitorSignalType.GENERATION_SAFETY_QUALITY: + for llm_data in signal.production_data: # type: ignore[union-attr] + self._job_operations._resolve_job_input(llm_data.input_data, schedule._base_path) + continue + if signal.type == MonitorSignalType.GENERATION_TOKEN_STATISTICS: + if not signal.production_data: # type: ignore[union-attr] + # if target dataset is absent and data collector for input is enabled, + # create a default target dataset with production app traces as target + if isinstance(signal, GenerationTokenStatisticsSignal): + signal.production_data = LlmData( # type: ignore[union-attr] + input_data=Input( + path=f"{app_traces_name}:{app_traces_version}", + type=self._data_operations.get(app_traces_name, app_traces_version).type, + ), + data_window=BaselineDataRange(lookback_window_size="P7D", lookback_window_offset="P0D"), + ) + self._job_operations._resolve_job_input( + signal.production_data.input_data, schedule._base_path # type: ignore[union-attr] + ) + continue + if signal.type == MonitorSignalType.CUSTOM: + if signal.inputs: # type: ignore[union-attr] + for inputs in signal.inputs.values(): # type: ignore[union-attr] + self._job_operations._resolve_job_input(inputs, schedule._base_path) + for data in signal.input_data.values(): # type: ignore[union-attr] + if data.input_data is not None: + for inputs in data.input_data.values(): + self._job_operations._resolve_job_input(inputs, schedule._base_path) + data.pre_processing_component = self._orchestrators.get_asset_arm_id( + asset=data.pre_processing_component if hasattr(data, "pre_processing_component") else None, + azureml_type=AzureMLResourceType.COMPONENT, + ) + continue + error_messages = [] + if not signal.production_data or not signal.reference_data: # type: ignore[union-attr] + # if there is no target dataset, we check the type of signal + if signal.type in {MonitorSignalType.DATA_DRIFT, MonitorSignalType.DATA_QUALITY}: + if mdc_input_enabled: + if not signal.production_data: # type: ignore[union-attr] + # if target dataset is absent and data collector for input is enabled, + # create a default target dataset with production model inputs as target + signal.production_data = ProductionData( # type: ignore[union-attr] + input_data=Input( + path=f"{model_inputs_name}:{model_inputs_version}", + type=self._data_operations.get(model_inputs_name, 
model_inputs_version).type, + ), + data_context=MonitorDatasetContext.MODEL_INPUTS, + data_window=BaselineDataRange( + lookback_window_size="default", lookback_window_offset="P0D" + ), + ) + if not signal.reference_data: # type: ignore[union-attr] + signal.reference_data = ReferenceData( # type: ignore[union-attr] + input_data=Input( + path=f"{model_inputs_name}:{model_inputs_version}", + type=self._data_operations.get(model_inputs_name, model_inputs_version).type, + ), + data_context=MonitorDatasetContext.MODEL_INPUTS, + data_window=BaselineDataRange( + lookback_window_size="default", lookback_window_offset="default" + ), + ) + elif not mdc_input_enabled and not ( + signal.production_data and signal.reference_data # type: ignore[union-attr] + ): + # if target or baseline dataset is absent and data collector for input is not enabled, + # collect exception message + msg = ( + f"A target and baseline dataset must be provided for signal with name {signal_name}" + f"and type {signal.type} if the monitoring_target endpoint_deployment_id is empty" + "or refers to a deployment for which data collection for model inputs is not enabled." + ) + error_messages.append(msg) + elif signal.type == MonitorSignalType.PREDICTION_DRIFT: + if mdc_output_enabled: + if not signal.production_data: # type: ignore[union-attr] + # if target dataset is absent and data collector for output is enabled, + # create a default target dataset with production model outputs as target + signal.production_data = ProductionData( # type: ignore[union-attr] + input_data=Input( + path=f"{model_outputs_name}:{model_outputs_version}", + type=self._data_operations.get(model_outputs_name, model_outputs_version).type, + ), + data_context=MonitorDatasetContext.MODEL_OUTPUTS, + data_window=BaselineDataRange( + lookback_window_size="default", lookback_window_offset="P0D" + ), + ) + if not signal.reference_data: # type: ignore[union-attr] + signal.reference_data = ReferenceData( # type: ignore[union-attr] + input_data=Input( + path=f"{model_outputs_name}:{model_outputs_version}", + type=self._data_operations.get(model_outputs_name, model_outputs_version).type, + ), + data_context=MonitorDatasetContext.MODEL_OUTPUTS, + data_window=BaselineDataRange( + lookback_window_size="default", lookback_window_offset="default" + ), + ) + elif not mdc_output_enabled and not ( + signal.production_data and signal.reference_data # type: ignore[union-attr] + ): + # if target dataset is absent and data collector for output is not enabled, + # collect exception message + msg = ( + f"A target and baseline dataset must be provided for signal with name {signal_name}" + f"and type {signal.type} if the monitoring_target endpoint_deployment_id is empty" + "or refers to a deployment for which data collection for model outputs is not enabled." 
+ ) + error_messages.append(msg) + elif signal.type == MonitorSignalType.FEATURE_ATTRIBUTION_DRIFT: + if mdc_input_enabled: + if not signal.production_data: # type: ignore[union-attr] + # if production dataset is absent and data collector for input is enabled, + # create a default prod dataset with production model inputs and outputs as target + signal.production_data = [ # type: ignore[union-attr] + FADProductionData( + input_data=Input( + path=f"{model_inputs_name}:{model_inputs_version}", + type=self._data_operations.get(model_inputs_name, model_inputs_version).type, + ), + data_context=MonitorDatasetContext.MODEL_INPUTS, + data_window=BaselineDataRange( + lookback_window_size="default", lookback_window_offset="P0D" + ), + ), + FADProductionData( + input_data=Input( + path=f"{model_outputs_name}:{model_outputs_version}", + type=self._data_operations.get(model_outputs_name, model_outputs_version).type, + ), + data_context=MonitorDatasetContext.MODEL_OUTPUTS, + data_window=BaselineDataRange( + lookback_window_size="default", lookback_window_offset="P0D" + ), + ), + ] + elif not mdc_output_enabled and not signal.production_data: # type: ignore[union-attr] + # if target dataset is absent and data collector for output is not enabled, + # collect exception message + msg = ( + f"A production data must be provided for signal with name {signal_name}" + f"and type {signal.type} if the monitoring_target endpoint_deployment_id is empty" + "or refers to a deployment for which data collection for model outputs is not enabled." + ) + error_messages.append(msg) + if error_messages: + # if any error messages, raise an exception with all of them so user knows which signals + # need to be fixed + msg = "\n".join(error_messages) + raise ScheduleException( + message=msg, + no_personal_data_message=msg, + ErrorTarget=ErrorTarget.SCHEDULE, + ErrorCategory=ErrorCategory.USER_ERROR, + ) + if signal.type == MonitorSignalType.FEATURE_ATTRIBUTION_DRIFT: + for prod_data in signal.production_data: # type: ignore[union-attr] + self._job_operations._resolve_job_input(prod_data.input_data, schedule._base_path) + prod_data.pre_processing_component = self._orchestrators.get_asset_arm_id( # type: ignore + asset=prod_data.pre_processing_component, # type: ignore[union-attr] + azureml_type=AzureMLResourceType.COMPONENT, + ) + self._job_operations._resolve_job_input( + signal.reference_data.input_data, schedule._base_path # type: ignore[union-attr] + ) + signal.reference_data.pre_processing_component = self._orchestrators.get_asset_arm_id( # type: ignore + asset=signal.reference_data.pre_processing_component, # type: ignore[union-attr] + azureml_type=AzureMLResourceType.COMPONENT, + ) + continue + + self._job_operations._resolve_job_inputs( + [signal.production_data.input_data, signal.reference_data.input_data], # type: ignore[union-attr] + schedule._base_path, + ) + signal.production_data.pre_processing_component = self._orchestrators.get_asset_arm_id( # type: ignore + asset=signal.production_data.pre_processing_component, # type: ignore[union-attr] + azureml_type=AzureMLResourceType.COMPONENT, + ) + signal.reference_data.pre_processing_component = self._orchestrators.get_asset_arm_id( # type: ignore + asset=signal.reference_data.pre_processing_component, # type: ignore[union-attr] + azureml_type=AzureMLResourceType.COMPONENT, + ) + + def _process_and_get_endpoint_deployment_names_from_id(self, target: MonitoringTarget) -> Tuple: + target.endpoint_deployment_id = ( + target.endpoint_deployment_id[len(ARM_ID_PREFIX) :] # 
type: ignore + if target.endpoint_deployment_id is not None and target.endpoint_deployment_id.startswith(ARM_ID_PREFIX) + else target.endpoint_deployment_id + ) + + # if it is an ARM ID, don't process it + if not is_ARM_id_for_parented_resource( + target.endpoint_deployment_id, + snake_to_camel(AzureMLResourceType.ONLINE_ENDPOINT), + AzureMLResourceType.DEPLOYMENT, + ): + endpoint_name, deployment_name = target.endpoint_deployment_id.split(":") # type: ignore + target.endpoint_deployment_id = NAMED_RESOURCE_ID_FORMAT_WITH_PARENT.format( + self._subscription_id, + self._resource_group_name, + AZUREML_RESOURCE_PROVIDER, + self._workspace_name, + snake_to_camel(AzureMLResourceType.ONLINE_ENDPOINT), + endpoint_name, + AzureMLResourceType.DEPLOYMENT, + deployment_name, + ) + else: + deployment_arm_id_entity = AMLNamedArmId(target.endpoint_deployment_id) + endpoint_name = deployment_arm_id_entity.parent_asset_name + deployment_name = deployment_arm_id_entity.asset_name + + return endpoint_name, deployment_name diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_serverless_endpoint_operations.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_serverless_endpoint_operations.py new file mode 100644 index 00000000..5efec117 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_serverless_endpoint_operations.py @@ -0,0 +1,223 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=protected-access + +import re +from typing import Iterable + +from azure.ai.ml._restclient.v2024_01_01_preview import ( + AzureMachineLearningWorkspaces as ServiceClient202401Preview, +) +from azure.ai.ml._restclient.v2024_01_01_preview.models import ( + KeyType, + RegenerateEndpointKeysRequest, +) +from azure.ai.ml._scope_dependent_operations import ( + OperationConfig, + OperationsContainer, + OperationScope, + _ScopeDependentOperations, +) +from azure.ai.ml._telemetry import ActivityType, monitor_with_activity +from azure.ai.ml._utils._experimental import experimental +from azure.ai.ml._utils._logger_utils import OpsLogger +from azure.ai.ml.constants._common import REGISTRY_VERSION_PATTERN, AzureMLResourceType +from azure.ai.ml.constants._endpoint import EndpointKeyType +from azure.ai.ml.entities._autogen_entities.models import ServerlessEndpoint +from azure.ai.ml.entities._endpoint.online_endpoint import EndpointAuthKeys +from azure.ai.ml.exceptions import ( + ErrorCategory, + ErrorTarget, + ValidationErrorType, + ValidationException, +) +from azure.core.polling import LROPoller + +ops_logger = OpsLogger(__name__) +module_logger = ops_logger.module_logger + + +class ServerlessEndpointOperations(_ScopeDependentOperations): + """ServerlessEndpointOperations. + + You should not instantiate this class directly. Instead, you should + create an MLClient instance that instantiates it for you and + attaches it as an attribute. 
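+
+    A minimal usage sketch (assuming ``ml_client`` is an authenticated
+    ``MLClient`` that exposes this class as ``ml_client.serverless_endpoints``):
+
+    .. code-block:: python
+
+        endpoint = ml_client.serverless_endpoints.get("my-endpoint")
+        keys = ml_client.serverless_endpoints.get_keys("my-endpoint")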
+ """ + + def __init__( + self, + operation_scope: OperationScope, + operation_config: OperationConfig, + service_client: ServiceClient202401Preview, + all_operations: OperationsContainer, + ): + super().__init__(operation_scope, operation_config) + ops_logger.update_filter() + self._service_client = service_client.serverless_endpoints + self._marketplace_subscriptions = service_client.marketplace_subscriptions + self._all_operations = all_operations + + def _get_workspace_location(self) -> str: + return str( + self._all_operations.all_operations[AzureMLResourceType.WORKSPACE].get(self._workspace_name).location + ) + + @experimental + @monitor_with_activity(ops_logger, "ServerlessEndpoint.BeginCreateOrUpdate", ActivityType.PUBLICAPI) + def begin_create_or_update(self, endpoint: ServerlessEndpoint, **kwargs) -> LROPoller[ServerlessEndpoint]: + """Create or update a serverless endpoint. + + :param endpoint: The serverless endpoint entity. + :type endpoint: ~azure.ai.ml.entities.ServerlessEndpoint + :raises ~azure.ai.ml.exceptions.ValidationException: Raised if ServerlessEndpoint cannot be + successfully validated. Details will be provided in the error message. + :return: A poller to track the operation status + :rtype: ~azure.core.polling.LROPoller[~azure.ai.ml.entities.ServerlessEndpoint] + """ + if not endpoint.location: + endpoint.location = self._get_workspace_location() + if re.match(REGISTRY_VERSION_PATTERN, endpoint.model_id): + msg = ( + "The given model_id {} points to a specific model version, which is not supported. " + "Please provide a model_id without the version information." + ) + raise ValidationException( + message=msg.format(endpoint.model_id), + no_personal_data_message="Invalid model_id given for serverless endpoint", + target=ErrorTarget.SERVERLESS_ENDPOINT, + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.INVALID_VALUE, + ) + return self._service_client.begin_create_or_update( + self._resource_group_name, + self._workspace_name, + endpoint.name, + endpoint._to_rest_object(), # type: ignore + cls=( + lambda response, deserialized, headers: ServerlessEndpoint._from_rest_object( # type: ignore + deserialized + ) + ), + **kwargs, + ) + + @experimental + @monitor_with_activity(ops_logger, "ServerlessEndpoint.Get", ActivityType.PUBLICAPI) + def get(self, name: str, **kwargs) -> ServerlessEndpoint: + """Get a Serverless Endpoint resource. + + :param name: Name of the serverless endpoint. + :type name: str + :return: Serverless endpoint object retrieved from the service. + :rtype: ~azure.ai.ml.entities.ServerlessEndpoint + """ + return self._service_client.get( + self._resource_group_name, + self._workspace_name, + name, + cls=( + lambda response, deserialized, headers: ServerlessEndpoint._from_rest_object( # type: ignore + deserialized + ) + ), + **kwargs, + ) + + @experimental + @monitor_with_activity(ops_logger, "ServerlessEndpoint.list", ActivityType.PUBLICAPI) + def list(self, **kwargs) -> Iterable[ServerlessEndpoint]: + """List serverless endpoints of the workspace. 
+ + :return: A list of serverless endpoints + :rtype: ~typing.Iterable[~azure.ai.ml.entities.ServerlessEndpoint] + """ + return self._service_client.list( + self._resource_group_name, + self._workspace_name, + cls=lambda objs: [ServerlessEndpoint._from_rest_object(obj) for obj in objs], # type: ignore + **kwargs, + ) + + @experimental + @monitor_with_activity(ops_logger, "ServerlessEndpoint.BeginDelete", ActivityType.PUBLICAPI) + def begin_delete(self, name: str, **kwargs) -> LROPoller[None]: + """Delete a Serverless Endpoint. + + :param name: Name of the serverless endpoint. + :type name: str + :return: A poller to track the operation status. + :rtype: ~azure.core.polling.LROPoller[None] + """ + return self._service_client.begin_delete( + self._resource_group_name, + self._workspace_name, + name, + **kwargs, + ) + + @experimental + @monitor_with_activity(ops_logger, "ServerlessEndpoint.GetKeys", ActivityType.PUBLICAPI) + def get_keys(self, name: str, **kwargs) -> EndpointAuthKeys: + """Get serveless endpoint auth keys. + + :param name: The serverless endpoint name + :type name: str + :return: Returns the keys of the serverless endpoint + :rtype: ~azure.ai.ml.entities.EndpointAuthKeys + """ + return self._service_client.list_keys( + self._resource_group_name, + self._workspace_name, + name, + cls=lambda response, deserialized, headers: EndpointAuthKeys._from_rest_object(deserialized), + **kwargs, + ) + + @experimental + @monitor_with_activity(ops_logger, "ServerlessEndpoint.BeginRegenerateKeys", ActivityType.PUBLICAPI) + def begin_regenerate_keys( + self, + name: str, + *, + key_type: str = EndpointKeyType.PRIMARY_KEY_TYPE, + **kwargs, + ) -> LROPoller[EndpointAuthKeys]: + """Regenerate keys for a serverless endpoint. + + :param name: The endpoint name. + :type name: str + :keyword key_type: One of "primary", "secondary". Defaults to "primary". + :paramtype key_type: str + :raises ~azure.ai.ml.exceptions.ValidationException: Raised if key_type is not "primary" + or "secondary" + :return: A poller to track the operation status. + :rtype: ~azure.core.polling.LROPoller[EndpointAuthKeys] + """ + keys = self.get_keys( + name=name, + ) + if key_type.lower() == EndpointKeyType.PRIMARY_KEY_TYPE: + key_request = RegenerateEndpointKeysRequest(key_type=KeyType.Primary, key_value=keys.primary_key) + elif key_type.lower() == EndpointKeyType.SECONDARY_KEY_TYPE: + key_request = RegenerateEndpointKeysRequest(key_type=KeyType.Secondary, key_value=keys.secondary_key) + else: + msg = "Key type must be 'primary' or 'secondary'." + raise ValidationException( + message=msg, + target=ErrorTarget.SERVERLESS_ENDPOINT, + no_personal_data_message=msg, + error_category=ErrorCategory.USER_ERROR, + error_type=ValidationErrorType.INVALID_VALUE, + ) + + return self._service_client.begin_regenerate_keys( + resource_group_name=self._resource_group_name, + workspace_name=self._workspace_name, + endpoint_name=name, + body=key_request, + cls=lambda response, deserialized, headers: EndpointAuthKeys._from_rest_object(deserialized), + **kwargs, + ) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_virtual_cluster_operations.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_virtual_cluster_operations.py new file mode 100644 index 00000000..256664ad --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_virtual_cluster_operations.py @@ -0,0 +1,174 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +from typing import Dict, Iterable, Optional, cast + +from azure.ai.ml._scope_dependent_operations import OperationScope +from azure.ai.ml._telemetry import ActivityType, monitor_with_activity +from azure.ai.ml._utils._arm_id_utils import AzureResourceId +from azure.ai.ml._utils._logger_utils import OpsLogger +from azure.ai.ml._utils._virtual_cluster_utils import ( + CrossRegionIndexEntitiesRequest, + IndexEntitiesRequest, + IndexEntitiesRequestFilter, + IndexEntitiesRequestOrder, + IndexEntitiesRequestOrderDirection, + IndexServiceAPIs, + index_entity_response_to_job, +) +from azure.ai.ml._utils.azure_resource_utils import ( + get_generic_resource_by_id, + get_virtual_cluster_by_name, + get_virtual_clusters_from_subscriptions, +) +from azure.ai.ml.constants._common import AZUREML_RESOURCE_PROVIDER, LEVEL_ONE_NAMED_RESOURCE_ID_FORMAT, Scope +from azure.ai.ml.entities import Job +from azure.ai.ml.exceptions import UserErrorException, ValidationException +from azure.core.credentials import TokenCredential +from azure.core.tracing.decorator import distributed_trace + +ops_logger = OpsLogger(__name__) +module_logger = ops_logger.module_logger + + +class VirtualClusterOperations: + """VirtualClusterOperations. + + You should not instantiate this class directly. Instead, you should create an MLClient instance that instantiates it + for you and attaches it as an attribute. + """ + + def __init__( + self, + operation_scope: OperationScope, + credentials: TokenCredential, + *, + _service_client_kwargs: Dict, + **kwargs: Dict, + ): + ops_logger.update_filter() + self._resource_group_name = operation_scope.resource_group_name + self._subscription_id = operation_scope.subscription_id + self._credentials = credentials + self._init_kwargs = kwargs + self._service_client_kwargs = _service_client_kwargs + + """A (mostly) autogenerated rest client for the index service. + + TODO: Remove this property and the rest client when list job by virtual cluster is added to virtual cluster rp + """ + self._index_service = IndexServiceAPIs( + credential=self._credentials, base_url="https://westus2.api.azureml.ms", **self._service_client_kwargs + ) + + @distributed_trace + @monitor_with_activity(ops_logger, "VirtualCluster.List", ActivityType.PUBLICAPI) + def list(self, *, scope: Optional[str] = None) -> Iterable[Dict]: + """List virtual clusters a user has access to. + + :keyword scope: scope of the listing, "subscription" or None, defaults to None. + If None, list virtual clusters across all subscriptions a customer has access to. + :paramtype scope: str + :return: An iterator like instance of dictionaries. + :rtype: ~azure.core.paging.ItemPaged[Dict] + """ + + if scope is None: + subscription_list = None + elif scope.lower() == Scope.SUBSCRIPTION: + subscription_list = [self._subscription_id] + else: + message = f"Invalid scope: {scope}. Valid values are 'subscription' or None." + raise UserErrorException(message=message, no_personal_data_message=message) + + try: + return cast( + Iterable[Dict], + get_virtual_clusters_from_subscriptions(self._credentials, subscription_list=subscription_list), + ) + except ImportError as e: + raise UserErrorException( + message="Met ImportError when trying to list virtual clusters. " + "Please install azure-mgmt-resource to enable this feature; " + "and please install azure-mgmt-resource to enable listing virtual clusters " + "across all subscriptions a customer has access to." 
+ ) from e + + @distributed_trace + @monitor_with_activity(ops_logger, "VirtualCluster.ListJobs", ActivityType.PUBLICAPI) + def list_jobs(self, name: str) -> Iterable[Job]: + """List of jobs that target the virtual cluster + + :param name: Name of virtual cluster + :type name: str + :return: An iterable of jobs. + :rtype: Iterable[Job] + """ + + def make_id(entity_type: str) -> str: + return str( + LEVEL_ONE_NAMED_RESOURCE_ID_FORMAT.format( + self._subscription_id, self._resource_group_name, AZUREML_RESOURCE_PROVIDER, entity_type, name + ) + ) + + # List of virtual cluster ids to match + # Needs to include several capitalizations for historical reasons. Will be fixed in a service side change + vc_ids = [make_id("virtualClusters"), make_id("virtualclusters"), make_id("virtualClusters").lower()] + + filters = [ + IndexEntitiesRequestFilter(field="type", operator="eq", values=["runs"]), + IndexEntitiesRequestFilter(field="annotations/archived", operator="eq", values=["false"]), + IndexEntitiesRequestFilter(field="properties/userProperties/azureml.VC", operator="eq", values=vc_ids), + ] + order = [ + IndexEntitiesRequestOrder( + field="properties/creationContext/createdTime", direction=IndexEntitiesRequestOrderDirection.DESC + ) + ] + index_entities_request = IndexEntitiesRequest(filters=filters, order=order) + + # cspell:ignore entites + return cast( + Iterable[Job], + self._index_service.index_entities.get_entites_cross_region( + body=CrossRegionIndexEntitiesRequest(index_entities_request=index_entities_request), + cls=lambda objs: [index_entity_response_to_job(obj) for obj in objs], + ), + ) + + @distributed_trace + @monitor_with_activity(ops_logger, "VirtualCluster.Get", ActivityType.PUBLICAPI) + def get(self, name: str) -> Dict: + """ + Get a virtual cluster resource. If name is provided, the virtual cluster + with the name in the subscription and resource group of the MLClient object + will be returned. If an ARM id is provided, a virtual cluster with the id will be returned. + + :param name: Name or ARM ID of the virtual cluster. + :type name: str + :return: Virtual cluster object + :rtype: Dict + """ + + try: + arm_id = AzureResourceId(name) + sub_id = arm_id.subscription_id + + return cast( + Dict, + get_generic_resource_by_id( + arm_id=name, credential=self._credentials, subscription_id=sub_id, api_version="2021-03-01-preview" + ), + ) + except ValidationException: + return cast( + Dict, + get_virtual_cluster_by_name( + name=name, + resource_group=self._resource_group_name, + subscription_id=self._subscription_id, + credential=self._credentials, + ), + ) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_workspace_connections_operations.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_workspace_connections_operations.py new file mode 100644 index 00000000..6debb243 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_workspace_connections_operations.py @@ -0,0 +1,189 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +# pylint: disable=protected-access + +from typing import Any, Dict, Iterable, Optional, cast + +from azure.ai.ml._restclient.v2024_04_01_preview import AzureMachineLearningWorkspaces as ServiceClient082023Preview +from azure.ai.ml._scope_dependent_operations import ( + OperationConfig, + OperationsContainer, + OperationScope, + _ScopeDependentOperations, +) +from azure.ai.ml._telemetry import ActivityType, monitor_with_activity +from azure.ai.ml._utils._logger_utils import OpsLogger +from azure.ai.ml._utils.utils import _snake_to_camel +from azure.ai.ml.entities._credentials import ApiKeyConfiguration +from azure.ai.ml.entities._workspace.connections.workspace_connection import WorkspaceConnection +from azure.core.credentials import TokenCredential +from azure.core.tracing.decorator import distributed_trace + +ops_logger = OpsLogger(__name__) +module_logger = ops_logger.module_logger + + +class WorkspaceConnectionsOperations(_ScopeDependentOperations): + """WorkspaceConnectionsOperations. + + You should not instantiate this class directly. Instead, you should create + an MLClient instance that instantiates it for you and attaches it as an attribute. + """ + + def __init__( + self, + operation_scope: OperationScope, + operation_config: OperationConfig, + service_client: ServiceClient082023Preview, + all_operations: OperationsContainer, + credentials: Optional[TokenCredential] = None, + **kwargs: Dict, + ): + super(WorkspaceConnectionsOperations, self).__init__(operation_scope, operation_config) + ops_logger.update_filter() + self._all_operations = all_operations + self._operation = service_client.workspace_connections + self._credentials = credentials + self._init_kwargs = kwargs + + def _try_fill_api_key(self, connection: WorkspaceConnection) -> None: + """Try to fill in a connection's credentials with it's actual values. + Connection data retrievals normally return an empty credential object that merely includes the + connection's credential type, but not the actual secrets of that credential. + However, it's extremely common for users to want to know the contents of their connection's credentials. + This method tries to fill in the user's credentials with the actual values by making + a secondary API call to the service. It requires that the user have the necessary permissions to do so, + and it only works on api key-based credentials. + + :param connection: The connection to try to fill in the credentials for. + :type connection: ~azure.ai.ml.entities.WorkspaceConnection + """ + if hasattr(connection, "credentials") and isinstance(connection.credentials, ApiKeyConfiguration): + list_secrets_response = self._operation.list_secrets( + connection_name=connection.name, + resource_group_name=self._operation_scope.resource_group_name, + workspace_name=self._operation_scope.workspace_name, + ) + if list_secrets_response.properties.credentials is not None: + connection.credentials.key = list_secrets_response.properties.credentials.key + + @distributed_trace + @monitor_with_activity(ops_logger, "WorkspaceConnections.Get", ActivityType.PUBLICAPI) + def get(self, name: str, *, populate_secrets: bool = False, **kwargs: Dict) -> WorkspaceConnection: + """Get a connection by name. + + :param name: Name of the connection. + :type name: str + :keyword populate_secrets: If true, make a secondary API call to try filling in the workspace + connections credentials. Currently only works for api key-based credentials. Defaults to False. 
+ :paramtype populate_secrets: bool + :raises ~azure.core.exceptions.HttpResponseError: Raised if the corresponding name and version cannot be + retrieved from the service. + :return: The connection with the provided name. + :rtype: ~azure.ai.ml.entities.WorkspaceConnection + """ + + connection = WorkspaceConnection._from_rest_object( + rest_obj=self._operation.get( + workspace_name=self._workspace_name, + connection_name=name, + **self._scope_kwargs, + **kwargs, + ) + ) + + if populate_secrets and connection is not None: + self._try_fill_api_key(connection) + return connection # type: ignore[return-value] + + @distributed_trace + @monitor_with_activity(ops_logger, "WorkspaceConnections.CreateOrUpdate", ActivityType.PUBLICAPI) + def create_or_update( + self, workspace_connection: WorkspaceConnection, *, populate_secrets: bool = False, **kwargs: Any + ) -> WorkspaceConnection: + """Create or update a connection. + + :param workspace_connection: Definition of a Workspace Connection or one of its subclasses + or object which can be translated to a connection. + :type workspace_connection: ~azure.ai.ml.entities.WorkspaceConnection + :keyword populate_secrets: If true, make a secondary API call to try filling in the workspace + connections credentials. Currently only works for api key-based credentials. Defaults to False. + :paramtype populate_secrets: bool + :return: Created or update connection. + :rtype: ~azure.ai.ml.entities.WorkspaceConnection + """ + rest_workspace_connection = workspace_connection._to_rest_object() + response = self._operation.create( + workspace_name=self._workspace_name, + connection_name=workspace_connection.name, + body=rest_workspace_connection, + **self._scope_kwargs, + **kwargs, + ) + conn = WorkspaceConnection._from_rest_object(rest_obj=response) + if populate_secrets and conn is not None: + self._try_fill_api_key(conn) + return conn + + @distributed_trace + @monitor_with_activity(ops_logger, "WorkspaceConnections.Delete", ActivityType.PUBLICAPI) + def delete(self, name: str, **kwargs: Any) -> None: + """Delete the connection. + + :param name: Name of the connection. + :type name: str + """ + + self._operation.delete( + connection_name=name, + workspace_name=self._workspace_name, + **self._scope_kwargs, + **kwargs, + ) + + @distributed_trace + @monitor_with_activity(ops_logger, "WorkspaceConnections.List", ActivityType.PUBLICAPI) + def list( + self, + connection_type: Optional[str] = None, + *, + populate_secrets: bool = False, + include_data_connections: bool = False, + **kwargs: Any, + ) -> Iterable[WorkspaceConnection]: + """List all connections for a workspace. + + :param connection_type: Type of connection to list. + :type connection_type: Optional[str] + :keyword populate_secrets: If true, make a secondary API call to try filling in the workspace + connections credentials. Currently only works for api key-based credentials. Defaults to False. + :paramtype populate_secrets: bool + :keyword include_data_connections: If true, also return data connections. Defaults to False. 
+ :paramtype include_data_connections: bool + :return: An iterator like instance of connection objects + :rtype: Iterable[~azure.ai.ml.entities.WorkspaceConnection] + """ + + if include_data_connections: + if "params" in kwargs: + kwargs["params"]["includeAll"] = "true" + else: + kwargs["params"] = {"includeAll": "true"} + + def post_process_conn(rest_obj): + result = WorkspaceConnection._from_rest_object(rest_obj) + if populate_secrets and result is not None: + self._try_fill_api_key(result) + return result + + result = self._operation.list( + workspace_name=self._workspace_name, + cls=lambda objs: [post_process_conn(obj) for obj in objs], + category=_snake_to_camel(connection_type) if connection_type else connection_type, + **self._scope_kwargs, + **kwargs, + ) + + return cast(Iterable[WorkspaceConnection], result) diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_workspace_operations.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_workspace_operations.py new file mode 100644 index 00000000..47fd8747 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_workspace_operations.py @@ -0,0 +1,443 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# pylint: disable=protected-access + +from typing import Any, Dict, Iterable, List, Optional, Union, cast + +from marshmallow import ValidationError + +from azure.ai.ml._restclient.v2024_10_01_preview import AzureMachineLearningWorkspaces as ServiceClient102024Preview +from azure.ai.ml._restclient.v2024_10_01_preview.models import ManagedNetworkProvisionOptions +from azure.ai.ml._scope_dependent_operations import OperationsContainer, OperationScope +from azure.ai.ml._telemetry import ActivityType, monitor_with_activity +from azure.ai.ml._utils._http_utils import HttpPipeline +from azure.ai.ml._utils._logger_utils import OpsLogger +from azure.ai.ml._utils.utils import ( + _get_workspace_base_url, + get_resource_and_group_name_from_resource_id, + get_resource_group_name_from_resource_group_id, + modified_operation_client, +) +from azure.ai.ml.constants._common import AzureMLResourceType, Scope, WorkspaceKind +from azure.ai.ml.entities import ( + DiagnoseRequestProperties, + DiagnoseResponseResult, + DiagnoseResponseResultValue, + DiagnoseWorkspaceParameters, + ManagedNetworkProvisionStatus, + Workspace, + WorkspaceKeys, +) +from azure.ai.ml.entities._credentials import IdentityConfiguration +from azure.core.credentials import TokenCredential +from azure.core.exceptions import HttpResponseError +from azure.core.polling import LROPoller +from azure.core.tracing.decorator import distributed_trace + +from ._workspace_operations_base import WorkspaceOperationsBase + +ops_logger = OpsLogger(__name__) +module_logger = ops_logger.module_logger + + +class WorkspaceOperations(WorkspaceOperationsBase): + """Handles workspaces and its subclasses, hubs and projects. + + You should not instantiate this class directly. Instead, you should create + an MLClient instance that instantiates it for you and attaches it as an attribute. 
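+
+    A minimal usage sketch (assuming ``ml_client`` is an authenticated
+    ``MLClient`` scoped to a subscription and resource group):
+
+    .. code-block:: python
+
+        workspace = ml_client.workspaces.get("my-workspace")
+        for ws in ml_client.workspaces.list(scope="subscription"):
+            print(ws.name, ws.location)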
+ """ + + def __init__( + self, + operation_scope: OperationScope, + service_client: ServiceClient102024Preview, + all_operations: OperationsContainer, + credentials: Optional[TokenCredential] = None, + **kwargs: Any, + ): + self.dataplane_workspace_operations = ( + kwargs.pop("dataplane_client").workspaces if kwargs.get("dataplane_client") else None + ) + self._requests_pipeline: HttpPipeline = kwargs.pop("requests_pipeline", None) + ops_logger.update_filter() + self._provision_network_operation = service_client.managed_network_provisions + super().__init__( + operation_scope=operation_scope, + service_client=service_client, + all_operations=all_operations, + credentials=credentials, + **kwargs, + ) + + @monitor_with_activity(ops_logger, "Workspace.List", ActivityType.PUBLICAPI) + def list( + self, *, scope: str = Scope.RESOURCE_GROUP, filtered_kinds: Optional[Union[str, List[str]]] = None + ) -> Iterable[Workspace]: + """List all Workspaces that the user has access to in the current resource group or subscription. + + :keyword scope: scope of the listing, "resource_group" or "subscription", defaults to "resource_group" + :paramtype scope: str + :keyword filtered_kinds: The kinds of workspaces to list. If not provided, all workspaces varieties will + be listed. Accepts either a single kind, or a list of them. + Valid kind options include: "default", "project", and "hub". + :return: An iterator like instance of Workspace objects + :rtype: ~azure.core.paging.ItemPaged[~azure.ai.ml.entities.Workspace] + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_workspace.py + :start-after: [START workspace_list] + :end-before: [END workspace_list] + :language: python + :dedent: 8 + :caption: List the workspaces by resource group or subscription. + """ + + # Kind should be converted to a comma-separating string if multiple values are supplied. + formatted_kinds = filtered_kinds + if filtered_kinds and not isinstance(filtered_kinds, str): + formatted_kinds = ",".join(filtered_kinds) # type: ignore[arg-type] + + if scope == Scope.SUBSCRIPTION: + return cast( + Iterable[Workspace], + self._operation.list_by_subscription( + cls=lambda objs: [Workspace._from_rest_object(obj) for obj in objs], + kind=formatted_kinds, + ), + ) + return cast( + Iterable[Workspace], + self._operation.list_by_resource_group( + self._resource_group_name, + cls=lambda objs: [Workspace._from_rest_object(obj) for obj in objs], + kind=formatted_kinds, + ), + ) + + @monitor_with_activity(ops_logger, "Workspace.Get", ActivityType.PUBLICAPI) + @distributed_trace + # pylint: disable=arguments-renamed + def get(self, name: Optional[str] = None, **kwargs: Dict) -> Optional[Workspace]: + """Get a Workspace by name. + + :param name: Name of the workspace. + :type name: str + :return: The workspace with the provided name. + :rtype: ~azure.ai.ml.entities.Workspace + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_workspace.py + :start-after: [START workspace_get] + :end-before: [END workspace_get] + :language: python + :dedent: 8 + :caption: Get the workspace with the given name. + """ + + return super().get(workspace_name=name, **kwargs) + + @monitor_with_activity(ops_logger, "Workspace.Get_Keys", ActivityType.PUBLICAPI) + @distributed_trace + def get_keys(self, name: Optional[str] = None) -> Optional[WorkspaceKeys]: + """Get WorkspaceKeys by workspace name. + + :param name: Name of the workspace. + :type name: str + :return: Keys of workspace dependent resources. 
+ :rtype: ~azure.ai.ml.entities.WorkspaceKeys + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_workspace.py + :start-after: [START workspace_get_keys] + :end-before: [END workspace_get_keys] + :language: python + :dedent: 8 + :caption: Get the workspace keys for the workspace with the given name. + """ + workspace_name = self._check_workspace_name(name) + obj = self._operation.list_keys(self._resource_group_name, workspace_name) + return WorkspaceKeys._from_rest_object(obj) + + @monitor_with_activity(ops_logger, "Workspace.BeginSyncKeys", ActivityType.PUBLICAPI) + @distributed_trace + def begin_sync_keys(self, name: Optional[str] = None) -> LROPoller[None]: + """Triggers the workspace to immediately synchronize keys. If keys for any resource in the workspace are + changed, it can take around an hour for them to automatically be updated. This function enables keys to be + updated upon request. An example scenario is needing immediate access to storage after regenerating storage + keys. + + :param name: Name of the workspace. + :type name: str + :return: An instance of LROPoller that returns either None or the sync keys result. + :rtype: ~azure.core.polling.LROPoller[None] + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_workspace.py + :start-after: [START workspace_sync_keys] + :end-before: [END workspace_sync_keys] + :language: python + :dedent: 8 + :caption: Begin sync keys for the workspace with the given name. + """ + workspace_name = self._check_workspace_name(name) + return self._operation.begin_resync_keys(self._resource_group_name, workspace_name) + + @monitor_with_activity(ops_logger, "Workspace.BeginProvisionNetwork", ActivityType.PUBLICAPI) + @distributed_trace + def begin_provision_network( + self, + *, + workspace_name: Optional[str] = None, + include_spark: bool = False, + **kwargs: Any, + ) -> LROPoller[ManagedNetworkProvisionStatus]: + """Triggers the workspace to provision the managed network. Specifying spark enabled + as true prepares the workspace managed network for supporting Spark. + + :keyword workspace_name: Name of the workspace. + :paramtype workspace_name: str + :keyword include_spark: Whether the workspace managed network should prepare to support Spark. + :paramtype include_space: bool + :return: An instance of LROPoller. + :rtype: ~azure.core.polling.LROPoller[~azure.ai.ml.entities.ManagedNetworkProvisionStatus] + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_workspace.py + :start-after: [START workspace_provision_network] + :end-before: [END workspace_provision_network] + :language: python + :dedent: 8 + :caption: Begin provision network for a workspace with managed network. 
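+
+        An equivalent inline sketch (assuming ``ml_client`` is an authenticated
+        ``MLClient`` for the target workspace):
+
+        .. code-block:: python
+
+            provision_status = ml_client.workspaces.begin_provision_network(
+                workspace_name="my-workspace", include_spark=True
+            ).result()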
+ """ + workspace_name = self._check_workspace_name(workspace_name) + + poller = self._provision_network_operation.begin_provision_managed_network( + self._resource_group_name, + workspace_name, + ManagedNetworkProvisionOptions(include_spark=include_spark), + polling=True, + cls=lambda response, deserialized, headers: ManagedNetworkProvisionStatus._from_rest_object(deserialized), + **kwargs, + ) + module_logger.info("Provision network request initiated for workspace: %s\n", workspace_name) + return poller + + @monitor_with_activity(ops_logger, "Workspace.BeginCreate", ActivityType.PUBLICAPI) + @distributed_trace + # pylint: disable=arguments-differ + def begin_create( + self, + workspace: Workspace, + update_dependent_resources: bool = False, + **kwargs: Any, + ) -> LROPoller[Workspace]: + """Create a new Azure Machine Learning Workspace. + + Returns the workspace if already exists. + + :param workspace: Workspace definition. + :type workspace: ~azure.ai.ml.entities.Workspace + :param update_dependent_resources: Whether to update dependent resources, defaults to False. + :type update_dependent_resources: boolean + :return: An instance of LROPoller that returns a Workspace. + :rtype: ~azure.core.polling.LROPoller[~azure.ai.ml.entities.Workspace] + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_workspace.py + :start-after: [START workspace_begin_create] + :end-before: [END workspace_begin_create] + :language: python + :dedent: 8 + :caption: Begin create for a workspace. + """ + # Add hub values to project if possible + if workspace._kind == WorkspaceKind.PROJECT: + try: + parent_name = workspace._hub_id.split("/")[-1] if workspace._hub_id else "" + parent = self.get(parent_name) + if parent: + # Project location can not differ from hub, so try to force match them if possible. + workspace.location = parent.location + # Project's technically don't save their PNA, since it implicitly matches their parent's. + # However, some PNA-dependent code is run server-side before that alignment is made, so make sure + # they're aligned before the request hits the server. + workspace.public_network_access = parent.public_network_access + except HttpResponseError: + module_logger.warning("Failed to get parent hub for project, some values won't be transferred:") + try: + return super().begin_create(workspace, update_dependent_resources=update_dependent_resources, **kwargs) + except HttpResponseError as error: + if error.status_code == 403 and workspace._kind == WorkspaceKind.PROJECT: + resource_group = kwargs.get("resource_group") or self._resource_group_name + hub_name, _ = get_resource_and_group_name_from_resource_id(workspace._hub_id) + rest_workspace_obj = self._operation.get(resource_group, hub_name) + hub_default_project_resource_group = get_resource_group_name_from_resource_group_id( + rest_workspace_obj.workspace_hub_config.default_workspace_resource_group + ) + # we only want to try joining the workspaceHub when the default workspace resource group + # is same with the user provided resource group. + if hub_default_project_resource_group == resource_group: + log_msg = ( + "User lacked permission to create project workspace," + + "trying to join the workspaceHub default resource group." 
+ ) + module_logger.info(log_msg) + return self._begin_join(workspace, **kwargs) + raise error + + @monitor_with_activity(ops_logger, "Workspace.BeginUpdate", ActivityType.PUBLICAPI) + @distributed_trace + def begin_update( + self, + workspace: Workspace, + *, + update_dependent_resources: bool = False, + **kwargs: Any, + ) -> LROPoller[Workspace]: + """Updates a Azure Machine Learning Workspace. + + :param workspace: Workspace definition. + :type workspace: ~azure.ai.ml.entities.Workspace + :keyword update_dependent_resources: Whether to update dependent resources, defaults to False. + :paramtype update_dependent_resources: boolean + :return: An instance of LROPoller that returns a Workspace. + :rtype: ~azure.core.polling.LROPoller[~azure.ai.ml.entities.Workspace] + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_workspace.py + :start-after: [START workspace_begin_update] + :end-before: [END workspace_begin_update] + :language: python + :dedent: 8 + :caption: Begin update for a workspace. + """ + return super().begin_update(workspace, update_dependent_resources=update_dependent_resources, **kwargs) + + @monitor_with_activity(ops_logger, "Workspace.BeginDelete", ActivityType.PUBLICAPI) + @distributed_trace + def begin_delete( + self, name: str, *, delete_dependent_resources: bool, permanently_delete: bool = False, **kwargs: Dict + ) -> LROPoller[None]: + """Delete a workspace. + + :param name: Name of the workspace + :type name: str + :keyword delete_dependent_resources: Whether to delete resources associated with the workspace, + i.e., container registry, storage account, key vault, application insights, log analytics. + The default is False. Set to True to delete these resources. + :paramtype delete_dependent_resources: bool + :keyword permanently_delete: Workspaces are soft-deleted by default to allow recovery of workspace data. + Set this flag to true to override the soft-delete behavior and permanently delete your workspace. + :paramtype permanently_delete: bool + :return: A poller to track the operation status. + :rtype: ~azure.core.polling.LROPoller[None] + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_workspace.py + :start-after: [START workspace_begin_delete] + :end-before: [END workspace_begin_delete] + :language: python + :dedent: 8 + :caption: Begin permanent (force) deletion for a workspace and delete dependent resources. + """ + return super().begin_delete( + name, delete_dependent_resources=delete_dependent_resources, permanently_delete=permanently_delete, **kwargs + ) + + @distributed_trace + @monitor_with_activity(ops_logger, "Workspace.BeginDiagnose", ActivityType.PUBLICAPI) + def begin_diagnose(self, name: str, **kwargs: Dict) -> LROPoller[DiagnoseResponseResultValue]: + """Diagnose workspace setup problems. + + If your workspace is not working as expected, you can run this diagnosis to + check if the workspace has been broken. + For private endpoint workspace, it will also help check if the network + setup to this workspace and its dependent resource has problems or not. + + :param name: Name of the workspace + :type name: str + :return: A poller to track the operation status. + :rtype: ~azure.core.polling.LROPoller[~azure.ai.ml.entities.DiagnoseResponseResultValue] + + .. admonition:: Example: + + .. 
literalinclude:: ../samples/ml_samples_workspace.py + :start-after: [START workspace_begin_diagnose] + :end-before: [END workspace_begin_diagnose] + :language: python + :dedent: 8 + :caption: Begin diagnose operation for a workspace. + """ + resource_group = kwargs.get("resource_group") or self._resource_group_name + parameters = DiagnoseWorkspaceParameters(value=DiagnoseRequestProperties())._to_rest_object() + + # pylint: disable=unused-argument, docstring-missing-param + def callback(_: Any, deserialized: Any, args: Any) -> Optional[DiagnoseResponseResultValue]: + """Callback to be called after completion + + :return: DiagnoseResponseResult deserialized. + :rtype: ~azure.ai.ml.entities.DiagnoseResponseResult + """ + diagnose_response_result = DiagnoseResponseResult._from_rest_object(deserialized) + res = None + if diagnose_response_result: + res = diagnose_response_result.value + return res + + poller = self._operation.begin_diagnose(resource_group, name, parameters, polling=True, cls=callback) + module_logger.info("Diagnose request initiated for workspace: %s\n", name) + return poller + + @distributed_trace + def _begin_join(self, workspace: Workspace, **kwargs: Dict) -> LROPoller[Workspace]: + """Join a WorkspaceHub by creating a project workspace under workspaceHub's default resource group. + + :param workspace: Project workspace definition to create + :type workspace: Workspace + :return: An instance of LROPoller that returns a project Workspace. + :rtype: ~azure.core.polling.LROPoller[~azure.ai.ml.entities.Workspace] + """ + if not workspace._hub_id: + raise ValidationError( + "{0} is not a Project workspace, join operation can only perform with workspaceHub provided".format( + workspace.name + ) + ) + + resource_group = kwargs.get("resource_group") or self._resource_group_name + hub_name, _ = get_resource_and_group_name_from_resource_id(workspace._hub_id) + rest_workspace_obj = self._operation.get(resource_group, hub_name) + + # override the location to the same as the workspaceHub + workspace.location = rest_workspace_obj.location + # set this to system assigned so ARM will create MSI + if not hasattr(workspace, "identity") or not workspace.identity: + workspace.identity = IdentityConfiguration(type="system_assigned") + + workspace_operations = self._all_operations.all_operations[AzureMLResourceType.WORKSPACE] + workspace_base_uri = _get_workspace_base_url(workspace_operations, hub_name, self._requests_pipeline) + + # pylint:disable=unused-argument + def callback(_: Any, deserialized: Any, args: Any) -> Optional[Workspace]: + return Workspace._from_rest_object(deserialized) + + with modified_operation_client(self.dataplane_workspace_operations, workspace_base_uri): + result = self.dataplane_workspace_operations.begin_hub_join( # type: ignore + resource_group_name=resource_group, + workspace_name=hub_name, + project_workspace_name=workspace.name, + body=workspace._to_rest_object(), + cls=callback, + **self._init_kwargs, + ) + return result diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_workspace_operations_base.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_workspace_operations_base.py new file mode 100644 index 00000000..1a3293fe --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_workspace_operations_base.py @@ -0,0 +1,1167 @@ +# pylint: disable=too-many-lines +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +# pylint: disable=protected-access + +import time +from abc import ABC +from typing import Any, Callable, Dict, MutableMapping, Optional, Tuple + +from azure.ai.ml._arm_deployments import ArmDeploymentExecutor +from azure.ai.ml._arm_deployments.arm_helper import get_template +from azure.ai.ml._restclient.v2024_10_01_preview import AzureMachineLearningWorkspaces as ServiceClient102024Preview +from azure.ai.ml._restclient.v2024_10_01_preview.models import ( + EncryptionKeyVaultUpdateProperties, + EncryptionUpdateProperties, + WorkspaceUpdateParameters, +) +from azure.ai.ml._scope_dependent_operations import OperationsContainer, OperationScope +from azure.ai.ml._utils._appinsights_utils import get_log_analytics_arm_id + +# from azure.ai.ml._telemetry import ActivityType, monitor_with_activity +from azure.ai.ml._utils._logger_utils import OpsLogger +from azure.ai.ml._utils._workspace_utils import ( + delete_resource_by_arm_id, + get_deployment_name, + get_generic_arm_resource_by_arm_id, + get_name_for_dependent_resource, + get_resource_and_group_name, + get_resource_group_location, + get_sub_id_resource_and_group_name, +) +from azure.ai.ml._utils.utils import camel_to_snake, from_iso_duration_format_min_sec +from azure.ai.ml._version import VERSION +from azure.ai.ml.constants import ManagedServiceIdentityType +from azure.ai.ml.constants._common import ( + WORKSPACE_PATCH_REJECTED_KEYS, + ArmConstants, + LROConfigurations, + WorkspaceKind, + WorkspaceResourceConstants, +) +from azure.ai.ml.constants._workspace import IsolationMode, OutboundRuleCategory +from azure.ai.ml.entities import Hub, Project, Workspace +from azure.ai.ml.entities._credentials import IdentityConfiguration +from azure.ai.ml.entities._workspace._ai_workspaces._constants import ENDPOINT_AI_SERVICE_KIND +from azure.ai.ml.entities._workspace.network_acls import NetworkAcls +from azure.ai.ml.entities._workspace.networking import ManagedNetwork +from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationException +from azure.core.credentials import TokenCredential +from azure.core.polling import LROPoller, PollingMethod + +ops_logger = OpsLogger(__name__) +module_logger = ops_logger.module_logger + + +class WorkspaceOperationsBase(ABC): + """Base class for WorkspaceOperations, can't be instantiated directly.""" + + def __init__( + self, + operation_scope: OperationScope, + service_client: ServiceClient102024Preview, + all_operations: OperationsContainer, + credentials: Optional[TokenCredential] = None, + **kwargs: Dict, + ): + ops_logger.update_filter() + self._subscription_id = operation_scope.subscription_id + self._resource_group_name = operation_scope.resource_group_name + self._default_workspace_name = operation_scope.workspace_name + self._all_operations = all_operations + self._operation = service_client.workspaces + self._credentials = credentials + self._init_kwargs = kwargs + self.containerRegistry = "none" + + def get(self, workspace_name: Optional[str] = None, **kwargs: Any) -> Optional[Workspace]: + """Get a Workspace by name. + + :param workspace_name: Name of the workspace. + :type workspace_name: str + :return: The workspace with the provided name. 
+ :rtype: ~azure.ai.ml.entities.Workspace + """ + workspace_name = self._check_workspace_name(workspace_name) + resource_group = kwargs.get("resource_group") or self._resource_group_name + obj = self._operation.get(resource_group, workspace_name) + v2_service_context = {} + + v2_service_context["subscription_id"] = self._subscription_id + v2_service_context["workspace_name"] = workspace_name + v2_service_context["resource_group_name"] = resource_group + v2_service_context["auth"] = self._credentials # type: ignore + + from urllib.parse import urlparse + + if obj is not None and obj.ml_flow_tracking_uri: + parsed_url = urlparse(obj.ml_flow_tracking_uri) + host_url = "https://{}".format(parsed_url.netloc) + v2_service_context["host_url"] = host_url + else: + v2_service_context["host_url"] = "" + + # host_url=service_context._get_mlflow_url(), + # cloud=_get_cloud_or_default( + # service_context.get_auth()._cloud_type.name + if obj is not None and obj.kind is not None and obj.kind.lower() == WorkspaceKind.HUB: + return Hub._from_rest_object(obj, v2_service_context) + if obj is not None and obj.kind is not None and obj.kind.lower() == WorkspaceKind.PROJECT: + return Project._from_rest_object(obj, v2_service_context) + return Workspace._from_rest_object(obj, v2_service_context) + + def begin_create( + self, + workspace: Workspace, + update_dependent_resources: bool = False, + get_callback: Optional[Callable[[], Workspace]] = None, + **kwargs: Any, + ) -> LROPoller[Workspace]: + """Create a new Azure Machine Learning Workspace. + + Returns the workspace if already exists. + + :param workspace: Workspace definition. + :type workspace: ~azure.ai.ml.entities.Workspace + :param update_dependent_resources: Whether to update dependent resources, defaults to False. + :type update_dependent_resources: boolean + :param get_callback: A callable function to call at the end of operation. + :type get_callback: Optional[Callable[[], ~azure.ai.ml.entities.Workspace]] + :return: An instance of LROPoller that returns a Workspace. + :rtype: ~azure.core.polling.LROPoller[~azure.ai.ml.entities.Workspace] + :raises ~azure.ai.ml.ValidationException: Raised if workspace is Project workspace and user + specifies any of the following in workspace object: storage_account, container_registry, key_vault, + public_network_access, managed_network, customer_managed_key, system_datastores_auth_mode. + """ + existing_workspace = None + resource_group = kwargs.get("resource_group") or workspace.resource_group or self._resource_group_name + endpoint_resource_id = kwargs.pop("endpoint_resource_id", "") + endpoint_kind = kwargs.pop("endpoint_kind", ENDPOINT_AI_SERVICE_KIND) + + try: + existing_workspace = self.get(workspace.name, resource_group=resource_group) + except Exception: # pylint: disable=W0718 + pass + + # idempotent behavior + if existing_workspace: + if workspace.tags is not None and workspace.tags.get("createdByToolkit") is not None: + workspace.tags.pop("createdByToolkit") + if existing_workspace.tags is not None: + existing_workspace.tags.update(workspace.tags) # type: ignore + workspace.tags = existing_workspace.tags + # TODO do we want projects to do this? 
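+            # For non-project workspaces, carry forward dependent-resource references,
+            # identity settings, and feature store settings from the existing workspace
+            # so that an idempotent create does not clear them before delegating to
+            # begin_update below.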
+ if workspace._kind != WorkspaceKind.PROJECT: + workspace.container_registry = workspace.container_registry or existing_workspace.container_registry + workspace.application_insights = ( + workspace.application_insights or existing_workspace.application_insights + ) + workspace.identity = workspace.identity or existing_workspace.identity + workspace.primary_user_assigned_identity = ( + workspace.primary_user_assigned_identity or existing_workspace.primary_user_assigned_identity + ) + workspace._feature_store_settings = ( + workspace._feature_store_settings or existing_workspace._feature_store_settings + ) + return self.begin_update( + workspace, + update_dependent_resources=update_dependent_resources, + **kwargs, + ) + # add tag in the workspace to indicate which sdk version the workspace is created from + if workspace.tags is None: + workspace.tags = {} + if workspace.tags.get("createdByToolkit") is None: + workspace.tags["createdByToolkit"] = "sdk-v2-{}".format(VERSION) + + workspace.resource_group = resource_group + ( + template, + param, + resources_being_deployed, + ) = self._populate_arm_parameters( + workspace, + endpoint_resource_id=endpoint_resource_id, + endpoint_kind=endpoint_kind, + **kwargs, + ) + # check if create with workspace hub request is valid + if workspace._kind == WorkspaceKind.PROJECT: + if not all( + x is None + for x in [ + workspace.storage_account, + workspace.container_registry, + workspace.key_vault, + ] + ): + msg = ( + "To create a project workspace with a workspace hub, please only specify name, description, " + + "display_name, location, application insight and identity" + ) + raise ValidationException( + message=msg, + target=ErrorTarget.WORKSPACE, + no_personal_data_message=msg, + error_category=ErrorCategory.USER_ERROR, + ) + + arm_submit = ArmDeploymentExecutor( + credentials=self._credentials, + resource_group_name=resource_group, + subscription_id=self._subscription_id, + deployment_name=get_deployment_name(workspace.name), + ) + + # deploy_resource() blocks for the poller to succeed if wait is True + poller = arm_submit.deploy_resource( + template=template, + resources_being_deployed=resources_being_deployed, + parameters=param, + wait=False, + ) + + def callback() -> Optional[Workspace]: + """Callback to be called after completion + + :return: Result of calling appropriate callback. + :rtype: Any + """ + return get_callback() if get_callback else self.get(workspace.name, resource_group=resource_group) + + real_callback = callback + injected_callback = kwargs.get("cls", None) + if injected_callback: + # pylint: disable=function-redefined + def real_callback() -> Any: + """Callback to be called after completion + + :return: Result of calling appropriate callback. + :rtype: Any + """ + return injected_callback(callback()) + + return LROPoller( + self._operation._client, + None, + lambda *x, **y: None, + CustomArmTemplateDeploymentPollingMethod(poller, arm_submit, real_callback), + ) + + # pylint: disable=too-many-statements,too-many-locals + def begin_update( + self, + workspace: Workspace, + *, + update_dependent_resources: bool = False, + deserialize_callback: Optional[Callable] = None, + **kwargs: Any, + ) -> LROPoller[Workspace]: + """Updates a Azure Machine Learning Workspace. + + :param workspace: Workspace resource. + :type workspace: ~azure.ai.ml.entities.Workspace + :keyword update_dependent_resources: Whether to update dependent resources, defaults to False. 
+ :paramtype update_dependent_resources: boolean + :keyword deserialize_callback: A callable function to call at the end of operation. + :paramtype deserialize_callback: Optional[Callable[[], ~azure.ai.ml.entities.Workspace]] + :return: An instance of LROPoller that returns a Workspace. + :rtype: ~azure.core.polling.LROPoller[~azure.ai.ml.entities.Workspace] + :raises ~azure.ai.ml.ValidationException: Raised if updating container_registry for a workspace + and update_dependent_resources is not True. + :raises ~azure.ai.ml.ValidationException: Raised if updating application_insights for a workspace + and update_dependent_resources is not True. + """ + identity = kwargs.get("identity", workspace.identity) + workspace_name = kwargs.get("workspace_name", workspace.name) + resource_group = kwargs.get("resource_group") or workspace.resource_group or self._resource_group_name + existing_workspace: Any = self.get(workspace_name, **kwargs) + if identity: + identity = identity._to_workspace_rest_object() + rest_user_assigned_identities = identity.user_assigned_identities + # add the uai resource_id which needs to be deleted (which is not provided in the list) + if ( + existing_workspace + and existing_workspace.identity + and existing_workspace.identity.user_assigned_identities + ): + if rest_user_assigned_identities is None: + rest_user_assigned_identities = {} + for uai in existing_workspace.identity.user_assigned_identities: + if uai.resource_id not in rest_user_assigned_identities: + rest_user_assigned_identities[uai.resource_id] = None + identity.user_assigned_identities = rest_user_assigned_identities + + managed_network = kwargs.get("managed_network", workspace.managed_network) + if isinstance(managed_network, str): + managed_network = ManagedNetwork(isolation_mode=managed_network)._to_rest_object() + elif isinstance(managed_network, ManagedNetwork): + if managed_network.outbound_rules is not None: + # drop recommended and required rules from the update request since it would result in bad request + managed_network.outbound_rules = [ + rule + for rule in managed_network.outbound_rules + if rule.category not in (OutboundRuleCategory.REQUIRED, OutboundRuleCategory.RECOMMENDED) + ] + managed_network = managed_network._to_rest_object() + + container_registry = kwargs.get("container_registry", workspace.container_registry) + # Empty string is for erasing the value of container_registry, None is to be ignored value + if ( + container_registry is not None + and container_registry != existing_workspace.container_registry + and not update_dependent_resources + ): + msg = ( + "Updating the workspace-attached Azure Container Registry resource may break lineage of " + "previous jobs or your ability to rerun earlier jobs in this workspace. " + "Are you sure you want to perform this operation? " + "Include the update_dependent_resources argument in SDK or the " + "--update-dependent-resources/-u parameter in CLI with this request to confirm." 
+ ) + raise ValidationException( + message=msg, + target=ErrorTarget.WORKSPACE, + no_personal_data_message=msg, + error_category=ErrorCategory.USER_ERROR, + ) + application_insights = kwargs.get("application_insights", workspace.application_insights) + # Empty string is for erasing the value of application_insights, None is to be ignored value + if ( + application_insights is not None + and application_insights != existing_workspace.application_insights + and not update_dependent_resources + ): + msg = ( + "Updating the workspace-attached Azure Application Insights resource may break lineage " + "of deployed inference endpoints this workspace. Are you sure you want to perform this " + "operation? " + "Include the update_dependent_resources argument in SDK or the " + "--update-dependent-resources/-u parameter in CLI with this request to confirm." + ) + raise ValidationException( + message=msg, + target=ErrorTarget.WORKSPACE, + no_personal_data_message=msg, + error_category=ErrorCategory.USER_ERROR, + ) + + feature_store_settings = kwargs.get("feature_store_settings", workspace._feature_store_settings) + if feature_store_settings: + feature_store_settings = feature_store_settings._to_rest_object() + + serverless_compute_settings = kwargs.get("serverless_compute", workspace.serverless_compute) + if serverless_compute_settings: + serverless_compute_settings = serverless_compute_settings._to_rest_object() + + public_network_access = kwargs.get("public_network_access", workspace.public_network_access) + network_acls = kwargs.get("network_acls", workspace.network_acls) + if network_acls: + network_acls = network_acls._to_rest_object() # pylint: disable=protected-access + + if public_network_access == "Disabled" or ( + existing_workspace + and existing_workspace.public_network_access == "Disabled" + and public_network_access is None + ): + network_acls = NetworkAcls()._to_rest_object() # pylint: disable=protected-access + + update_param = WorkspaceUpdateParameters( + tags=kwargs.get("tags", workspace.tags), + description=kwargs.get("description", workspace.description), + friendly_name=kwargs.get("display_name", workspace.display_name), + public_network_access=kwargs.get("public_network_access", workspace.public_network_access), + system_datastores_auth_mode=kwargs.get( + "system_datastores_auth_mode", workspace.system_datastores_auth_mode + ), + allow_role_assignment_on_rg=kwargs.get( + "allow_roleassignment_on_rg", workspace.allow_roleassignment_on_rg + ), # diff due to swagger restclient casing diff + image_build_compute=kwargs.get("image_build_compute", workspace.image_build_compute), + identity=identity, + primary_user_assigned_identity=kwargs.get( + "primary_user_assigned_identity", workspace.primary_user_assigned_identity + ), + managed_network=managed_network, + feature_store_settings=feature_store_settings, + network_acls=network_acls, + ) + if serverless_compute_settings: + update_param.serverless_compute_settings = serverless_compute_settings + update_param.container_registry = container_registry or None + update_param.application_insights = application_insights or None + + # Only the key uri property of customer_managed_key can be updated. 
+ # Check if user is updating CMK key uri, if so, add to update_param + if workspace.customer_managed_key is not None and workspace.customer_managed_key.key_uri is not None: + customer_managed_key_uri = workspace.customer_managed_key.key_uri + update_param.encryption = EncryptionUpdateProperties( + key_vault_properties=EncryptionKeyVaultUpdateProperties( + key_identifier=customer_managed_key_uri, + ) + ) + + update_role_assignment = ( + kwargs.get("update_workspace_role_assignment", None) + or kwargs.get("update_offline_store_role_assignment", None) + or kwargs.get("update_online_store_role_assignment", None) + ) + grant_materialization_permissions = kwargs.get("grant_materialization_permissions", None) + + # Remove deprecated keys from older workspaces that might still have them before we try to update. + if workspace.tags is not None: + for bad_key in WORKSPACE_PATCH_REJECTED_KEYS: + _ = workspace.tags.pop(bad_key, None) + + # pylint: disable=unused-argument, docstring-missing-param + def callback(_: Any, deserialized: Any, args: Any) -> Optional[Workspace]: + """Callback to be called after completion + + :return: Workspace deserialized. + :rtype: ~azure.ai.ml.entities.Workspace + """ + if ( + workspace._kind + and workspace._kind.lower() == "featurestore" + and update_role_assignment + and grant_materialization_permissions + ): + module_logger.info("updating feature store materialization identity role assignments..") + template, param, resources_being_deployed = self._populate_feature_store_role_assignment_parameters( + workspace, resource_group=resource_group, location=existing_workspace.location, **kwargs + ) + + arm_submit = ArmDeploymentExecutor( + credentials=self._credentials, + resource_group_name=resource_group, + subscription_id=self._subscription_id, + deployment_name=get_deployment_name(workspace.name), + ) + + # deploy_resource() blocks for the poller to succeed if wait is True + poller = arm_submit.deploy_resource( + template=template, + resources_being_deployed=resources_being_deployed, + parameters=param, + wait=False, + ) + + poller.result() + return ( + deserialize_callback(deserialized) + if deserialize_callback + else Workspace._from_rest_object(deserialized) + ) + + real_callback = callback + injected_callback = kwargs.get("cls", None) + if injected_callback: + # pylint: disable=function-redefined, docstring-missing-param + def real_callback(_: Any, deserialized: Any, args: Any) -> Any: + """Callback to be called after completion + + :return: Result of calling appropriate callback. + :rtype: Any + """ + return injected_callback(callback(_, deserialized, args)) + + poller = self._operation.begin_update( + resource_group, workspace_name, update_param, polling=True, cls=real_callback + ) + return poller + + def begin_delete( + self, name: str, *, delete_dependent_resources: bool, permanently_delete: bool = False, **kwargs: Any + ) -> LROPoller[None]: + """Delete a Workspace. + + :param name: Name of the Workspace + :type name: str + :keyword delete_dependent_resources: Whether to delete resources associated with the Workspace, + i.e., container registry, storage account, key vault, application insights, log analytics. + The default is False. Set to True to delete these resources. + :paramtype delete_dependent_resources: bool + :keyword permanently_delete: Workspaces are soft-deleted by default to allow recovery of workspace data. + Set this flag to true to override the soft-delete behavior and permanently delete your workspace. 
+ :paramtype permanently_delete: bool + :return: A poller to track the operation status. + :rtype: ~azure.core.polling.LROPoller[None] + """ + workspace: Any = self.get(name, **kwargs) + resource_group = kwargs.get("resource_group") or self._resource_group_name + + # prevent dependent resource delete for lean workspace, only delete appinsight and associated log analytics + if workspace._kind == WorkspaceKind.PROJECT and delete_dependent_resources: + app_insights = get_generic_arm_resource_by_arm_id( + self._credentials, + self._subscription_id, + workspace.application_insights, + ArmConstants.AZURE_MGMT_APPINSIGHT_API_VERSION, + ) + if app_insights is not None and "WorkspaceResourceId" in app_insights.properties: + delete_resource_by_arm_id( + self._credentials, + self._subscription_id, + app_insights.properties["WorkspaceResourceId"], + ArmConstants.AZURE_MGMT_LOGANALYTICS_API_VERSION, + ) + delete_resource_by_arm_id( + self._credentials, + self._subscription_id, + workspace.application_insights, + ArmConstants.AZURE_MGMT_APPINSIGHT_API_VERSION, + ) + elif delete_dependent_resources: + app_insights = get_generic_arm_resource_by_arm_id( + self._credentials, + self._subscription_id, + workspace.application_insights, + ArmConstants.AZURE_MGMT_APPINSIGHT_API_VERSION, + ) + if app_insights is not None and "WorkspaceResourceId" in app_insights.properties: + delete_resource_by_arm_id( + self._credentials, + self._subscription_id, + app_insights.properties["WorkspaceResourceId"], + ArmConstants.AZURE_MGMT_LOGANALYTICS_API_VERSION, + ) + delete_resource_by_arm_id( + self._credentials, + self._subscription_id, + workspace.application_insights, + ArmConstants.AZURE_MGMT_APPINSIGHT_API_VERSION, + ) + delete_resource_by_arm_id( + self._credentials, + self._subscription_id, + workspace.storage_account, + ArmConstants.AZURE_MGMT_STORAGE_API_VERSION, + ) + delete_resource_by_arm_id( + self._credentials, + self._subscription_id, + workspace.key_vault, + ArmConstants.AZURE_MGMT_KEYVAULT_API_VERSION, + ) + delete_resource_by_arm_id( + self._credentials, + self._subscription_id, + workspace.container_registry, + ArmConstants.AZURE_MGMT_CONTAINER_REG_API_VERSION, + ) + + poller = self._operation.begin_delete( + resource_group_name=resource_group, + workspace_name=name, + force_to_purge=permanently_delete, + **self._init_kwargs, + ) + module_logger.info("Delete request initiated for workspace: %s\n", name) + return poller + + # pylint: disable=too-many-statements,too-many-branches,too-many-locals + def _populate_arm_parameters(self, workspace: Workspace, **kwargs: Any) -> Tuple[dict, dict, dict]: + """Populates ARM template parameters for use to deploy a workspace resource. + + :param workspace: Workspace resource. + :type workspace: ~azure.ai.ml.entities.Workspace + :return: A tuple of three dicts: an ARM template, ARM template parameters, resources_being_deployed. 
+ :rtype: Tuple[dict, dict, dict] + """ + resources_being_deployed: Dict = {} + if not workspace.location: + workspace.location = get_resource_group_location( + self._credentials, self._subscription_id, workspace.resource_group + ) + template = get_template(resource_type=ArmConstants.WORKSPACE_BASE) + param = get_template(resource_type=ArmConstants.WORKSPACE_PARAM) + if workspace._kind == WorkspaceKind.PROJECT: + template = get_template(resource_type=ArmConstants.WORKSPACE_PROJECT) + endpoint_resource_id = kwargs.get("endpoint_resource_id") or "" + endpoint_kind = kwargs.get("endpoint_kind") or ENDPOINT_AI_SERVICE_KIND + _set_val(param["workspaceName"], workspace.name) + if not workspace.display_name: + _set_val(param["friendlyName"], workspace.name) + else: + _set_val(param["friendlyName"], workspace.display_name) + + if not workspace.description: + _set_val(param["description"], workspace.name) + else: + _set_val(param["description"], workspace.description) + _set_val(param["location"], workspace.location) + + if not workspace._kind: + _set_val(param["kind"], "default") + else: + _set_val(param["kind"], workspace._kind) + + _set_val(param["resourceGroupName"], workspace.resource_group) + + if workspace.key_vault: + resource_name, group_name = get_resource_and_group_name(workspace.key_vault) + _set_val(param["keyVaultName"], resource_name) + _set_val(param["keyVaultOption"], "existing") + _set_val(param["keyVaultResourceGroupName"], group_name) + else: + key_vault = _generate_key_vault(workspace.name, resources_being_deployed) + _set_val(param["keyVaultName"], key_vault) + _set_val( + param["keyVaultResourceGroupName"], + workspace.resource_group, + ) + + if workspace.storage_account: + subscription_id, resource_name, group_name = get_sub_id_resource_and_group_name(workspace.storage_account) + _set_val(param["storageAccountName"], resource_name) + _set_val(param["storageAccountOption"], "existing") + _set_val(param["storageAccountResourceGroupName"], group_name) + _set_val(param["storageAccountSubscriptionId"], subscription_id) + else: + storage = _generate_storage(workspace.name, resources_being_deployed) + _set_val(param["storageAccountName"], storage) + _set_val( + param["storageAccountResourceGroupName"], + workspace.resource_group, + ) + _set_val( + param["storageAccountSubscriptionId"], + self._subscription_id, + ) + + if workspace.application_insights: + resource_name, group_name = get_resource_and_group_name(workspace.application_insights) + _set_val(param["applicationInsightsName"], resource_name) + _set_val(param["applicationInsightsOption"], "existing") + _set_val( + param["applicationInsightsResourceGroupName"], + group_name, + ) + elif workspace._kind and workspace._kind.lower() in {WorkspaceKind.HUB, WorkspaceKind.PROJECT}: + _set_val(param["applicationInsightsOption"], "none") + # Set empty values because arm templates whine over unset values. 
+ _set_val(param["applicationInsightsName"], "ignoredButCantBeEmpty") + _set_val( + param["applicationInsightsResourceGroupName"], + "ignoredButCantBeEmpty", + ) + else: + log_analytics = _generate_log_analytics(workspace.name, resources_being_deployed) + _set_val(param["logAnalyticsName"], log_analytics) + _set_val( + param["logAnalyticsArmId"], + get_log_analytics_arm_id(self._subscription_id, self._resource_group_name, log_analytics), + ) + + app_insights = _generate_app_insights(workspace.name, resources_being_deployed) + _set_val(param["applicationInsightsName"], app_insights) + _set_val( + param["applicationInsightsResourceGroupName"], + workspace.resource_group, + ) + + if workspace.container_registry: + resource_name, group_name = get_resource_and_group_name(workspace.container_registry) + _set_val(param["containerRegistryName"], resource_name) + _set_val(param["containerRegistryOption"], "existing") + _set_val(param["containerRegistryResourceGroupName"], group_name) + + if workspace.customer_managed_key: + _set_val(param["cmk_keyvault"], workspace.customer_managed_key.key_vault) + _set_val(param["resource_cmk_uri"], workspace.customer_managed_key.key_uri) + _set_val( + param["encryption_status"], + WorkspaceResourceConstants.ENCRYPTION_STATUS_ENABLED, + ) + _set_val( + param["encryption_cosmosdb_resourceid"], + workspace.customer_managed_key.cosmosdb_id, + ) + _set_val( + param["encryption_storage_resourceid"], + workspace.customer_managed_key.storage_id, + ) + _set_val( + param["encryption_search_resourceid"], + workspace.customer_managed_key.search_id, + ) + + if workspace.hbi_workspace: + _set_val(param["confidential_data"], "true") + + if workspace.public_network_access: + _set_val(param["publicNetworkAccess"], workspace.public_network_access) + _set_val(param["associatedResourcePNA"], workspace.public_network_access) + + if workspace.system_datastores_auth_mode: + _set_val(param["systemDatastoresAuthMode"], workspace.system_datastores_auth_mode) + + if workspace.allow_roleassignment_on_rg is False: + _set_val(param["allowRoleAssignmentOnRG"], "false") + + if workspace.image_build_compute: + _set_val(param["imageBuildCompute"], workspace.image_build_compute) + + if workspace.tags: + for key, val in workspace.tags.items(): + param["tagValues"]["value"][key] = val + + identity = None + if workspace.identity: + identity = workspace.identity._to_workspace_rest_object() + else: + identity = IdentityConfiguration( + type=camel_to_snake(ManagedServiceIdentityType.SYSTEM_ASSIGNED) + )._to_workspace_rest_object() + _set_val(param["identity"], identity) + + if workspace.primary_user_assigned_identity: + _set_val(param["primaryUserAssignedIdentity"], workspace.primary_user_assigned_identity) + + if workspace._feature_store_settings: + _set_val( + param["spark_runtime_version"], workspace._feature_store_settings.compute_runtime.spark_runtime_version + ) + if workspace._feature_store_settings.offline_store_connection_name: + _set_val( + param["offline_store_connection_name"], + workspace._feature_store_settings.offline_store_connection_name, + ) + if workspace._feature_store_settings.online_store_connection_name: + _set_val( + param["online_store_connection_name"], + workspace._feature_store_settings.online_store_connection_name, + ) + + if workspace._kind and workspace._kind.lower() == "featurestore": + materialization_identity = kwargs.get("materialization_identity", None) + offline_store_target = kwargs.get("offline_store_target", None) + online_store_target = 
kwargs.get("online_store_target", None) + + from azure.ai.ml._utils._arm_id_utils import AzureResourceId, AzureStorageContainerResourceId + + # set workspace storage account access auth type to identity-based + _set_val(param["systemDatastoresAuthMode"], "identity") + + if offline_store_target: + arm_id = AzureStorageContainerResourceId(offline_store_target) + _set_val(param["offlineStoreStorageAccountOption"], "existing") + _set_val(param["offline_store_container_name"], arm_id.container) + _set_val(param["offline_store_storage_account_name"], arm_id.storage_account) + _set_val(param["offline_store_resource_group_name"], arm_id.resource_group_name) + _set_val(param["offline_store_subscription_id"], arm_id.subscription_id) + else: + _set_val(param["offlineStoreStorageAccountOption"], "new") + _set_val( + param["offline_store_container_name"], + _generate_storage_container(workspace.name, resources_being_deployed), + ) + if not workspace.storage_account: + _set_val(param["offline_store_storage_account_name"], param["storageAccountName"]["value"]) + else: + _set_val( + param["offline_store_storage_account_name"], + _generate_storage(workspace.name, resources_being_deployed), + ) + _set_val(param["offline_store_resource_group_name"], workspace.resource_group) + _set_val(param["offline_store_subscription_id"], self._subscription_id) + + if online_store_target: + arm_id = AzureResourceId(online_store_target) + _set_val(param["online_store_resource_id"], online_store_target) + _set_val(param["online_store_resource_group_name"], arm_id.resource_group_name) + _set_val(param["online_store_subscription_id"], arm_id.subscription_id) + + if materialization_identity: + arm_id = AzureResourceId(materialization_identity.resource_id) + _set_val(param["materializationIdentityOption"], "existing") + _set_val(param["materialization_identity_name"], arm_id.asset_name) + _set_val(param["materialization_identity_resource_group_name"], arm_id.resource_group_name) + _set_val(param["materialization_identity_subscription_id"], arm_id.subscription_id) + else: + _set_val(param["materializationIdentityOption"], "new") + _set_val( + param["materialization_identity_name"], + _generate_materialization_identity(workspace, self._subscription_id, resources_being_deployed), + ) + _set_val(param["materialization_identity_resource_group_name"], workspace.resource_group) + _set_val(param["materialization_identity_subscription_id"], self._subscription_id) + + if not kwargs.get("grant_materialization_permissions", None): + _set_val(param["grant_materialization_permissions"], "false") + + if workspace.provision_network_now: + _set_val(param["provisionNetworkNow"], "true") + + managed_network = None + if workspace.managed_network: + managed_network = workspace.managed_network._to_rest_object() + if workspace.managed_network.isolation_mode in [ + IsolationMode.ALLOW_INTERNET_OUTBOUND, + IsolationMode.ALLOW_ONLY_APPROVED_OUTBOUND, + ]: + _set_val(param["associatedResourcePNA"], "Disabled") + else: + managed_network = ManagedNetwork(isolation_mode=IsolationMode.DISABLED)._to_rest_object() + _set_obj_val(param["managedNetwork"], managed_network) + if workspace.enable_data_isolation: + _set_val(param["enable_data_isolation"], "true") + + if workspace._kind and workspace._kind.lower() == WorkspaceKind.HUB: + _set_obj_val(param["workspace_hub_config"], workspace._hub_values_to_rest_object()) # type: ignore + # A user-supplied resource ID (either AOAI or AI Services or null) + # endpoint_kind differentiates between a 'Bring a legacy AOAI 
resource hub' and 'any other kind of hub' + # The former doesn't create non-AOAI endpoints, and is set below if the user provided a byo AOAI + # resource ID. The latter case is the default and not shown here. + if endpoint_resource_id != "": + _set_val(param["endpoint_resource_id"], endpoint_resource_id) + _set_val(param["endpoint_kind"], endpoint_kind) + + # Lean related param + if ( + hasattr(workspace, "_kind") + and workspace._kind is not None + and workspace._kind.lower() == WorkspaceKind.PROJECT + ): + if hasattr(workspace, "_hub_id"): + _set_val(param["workspace_hub"], workspace._hub_id) + + # Serverless compute related param + serverless_compute = workspace.serverless_compute if workspace.serverless_compute else None + if serverless_compute: + _set_obj_val(param["serverless_compute_settings"], serverless_compute._to_rest_object()) + + resources_being_deployed[workspace.name] = (ArmConstants.WORKSPACE, None) + return template, param, resources_being_deployed + + def _populate_feature_store_role_assignment_parameters( + self, workspace: Workspace, **kwargs: Any + ) -> Tuple[dict, dict, dict]: + """Populates ARM template parameters for use to update feature store materialization identity role assignments. + + :param workspace: Workspace resource. + :type workspace: ~azure.ai.ml.entities.Workspace + :return: A tuple of three dicts: an ARM template, ARM template parameters, resources_being_deployed. + :rtype: Tuple[dict, dict, dict] + """ + resources_being_deployed = {} + template = get_template(resource_type=ArmConstants.FEATURE_STORE_ROLE_ASSIGNMENTS) + param = get_template(resource_type=ArmConstants.FEATURE_STORE_ROLE_ASSIGNMENTS_PARAM) + + materialization_identity_id = kwargs.get("materialization_identity_id", None) + if materialization_identity_id: + _set_val(param["materialization_identity_resource_id"], materialization_identity_id) + + _set_val(param["workspace_name"], workspace.name) + resource_group = kwargs.get("resource_group", workspace.resource_group) + _set_val(param["resource_group_name"], resource_group) + location = kwargs.get("location", workspace.location) + _set_val(param["location"], location) + + update_workspace_role_assignment = kwargs.get("update_workspace_role_assignment", None) + if update_workspace_role_assignment: + _set_val(param["update_workspace_role_assignment"], "true") + update_offline_store_role_assignment = kwargs.get("update_offline_store_role_assignment", None) + if update_offline_store_role_assignment: + _set_val(param["update_offline_store_role_assignment"], "true") + update_online_store_role_assignment = kwargs.get("update_online_store_role_assignment", None) + if update_online_store_role_assignment: + _set_val(param["update_online_store_role_assignment"], "true") + + offline_store_target = kwargs.get("offline_store_target", None) + online_store_target = kwargs.get("online_store_target", None) + + from azure.ai.ml._utils._arm_id_utils import AzureResourceId + + if offline_store_target: + arm_id = AzureResourceId(offline_store_target) + _set_val(param["offline_store_target"], offline_store_target) + _set_val(param["offline_store_resource_group_name"], arm_id.resource_group_name) + _set_val(param["offline_store_subscription_id"], arm_id.subscription_id) + + if online_store_target: + arm_id = AzureResourceId(online_store_target) + _set_val(param["online_store_target"], online_store_target) + _set_val(param["online_store_resource_group_name"], arm_id.resource_group_name) + _set_val(param["online_store_subscription_id"], arm_id.subscription_id) + + 
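As an aside before the role-assignment parameter population above wraps up: the offline and online store targets are decomposed with AzureResourceId into subscription, resource group, and resource name. A minimal stand-in for that parsing, written here as a hypothetical helper rather than the SDK's implementation:

from dataclasses import dataclass


@dataclass
class ParsedArmId:
    subscription_id: str
    resource_group_name: str
    asset_name: str


def parse_arm_id(arm_id: str) -> ParsedArmId:
    # Expects the canonical form:
    # /subscriptions/<sub>/resourceGroups/<rg>/providers/<namespace>/<type>/<name>
    parts = [segment for segment in arm_id.split("/") if segment]
    try:
        sub = parts[parts.index("subscriptions") + 1]
        rg = parts[parts.index("resourceGroups") + 1]
    except (ValueError, IndexError) as exc:
        raise ValueError(f"not a canonical ARM resource id: {arm_id}") from exc
    return ParsedArmId(subscription_id=sub, resource_group_name=rg, asset_name=parts[-1])


print(parse_arm_id(
    "/subscriptions/00000000-0000-0000-0000-000000000000/resourceGroups/my-rg"
    "/providers/Microsoft.Cache/Redis/my-online-store"
))
# ParsedArmId(subscription_id='00000000-...', resource_group_name='my-rg',
#             asset_name='my-online-store')

The real AzureResourceId also validates the provider segment; this sketch only splits the canonical path.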
resources_being_deployed[materialization_identity_id] = (ArmConstants.USER_ASSIGNED_IDENTITIES, None) + return template, param, resources_being_deployed + + def _check_workspace_name(self, name: Optional[str]) -> str: + """Validates that a workspace name exists. + + :param name: Name for a workspace resource. + :type name: str + :return: No Return. + :rtype: None + :raises ~azure.ai.ml.ValidationException: Raised if updating nothing is specified for name and + MLClient does not have workspace name set. + """ + workspace_name = name or self._default_workspace_name + if not workspace_name: + msg = "Please provide a workspace name or use a MLClient with a workspace name set." + raise ValidationException( + message=msg, + target=ErrorTarget.WORKSPACE, + no_personal_data_message=msg, + error_category=ErrorCategory.USER_ERROR, + ) + return workspace_name + + +def _set_val(dict: dict, val: Optional[str]) -> None: + """Sets the value of a reference in parameters dict to a certain value. + + :param dict: Dict for a certain parameter. + :type dict: dict + :param val: The value to set for "value" in the passed in dict. + :type val: str + :return: No Return. + :rtype: None + """ + dict["value"] = val + + +def _set_obj_val(dict: dict, val: Any) -> None: + """Serializes a JSON string into the parameters dict. + + :param dict: Parameters dict. + :type dict: dict + :param val: The obj to serialize. + :type val: Any type. Must have `.serialize() -> MutableMapping[str, Any]` method. + :return: No Return. + :rtype: None + """ + from copy import deepcopy + + json: MutableMapping[str, Any] = val.serialize() + dict["value"] = deepcopy(json) + + +def _generate_key_vault(name: Optional[str], resources_being_deployed: dict) -> str: + """Generates a name for a key vault resource to be created with workspace based on workspace name, + sets name and type in resources_being_deployed. + + :param name: The name for the related workspace. + :type name: str + :param resources_being_deployed: Dict for resources being deployed. + :type resources_being_deployed: dict + :return: String for name of key vault. + :rtype: str + """ + # Vault name must only contain alphanumeric characters and dashes and cannot start with a number. + # Vault name must be between 3-24 alphanumeric characters. + # The name must begin with a letter, end with a letter or digit, and not contain consecutive hyphens. + key_vault = get_name_for_dependent_resource(name, "keyvault") + resources_being_deployed[key_vault] = (ArmConstants.KEY_VAULT, None) + return str(key_vault) + + +def _generate_storage(name: Optional[str], resources_being_deployed: dict) -> str: + """Generates a name for a storage account resource to be created with workspace based on workspace name, + sets name and type in resources_being_deployed. + + :param name: The name for the related workspace. + :type name: str + :param resources_being_deployed: Dict for resources being deployed. + :type resources_being_deployed: dict + :return: String for name of storage account. + :rtype: str + """ + storage = get_name_for_dependent_resource(name, "storage") + resources_being_deployed[storage] = (ArmConstants.STORAGE, None) + return str(storage) + + +def _generate_storage_container(name: Optional[str], resources_being_deployed: dict) -> str: + """Generates a name for a storage container resource to be created with workspace based on workspace name, + sets name and type in resources_being_deployed. + + :param name: The name for the related workspace. 
+ :type name: str + :param resources_being_deployed: Dict for resources being deployed. + :type resources_being_deployed: dict + :return: String for name of storage container + :rtype: str + """ + storage_container = get_name_for_dependent_resource(name, "container") + resources_being_deployed[storage_container] = (ArmConstants.STORAGE_CONTAINER, None) + return str(storage_container) + + +def _generate_log_analytics(name: Optional[str], resources_being_deployed: dict) -> str: + """Generates a name for a log analytics resource to be created with workspace based on workspace name, + sets name and type in resources_being_deployed. + + :param name: The name for the related workspace. + :type name: str + :param resources_being_deployed: Dict for resources being deployed. + :type resources_being_deployed: dict + :return: String for name of log analytics. + :rtype: str + """ + log_analytics = get_name_for_dependent_resource(name, "logalytics") # cspell:disable-line + resources_being_deployed[log_analytics] = ( + ArmConstants.LOG_ANALYTICS, + None, + ) + return str(log_analytics) + + +def _generate_app_insights(name: Optional[str], resources_being_deployed: dict) -> str: + """Generates a name for an application insights resource to be created with workspace based on workspace name, + sets name and type in resources_being_deployed. + + :param name: The name for the related workspace. + :type name: str + :param resources_being_deployed: Dict for resources being deployed. + :type resources_being_deployed: dict + :return: String for name of app insights. + :rtype: str + """ + # Application name only allows alphanumeric characters, periods, underscores, + # hyphens and parenthesis and cannot end in a period + app_insights = get_name_for_dependent_resource(name, "insights") + resources_being_deployed[app_insights] = ( + ArmConstants.APP_INSIGHTS, + None, + ) + return str(app_insights) + + +def _generate_container_registry(name: Optional[str], resources_being_deployed: dict) -> str: + """Generates a name for a container registry resource to be created with workspace based on workspace name, + sets name and type in resources_being_deployed. + + :param name: The name for the related workspace. + :type name: str + :param resources_being_deployed: Dict for resources being deployed. + :type resources_being_deployed: dict + :return: String for name of container registry. + :rtype: str + """ + # Application name only allows alphanumeric characters, periods, underscores, + # hyphens and parenthesis and cannot end in a period + con_reg = get_name_for_dependent_resource(name, "containerRegistry") + resources_being_deployed[con_reg] = ( + ArmConstants.CONTAINER_REGISTRY, + None, + ) + return str(con_reg) + + +def _generate_materialization_identity( + workspace: Workspace, subscription_id: str, resources_being_deployed: dict +) -> str: + """Generates a name for a materialization identity resource to be created + with feature store based on workspace information, + sets name and type in resources_being_deployed. + + :param workspace: The workspace object. + :type workspace: Workspace + :param subscription_id: The subscription id + :type subscription_id: str + :param resources_being_deployed: Dict for resources being deployed. + :type resources_being_deployed: dict + :return: String for name of materialization identity. 
+ :rtype: str + """ + import uuid + + namespace = "" + namespace_raw = f"{subscription_id[:12]}_{str(workspace.resource_group)[:12]}_{workspace.location}" + for char in namespace_raw.lower(): + if char.isalpha() or char.isdigit(): + namespace = namespace + char + namespace = namespace.encode("utf-8").hex() + uuid_namespace = uuid.UUID(namespace[:32].ljust(32, "0")) + materialization_identity = f"materialization-uai-" f"{uuid.uuid3(uuid_namespace, str(workspace.name).lower()).hex}" + resources_being_deployed[materialization_identity] = ( + ArmConstants.USER_ASSIGNED_IDENTITIES, + None, + ) + return materialization_identity + + +class CustomArmTemplateDeploymentPollingMethod(PollingMethod): + """A custom polling method for ARM template deployment used internally for workspace creation.""" + + def __init__(self, poller: Any, arm_submit: Any, func: Any) -> None: + self.poller = poller + self.arm_submit = arm_submit + self.func = func + super().__init__() + + def resource(self) -> Any: + """ + Polls for the resource creation completing every so often with ability to cancel deployment and outputs + either error or executes function to "deserialize" result. + + :return: The response from polling result and calling func from CustomArmTemplateDeploymentPollingMethod + :rtype: Any + """ + error: Any = None + try: + while not self.poller.done(): + try: + time.sleep(LROConfigurations.SLEEP_TIME) + self.arm_submit._check_deployment_status() + except KeyboardInterrupt as e: + self.arm_submit._client.close() + error = e + raise + + if self.poller._exception is not None: + error = self.poller._exception + except Exception as e: # pylint: disable=W0718 + error = e + finally: + # one last check to make sure all print statements make it + if not isinstance(error, KeyboardInterrupt): + self.arm_submit._check_deployment_status() + total_duration = self.poller.result().properties.duration + + if error is not None: + error_msg = f"Unable to create resource. \n {error}\n" + module_logger.error(error_msg) + raise error + module_logger.info( + "Total time : %s\n", from_iso_duration_format_min_sec(total_duration) # pylint: disable=E0606 + ) + return self.func() + + # pylint: disable=docstring-missing-param + def initialize(self, *args: Any, **kwargs: Any) -> None: + """ + unused stub overridden from ABC + + :return: No return. + :rtype: ~azure.ai.ml.entities.OutboundRule + """ + + def finished(self) -> None: + """ + unused stub overridden from ABC + + :return: No return. + :rtype: ~azure.ai.ml.entities.OutboundRule + """ + + def run(self) -> None: + """ + unused stub overridden from ABC + + :return: No return. + :rtype: ~azure.ai.ml.entities.OutboundRule + """ + + def status(self) -> None: + """ + unused stub overridden from ABC + + :return: No return. + :rtype: ~azure.ai.ml.entities.OutboundRule + """ diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_workspace_outbound_rule_operations.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_workspace_outbound_rule_operations.py new file mode 100644 index 00000000..e7fde295 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_workspace_outbound_rule_operations.py @@ -0,0 +1,245 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- + +from typing import Any, Dict, Iterable, Optional + +from azure.ai.ml._restclient.v2024_10_01_preview import AzureMachineLearningWorkspaces as ServiceClient102024Preview +from azure.ai.ml._restclient.v2024_10_01_preview.models import OutboundRuleBasicResource +from azure.ai.ml._scope_dependent_operations import OperationsContainer, OperationScope +from azure.ai.ml._telemetry import ActivityType, monitor_with_activity +from azure.ai.ml._utils._logger_utils import OpsLogger +from azure.ai.ml.entities._workspace.networking import OutboundRule +from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationException +from azure.core.credentials import TokenCredential +from azure.core.polling import LROPoller + +ops_logger = OpsLogger(__name__) +module_logger = ops_logger.module_logger + + +class WorkspaceOutboundRuleOperations: + """WorkspaceOutboundRuleOperations. + + You should not instantiate this class directly. Instead, you should create + an MLClient instance that instantiates it for you and attaches it as an attribute. + """ + + def __init__( + self, + operation_scope: OperationScope, + service_client: ServiceClient102024Preview, + all_operations: OperationsContainer, + credentials: TokenCredential = None, + **kwargs: Dict, + ): + ops_logger.update_filter() + self._subscription_id = operation_scope.subscription_id + self._resource_group_name = operation_scope.resource_group_name + self._default_workspace_name = operation_scope.workspace_name + self._all_operations = all_operations + self._rule_operation = service_client.managed_network_settings_rule + self._credentials = credentials + self._init_kwargs = kwargs + + @monitor_with_activity(ops_logger, "WorkspaceOutboundRule.Get", ActivityType.PUBLICAPI) + def get(self, workspace_name: str, outbound_rule_name: str, **kwargs: Any) -> OutboundRule: + """Get a workspace OutboundRule by name. + + :param workspace_name: Name of the workspace. + :type workspace_name: str + :param outbound_rule_name: Name of the outbound rule. + :type outbound_rule_name: str + :return: The OutboundRule with the provided name for the workspace. + :rtype: ~azure.ai.ml.entities.OutboundRule + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_workspace.py + :start-after: [START outbound_rule_get] + :end-before: [END outbound_rule_get] + :language: python + :dedent: 8 + :caption: Get the outbound rule for a workspace with the given name. + """ + + workspace_name = self._check_workspace_name(workspace_name) + resource_group = kwargs.get("resource_group") or self._resource_group_name + + obj = self._rule_operation.get(resource_group, workspace_name, outbound_rule_name) + # pylint: disable=protected-access + res: OutboundRule = OutboundRule._from_rest_object(obj.properties, name=obj.name) # type: ignore + return res + + @monitor_with_activity(ops_logger, "WorkspaceOutboundRule.BeginCreate", ActivityType.PUBLICAPI) + def begin_create(self, workspace_name: str, rule: OutboundRule, **kwargs: Any) -> LROPoller[OutboundRule]: + """Create a Workspace OutboundRule. + + :param workspace_name: Name of the workspace. + :type workspace_name: str + :param rule: OutboundRule definition (FqdnDestination, PrivateEndpointDestination, or ServiceTagDestination). + :type rule: ~azure.ai.ml.entities.OutboundRule + :return: An instance of LROPoller that returns an OutboundRule. + :rtype: ~azure.core.polling.LROPoller[~azure.ai.ml.entities.OutboundRule] + + .. admonition:: Example: + + .. 
literalinclude:: ../samples/ml_samples_workspace.py + :start-after: [START outbound_rule_begin_create] + :end-before: [END outbound_rule_begin_create] + :language: python + :dedent: 8 + :caption: Create an FQDN outbound rule for a workspace with the given name, + similar can be done for PrivateEndpointDestination or ServiceTagDestination. + """ + + workspace_name = self._check_workspace_name(workspace_name) + resource_group = kwargs.get("resource_group") or self._resource_group_name + + # pylint: disable=protected-access + rule_params = OutboundRuleBasicResource(properties=rule._to_rest_object()) # type: ignore + + # pylint: disable=unused-argument, docstring-missing-param + def callback(_: Any, deserialized: Any, args: Any) -> Optional[OutboundRule]: + """Callback to be called after completion + + :return: Outbound rule deserialized. + :rtype: ~azure.ai.ml.entities.OutboundRule + """ + properties = deserialized.properties + name = deserialized.name + return OutboundRule._from_rest_object(properties, name=name) # pylint: disable=protected-access + + poller = self._rule_operation.begin_create_or_update( + resource_group, workspace_name, rule.name, rule_params, polling=True, cls=callback + ) + module_logger.info("Create request initiated for outbound rule with name: %s\n", rule.name) + return poller + + @monitor_with_activity(ops_logger, "WorkspaceOutboundRule.BeginUpdate", ActivityType.PUBLICAPI) + def begin_update(self, workspace_name: str, rule: OutboundRule, **kwargs: Any) -> LROPoller[OutboundRule]: + """Update a Workspace OutboundRule. + + :param workspace_name: Name of the workspace. + :type workspace_name: str + :param rule: OutboundRule definition (FqdnDestination, PrivateEndpointDestination, or ServiceTagDestination). + :type rule: ~azure.ai.ml.entities.OutboundRule + :return: An instance of LROPoller that returns an OutboundRule. + :rtype: ~azure.core.polling.LROPoller[~azure.ai.ml.entities.OutboundRule] + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_workspace.py + :start-after: [START outbound_rule_begin_update] + :end-before: [END outbound_rule_begin_update] + :language: python + :dedent: 8 + :caption: Update an FQDN outbound rule for a workspace with the given name, + similar can be done for PrivateEndpointDestination or ServiceTagDestination. + """ + + workspace_name = self._check_workspace_name(workspace_name) + resource_group = kwargs.get("resource_group") or self._resource_group_name + + # pylint: disable=protected-access + rule_params = OutboundRuleBasicResource(properties=rule._to_rest_object()) # type: ignore + + # pylint: disable=unused-argument, docstring-missing-param + def callback(_: Any, deserialized: Any, args: Any) -> Optional[OutboundRule]: + """Callback to be called after completion + + :return: Outbound rule deserialized. + :rtype: ~azure.ai.ml.entities.OutboundRule + """ + properties = deserialized.properties + name = deserialized.name + return OutboundRule._from_rest_object(properties, name=name) # pylint: disable=protected-access + + poller = self._rule_operation.begin_create_or_update( + resource_group, workspace_name, rule.name, rule_params, polling=True, cls=callback + ) + module_logger.info("Update request initiated for outbound rule with name: %s\n", rule.name) + return poller + + @monitor_with_activity(ops_logger, "WorkspaceOutboundRule.List", ActivityType.PUBLICAPI) + def list(self, workspace_name: str, **kwargs: Any) -> Iterable[OutboundRule]: + """List Workspace OutboundRules. 
+ + :param workspace_name: Name of the workspace. + :type workspace_name: str + :return: An Iterable of OutboundRule. + :rtype: Iterable[OutboundRule] + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_workspace.py + :start-after: [START outbound_rule_list] + :end-before: [END outbound_rule_list] + :language: python + :dedent: 8 + :caption: List the outbound rule for a workspace with the given name. + """ + + workspace_name = self._check_workspace_name(workspace_name) + resource_group = kwargs.get("resource_group") or self._resource_group_name + + rest_rules = self._rule_operation.list(resource_group, workspace_name) + + result = [ + OutboundRule._from_rest_object(rest_obj=obj.properties, name=obj.name) # pylint: disable=protected-access + for obj in rest_rules + ] + return result # type: ignore + + @monitor_with_activity(ops_logger, "WorkspaceOutboundRule.Remove", ActivityType.PUBLICAPI) + def begin_remove(self, workspace_name: str, outbound_rule_name: str, **kwargs: Any) -> LROPoller[None]: + """Remove a Workspace OutboundRule. + + :param workspace_name: Name of the workspace. + :type workspace_name: str + :param outbound_rule_name: Name of the outbound rule to remove. + :type outbound_rule_name: str + :return: An Iterable of OutboundRule. + :rtype: Iterable[OutboundRule] + + .. admonition:: Example: + + .. literalinclude:: ../samples/ml_samples_workspace.py + :start-after: [START outbound_rule_begin_remove] + :end-before: [END outbound_rule_begin_remove] + :language: python + :dedent: 8 + :caption: Remove the outbound rule for a workspace with the given name. + """ + + workspace_name = self._check_workspace_name(workspace_name) + resource_group = kwargs.get("resource_group") or self._resource_group_name + + poller = self._rule_operation.begin_delete( + resource_group_name=resource_group, + workspace_name=workspace_name, + rule_name=outbound_rule_name, + ) + module_logger.info("Delete request initiated for outbound rule: %s\n", outbound_rule_name) + return poller + + def _check_workspace_name(self, name: str) -> str: + """Validates that a workspace name exists. + + :param name: Name for a workspace resource. + :type name: str + :raises ~azure.ai.ml.ValidationException: Raised if updating nothing is specified for name and + MLClient does not have workspace name set. + :return: No return + :rtype: None + """ + workspace_name = name or self._default_workspace_name + if not workspace_name: + msg = "Please provide a workspace name or use a MLClient with a workspace name set." + raise ValidationException( + message=msg, + target=ErrorTarget.WORKSPACE, + no_personal_data_message=msg, + error_category=ErrorCategory.USER_ERROR, + ) + return workspace_name |
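Looking back at the begin_update flow at the top of this section: user-assigned identities that exist on the workspace but are missing from the update request are added to the payload with a None value so the PATCH detaches them. A small standalone sketch of that merge, using plain dicts in place of the REST identity models:

from typing import Dict, List, Optional


def merge_user_assigned_identities(
    requested: Optional[Dict[str, dict]],
    existing_resource_ids: List[str],
) -> Dict[str, Optional[dict]]:
    # Identities already attached to the workspace but absent from the request
    # are mapped to None, which the service interprets as "detach".
    merged: Dict[str, Optional[dict]] = dict(requested or {})
    for resource_id in existing_resource_ids:
        if resource_id not in merged:
            merged[resource_id] = None
    return merged


existing = [
    "/subscriptions/000/resourceGroups/rg/providers/"
    "Microsoft.ManagedIdentity/userAssignedIdentities/uai-old",
]
requested = {
    "/subscriptions/000/resourceGroups/rg/providers/"
    "Microsoft.ManagedIdentity/userAssignedIdentities/uai-new": {},
}
print(merge_user_assigned_identities(requested, existing))
# uai-new keeps its payload; uai-old maps to None and gets detached by the update.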
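For completeness, deleting a workspace through the public client exercises the begin_delete path shown earlier: delete_dependent_resources also removes the attached storage account, key vault, container registry, and Application Insights (plus its Log Analytics workspace), and permanently_delete bypasses the soft-delete recovery window. Placeholder names throughout:

from azure.ai.ml import MLClient
from azure.identity import DefaultAzureCredential

ml_client = MLClient(
    credential=DefaultAzureCredential(),
    subscription_id="<subscription-id>",
    resource_group_name="<resource-group>",
)

poller = ml_client.workspaces.begin_delete(
    name="my-workspace",
    delete_dependent_resources=True,  # also removes storage, key vault, ACR, app insights
    permanently_delete=True,          # skip the soft-delete recovery window
)
poller.result()  # block until the delete completes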
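The _set_val and _set_obj_val helpers only ever assign into the {"value": ...} envelope that ARM parameter files use. A toy illustration with made-up parameter names (not the real workspace template):

import json

# ARM parameter files wrap every value in a {"value": ...} envelope, which is
# why the helpers above only ever assign entry["value"].
param = {
    "workspaceName": {"value": None},
    "tagValues": {"value": {}},
}


def set_val(entry: dict, val) -> None:
    entry["value"] = val


set_val(param["workspaceName"], "my-workspace")
param["tagValues"]["value"]["team"] = "ml-platform"

print(json.dumps(param, indent=2))
# {
#   "workspaceName": {"value": "my-workspace"},
#   "tagValues": {"value": {"team": "ml-platform"}}
# }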
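_generate_materialization_identity produces a deterministic name: a UUIDv3 of the lowercased workspace name under a namespace derived from the subscription, resource group, and location, prefixed with materialization-uai-. A standalone sketch of the same idea with toy inputs (not the SDK's exact helper):

import uuid


def materialization_identity_name(
    subscription_id: str, resource_group: str, location: str, workspace_name: str
) -> str:
    # Build a namespace from truncated subscription/resource group plus location,
    # keep only alphanumerics, hex-encode, and pad to a 32-character UUID.
    raw = f"{subscription_id[:12]}_{resource_group[:12]}_{location}".lower()
    cleaned = "".join(ch for ch in raw if ch.isalnum())
    namespace = uuid.UUID(cleaned.encode("utf-8").hex()[:32].ljust(32, "0"))
    # uuid3 is deterministic, so repeated provisioning of the same feature store
    # resolves to the same user-assigned identity name.
    return f"materialization-uai-{uuid.uuid3(namespace, workspace_name.lower()).hex}"


print(materialization_identity_name(
    "00000000-0000-0000-0000-000000000000", "my-rg", "eastus", "my-feature-store"
))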
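Finally, a hedged sketch of driving the outbound-rule operations from the second file in this diff. It assumes the operations are exposed as workspace_outbound_rules on MLClient (the attribute name may vary across azure-ai-ml versions) and uses placeholder workspace and rule names:

from azure.ai.ml import MLClient
from azure.ai.ml.entities import FqdnDestination
from azure.identity import DefaultAzureCredential

ml_client = MLClient(
    credential=DefaultAzureCredential(),
    subscription_id="<subscription-id>",
    resource_group_name="<resource-group>",
    workspace_name="my-workspace",
)

# Allow outbound traffic to an FQDN from the workspace's managed network.
rule = FqdnDestination(name="allow-pypi", destination="pypi.org")
created = ml_client.workspace_outbound_rules.begin_create(
    workspace_name="my-workspace", rule=rule
).result()
print(created.name)

for existing_rule in ml_client.workspace_outbound_rules.list(workspace_name="my-workspace"):
    print(existing_rule.name)

ml_client.workspace_outbound_rules.begin_remove(
    workspace_name="my-workspace", outbound_rule_name="allow-pypi"
).result()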