Diffstat (limited to '.venv/lib/python3.12/site-packages/azure/ai/ml/operations')
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/operations/__init__.py  65
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/operations/_azure_openai_deployment_operations.py  60
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/operations/_batch_deployment_operations.py  392
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/operations/_batch_endpoint_operations.py  553
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/operations/_capability_hosts_operations.py  304
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/operations/_code_operations.py  307
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/operations/_component_operations.py  1289
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/operations/_compute_operations.py  447
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/operations/_data_operations.py  891
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/operations/_dataset_dataplane_operations.py  32
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/operations/_datastore_operations.py  329
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/operations/_environment_operations.py  569
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/operations/_evaluator_operations.py  222
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/operations/_feature_set_operations.py  456
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/operations/_feature_store_entity_operations.py  191
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/operations/_feature_store_operations.py  566
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/operations/_index_operations.py  483
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/operations/_job_operations.py  1677
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/operations/_job_ops_helper.py  513
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/operations/_local_deployment_helper.py  390
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/operations/_local_endpoint_helper.py  205
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/operations/_local_job_invoker.py  432
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/operations/_marketplace_subscription_operations.py  122
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/operations/_model_dataplane_operations.py  32
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/operations/_model_operations.py  833
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/operations/_online_deployment_operations.py  415
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/operations/_online_endpoint_operations.py  471
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/operations/_operation_orchestrator.py  571
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/operations/_registry_operations.py  168
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/operations/_run_history_constants.py  82
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/operations/_run_operations.py  94
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/operations/_schedule_operations.py  608
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/operations/_serverless_endpoint_operations.py  223
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/operations/_virtual_cluster_operations.py  174
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/operations/_workspace_connections_operations.py  189
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/operations/_workspace_operations.py  443
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/operations/_workspace_operations_base.py  1167
-rw-r--r--  .venv/lib/python3.12/site-packages/azure/ai/ml/operations/_workspace_outbound_rule_operations.py  245
38 files changed, 16210 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/__init__.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/__init__.py
new file mode 100644
index 00000000..4508bd95
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/__init__.py
@@ -0,0 +1,65 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+"""Contains supported operations for Azure Machine Learning SDKv2.
+
+Operations are classes that contain logic to interact with backend services, usually via auto-generated operation calls.
+"""
+__path__ = __import__("pkgutil").extend_path(__path__, __name__)
+
+
+from ._azure_openai_deployment_operations import AzureOpenAIDeploymentOperations
+from ._batch_deployment_operations import BatchDeploymentOperations
+from ._batch_endpoint_operations import BatchEndpointOperations
+from ._component_operations import ComponentOperations
+from ._compute_operations import ComputeOperations
+from ._data_operations import DataOperations
+from ._datastore_operations import DatastoreOperations
+from ._environment_operations import EnvironmentOperations
+from ._feature_set_operations import FeatureSetOperations
+from ._feature_store_entity_operations import FeatureStoreEntityOperations
+from ._feature_store_operations import FeatureStoreOperations
+from ._index_operations import IndexOperations
+from ._job_operations import JobOperations
+from ._model_operations import ModelOperations
+from ._online_deployment_operations import OnlineDeploymentOperations
+from ._online_endpoint_operations import OnlineEndpointOperations
+from ._registry_operations import RegistryOperations
+from ._schedule_operations import ScheduleOperations
+from ._workspace_connections_operations import WorkspaceConnectionsOperations
+from ._workspace_operations import WorkspaceOperations
+from ._workspace_outbound_rule_operations import WorkspaceOutboundRuleOperations
+from ._evaluator_operations import EvaluatorOperations
+from ._serverless_endpoint_operations import ServerlessEndpointOperations
+from ._marketplace_subscription_operations import MarketplaceSubscriptionOperations
+from ._capability_hosts_operations import CapabilityHostsOperations
+
+__all__ = [
+ "ComputeOperations",
+ "DatastoreOperations",
+ "JobOperations",
+ "ModelOperations",
+ "EvaluatorOperations",
+ "WorkspaceOperations",
+ "RegistryOperations",
+ "OnlineEndpointOperations",
+ "BatchEndpointOperations",
+ "OnlineDeploymentOperations",
+ "BatchDeploymentOperations",
+ "DataOperations",
+ "EnvironmentOperations",
+ "ComponentOperations",
+ "WorkspaceConnectionsOperations",
+ "RegistryOperations",
+ "ScheduleOperations",
+ "WorkspaceOutboundRuleOperations",
+ "FeatureSetOperations",
+ "FeatureStoreEntityOperations",
+ "FeatureStoreOperations",
+ "ServerlessEndpointOperations",
+ "MarketplaceSubscriptionOperations",
+ "IndexOperations",
+ "AzureOpenAIDeploymentOperations",
+ "CapabilityHostsOperations",
+]
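
The operations classes exported above are wired up by MLClient rather than constructed directly. A minimal usage sketch, assuming azure-ai-ml and azure-identity are installed; the subscription, resource group, and workspace values below are placeholders:

from azure.ai.ml import MLClient
from azure.identity import DefaultAzureCredential

ml_client = MLClient(
    credential=DefaultAzureCredential(),
    subscription_id="<subscription-id>",
    resource_group_name="<resource-group>",
    workspace_name="<workspace-name>",
)

# MLClient instantiates the operations classes listed in __all__ and exposes them as
# attributes, e.g. ml_client.jobs (JobOperations), ml_client.models (ModelOperations),
# and ml_client.batch_endpoints (BatchEndpointOperations).
for job in ml_client.jobs.list():
    print(job.name)
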
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_azure_openai_deployment_operations.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_azure_openai_deployment_operations.py
new file mode 100644
index 00000000..542448a9
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_azure_openai_deployment_operations.py
@@ -0,0 +1,60 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+import logging
+from typing import Iterable
+
+from azure.ai.ml._restclient.v2024_04_01_preview import AzureMachineLearningWorkspaces as ServiceClient2020404Preview
+from azure.ai.ml._scope_dependent_operations import OperationConfig, OperationScope, _ScopeDependentOperations
+from azure.ai.ml.entities._autogen_entities.models import AzureOpenAIDeployment
+
+from ._workspace_connections_operations import WorkspaceConnectionsOperations
+
+module_logger = logging.getLogger(__name__)
+
+
+class AzureOpenAIDeploymentOperations(_ScopeDependentOperations):
+ """AzureOpenAIDeploymentOperations.
+
+ You should not instantiate this class directly. Instead, you should
+ create an MLClient instance that instantiates it for you and
+ attaches it as an attribute.
+ """
+
+ def __init__(
+ self,
+ operation_scope: OperationScope,
+ operation_config: OperationConfig,
+ service_client: ServiceClient2020404Preview,
+ connections_operations: WorkspaceConnectionsOperations,
+ ):
+ super().__init__(operation_scope, operation_config)
+ self._service_client = service_client.connection
+ self._workspace_connections_operations = connections_operations
+
+ def list(self, connection_name: str, **kwargs) -> Iterable[AzureOpenAIDeployment]:
+ """List Azure OpenAI deployments of the workspace.
+
+ :param connection_name: Name of the connection from which to list deployments
+ :type connection_name: str
+ :return: A list of Azure OpenAI deployments
+ :rtype: ~typing.Iterable[~azure.ai.ml.entities.AzureOpenAIDeployment]
+ """
+ connection = self._workspace_connections_operations.get(connection_name)
+
+ def _from_rest_add_connection_name(obj):
+ from_rest_deployment = AzureOpenAIDeployment._from_rest_object(obj)
+ from_rest_deployment.connection_name = connection_name
+ from_rest_deployment.target_url = connection.target
+ return from_rest_deployment
+
+ return self._service_client.list_deployments(
+ self._resource_group_name,
+ self._workspace_name,
+ connection_name,
+ cls=lambda objs: [_from_rest_add_connection_name(obj) for obj in objs],
+ **kwargs,
+ )
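
As the docstring notes, AzureOpenAIDeploymentOperations is reached through an MLClient rather than instantiated directly. A minimal sketch of listing deployments behind a workspace connection; the MLClient attribute name and the connection name are illustrative assumptions, not taken from this diff:

from azure.ai.ml import MLClient
from azure.identity import DefaultAzureCredential

ml_client = MLClient(
    credential=DefaultAzureCredential(),
    subscription_id="<subscription-id>",
    resource_group_name="<resource-group>",
    workspace_name="<workspace-name>",
)

# list() resolves the named connection, then stamps each returned deployment with the
# connection name and target URL (see _from_rest_add_connection_name above).
# NOTE: the attribute name "azure_openai_deployments" is assumed here for illustration.
for deployment in ml_client.azure_openai_deployments.list(connection_name="<aoai-connection>"):
    print(deployment.name, deployment.connection_name, deployment.target_url)
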
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_batch_deployment_operations.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_batch_deployment_operations.py
new file mode 100644
index 00000000..9bde33c8
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_batch_deployment_operations.py
@@ -0,0 +1,392 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access, too-many-boolean-expressions
+
+import re
+from typing import Any, Optional, TypeVar, Union
+
+from azure.ai.ml._restclient.v2024_01_01_preview import AzureMachineLearningWorkspaces as ServiceClient012024Preview
+from azure.ai.ml._scope_dependent_operations import (
+ OperationConfig,
+ OperationsContainer,
+ OperationScope,
+ _ScopeDependentOperations,
+)
+from azure.ai.ml._telemetry import ActivityType, monitor_with_activity
+from azure.ai.ml._utils._arm_id_utils import AMLVersionedArmId
+from azure.ai.ml._utils._azureml_polling import AzureMLPolling
+from azure.ai.ml._utils._endpoint_utils import upload_dependencies, validate_scoring_script
+from azure.ai.ml._utils._http_utils import HttpPipeline
+from azure.ai.ml._utils._logger_utils import OpsLogger
+from azure.ai.ml._utils._package_utils import package_deployment
+from azure.ai.ml._utils.utils import _get_mfe_base_url_from_discovery_service, modified_operation_client
+from azure.ai.ml.constants._common import ARM_ID_PREFIX, AzureMLResourceType, LROConfigurations
+from azure.ai.ml.entities import BatchDeployment, BatchJob, ModelBatchDeployment, PipelineComponent, PipelineJob
+from azure.ai.ml.entities._deployment.pipeline_component_batch_deployment import PipelineComponentBatchDeployment
+from azure.core.credentials import TokenCredential
+from azure.core.exceptions import HttpResponseError, ResourceNotFoundError
+from azure.core.paging import ItemPaged
+from azure.core.polling import LROPoller
+from azure.core.tracing.decorator import distributed_trace
+
+from ._operation_orchestrator import OperationOrchestrator
+
+ops_logger = OpsLogger(__name__)
+module_logger = ops_logger.module_logger
+DeploymentType = TypeVar(
+ "DeploymentType", bound=Union[BatchDeployment, PipelineComponentBatchDeployment, ModelBatchDeployment]
+)
+
+
+class BatchDeploymentOperations(_ScopeDependentOperations):
+ """BatchDeploymentOperations.
+
+ You should not instantiate this class directly. Instead, you should create an MLClient instance that instantiates it
+ for you and attaches it as an attribute.
+
+ :param operation_scope: Scope variables for the operations classes of an MLClient object.
+ :type operation_scope: ~azure.ai.ml._scope_dependent_operations.OperationScope
+ :param operation_config: Common configuration for operations classes of an MLClient object.
+ :type operation_config: ~azure.ai.ml._scope_dependent_operations.OperationConfig
+ :param service_client_01_2024_preview: Service client to allow end users to operate on Azure Machine Learning
+ Workspace resources.
+ :type service_client_01_2024_preview: ~azure.ai.ml._restclient.v2024_01_01_preview._azure_machine_learning_workspaces.
+ AzureMachineLearningWorkspaces
+ :param all_operations: All operations classes of an MLClient object.
+ :type all_operations: ~azure.ai.ml._scope_dependent_operations.OperationsContainer
+ :param credentials: Credential to use for authentication.
+ :type credentials: ~azure.core.credentials.TokenCredential
+ """
+
+ def __init__(
+ self,
+ operation_scope: OperationScope,
+ operation_config: OperationConfig,
+ service_client_01_2024_preview: ServiceClient012024Preview,
+ all_operations: OperationsContainer,
+ credentials: Optional[TokenCredential] = None,
+ **kwargs: Any,
+ ):
+ super(BatchDeploymentOperations, self).__init__(operation_scope, operation_config)
+ ops_logger.update_filter()
+ self._batch_deployment = service_client_01_2024_preview.batch_deployments
+ self._batch_job_deployment = kwargs.pop("service_client_09_2020_dataplanepreview").batch_job_deployment
+ service_client_02_2023_preview = kwargs.pop("service_client_02_2023_preview")
+ self._component_batch_deployment_operations = service_client_02_2023_preview.batch_deployments
+ self._batch_endpoint_operations = service_client_01_2024_preview.batch_endpoints
+ self._component_operations = service_client_02_2023_preview.component_versions
+ self._all_operations = all_operations
+ self._credentials = credentials
+ self._init_kwargs = kwargs
+
+ self._requests_pipeline: HttpPipeline = kwargs.pop("requests_pipeline")
+
+ @distributed_trace
+ @monitor_with_activity(ops_logger, "BatchDeployment.BeginCreateOrUpdate", ActivityType.PUBLICAPI)
+ def begin_create_or_update(
+ self,
+ deployment: DeploymentType,
+ *,
+ skip_script_validation: bool = False,
+ **kwargs: Any,
+ ) -> LROPoller[DeploymentType]:
+ """Create or update a batch deployment.
+
+ :param deployment: The deployment entity.
+ :type deployment: ~azure.ai.ml.entities.BatchDeployment
+ :keyword skip_script_validation: If set to True, the script validation will be skipped. Defaults to False.
+ :paramtype skip_script_validation: bool
+ :raises ~azure.ai.ml.exceptions.ValidationException: Raised if BatchDeployment cannot be
+ successfully validated. Details will be provided in the error message.
+ :raises ~azure.ai.ml.exceptions.AssetException: Raised if BatchDeployment assets
+ (e.g. Data, Code, Model, Environment) cannot be successfully validated.
+ Details will be provided in the error message.
+ :raises ~azure.ai.ml.exceptions.ModelException: Raised if BatchDeployment model
+ cannot be successfully validated. Details will be provided in the error message.
+ :return: A poller to track the operation status.
+ :rtype: ~azure.core.polling.LROPoller[~azure.ai.ml.entities.BatchDeployment]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_misc.py
+ :start-after: [START batch_deployment_operations_begin_create_or_update]
+ :end-before: [END batch_deployment_operations_begin_create_or_update]
+ :language: python
+ :dedent: 8
+ :caption: Create example.
+ """
+ if (
+ not skip_script_validation
+ and not isinstance(deployment, PipelineComponentBatchDeployment)
+ and deployment
+ and deployment.code_configuration # type: ignore
+ and not deployment.code_configuration.code.startswith(ARM_ID_PREFIX) # type: ignore
+ and not re.match(AMLVersionedArmId.REGEX_PATTERN, deployment.code_configuration.code) # type: ignore
+ ):
+ validate_scoring_script(deployment)
+ module_logger.debug("Checking endpoint %s exists", deployment.endpoint_name)
+ self._batch_endpoint_operations.get(
+ endpoint_name=deployment.endpoint_name,
+ resource_group_name=self._resource_group_name,
+ workspace_name=self._workspace_name,
+ )
+ orchestrators = OperationOrchestrator(
+ operation_container=self._all_operations,
+ operation_scope=self._operation_scope,
+ operation_config=self._operation_config,
+ )
+ if isinstance(deployment, PipelineComponentBatchDeployment):
+ self._validate_component(deployment, orchestrators) # type: ignore
+ else:
+ upload_dependencies(deployment, orchestrators)
+ try:
+ location = self._get_workspace_location()
+ if kwargs.pop("package_model", False):
+ deployment = package_deployment(deployment, self._all_operations.all_operations)
+ module_logger.info("\nStarting deployment")
+ deployment_rest = deployment._to_rest_object(location=location)
+ if isinstance(deployment, PipelineComponentBatchDeployment): # pylint: disable=no-else-return
+ return self._component_batch_deployment_operations.begin_create_or_update(
+ resource_group_name=self._resource_group_name,
+ workspace_name=self._workspace_name,
+ endpoint_name=deployment.endpoint_name,
+ deployment_name=deployment.name,
+ body=deployment_rest,
+ **self._init_kwargs,
+ cls=lambda response, deserialized, headers: PipelineComponentBatchDeployment._from_rest_object(
+ deserialized
+ ),
+ )
+ else:
+ return self._batch_deployment.begin_create_or_update(
+ resource_group_name=self._resource_group_name,
+ workspace_name=self._workspace_name,
+ endpoint_name=deployment.endpoint_name,
+ deployment_name=deployment.name,
+ body=deployment_rest,
+ **self._init_kwargs,
+ cls=lambda response, deserialized, headers: BatchDeployment._from_rest_object(deserialized),
+ )
+ except Exception as ex:
+ raise ex
+
+ @distributed_trace
+ @monitor_with_activity(ops_logger, "BatchDeployment.Get", ActivityType.PUBLICAPI)
+ def get(self, name: str, endpoint_name: str) -> BatchDeployment:
+ """Get a deployment resource.
+
+ :param name: The name of the deployment
+ :type name: str
+ :param endpoint_name: The name of the endpoint
+ :type endpoint_name: str
+ :return: A deployment entity
+ :rtype: ~azure.ai.ml.entities.BatchDeployment
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_misc.py
+ :start-after: [START batch_deployment_operations_get]
+ :end-before: [END batch_deployment_operations_get]
+ :language: python
+ :dedent: 8
+ :caption: Get example.
+ """
+ deployment = BatchDeployment._from_rest_object(
+ self._batch_deployment.get(
+ endpoint_name=endpoint_name,
+ deployment_name=name,
+ resource_group_name=self._resource_group_name,
+ workspace_name=self._workspace_name,
+ **self._init_kwargs,
+ )
+ )
+ deployment.endpoint_name = endpoint_name
+ return deployment
+
+ @distributed_trace
+ @monitor_with_activity(ops_logger, "BatchDeployment.BeginDelete", ActivityType.PUBLICAPI)
+ def begin_delete(self, name: str, endpoint_name: str) -> LROPoller[None]:
+ """Delete a batch deployment.
+
+ :param name: Name of the batch deployment.
+ :type name: str
+ :param endpoint_name: Name of the batch endpoint
+ :type endpoint_name: str
+ :return: A poller to track the operation status.
+ :rtype: ~azure.core.polling.LROPoller[None]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_misc.py
+ :start-after: [START batch_deployment_operations_delete]
+ :end-before: [END batch_deployment_operations_delete]
+ :language: python
+ :dedent: 8
+ :caption: Delete example.
+ """
+ path_format_arguments = {
+ "endpointName": name,
+ "resourceGroupName": self._resource_group_name,
+ "workspaceName": self._workspace_name,
+ }
+
+ delete_poller = self._batch_deployment.begin_delete(
+ resource_group_name=self._resource_group_name,
+ workspace_name=self._workspace_name,
+ endpoint_name=endpoint_name,
+ deployment_name=name,
+ polling=AzureMLPolling(
+ LROConfigurations.POLL_INTERVAL,
+ path_format_arguments=path_format_arguments,
+ **self._init_kwargs,
+ ),
+ polling_interval=LROConfigurations.POLL_INTERVAL,
+ **self._init_kwargs,
+ )
+ return delete_poller
+
+ @distributed_trace
+ @monitor_with_activity(ops_logger, "BatchDeployment.List", ActivityType.PUBLICAPI)
+ def list(self, endpoint_name: str) -> ItemPaged[BatchDeployment]:
+ """List a deployment resource.
+
+ :param endpoint_name: The name of the endpoint
+ :type endpoint_name: str
+ :return: An iterator of deployment entities
+ :rtype: ~azure.core.paging.ItemPaged[~azure.ai.ml.entities.BatchDeployment]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_misc.py
+ :start-after: [START batch_deployment_operations_list]
+ :end-before: [END batch_deployment_operations_list]
+ :language: python
+ :dedent: 8
+ :caption: List deployment resource example.
+ """
+ return self._batch_deployment.list(
+ endpoint_name=endpoint_name,
+ resource_group_name=self._resource_group_name,
+ workspace_name=self._workspace_name,
+ cls=lambda objs: [BatchDeployment._from_rest_object(obj) for obj in objs],
+ **self._init_kwargs,
+ )
+
+ @distributed_trace
+ @monitor_with_activity(ops_logger, "BatchDeployment.ListJobs", ActivityType.PUBLICAPI)
+ def list_jobs(self, endpoint_name: str, *, name: Optional[str] = None) -> ItemPaged[BatchJob]:
+ """List jobs under the provided batch endpoint deployment. This is only valid for batch endpoint.
+
+ :param endpoint_name: Name of endpoint.
+ :type endpoint_name: str
+ :keyword name: (Optional) Name of deployment.
+ :paramtype name: str
+ :raise: Exception if endpoint_type is not BATCH_ENDPOINT_TYPE
+ :return: List of jobs
+ :rtype: ~azure.core.paging.ItemPaged[~azure.ai.ml.entities.BatchJob]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_misc.py
+ :start-after: [START batch_deployment_operations_list_jobs]
+ :end-before: [END batch_deployment_operations_list_jobs]
+ :language: python
+ :dedent: 8
+ :caption: List jobs example.
+ """
+
+ workspace_operations = self._all_operations.all_operations[AzureMLResourceType.WORKSPACE]
+ mfe_base_uri = _get_mfe_base_url_from_discovery_service(
+ workspace_operations, self._workspace_name, self._requests_pipeline
+ )
+
+ with modified_operation_client(self._batch_job_deployment, mfe_base_uri):
+ result = self._batch_job_deployment.list(
+ endpoint_name=endpoint_name,
+ deployment_name=name,
+ resource_group_name=self._resource_group_name,
+ workspace_name=self._workspace_name,
+ **self._init_kwargs,
+ )
+
+ # This is necessary as the paged result needs to be resolved inside the context manager
+ return list(result)
+
+ def _get_workspace_location(self) -> str:
+ """Get the workspace location
+
+ TODO[TASK 1260265]: can we cache this information and only refresh when the operation_scope is changed?
+
+ :return: The workspace location
+ :rtype: str
+ """
+ return str(
+ self._all_operations.all_operations[AzureMLResourceType.WORKSPACE].get(self._workspace_name).location
+ )
+
+ def _validate_component(self, deployment: Any, orchestrators: OperationOrchestrator) -> None:
+ """Validates that the value provided is associated to an existing component or otherwise we will try to create
+ an anonymous component that will be use for batch deployment.
+
+ :param deployment: Batch deployment
+ :type deployment: ~azure.ai.ml.entities._deployment.deployment.Deployment
+ :param orchestrators: Operation Orchestrator
+ :type orchestrators: _operation_orchestrator.OperationOrchestrator
+ """
+ if isinstance(deployment.component, PipelineComponent):
+ try:
+ registered_component = self._all_operations.all_operations[AzureMLResourceType.COMPONENT].get(
+ name=deployment.component.name, version=deployment.component.version
+ )
+ deployment.component = registered_component.id
+ except Exception as err: # pylint: disable=W0718
+ if isinstance(err, (ResourceNotFoundError, HttpResponseError)):
+ deployment.component = self._all_operations.all_operations[
+ AzureMLResourceType.COMPONENT
+ ].create_or_update(
+ name=deployment.component.name,
+ resource_group_name=self._resource_group_name,
+ workspace_name=self._workspace_name,
+ component=deployment.component,
+ version=deployment.component.version,
+ **self._init_kwargs,
+ )
+ else:
+ raise err
+ elif isinstance(deployment.component, str):
+ component_id = orchestrators.get_asset_arm_id(
+ deployment.component, azureml_type=AzureMLResourceType.COMPONENT
+ )
+ deployment.component = component_id
+ elif isinstance(deployment.job_definition, str):
+ job_component = PipelineComponent(source_job_id=deployment.job_definition)
+ job_component = self._component_operations.create_or_update(
+ name=job_component.name,
+ resource_group_name=self._resource_group_name,
+ workspace_name=self._workspace_name,
+ body=job_component._to_rest_object(),
+ version=job_component.version,
+ **self._init_kwargs,
+ )
+ deployment.component = job_component.id
+
+ elif isinstance(deployment.job_definition, PipelineJob):
+ try:
+ registered_job = self._all_operations.all_operations[AzureMLResourceType.JOB].get(
+ name=deployment.job_definition.name
+ )
+ if registered_job:
+ job_component = PipelineComponent(source_job_id=registered_job.name)
+ job_component = self._component_operations.create_or_update(
+ name=job_component.name,
+ resource_group_name=self._resource_group_name,
+ workspace_name=self._workspace_name,
+ body=job_component._to_rest_object(),
+ version=job_component.version,
+ **self._init_kwargs,
+ )
+ deployment.component = job_component.id
+ except ResourceNotFoundError as err:
+ raise err
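
A minimal sketch of the read and delete paths shown above (list, get, begin_delete), assuming an MLClient configured for the target workspace and exposing this class as batch_deployments; the endpoint and deployment names are placeholders:

from azure.ai.ml import MLClient
from azure.identity import DefaultAzureCredential

ml_client = MLClient(
    credential=DefaultAzureCredential(),
    subscription_id="<subscription-id>",
    resource_group_name="<resource-group>",
    workspace_name="<workspace-name>",
)

# List all deployments under a batch endpoint, then fetch one by name.
for deployment in ml_client.batch_deployments.list(endpoint_name="<batch-endpoint>"):
    print(deployment.name)

deployment = ml_client.batch_deployments.get(name="<deployment-name>", endpoint_name="<batch-endpoint>")

# begin_delete returns an LROPoller; result() blocks until the service finishes deleting.
ml_client.batch_deployments.begin_delete(name="<deployment-name>", endpoint_name="<batch-endpoint>").result()
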
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_batch_endpoint_operations.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_batch_endpoint_operations.py
new file mode 100644
index 00000000..650185b3
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_batch_endpoint_operations.py
@@ -0,0 +1,553 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+import json
+import os
+import re
+import uuid
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Dict, Optional, cast
+
+from marshmallow.exceptions import ValidationError as SchemaValidationError
+
+from azure.ai.ml._artifacts._artifact_utilities import _upload_and_generate_remote_uri
+from azure.ai.ml._azure_environments import _get_aml_resource_id_from_metadata, _resource_to_scopes
+from azure.ai.ml._exception_helper import log_and_raise_error
+from azure.ai.ml._restclient.v2020_09_01_dataplanepreview.models import BatchJobResource
+from azure.ai.ml._restclient.v2023_10_01 import AzureMachineLearningServices as ServiceClient102023
+from azure.ai.ml._schema._deployment.batch.batch_job import BatchJobSchema
+from azure.ai.ml._scope_dependent_operations import (
+ OperationConfig,
+ OperationsContainer,
+ OperationScope,
+ _ScopeDependentOperations,
+)
+from azure.ai.ml._telemetry import ActivityType, monitor_with_activity
+from azure.ai.ml._utils._arm_id_utils import is_ARM_id_for_resource, remove_aml_prefix
+from azure.ai.ml._utils._azureml_polling import AzureMLPolling
+from azure.ai.ml._utils._endpoint_utils import convert_v1_dataset_to_v2, validate_response
+from azure.ai.ml._utils._http_utils import HttpPipeline
+from azure.ai.ml._utils._logger_utils import OpsLogger
+from azure.ai.ml._utils.utils import (
+ _get_mfe_base_url_from_discovery_service,
+ is_private_preview_enabled,
+ modified_operation_client,
+)
+from azure.ai.ml.constants._common import (
+ ARM_ID_FULL_PREFIX,
+ AZUREML_REGEX_FORMAT,
+ BASE_PATH_CONTEXT_KEY,
+ HTTP_PREFIX,
+ LONG_URI_REGEX_FORMAT,
+ PARAMS_OVERRIDE_KEY,
+ SHORT_URI_REGEX_FORMAT,
+ AssetTypes,
+ AzureMLResourceType,
+ InputTypes,
+ LROConfigurations,
+)
+from azure.ai.ml.constants._endpoint import EndpointInvokeFields, EndpointYamlFields
+from azure.ai.ml.entities import BatchEndpoint, BatchJob
+from azure.ai.ml.entities._inputs_outputs import Input
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, MlException, ValidationErrorType, ValidationException
+from azure.core.credentials import TokenCredential
+from azure.core.exceptions import HttpResponseError, ServiceRequestError, ServiceResponseError
+from azure.core.paging import ItemPaged
+from azure.core.polling import LROPoller
+from azure.core.tracing.decorator import distributed_trace
+
+from ._operation_orchestrator import OperationOrchestrator
+
+if TYPE_CHECKING:
+ from azure.ai.ml.operations import DatastoreOperations
+
+ops_logger = OpsLogger(__name__)
+module_logger = ops_logger.module_logger
+
+
+class BatchEndpointOperations(_ScopeDependentOperations):
+ """BatchEndpointOperations.
+
+ You should not instantiate this class directly. Instead, you should create an MLClient instance that instantiates it
+ for you and attaches it as an attribute.
+
+ :param operation_scope: Scope variables for the operations classes of an MLClient object.
+ :type operation_scope: ~azure.ai.ml._scope_dependent_operations.OperationScope
+ :param operation_config: Common configuration for operations classes of an MLClient object.
+ :type operation_config: ~azure.ai.ml._scope_dependent_operations.OperationConfig
+ :param service_client_10_2023: Service client to allow end users to operate on Azure Machine Learning Workspace
+ resources.
+ :type service_client_10_2023: ~azure.ai.ml._restclient.v2023_10_01._azure_machine_learning_workspaces.
+ AzureMachineLearningWorkspaces
+ :param all_operations: All operations classes of an MLClient object.
+ :type all_operations: ~azure.ai.ml._scope_dependent_operations.OperationsContainer
+ :param credentials: Credential to use for authentication.
+ :type credentials: ~azure.core.credentials.TokenCredential
+ """
+
+ def __init__(
+ self,
+ operation_scope: OperationScope,
+ operation_config: OperationConfig,
+ service_client_10_2023: ServiceClient102023,
+ all_operations: OperationsContainer,
+ credentials: Optional[TokenCredential] = None,
+ **kwargs: Any,
+ ):
+ super(BatchEndpointOperations, self).__init__(operation_scope, operation_config)
+ ops_logger.update_filter()
+ self._batch_operation = service_client_10_2023.batch_endpoints
+ self._batch_deployment_operation = service_client_10_2023.batch_deployments
+ self._batch_job_endpoint = kwargs.pop("service_client_09_2020_dataplanepreview").batch_job_endpoint
+ self._all_operations = all_operations
+ self._credentials = credentials
+ self._init_kwargs = kwargs
+
+ self._requests_pipeline: HttpPipeline = kwargs.pop("requests_pipeline")
+
+ @property
+ def _datastore_operations(self) -> "DatastoreOperations":
+ from azure.ai.ml.operations import DatastoreOperations
+
+ return cast(DatastoreOperations, self._all_operations.all_operations[AzureMLResourceType.DATASTORE])
+
+ @distributed_trace
+ @monitor_with_activity(ops_logger, "BatchEndpoint.List", ActivityType.PUBLICAPI)
+ def list(self) -> ItemPaged[BatchEndpoint]:
+ """List endpoints of the workspace.
+
+ :return: A list of endpoints
+ :rtype: ~azure.core.paging.ItemPaged[~azure.ai.ml.entities.BatchEndpoint]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_misc.py
+ :start-after: [START batch_endpoint_operations_list]
+ :end-before: [END batch_endpoint_operations_list]
+ :language: python
+ :dedent: 8
+ :caption: List example.
+ """
+ return self._batch_operation.list(
+ resource_group_name=self._resource_group_name,
+ workspace_name=self._workspace_name,
+ cls=lambda objs: [BatchEndpoint._from_rest_object(obj) for obj in objs],
+ **self._init_kwargs,
+ )
+
+ @distributed_trace
+ @monitor_with_activity(ops_logger, "BatchEndpoint.Get", ActivityType.PUBLICAPI)
+ def get(
+ self,
+ name: str,
+ ) -> BatchEndpoint:
+ """Get a Endpoint resource.
+
+ :param name: Name of the endpoint.
+ :type name: str
+ :return: Endpoint object retrieved from the service.
+ :rtype: ~azure.ai.ml.entities.BatchEndpoint
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_misc.py
+ :start-after: [START batch_endpoint_operations_get]
+ :end-before: [END batch_endpoint_operations_get]
+ :language: python
+ :dedent: 8
+ :caption: Get endpoint example.
+ """
+ # first get the endpoint
+ endpoint = self._batch_operation.get(
+ resource_group_name=self._resource_group_name,
+ workspace_name=self._workspace_name,
+ endpoint_name=name,
+ **self._init_kwargs,
+ )
+
+ endpoint_data = BatchEndpoint._from_rest_object(endpoint)
+ return endpoint_data
+
+ @distributed_trace
+ @monitor_with_activity(ops_logger, "BatchEndpoint.BeginDelete", ActivityType.PUBLICAPI)
+ def begin_delete(self, name: str) -> LROPoller[None]:
+ """Delete a batch Endpoint.
+
+ :param name: Name of the batch endpoint.
+ :type name: str
+ :return: A poller to track the operation status.
+ :rtype: ~azure.core.polling.LROPoller[None]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_misc.py
+ :start-after: [START batch_endpoint_operations_delete]
+ :end-before: [END batch_endpoint_operations_delete]
+ :language: python
+ :dedent: 8
+ :caption: Delete endpoint example.
+ """
+ path_format_arguments = {
+ "endpointName": name,
+ "resourceGroupName": self._resource_group_name,
+ "workspaceName": self._workspace_name,
+ }
+
+ delete_poller = self._batch_operation.begin_delete(
+ resource_group_name=self._resource_group_name,
+ workspace_name=self._workspace_name,
+ endpoint_name=name,
+ polling=AzureMLPolling(
+ LROConfigurations.POLL_INTERVAL,
+ path_format_arguments=path_format_arguments,
+ **self._init_kwargs,
+ ),
+ polling_interval=LROConfigurations.POLL_INTERVAL,
+ **self._init_kwargs,
+ )
+ return delete_poller
+
+ @distributed_trace
+ @monitor_with_activity(ops_logger, "BatchEndpoint.BeginCreateOrUpdate", ActivityType.PUBLICAPI)
+ def begin_create_or_update(self, endpoint: BatchEndpoint) -> LROPoller[BatchEndpoint]:
+ """Create or update a batch endpoint.
+
+ :param endpoint: The endpoint entity.
+ :type endpoint: ~azure.ai.ml.entities.BatchEndpoint
+ :return: A poller to track the operation status.
+ :rtype: ~azure.core.polling.LROPoller[~azure.ai.ml.entities.BatchEndpoint]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_misc.py
+ :start-after: [START batch_endpoint_operations_create_or_update]
+ :end-before: [END batch_endpoint_operations_create_or_update]
+ :language: python
+ :dedent: 8
+ :caption: Create endpoint example.
+ """
+
+ try:
+ location = self._get_workspace_location()
+
+ endpoint_resource = endpoint._to_rest_batch_endpoint(location=location)
+ poller = self._batch_operation.begin_create_or_update(
+ resource_group_name=self._resource_group_name,
+ workspace_name=self._workspace_name,
+ endpoint_name=endpoint.name,
+ body=endpoint_resource,
+ polling=True,
+ **self._init_kwargs,
+ cls=lambda response, deserialized, headers: BatchEndpoint._from_rest_object(deserialized),
+ )
+ return poller
+
+ except Exception as ex:
+ if isinstance(ex, (ValidationException, SchemaValidationError)):
+ log_and_raise_error(ex)
+ raise ex
+
+ @distributed_trace
+ @monitor_with_activity(ops_logger, "BatchEndpoint.Invoke", ActivityType.PUBLICAPI)
+ def invoke( # pylint: disable=too-many-statements
+ self,
+ endpoint_name: str,
+ *,
+ deployment_name: Optional[str] = None,
+ inputs: Optional[Dict[str, Input]] = None,
+ **kwargs: Any,
+ ) -> BatchJob:
+ """Invokes the batch endpoint with the provided payload.
+
+ :param endpoint_name: The endpoint name.
+ :type endpoint_name: str
+ :keyword deployment_name: (Optional) The name of a specific deployment to invoke. This is optional.
+ By default requests are routed to any of the deployments according to the traffic rules.
+ :paramtype deployment_name: str
+ :keyword inputs: (Optional) A dictionary of existing data asset, public uri file or folder
+ to use with the deployment
+ :paramtype inputs: Dict[str, Input]
+ :raises ~azure.ai.ml.exceptions.ValidationException: Raised if deployment cannot be successfully validated.
+ Details will be provided in the error message.
+ :raises ~azure.ai.ml.exceptions.AssetException: Raised if BatchEndpoint assets
+ (e.g. Data, Code, Model, Environment) cannot be successfully validated.
+ Details will be provided in the error message.
+ :raises ~azure.ai.ml.exceptions.ModelException: Raised if BatchEndpoint model cannot be successfully validated.
+ Details will be provided in the error message.
+ :raises ~azure.ai.ml.exceptions.EmptyDirectoryError: Raised if local path provided points to an empty directory.
+ :return: The invoked batch deployment job.
+ :rtype: ~azure.ai.ml.entities.BatchJob
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_misc.py
+ :start-after: [START batch_endpoint_operations_invoke]
+ :end-before: [END batch_endpoint_operations_invoke]
+ :language: python
+ :dedent: 8
+ :caption: Invoke endpoint example.
+ """
+ outputs = kwargs.get("outputs", None)
+ job_name = kwargs.get("job_name", None)
+ params_override = kwargs.get("params_override", None) or []
+ experiment_name = kwargs.get("experiment_name", None)
+ input = kwargs.get("input", None) # pylint: disable=redefined-builtin
+ # Until this bug is resolved https://msdata.visualstudio.com/Vienna/_workitems/edit/1446538
+ if deployment_name:
+ self._validate_deployment_name(endpoint_name, deployment_name)
+
+ if input and isinstance(input, Input):
+ if HTTP_PREFIX not in input.path:
+ self._resolve_input(input, os.getcwd())
+ # MFE expects a dictionary as input_data that's why we are using
+ # "UriFolder" or "UriFile" as keys depending on the input type
+ if input.type == "uri_folder":
+ params_override.append({EndpointYamlFields.BATCH_JOB_INPUT_DATA: {"UriFolder": input}})
+ elif input.type == "uri_file":
+ params_override.append({EndpointYamlFields.BATCH_JOB_INPUT_DATA: {"UriFile": input}})
+ else:
+ msg = (
+ "Unsupported input type please use a dictionary of either a path on the datastore, public URI, "
+ "a registered data asset, or a local folder path."
+ )
+ raise ValidationException(
+ message=msg,
+ target=ErrorTarget.BATCH_ENDPOINT,
+ no_personal_data_message=msg,
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
+ elif inputs:
+ for key, input_data in inputs.items():
+ if (
+ isinstance(input_data, Input)
+ and input_data.type
+ not in [InputTypes.NUMBER, InputTypes.BOOLEAN, InputTypes.INTEGER, InputTypes.STRING]
+ and HTTP_PREFIX not in input_data.path
+ ):
+ self._resolve_input(input_data, os.getcwd())
+ params_override.append({EndpointYamlFields.BATCH_JOB_INPUT_DATA: inputs})
+
+ properties = {}
+
+ if outputs:
+ params_override.append({EndpointYamlFields.BATCH_JOB_OUTPUT_DATA: outputs})
+ if job_name:
+ params_override.append({EndpointYamlFields.BATCH_JOB_NAME: job_name})
+ if experiment_name:
+ properties["experimentName"] = experiment_name
+
+ if properties:
+ params_override.append({EndpointYamlFields.BATCH_JOB_PROPERTIES: properties})
+
+ # Batch job doesn't have a python class, loading a rest object using params override
+ context = {
+ BASE_PATH_CONTEXT_KEY: Path(".").parent,
+ PARAMS_OVERRIDE_KEY: params_override,
+ }
+
+ batch_job = BatchJobSchema(context=context).load(data={})
+ # update output datastore to arm id if needed
+ # TODO: Unify datastore name -> arm id logic, TASK: 1104172
+ request = {}
+ if (
+ batch_job.output_dataset
+ and batch_job.output_dataset.datastore_id
+ and (not is_ARM_id_for_resource(batch_job.output_dataset.datastore_id))
+ ):
+ v2_dataset_dictionary = convert_v1_dataset_to_v2(batch_job.output_dataset, batch_job.output_file_name)
+ batch_job.output_dataset = None
+ batch_job.output_file_name = None
+ request = BatchJobResource(properties=batch_job).serialize()
+ request["properties"]["outputData"] = v2_dataset_dictionary
+ else:
+ request = BatchJobResource(properties=batch_job).serialize()
+
+ endpoint = self._batch_operation.get(
+ resource_group_name=self._resource_group_name,
+ workspace_name=self._workspace_name,
+ endpoint_name=endpoint_name,
+ **self._init_kwargs,
+ )
+
+ headers = EndpointInvokeFields.DEFAULT_HEADER
+ ml_audience_scopes = _resource_to_scopes(_get_aml_resource_id_from_metadata())
+ module_logger.debug("ml_audience_scopes used: `%s`\n", ml_audience_scopes)
+ key = self._credentials.get_token(*ml_audience_scopes).token if self._credentials is not None else ""
+ headers[EndpointInvokeFields.AUTHORIZATION] = f"Bearer {key}"
+ headers[EndpointInvokeFields.REPEATABILITY_REQUEST_ID] = str(uuid.uuid4())
+
+ if deployment_name:
+ headers[EndpointInvokeFields.MODEL_DEPLOYMENT] = deployment_name
+
+ retry_attempts = 0
+ while retry_attempts < 5:
+ try:
+ response = self._requests_pipeline.post(
+ endpoint.properties.scoring_uri,
+ json=request,
+ headers=headers,
+ )
+ except (ServiceRequestError, ServiceResponseError):
+ retry_attempts += 1
+ continue
+ break
+ if retry_attempts == 5:
+ retry_msg = "Max retry attempts reached while trying to connect to server. Please check connection and invoke again." # pylint: disable=line-too-long
+ raise MlException(message=retry_msg, no_personal_data_message=retry_msg, target=ErrorTarget.BATCH_ENDPOINT)
+ validate_response(response)
+ batch_job = json.loads(response.text())
+ return BatchJobResource.deserialize(batch_job)
+
+ @distributed_trace
+ @monitor_with_activity(ops_logger, "BatchEndpoint.ListJobs", ActivityType.PUBLICAPI)
+ def list_jobs(self, endpoint_name: str) -> ItemPaged[BatchJob]:
+ """List jobs under the provided batch endpoint deployment. This is only valid for batch endpoint.
+
+ :param endpoint_name: The endpoint name
+ :type endpoint_name: str
+ :return: List of jobs
+ :rtype: ~azure.core.paging.ItemPaged[~azure.ai.ml.entities.BatchJob]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_misc.py
+ :start-after: [START batch_endpoint_operations_list_jobs]
+ :end-before: [END batch_endpoint_operations_list_jobs]
+ :language: python
+ :dedent: 8
+ :caption: List jobs example.
+ """
+
+ workspace_operations = self._all_operations.all_operations[AzureMLResourceType.WORKSPACE]
+ mfe_base_uri = _get_mfe_base_url_from_discovery_service(
+ workspace_operations, self._workspace_name, self._requests_pipeline
+ )
+
+ with modified_operation_client(self._batch_job_endpoint, mfe_base_uri):
+ result = self._batch_job_endpoint.list(
+ endpoint_name=endpoint_name,
+ resource_group_name=self._resource_group_name,
+ workspace_name=self._workspace_name,
+ **self._init_kwargs,
+ )
+
+ # This is necessary as the paged result needs to be resolved inside the context manager
+ return list(result)
+
+ def _get_workspace_location(self) -> str:
+ return str(
+ self._all_operations.all_operations[AzureMLResourceType.WORKSPACE].get(self._workspace_name).location
+ )
+
+ def _validate_deployment_name(self, endpoint_name: str, deployment_name: str) -> None:
+ deployments_list = self._batch_deployment_operation.list(
+ endpoint_name=endpoint_name,
+ resource_group_name=self._resource_group_name,
+ workspace_name=self._workspace_name,
+ cls=lambda objs: [obj.name for obj in objs],
+ **self._init_kwargs,
+ )
+ if deployments_list:
+ if deployment_name not in deployments_list:
+ msg = f"Deployment name {deployment_name} not found for this endpoint"
+ raise ValidationException(
+ message=msg.format(deployment_name),
+ no_personal_data_message=msg.format("[deployment_name]"),
+ target=ErrorTarget.DEPLOYMENT,
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.RESOURCE_NOT_FOUND,
+ )
+ else:
+ msg = "No deployment exists for this endpoint"
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ error_category=ErrorCategory.USER_ERROR,
+ target=ErrorTarget.DEPLOYMENT,
+ error_type=ValidationErrorType.RESOURCE_NOT_FOUND,
+ )
+
+ def _resolve_input(self, entry: Input, base_path: str) -> None:
+ # We should not verify anything that is not of type Input
+ if not isinstance(entry, Input):
+ return
+
+ # Input path should not be empty
+ if not entry.path:
+ msg = "Input path can't be empty for batch endpoint invoke"
+ raise MlException(message=msg, no_personal_data_message=msg)
+
+ if entry.type in [InputTypes.NUMBER, InputTypes.BOOLEAN, InputTypes.INTEGER, InputTypes.STRING]:
+ return
+
+ try:
+ if entry.path.startswith(ARM_ID_FULL_PREFIX):
+ if not is_ARM_id_for_resource(entry.path, AzureMLResourceType.DATA):
+ raise ValidationException(
+ message="Invalid input path",
+ target=ErrorTarget.BATCH_ENDPOINT,
+ no_personal_data_message="Invalid input path",
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
+ elif os.path.isabs(entry.path): # absolute local path, upload, transform to remote url
+ if entry.type == AssetTypes.URI_FOLDER and not os.path.isdir(entry.path):
+ raise ValidationException(
+ message="There is no folder on target path: {}".format(entry.path),
+ target=ErrorTarget.BATCH_ENDPOINT,
+ no_personal_data_message="There is no folder on target path",
+ error_type=ValidationErrorType.FILE_OR_FOLDER_NOT_FOUND,
+ )
+ if entry.type == AssetTypes.URI_FILE and not os.path.isfile(entry.path):
+ raise ValidationException(
+ message="There is no file on target path: {}".format(entry.path),
+ target=ErrorTarget.BATCH_ENDPOINT,
+ no_personal_data_message="There is no file on target path",
+ error_type=ValidationErrorType.FILE_OR_FOLDER_NOT_FOUND,
+ )
+ # absolute local path
+ entry.path = _upload_and_generate_remote_uri(
+ self._operation_scope,
+ self._datastore_operations,
+ entry.path,
+ )
+ if entry.type == AssetTypes.URI_FOLDER and entry.path and not entry.path.endswith("/"):
+ entry.path = entry.path + "/"
+ elif ":" in entry.path or "@" in entry.path: # Check registered file or folder datastore
+ # If we receive a datastore path in long/short form we don't need
+ # to get the arm asset id
+ if re.match(SHORT_URI_REGEX_FORMAT, entry.path) or re.match(LONG_URI_REGEX_FORMAT, entry.path):
+ return
+ if is_private_preview_enabled() and re.match(AZUREML_REGEX_FORMAT, entry.path):
+ return
+ asset_type = AzureMLResourceType.DATA
+ entry.path = remove_aml_prefix(entry.path)
+ orchestrator = OperationOrchestrator(
+ self._all_operations, self._operation_scope, self._operation_config
+ )
+ entry.path = orchestrator.get_asset_arm_id(entry.path, asset_type)
+ else: # relative local path, upload, transform to remote url
+ local_path = Path(base_path, entry.path).resolve()
+ entry.path = _upload_and_generate_remote_uri(
+ self._operation_scope,
+ self._datastore_operations,
+ local_path,
+ )
+ if entry.type == AssetTypes.URI_FOLDER and entry.path and not entry.path.endswith("/"):
+ entry.path = entry.path + "/"
+ except (MlException, HttpResponseError) as e:
+ raise e
+ except Exception as e:
+ raise ValidationException(
+ message=f"Supported input path value are: path on the datastore, public URI, "
+ "a registered data asset, or a local folder path.\n"
+ f"Met {type(e)}:\n{e}",
+ target=ErrorTarget.BATCH_ENDPOINT,
+ no_personal_data_message="Supported input path value are: path on the datastore, "
+ "public URI, a registered data asset, or a local folder path.",
+ error=e,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ ) from e
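
The invoke() path above posts the resolved inputs to the endpoint's scoring URI with a bearer token and returns the created batch job. A minimal sketch, assuming an MLClient configured for the target workspace; endpoint, deployment, and data paths are placeholders:

from azure.ai.ml import Input, MLClient
from azure.ai.ml.constants import AssetTypes
from azure.identity import DefaultAzureCredential

ml_client = MLClient(
    credential=DefaultAzureCredential(),
    subscription_id="<subscription-id>",
    resource_group_name="<resource-group>",
    workspace_name="<workspace-name>",
)

# A uri_folder input is resolved by _resolve_input (datastore path, registered data asset,
# public URI, or a local folder that gets uploaded), then passed as the job's input data.
job = ml_client.batch_endpoints.invoke(
    endpoint_name="<batch-endpoint>",
    deployment_name="<deployment-name>",  # optional; omit to use the endpoint's default routing
    inputs={"input_data": Input(type=AssetTypes.URI_FOLDER, path="azureml://datastores/<datastore>/paths/<folder>/")},
)
print(job.name)
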
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_capability_hosts_operations.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_capability_hosts_operations.py
new file mode 100644
index 00000000..e8cc03f2
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_capability_hosts_operations.py
@@ -0,0 +1,304 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+# pylint: disable=protected-access
+
+from typing import Any, List
+
+from marshmallow.exceptions import ValidationError as SchemaValidationError
+
+from azure.ai.ml._exception_helper import log_and_raise_error
+from azure.ai.ml._restclient.v2024_10_01_preview import AzureMachineLearningWorkspaces as ServiceClient102024Preview
+from azure.ai.ml._scope_dependent_operations import (
+ OperationConfig,
+ OperationsContainer,
+ OperationScope,
+ _ScopeDependentOperations,
+)
+from azure.ai.ml._telemetry import ActivityType, monitor_with_activity
+from azure.ai.ml._utils._experimental import experimental
+from azure.ai.ml._utils._logger_utils import OpsLogger
+from azure.ai.ml.constants._common import DEFAULT_STORAGE_CONNECTION_NAME, WorkspaceKind
+from azure.ai.ml.entities._workspace._ai_workspaces.capability_host import CapabilityHost
+from azure.ai.ml.entities._workspace.workspace import Workspace
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationException
+from azure.core.credentials import TokenCredential
+from azure.core.polling import LROPoller
+from azure.core.tracing.decorator import distributed_trace
+
+ops_logger = OpsLogger(__name__)
+module_logger = ops_logger.module_logger
+
+
+class CapabilityHostsOperations(_ScopeDependentOperations):
+ """CapabilityHostsOperations.
+
+ You should not instantiate this class directly. Instead, you should create an MLClient instance that instantiates it
+ for you and attaches it as an attribute.
+
+ :param operation_scope: Scope variables for the operations classes of an MLClient object.
+ :type operation_scope: ~azure.ai.ml._scope_dependent_operations.OperationScope
+ :param operation_config: Common configuration for operations classes of an MLClient object.
+ :type operation_config: ~azure.ai.ml._scope_dependent_operations.OperationConfig
+ :param service_client_10_2024: Service client to allow end users to operate on Azure Machine Learning Workspace
+ resources (ServiceClient102024Preview).
+ :type service_client_10_2024: ~azure.ai.ml._restclient.v2024_10_01_preview._azure_machine_learning_workspaces.AzureMachineLearningWorkspaces # pylint: disable=line-too-long
+ :param all_operations: All operations classes of an MLClient object.
+ :type all_operations: ~azure.ai.ml._scope_dependent_operations.OperationsContainer
+ :param credentials: Credential to use for authentication.
+ :type credentials: ~azure.core.credentials.TokenCredential
+ :param kwargs: Additional keyword arguments.
+ :type kwargs: Any
+ """
+
+ def __init__(
+ self,
+ operation_scope: OperationScope,
+ operation_config: OperationConfig,
+ service_client_10_2024: ServiceClient102024Preview,
+ all_operations: OperationsContainer,
+ credentials: TokenCredential,
+ **kwargs: Any,
+ ):
+ """Constructor of CapabilityHostsOperations class.
+
+ :param operation_scope: Scope variables for the operations classes of an MLClient object.
+ :type operation_scope: ~azure.ai.ml._scope_dependent_operations.OperationScope
+ :param operation_config: Common configuration for operations classes of an MLClient object.
+ :type operation_config: ~azure.ai.ml._scope_dependent_operations.OperationConfig
+ :param service_client_10_2024: Service client to allow end users to operate on Azure Machine Learning Workspace
+ resources (ServiceClient102024Preview).
+ :type service_client_10_2024: ~azure.ai.ml._restclient.v2024_10_01_preview._azure_machine_learning_workspaces.AzureMachineLearningWorkspaces # pylint: disable=line-too-long
+ :param all_operations: All operations classes of an MLClient object.
+ :type all_operations: ~azure.ai.ml._scope_dependent_operations.OperationsContainer
+ :param credentials: Credential to use for authentication.
+ :type credentials: ~azure.core.credentials.TokenCredential
+ :param kwargs: Additional keyword arguments.
+ :type kwargs: Any
+ """
+
+ super(CapabilityHostsOperations, self).__init__(operation_scope, operation_config)
+ ops_logger.update_filter()
+ self._all_operations = all_operations
+ self._capability_hosts_operations = service_client_10_2024.capability_hosts
+ self._workspace_operations = service_client_10_2024.workspaces
+ self._credentials = credentials
+ self._init_kwargs = kwargs
+
+ @experimental
+ @monitor_with_activity(ops_logger, "CapabilityHost.Get", ActivityType.PUBLICAPI)
+ @distributed_trace
+ def get(self, name: str, **kwargs: Any) -> CapabilityHost:
+ """Retrieve a capability host resource.
+
+ :param name: The name of the capability host to retrieve.
+ :type name: str
+ :raises ~azure.ai.ml.exceptions.ValidationException: Raised if a project or hub name was not
+ provided in the workspace_name parameter when the MLClient object was created.
+ Details will be provided in the error message.
+ :raises ~azure.ai.ml.exceptions.ValidationException: Raised if the capability host name is not provided.
+ Details will be provided in the error message.
+ :return: CapabilityHost object.
+ :rtype: ~azure.ai.ml.entities._workspace._ai_workspaces.capability_host.CapabilityHost
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_capability_host.py
+ :start-after: [START capability_host_get_operation]
+ :end-before: [END capability_host_get_operation]
+ :language: python
+ :dedent: 8
+ :caption: Get example.
+ """
+
+ self._validate_workspace_name()
+
+ rest_obj = self._capability_hosts_operations.get(
+ resource_group_name=self._resource_group_name,
+ workspace_name=self._workspace_name,
+ name=name,
+ **kwargs,
+ )
+
+ capability_host = CapabilityHost._from_rest_object(rest_obj)
+
+ return capability_host
+
+ @experimental
+ @monitor_with_activity(ops_logger, "CapabilityHost.BeginCreateOrUpdate", ActivityType.PUBLICAPI)
+ @distributed_trace
+ def begin_create_or_update(self, capability_host: CapabilityHost, **kwargs: Any) -> LROPoller[CapabilityHost]:
+ """Begin the creation of a capability host in a Hub or Project workspace.
+ Note that currently this method can only accept the `create` operation request
+ and not `update` operation request.
+
+ :param capability_host: The CapabilityHost object containing the details of the capability host to create.
+ :type capability_host: ~azure.ai.ml.entities.CapabilityHost
+ :return: An LROPoller object that can be used to track the long-running
+ operation that is creation of capability host.
+ :rtype: ~azure.core.polling.LROPoller[~azure.ai.ml.entities._workspace._ai_workspaces.capability_host.CapabilityHost] # pylint: disable=line-too-long
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_capability_host.py
+ :start-after: [START capability_host_begin_create_or_update_operation]
+ :end-before: [END capability_host_begin_create_or_update_operation]
+ :language: python
+ :dedent: 8
+ :caption: Create example.
+ """
+ try:
+ self._validate_workspace_name()
+
+ workspace = self._get_workspace()
+
+ self._validate_workspace_kind(workspace._kind)
+
+ self._validate_properties(capability_host, workspace._kind)
+
+ if workspace._kind == WorkspaceKind.PROJECT:
+ if capability_host.storage_connections is None or len(capability_host.storage_connections) == 0:
+ capability_host.storage_connections = self._get_default_storage_connections()
+
+ capability_host_resource = capability_host._to_rest_object()
+
+ poller = self._capability_hosts_operations.begin_create_or_update(
+ resource_group_name=self._resource_group_name,
+ workspace_name=self._workspace_name,
+ name=capability_host.name,
+ body=capability_host_resource,
+ polling=True,
+ **kwargs,
+ cls=lambda response, deserialized, headers: CapabilityHost._from_rest_object(deserialized),
+ )
+ return poller
+
+ except Exception as ex:
+ if isinstance(ex, (ValidationException, SchemaValidationError)):
+ log_and_raise_error(ex)
+ raise ex
+
+ @experimental
+ @distributed_trace
+ @monitor_with_activity(ops_logger, "CapabilityHost.Delete", ActivityType.PUBLICAPI)
+ def begin_delete(
+ self,
+ name: str,
+ **kwargs: Any,
+ ) -> LROPoller[None]:
+ """Delete capability host.
+
+ :param name: capability host name.
+ :type name: str
+ :return: A poller for deletion status
+ :rtype: ~azure.core.polling.LROPoller[None]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_capability_host.py
+ :start-after: [START capability_host_delete_operation]
+ :end-before: [END capability_host_delete_operation]
+ :language: python
+ :dedent: 8
+ :caption: Delete example.
+ """
+ poller = self._capability_hosts_operations.begin_delete(
+ resource_group_name=self._resource_group_name,
+ workspace_name=self._workspace_name,
+ name=name,
+ polling=True,
+ **kwargs,
+ )
+ return poller
+
+ def _get_default_storage_connections(self) -> List[str]:
+ """Retrieve the default storage connections for a capability host.
+
+ :return: A list of default storage connections.
+ :rtype: List[str]
+ """
+ return [f"{self._workspace_name}/{DEFAULT_STORAGE_CONNECTION_NAME}"]
+
+ def _validate_workspace_kind(self, workspace_kind: str) -> None:
+ """Validate the workspace kind, it should be either hub or project only.
+
+ :param workspace_kind: The kind of the workspace, either hub or project only.
+ :type workspace_kind: str
+ :raises ~azure.ai.ml.exceptions.ValidationException: Raised if workspace kind is not Hub or Project.
+ Details will be provided in the error message.
+ :return: None, or the result of cls(response)
+ :rtype: None
+ """
+
+ valid_kind = workspace_kind in {WorkspaceKind.HUB, WorkspaceKind.PROJECT}
+ if not valid_kind:
+ msg = f"Invalid workspace kind: {workspace_kind}. Workspace kind should be either 'Hub' or 'Project'."
+ raise ValidationException(
+ message=msg,
+ target=ErrorTarget.CAPABILITY_HOST,
+ no_personal_data_message=msg,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+
+ def _validate_properties(self, capability_host: CapabilityHost, workspace_kind: str) -> None:
+ """Validate the properties of the capability host for project workspace.
+
+ :param capability_host: The capability host to validate.
+ :type capability_host: CapabilityHost
+ :param workspace_kind: The kind of the workspace, Project only.
+ :type workspace_kind: str
+ :raises ~azure.ai.ml.exceptions.ValidationException: Raised if the OpenAI service connection or
+ vector store (AISearch) connection is empty for a Project workspace kind.
+ Details will be provided in the error message.
+ :return: None, or the result of cls(response)
+ :rtype: None
+ """
+
+ if workspace_kind == WorkspaceKind.PROJECT:
+ if capability_host.ai_services_connections is None or capability_host.vector_store_connections is None:
+ msg = "For Project workspace kind, OpenAI service connections and vector store (AISearch) connections are required." # pylint: disable=line-too-long
+ raise ValidationException(
+ message=msg,
+ target=ErrorTarget.CAPABILITY_HOST,
+ no_personal_data_message=msg,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+
+ def _get_workspace(self) -> Workspace:
+ """Retrieve the workspace object.
+
+ :raises ~azure.ai.ml.exceptions.ValidationException: Raised if the specified Hub or Project does not exist.
+ Details will be provided in the error message.
+ :return: Hub or Project object if it exists
+ :rtype: ~azure.ai.ml.entities._workspace.workspace.Workspace
+ """
+ rest_workspace = self._workspace_operations.get(self._resource_group_name, self._workspace_name)
+ workspace = Workspace._from_rest_object(rest_workspace)
+ if workspace is None:
+ msg = f"Workspace with name {self._workspace_name} does not exist."
+ raise ValidationException(
+ message=msg,
+ target=ErrorTarget.CAPABILITY_HOST,
+ no_personal_data_message=msg,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+ return workspace
+
+ def _validate_workspace_name(self) -> None:
+ """Validates that a hub name or project name is set in the MLClient's workspace name parameter.
+
+ :raises ~azure.ai.ml.exceptions.ValidationException: Raised if neither a hub name nor a project name
+ was provided in the workspace_name parameter when the MLClient object was created.
+ Details will be provided in the error message.
+ :return: None
+ :rtype: None
+ """
+ workspace_name = self._workspace_name
+ if not workspace_name:
+ msg = "Please pass either a hub name or project name to the workspace_name parameter when initializing an MLClient object." # pylint: disable=line-too-long
+ raise ValidationException(
+ message=msg,
+ target=ErrorTarget.CAPABILITY_HOST,
+ no_personal_data_message=msg,
+ error_category=ErrorCategory.USER_ERROR,
+ )
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_code_operations.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_code_operations.py
new file mode 100644
index 00000000..7b0e64c3
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_code_operations.py
@@ -0,0 +1,307 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+import re
+from os import PathLike
+from pathlib import Path
+from typing import Dict, Optional, Union
+
+from marshmallow.exceptions import ValidationError as SchemaValidationError
+
+from azure.ai.ml._artifacts._artifact_utilities import (
+ _check_and_upload_path,
+ _get_datastore_name,
+ _get_snapshot_path_info,
+ get_datastore_info,
+)
+from azure.ai.ml._artifacts._constants import (
+ ASSET_PATH_ERROR,
+ CHANGED_ASSET_PATH_MSG,
+ CHANGED_ASSET_PATH_MSG_NO_PERSONAL_DATA,
+)
+from azure.ai.ml._exception_helper import log_and_raise_error
+from azure.ai.ml._restclient.v2021_10_01_dataplanepreview import (
+ AzureMachineLearningWorkspaces as ServiceClient102021Dataplane,
+)
+from azure.ai.ml._restclient.v2022_10_01_preview import AzureMachineLearningWorkspaces as ServiceClient102022
+from azure.ai.ml._restclient.v2023_04_01 import AzureMachineLearningWorkspaces as ServiceClient042023
+from azure.ai.ml._scope_dependent_operations import OperationConfig, OperationScope, _ScopeDependentOperations
+from azure.ai.ml._telemetry import ActivityType, monitor_with_activity
+from azure.ai.ml._utils._asset_utils import (
+ _get_existing_asset_name_and_version,
+ get_content_hash_version,
+ get_storage_info_for_non_registry_asset,
+)
+from azure.ai.ml._utils._logger_utils import OpsLogger
+from azure.ai.ml._utils._registry_utils import get_asset_body_for_registry_storage, get_sas_uri_for_registry_asset
+from azure.ai.ml._utils._storage_utils import get_storage_client
+from azure.ai.ml.entities._assets import Code
+from azure.ai.ml.exceptions import (
+ AssetPathException,
+ ErrorCategory,
+ ErrorTarget,
+ ValidationErrorType,
+ ValidationException,
+)
+from azure.ai.ml.operations._datastore_operations import DatastoreOperations
+from azure.core.exceptions import HttpResponseError
+
+# pylint: disable=protected-access
+
+
+ops_logger = OpsLogger(__name__)
+module_logger = ops_logger.module_logger
+
+
+class CodeOperations(_ScopeDependentOperations):
+ """Represents a client for performing operations on code assets.
+
+ You should not instantiate this class directly. Instead, you should create an MLClient instance and use this
+ client via the property MLClient.code.
+
+ :param operation_scope: Scope variables for the operations classes of an MLClient object.
+ :type operation_scope: ~azure.ai.ml._scope_dependent_operations.OperationScope
+ :param operation_config: Common configuration for operations classes of an MLClient object.
+ :type operation_config: ~azure.ai.ml._scope_dependent_operations.OperationConfig
+ :param service_client: Service client to allow end users to operate on Azure Machine Learning Workspace resources.
+ :type service_client: typing.Union[
+ ~azure.ai.ml._restclient.v2022_10_01_preview._azure_machine_learning_workspaces.AzureMachineLearningWorkspaces,
+ ~azure.ai.ml._restclient.v2021_10_01_dataplanepreview._azure_machine_learning_workspaces.
+ AzureMachineLearningWorkspaces,
+ ~azure.ai.ml._restclient.v2023_04_01._azure_machine_learning_workspaces.AzureMachineLearningWorkspaces]
+ :param datastore_operations: Represents a client for performing operations on Datastores.
+ :type datastore_operations: ~azure.ai.ml.operations._datastore_operations.DatastoreOperations
+ """
+
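+ # Illustrative usage sketch (added comments, not original SDK source): per the
+ # class docstring above, this client is reached through MLClient rather than
+ # constructed directly. Assuming an authenticated `ml_client` and a local source
+ # folder, registering a code asset could look like:
+ #
+ #     from azure.ai.ml.entities import Code
+ #
+ #     code_asset = Code(name="my-code", version="1", path="./src")
+ #     registered = ml_client.code.create_or_update(code_asset)
+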
+ def __init__(
+ self,
+ operation_scope: OperationScope,
+ operation_config: OperationConfig,
+ service_client: Union[ServiceClient102022, ServiceClient102021Dataplane, ServiceClient042023],
+ datastore_operations: DatastoreOperations,
+ **kwargs: Dict,
+ ):
+ super(CodeOperations, self).__init__(operation_scope, operation_config)
+ ops_logger.update_filter()
+ self._service_client = service_client
+ self._version_operation = service_client.code_versions
+ self._container_operation = service_client.code_containers
+ self._datastore_operation = datastore_operations
+ self._init_kwargs = kwargs
+
+ @monitor_with_activity(ops_logger, "Code.CreateOrUpdate", ActivityType.PUBLICAPI)
+ def create_or_update(self, code: Code) -> Code:
+ """Returns created or updated code asset.
+
+ If not already in storage, asset will be uploaded to the workspace's default datastore.
+
+ :param code: Code asset object.
+ :type code: Code
+ :raises ~azure.ai.ml.exceptions.AssetPathException: Raised when the Code artifact path is
+ already linked to another asset
+ :raises ~azure.ai.ml.exceptions.ValidationException: Raised if Code cannot be successfully validated.
+ Details will be provided in the error message.
+ :raises ~azure.ai.ml.exceptions.EmptyDirectoryError: Raised if the local path provided points to an empty directory.
+ :return: Code asset object.
+ :rtype: ~azure.ai.ml.entities.Code
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_misc.py
+ :start-after: [START code_operations_create_or_update]
+ :end-before: [END code_operations_create_or_update]
+ :language: python
+ :dedent: 8
+ :caption: Create code asset example.
+ """
+ try:
+ name = code.name
+ version = code.version
+ sas_uri = None
+ blob_uri = None
+
+ if self._registry_name:
+ sas_uri = get_sas_uri_for_registry_asset(
+ service_client=self._service_client,
+ name=name,
+ version=version,
+ resource_group=self._resource_group_name,
+ registry=self._registry_name,
+ body=get_asset_body_for_registry_storage(self._registry_name, "codes", name, version),
+ )
+ else:
+ snapshot_path_info = _get_snapshot_path_info(code)
+ if snapshot_path_info:
+ _, _, asset_hash = snapshot_path_info
+ existing_assets = list(
+ self._version_operation.list(
+ resource_group_name=self._resource_group_name,
+ workspace_name=self._workspace_name,
+ name=name,
+ hash=asset_hash,
+ hash_version=str(get_content_hash_version()),
+ )
+ )
+
+ if len(existing_assets) > 0:
+ existing_asset = existing_assets[0]
+ name, version = _get_existing_asset_name_and_version(existing_asset)
+ return self.get(name=name, version=version)
+ sas_info = get_storage_info_for_non_registry_asset(
+ service_client=self._service_client,
+ workspace_name=self._workspace_name,
+ name=name,
+ version=version,
+ resource_group=self._resource_group_name,
+ )
+ sas_uri = sas_info["sas_uri"]
+ blob_uri = sas_info["blob_uri"]
+
+ code, _ = _check_and_upload_path(
+ artifact=code,
+ asset_operations=self,
+ sas_uri=sas_uri,
+ artifact_type=ErrorTarget.CODE,
+ show_progress=self._show_progress,
+ blob_uri=blob_uri,
+ )
+
+ # For anonymous code, if the code already exists in storage, we reuse the name,
+ # version stored in the storage metadata so the same anonymous code won't be created again.
+ if code._is_anonymous:
+ name = code.name
+ version = code.version
+
+ code_version_resource = code._to_rest_object()
+
+ result = (
+ self._version_operation.begin_create_or_update(
+ name=name,
+ version=version,
+ registry_name=self._registry_name,
+ resource_group_name=self._operation_scope.resource_group_name,
+ body=code_version_resource,
+ **self._init_kwargs,
+ ).result()
+ if self._registry_name
+ else self._version_operation.create_or_update(
+ name=name,
+ version=version,
+ workspace_name=self._workspace_name,
+ resource_group_name=self._operation_scope.resource_group_name,
+ body=code_version_resource,
+ **self._init_kwargs,
+ )
+ )
+
+ if not result:
+ return self.get(name=name, version=version)
+ return Code._from_rest_object(result)
+ except Exception as ex:
+ if isinstance(ex, (ValidationException, SchemaValidationError)):
+ log_and_raise_error(ex)
+ elif isinstance(ex, HttpResponseError):
+ # service side raises an exception if we attempt to update an existing asset's asset path
+ if str(ex) == ASSET_PATH_ERROR:
+ raise AssetPathException(
+ message=CHANGED_ASSET_PATH_MSG,
+ target=ErrorTarget.CODE,
+ no_personal_data_message=CHANGED_ASSET_PATH_MSG_NO_PERSONAL_DATA,
+ error_category=ErrorCategory.USER_ERROR,
+ ) from ex
+ raise ex
+
+ @monitor_with_activity(ops_logger, "Code.Get", ActivityType.PUBLICAPI)
+ def get(self, name: str, version: str) -> Code:
+ """Returns information about the specified code asset.
+
+ :param name: Name of the code asset.
+ :type name: str
+ :param version: Version of the code asset.
+ :type version: str
+ :raises ~azure.ai.ml.exceptions.ValidationException: Raised if Code cannot be successfully validated.
+ Details will be provided in the error message.
+ :return: Code asset object.
+ :rtype: ~azure.ai.ml.entities.Code
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_misc.py
+ :start-after: [START code_operations_get]
+ :end-before: [END code_operations_get]
+ :language: python
+ :dedent: 8
+ :caption: Get code asset example.
+ """
+ return self._get(name=name, version=version)
+
+ # this is a public API but CodeOperations is hidden, so it may only monitor internal calls
+ @monitor_with_activity(ops_logger, "Code.Download", ActivityType.PUBLICAPI)
+ def download(self, name: str, version: str, download_path: Union[PathLike, str]) -> None:
+ """Download content of a code.
+
+ :param str name: Name of the code.
+ :param str version: Version of the code.
+ :param Union[PathLike, str] download_path: Local path to use as the download destination,
+ defaults to the current working directory. The destination must be an empty directory or not exist yet;
+ otherwise an OSError is raised.
+ :raise: ResourceNotFoundError if a code asset matching the provided name cannot be found.
+ """
+ output_dir = Path(download_path)
+ if output_dir.is_dir():
+ # an OSError will be raised if the directory is not empty
+ output_dir.rmdir()
+ output_dir.mkdir(parents=True)
+
+ code = self._get(name=name, version=version)
+
+ # TODO: how should we maintain this regex?
+ m = re.match(
+ r"https://(?P<account_name>.+)\.blob\.core\.windows\.net"
+ r"(:[0-9]+)?/(?P<container_name>.+)/(?P<blob_name>.*)",
+ str(code.path),
+ )
+ if not m:
+ raise ValueError(f"Invalid code path: {code.path}")
+
+ datastore_info = get_datastore_info(
+ self._datastore_operation,
+ # always use WORKSPACE_BLOB_STORE
+ name=_get_datastore_name(),
+ container_name=m.group("container_name"),
+ )
+ storage_client = get_storage_client(**datastore_info)
+ storage_client.download(
+ starts_with=m.group("blob_name"),
+ destination=output_dir.as_posix(),
+ )
+ if not output_dir.is_dir() or not any(output_dir.iterdir()):
+ raise RuntimeError(f"Failed to download code to {output_dir}")
+
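+ # Illustrative usage sketch (added comments, not original SDK source): per the
+ # rmdir/mkdir logic above, the destination folder must be empty or not yet
+ # exist. A call could look like:
+ #
+ #     ml_client.code.download(name="my-code", version="1", download_path="./downloaded_code")
+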
+ def _get(self, name: str, version: Optional[str] = None) -> Code:
+ if not version:
+ msg = "Code asset version must be specified as part of name parameter, in format 'name:version'."
+ raise ValidationException(
+ message=msg,
+ target=ErrorTarget.CODE,
+ no_personal_data_message=msg,
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
+ code_version_resource = (
+ self._version_operation.get(
+ name=name,
+ version=version,
+ resource_group_name=self._operation_scope.resource_group_name,
+ registry_name=self._registry_name,
+ **self._init_kwargs,
+ )
+ if self._registry_name
+ else self._version_operation.get(
+ name=name,
+ version=version,
+ resource_group_name=self._operation_scope.resource_group_name,
+ workspace_name=self._workspace_name,
+ **self._init_kwargs,
+ )
+ )
+ return Code._from_rest_object(code_version_resource)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_component_operations.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_component_operations.py
new file mode 100644
index 00000000..f9e43f1d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_component_operations.py
@@ -0,0 +1,1289 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access,too-many-lines
+import time
+import collections
+import types
+from functools import partial
+from inspect import Parameter, signature
+from os import PathLike
+from pathlib import Path
+from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union, cast
+import hashlib
+
+from azure.ai.ml._restclient.v2021_10_01_dataplanepreview import (
+ AzureMachineLearningWorkspaces as ServiceClient102021Dataplane,
+)
+from azure.ai.ml._restclient.v2024_01_01_preview import (
+ AzureMachineLearningWorkspaces as ServiceClient012024,
+)
+from azure.ai.ml._restclient.v2024_01_01_preview.models import (
+ ComponentVersion,
+ ListViewType,
+)
+from azure.ai.ml._scope_dependent_operations import (
+ OperationConfig,
+ OperationsContainer,
+ OperationScope,
+ _ScopeDependentOperations,
+)
+from azure.ai.ml._telemetry import (
+ ActivityType,
+ monitor_with_activity,
+ monitor_with_telemetry_mixin,
+)
+from azure.ai.ml._utils._asset_utils import (
+ _archive_or_restore,
+ _create_or_update_autoincrement,
+ _get_file_hash,
+ _get_latest,
+ _get_next_version_from_container,
+ _resolve_label_to_asset,
+ get_ignore_file,
+ get_upload_files_from_folder,
+ IgnoreFile,
+ delete_two_catalog_files,
+ create_catalog_files,
+)
+from azure.ai.ml._utils._azureml_polling import AzureMLPolling
+from azure.ai.ml._utils._endpoint_utils import polling_wait
+from azure.ai.ml._utils._logger_utils import OpsLogger
+from azure.ai.ml._vendor.azure_resources.operations import DeploymentsOperations
+from azure.ai.ml.constants._common import (
+ DEFAULT_COMPONENT_VERSION,
+ DEFAULT_LABEL_NAME,
+ AzureMLResourceType,
+ DefaultOpenEncoding,
+ LROConfigurations,
+)
+from azure.ai.ml.entities import Component, ValidationResult
+from azure.ai.ml.exceptions import (
+ ComponentException,
+ ErrorCategory,
+ ErrorTarget,
+ ValidationException,
+)
+from azure.core.exceptions import HttpResponseError, ResourceNotFoundError
+
+from .._utils._cache_utils import CachedNodeResolver
+from .._utils._experimental import experimental
+from .._utils.utils import extract_name_and_version, is_data_binding_expression
+from ..entities._builders import BaseNode
+from ..entities._builders.condition_node import ConditionNode
+from ..entities._builders.control_flow_node import LoopNode
+from ..entities._component.automl_component import AutoMLComponent
+from ..entities._component.code import ComponentCodeMixin
+from ..entities._component.pipeline_component import PipelineComponent
+from ..entities._job.pipeline._attr_dict import has_attr_safe
+from ._code_operations import CodeOperations
+from ._environment_operations import EnvironmentOperations
+from ._operation_orchestrator import OperationOrchestrator, _AssetResolver
+from ._workspace_operations import WorkspaceOperations
+
+ops_logger = OpsLogger(__name__)
+logger, module_logger = ops_logger.package_logger, ops_logger.module_logger
+
+
+class ComponentOperations(_ScopeDependentOperations):
+ """ComponentOperations.
+
+ You should not instantiate this class directly. Instead, you should
+ create an MLClient instance that instantiates it for you and
+ attaches it as an attribute.
+
+ :param operation_scope: The operation scope.
+ :type operation_scope: ~azure.ai.ml._scope_dependent_operations.OperationScope
+ :param operation_config: The operation configuration.
+ :type operation_config: ~azure.ai.ml._scope_dependent_operations.OperationConfig
+ :param service_client: The service client for API operations.
+ :type service_client: Union[
+ ~azure.ai.ml._restclient.v2022_10_01.AzureMachineLearningWorkspaces,
+ ~azure.ai.ml._restclient.v2021_10_01_dataplanepreview.AzureMachineLearningWorkspaces]
+ :param all_operations: The container for all available operations.
+ :type all_operations: ~azure.ai.ml._scope_dependent_operations.OperationsContainer
+ :param preflight_operation: The preflight operation for deployments.
+ :type preflight_operation: Optional[~azure.ai.ml._vendor.azure_resources.operations.DeploymentsOperations]
+ :param kwargs: Additional keyword arguments.
+ :type kwargs: Dict
+ """
+
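+ # Illustrative usage sketch (added comments, not original SDK source): these
+ # operations are normally reached through an MLClient instance (commonly
+ # `ml_client.components`). Assuming a local component YAML definition, a typical
+ # round trip could look like:
+ #
+ #     from azure.ai.ml import load_component
+ #
+ #     component = load_component(source="./my_component.yml")
+ #     registered = ml_client.components.create_or_update(component)
+ #     fetched = ml_client.components.get(name=registered.name, version=registered.version)
+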
+ def __init__(
+ self,
+ operation_scope: OperationScope,
+ operation_config: OperationConfig,
+ service_client: Union[ServiceClient012024, ServiceClient102021Dataplane],
+ all_operations: OperationsContainer,
+ preflight_operation: Optional[DeploymentsOperations] = None,
+ **kwargs: Dict,
+ ) -> None:
+ super(ComponentOperations, self).__init__(operation_scope, operation_config)
+ ops_logger.update_filter()
+ self._version_operation = service_client.component_versions
+ self._preflight_operation = preflight_operation
+ self._container_operation = service_client.component_containers
+ self._all_operations = all_operations
+ self._init_args = kwargs
+ # Maps a label to a function which given an asset name,
+ # returns the asset associated with the label
+ self._managed_label_resolver = {"latest": self._get_latest_version}
+ self._orchestrators = OperationOrchestrator(self._all_operations, self._operation_scope, self._operation_config)
+
+ self._client_key: Optional[str] = None
+
+ @property
+ def _code_operations(self) -> CodeOperations:
+ res: CodeOperations = self._all_operations.get_operation( # type: ignore[misc]
+ AzureMLResourceType.CODE, lambda x: isinstance(x, CodeOperations)
+ )
+ return res
+
+ @property
+ def _environment_operations(self) -> EnvironmentOperations:
+ return cast(
+ EnvironmentOperations,
+ self._all_operations.get_operation( # type: ignore[misc]
+ AzureMLResourceType.ENVIRONMENT,
+ lambda x: isinstance(x, EnvironmentOperations),
+ ),
+ )
+
+ @property
+ def _workspace_operations(self) -> WorkspaceOperations:
+ return cast(
+ WorkspaceOperations,
+ self._all_operations.get_operation( # type: ignore[misc]
+ AzureMLResourceType.WORKSPACE,
+ lambda x: isinstance(x, WorkspaceOperations),
+ ),
+ )
+
+ @property
+ def _job_operations(self) -> Any:
+ from ._job_operations import JobOperations
+
+ return self._all_operations.get_operation( # type: ignore[misc]
+ AzureMLResourceType.JOB, lambda x: isinstance(x, JobOperations)
+ )
+
+ @monitor_with_activity(ops_logger, "Component.List", ActivityType.PUBLICAPI)
+ def list(
+ self,
+ name: Union[str, None] = None,
+ *,
+ list_view_type: ListViewType = ListViewType.ACTIVE_ONLY,
+ ) -> Iterable[Component]:
+ """List specific component or components of the workspace.
+
+ :param name: Component name. If not set, all components of the workspace are listed.
+ :type name: Optional[str]
+ :keyword list_view_type: View type for including/excluding (for example) archived components.
+ Default: ACTIVE_ONLY.
+ :type list_view_type: Optional[ListViewType]
+ :return: An iterator like instance of component objects
+ :rtype: ~azure.core.paging.ItemPaged[Component]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_misc.py
+ :start-after: [START component_operations_list]
+ :end-before: [END component_operations_list]
+ :language: python
+ :dedent: 8
+ :caption: List component example.
+ """
+
+ if name:
+ return cast(
+ Iterable[Component],
+ (
+ self._version_operation.list(
+ name=name,
+ resource_group_name=self._resource_group_name,
+ registry_name=self._registry_name,
+ **self._init_args,
+ cls=lambda objs: [Component._from_rest_object(obj) for obj in objs],
+ )
+ if self._registry_name
+ else self._version_operation.list(
+ name=name,
+ resource_group_name=self._resource_group_name,
+ workspace_name=self._workspace_name,
+ list_view_type=list_view_type,
+ **self._init_args,
+ cls=lambda objs: [Component._from_rest_object(obj) for obj in objs],
+ )
+ ),
+ )
+ return cast(
+ Iterable[Component],
+ (
+ self._container_operation.list(
+ resource_group_name=self._resource_group_name,
+ registry_name=self._registry_name,
+ **self._init_args,
+ cls=lambda objs: [Component._from_container_rest_object(obj) for obj in objs],
+ )
+ if self._registry_name
+ else self._container_operation.list(
+ resource_group_name=self._resource_group_name,
+ workspace_name=self._workspace_name,
+ list_view_type=list_view_type,
+ **self._init_args,
+ cls=lambda objs: [Component._from_container_rest_object(obj) for obj in objs],
+ )
+ ),
+ )
+
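+ # Illustrative usage sketch (added comments, not original SDK source): per the
+ # docstring above, omitting `name` lists the components of the workspace, while
+ # passing a name lists the versions of that component, for example:
+ #
+ #     all_components = list(ml_client.components.list())
+ #     versions = list(ml_client.components.list(name="my_component"))
+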
+ @monitor_with_telemetry_mixin(ops_logger, "ComponentVersion.Get", ActivityType.INTERNALCALL)
+ def _get_component_version(self, name: str, version: Optional[str] = DEFAULT_COMPONENT_VERSION) -> ComponentVersion:
+ """Returns ComponentVersion information about the specified component name and version.
+
+ :param name: Name of the code component.
+ :type name: str
+ :param version: Version of the component.
+ :type version: Optional[str]
+ :return: The ComponentVersion object of the specified component name and version.
+ :rtype: ~azure.ai.ml.entities.ComponentVersion
+ """
+ result = (
+ self._version_operation.get(
+ name=name,
+ version=version,
+ resource_group_name=self._resource_group_name,
+ registry_name=self._registry_name,
+ **self._init_args,
+ )
+ if self._registry_name
+ else self._version_operation.get(
+ name=name,
+ version=version,
+ resource_group_name=self._resource_group_name,
+ workspace_name=self._workspace_name,
+ **self._init_args,
+ )
+ )
+ return result
+
+ @monitor_with_telemetry_mixin(ops_logger, "Component.Get", ActivityType.PUBLICAPI)
+ def get(self, name: str, version: Optional[str] = None, label: Optional[str] = None) -> Component:
+ """Returns information about the specified component.
+
+ :param name: Name of the code component.
+ :type name: str
+ :param version: Version of the component.
+ :type version: Optional[str]
+ :param label: Label of the component, mutually exclusive with version.
+ :type label: Optional[str]
+ :raises ~azure.ai.ml.exceptions.ValidationException: Raised if Component cannot be successfully
+ identified and retrieved. Details will be provided in the error message.
+ :return: The specified component object.
+ :rtype: ~azure.ai.ml.entities.Component
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_misc.py
+ :start-after: [START component_operations_get]
+ :end-before: [END component_operations_get]
+ :language: python
+ :dedent: 8
+ :caption: Get component example.
+ """
+ return self._get(name=name, version=version, label=label)
+
+ def _localize_code(self, component: Component, base_dir: Path) -> None:
+ if not isinstance(component, ComponentCodeMixin):
+ return
+ code = component._get_origin_code_value()
+ if not isinstance(code, str):
+ return
+ # registry code keeps the "azureml:" prefix and can be used directly
+ if code.startswith("azureml://registries"):
+ return
+
+ target_code_value = "./code"
+ self._code_operations.download(
+ **extract_name_and_version(code),
+ download_path=base_dir.joinpath(target_code_value),
+ )
+
+ setattr(component, component._get_code_field_name(), target_code_value)
+
+ def _localize_environment(self, component: Component, base_dir: Path) -> None:
+ from azure.ai.ml.entities import ParallelComponent
+
+ parent: Any = None
+ if hasattr(component, "environment"):
+ parent = component
+ elif isinstance(component, ParallelComponent):
+ parent = component.task
+ else:
+ return
+
+ # environment can be None
+ if not isinstance(parent.environment, str):
+ return
+ # registry environment keeps the "azureml:" prefix and can be used directly
+ if parent.environment.startswith("azureml://registries"):
+ return
+
+ environment = self._environment_operations.get(**extract_name_and_version(parent.environment))
+ environment._localize(base_path=base_dir.absolute().as_posix())
+ parent.environment = environment
+
+ @experimental
+ @monitor_with_telemetry_mixin(ops_logger, "Component.Download", ActivityType.PUBLICAPI)
+ def download(
+ self,
+ name: str,
+ download_path: Union[PathLike, str] = ".",
+ *,
+ version: Optional[str] = None,
+ ) -> None:
+ """Download the specified component and its dependencies to local. Local component can be used to create
+ the component in another workspace or for offline development.
+
+ :param name: Name of the code component.
+ :type name: str
+ :param download_path: Local path to use as the download destination,
+ defaults to the current working directory. Will be created if it does not exist.
+ :type download_path: Union[PathLike, str]
+ :keyword version: Version of the component.
+ :paramtype version: Optional[str]
+ :raises OSError: Raised if download_path points to an existing directory that is not empty.
+ :return: None
+ :rtype: None
+ """
+ download_path = Path(download_path)
+ component = self._get(name=name, version=version)
+ self._resolve_azureml_id(component)
+
+ output_dir = Path(download_path)
+ if output_dir.is_dir():
+ # an OSError will be raised if the directory is not empty
+ output_dir.rmdir()
+ output_dir.mkdir(parents=True)
+ # download code
+ self._localize_code(component, output_dir)
+
+ # download environment
+ self._localize_environment(component, output_dir)
+
+ component._localize(output_dir.absolute().as_posix())
+ (output_dir / "component_spec.yaml").write_text(component._to_yaml(), encoding=DefaultOpenEncoding.WRITE)
+
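+ # Illustrative usage sketch (added comments, not original SDK source): the
+ # target folder must be empty or not yet exist; the spec is written as
+ # component_spec.yaml next to any localized code and environment, for example:
+ #
+ #     ml_client.components.download("my_component", download_path="./exported", version="1")
+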
+ def _get(self, name: str, version: Optional[str] = None, label: Optional[str] = None) -> Component:
+ if version and label:
+ msg = "Cannot specify both version and label."
+ raise ValidationException(
+ message=msg,
+ target=ErrorTarget.COMPONENT,
+ no_personal_data_message=msg,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+
+ if not version and not label:
+ label = DEFAULT_LABEL_NAME
+
+ if label == DEFAULT_LABEL_NAME:
+ label = None
+ version = DEFAULT_COMPONENT_VERSION
+
+ if label:
+ return _resolve_label_to_asset(self, name, label)
+
+ result = self._get_component_version(name, version)
+ component = Component._from_rest_object(result)
+ self._resolve_azureml_id(component, jobs_only=True)
+ return component
+
+ @experimental
+ @monitor_with_telemetry_mixin(ops_logger, "Component.Validate", ActivityType.PUBLICAPI)
+ def validate(
+ self,
+ component: Union[Component, types.FunctionType],
+ raise_on_failure: bool = False,
+ **kwargs: Any,
+ ) -> ValidationResult:
+ """validate a specified component. if there are inline defined
+ entities, e.g. Environment, Code, they won't be created.
+
+ :param component: The component object or a mldesigner component function that generates component object
+ :type component: Union[Component, types.FunctionType]
+ :param raise_on_failure: Whether to raise exception on validation error. Defaults to False
+ :type raise_on_failure: bool
+ :return: All validation errors
+ :rtype: ~azure.ai.ml.entities.ValidationResult
+ """
+ return self._validate(
+ component,
+ raise_on_failure=raise_on_failure,
+ # TODO 2330505: change this to True after remote validation is ready
+ skip_remote_validation=kwargs.pop("skip_remote_validation", True),
+ )
+
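+ # Illustrative usage sketch (added comments, not original SDK source): validation
+ # can be run before registration to surface definition problems locally. Assuming
+ # ValidationResult exposes `passed` and `error_messages` as in the entities module:
+ #
+ #     result = ml_client.components.validate(component)
+ #     if not result.passed:
+ #         print(result.error_messages)
+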
+ @monitor_with_telemetry_mixin(ops_logger, "Component.Validate", ActivityType.INTERNALCALL)
+ def _validate(
+ self,
+ component: Union[Component, types.FunctionType],
+ raise_on_failure: bool,
+ skip_remote_validation: bool,
+ ) -> ValidationResult:
+ """Implementation of validate. Add this function to avoid calling validate() directly in create_or_update(),
+ which will impact telemetry statistics & bring experimental warning in create_or_update().
+
+ :param component: The component
+ :type component: Union[Component, types.FunctionType]
+ :param raise_on_failure: Whether to raise on failure.
+ :type raise_on_failure: bool
+ :param skip_remote_validation: Whether to skip remote validation.
+ :type skip_remote_validation: bool
+ :return: The validation result
+ :rtype: ValidationResult
+ """
+ # Update component when the input is a component function
+ if isinstance(component, types.FunctionType):
+ component = _refine_component(component)
+
+ # local validation
+ result = component._validate(raise_error=raise_on_failure)
+ # remote validation, note that preflight_operation is not available for registry client
+ if not skip_remote_validation and self._preflight_operation:
+ workspace = self._workspace_operations.get()
+ remote_validation_result = self._preflight_operation.begin_validate(
+ resource_group_name=self._resource_group_name,
+ deployment_name=self._workspace_name,
+ parameters=component._build_rest_object_for_remote_validation(
+ location=workspace.location,
+ workspace_name=self._workspace_name,
+ ),
+ **self._init_args,
+ )
+ result.merge_with(
+ # pylint: disable=protected-access
+ component._build_validation_result_from_rest_object(remote_validation_result.result()),
+ overwrite=True,
+ )
+ # resolve location for diagnostics from remote validation
+ result.resolve_location_for_diagnostics(component._source_path) # type: ignore
+ return component._try_raise( # pylint: disable=protected-access
+ result,
+ raise_error=raise_on_failure,
+ )
+
+ def _update_flow_rest_object(self, rest_component_resource: Any) -> None:
+ import re
+
+ from azure.ai.ml._utils._arm_id_utils import AMLVersionedArmId
+
+ component_spec = rest_component_resource.properties.component_spec
+ code, flow_file_name = AMLVersionedArmId(component_spec["code"]), component_spec.pop("flow_file_name")
+ # TODO: avoid remote request here if met performance issue
+ created_code = self._code_operations.get(name=code.asset_name, version=code.asset_version)
+ # remove port number and append flow file name to get full uri for flow.dag.yaml
+ component_spec["flow_definition_uri"] = f"{re.sub(r':[0-9]+/', '/', created_code.path)}/{flow_file_name}"
+
+ def _reset_version_if_no_change(self, component: Component, current_name: str, current_version: str) -> Tuple:
+ """Reset component version to default version if there's no change in the component.
+
+ :param component: The component object
+ :type component: Component
+ :param current_name: The component name
+ :type current_name: str
+ :param current_version: The component version
+ :type current_version: str
+ :return: The new version and rest component resource
+ :rtype: Tuple[str, ComponentVersion]
+ """
+ rest_component_resource = component._to_rest_object()
+
+ try:
+ client_component_hash = rest_component_resource.properties.properties.get("client_component_hash")
+ remote_component_version = self._get_component_version(name=current_name) # will raise error if not found.
+ remote_component_hash = remote_component_version.properties.properties.get("client_component_hash")
+ if client_component_hash == remote_component_hash:
+ component.version = remote_component_version.properties.component_spec.get(
+ "version"
+ ) # only update the default version component instead of creating a new version component
+ logger.warning(
+ "The component is not modified compared to the default version "
+ "and the new version component registration is skipped."
+ )
+ return component.version, component._to_rest_object()
+ except ResourceNotFoundError as e:
+ logger.info("Failed to get component version, %s", e)
+ except Exception as e: # pylint: disable=W0718
+ logger.error("Failed to compare client_component_hash, %s", e)
+
+ return current_version, rest_component_resource
+
+ def _create_or_update_component_version(
+ self,
+ component: Component,
+ name: str,
+ version: Optional[str],
+ rest_component_resource: Any,
+ ) -> Any:
+ try:
+ if self._registry_name:
+ start_time = time.time()
+ path_format_arguments = {
+ "componentName": component.name,
+ "resourceGroupName": self._resource_group_name,
+ "registryName": self._registry_name,
+ }
+ poller = self._version_operation.begin_create_or_update(
+ name=name,
+ version=version,
+ resource_group_name=self._operation_scope.resource_group_name,
+ registry_name=self._registry_name,
+ body=rest_component_resource,
+ polling=AzureMLPolling(
+ LROConfigurations.POLL_INTERVAL,
+ path_format_arguments=path_format_arguments,
+ ),
+ )
+ message = f"Creating/updating registry component {component.name} with version {component.version} "
+ polling_wait(poller=poller, start_time=start_time, message=message, timeout=None)
+
+ else:
+ # _auto_increment_version can be True for non-registry component creation operation;
+ # and anonymous component should use hash as version
+ if not component._is_anonymous and component._auto_increment_version:
+ return _create_or_update_autoincrement(
+ name=name,
+ body=rest_component_resource,
+ version_operation=self._version_operation,
+ container_operation=self._container_operation,
+ resource_group_name=self._operation_scope.resource_group_name,
+ workspace_name=self._workspace_name,
+ **self._init_args,
+ )
+
+ return self._version_operation.create_or_update(
+ name=name,
+ version=version,
+ resource_group_name=self._resource_group_name,
+ workspace_name=self._workspace_name,
+ body=rest_component_resource,
+ **self._init_args,
+ )
+ except Exception as e:
+ raise e
+
+ return None
+
+ @monitor_with_telemetry_mixin(
+ logger,
+ "Component.CreateOrUpdate",
+ ActivityType.PUBLICAPI,
+ extra_keys=["is_anonymous"],
+ )
+ def create_or_update(
+ self,
+ component: Component,
+ version: Optional[str] = None,
+ *,
+ skip_validation: bool = False,
+ **kwargs: Any,
+ ) -> Component:
+ """Create or update a specified component. if there're inline defined
+ entities, e.g. Environment, Code, they'll be created together with the
+ component.
+
+ :param component: The component object or a mldesigner component function that generates component object
+ :type component: Union[Component, types.FunctionType]
+ :param version: The component version to override.
+ :type version: str
+ :keyword skip_validation: whether to skip validation before creating/updating the component, defaults to False
+ :paramtype skip_validation: bool
+ :raises ~azure.ai.ml.exceptions.ValidationException: Raised if Component cannot be successfully validated.
+ Details will be provided in the error message.
+ :raises ~azure.ai.ml.exceptions.AssetException: Raised if Component assets
+ (e.g. Data, Code, Model, Environment) cannot be successfully validated.
+ Details will be provided in the error message.
+ :raises ~azure.ai.ml.exceptions.ComponentException: Raised if Component type is unsupported.
+ Details will be provided in the error message.
+ :raises ~azure.ai.ml.exceptions.ModelException: Raised if Component model cannot be successfully validated.
+ Details will be provided in the error message.
+ :raises ~azure.ai.ml.exceptions.EmptyDirectoryError: Raised if the local path provided points to an empty directory.
+ :return: The specified component object.
+ :rtype: ~azure.ai.ml.entities.Component
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_misc.py
+ :start-after: [START component_operations_create_or_update]
+ :end-before: [END component_operations_create_or_update]
+ :language: python
+ :dedent: 8
+ :caption: Create component example.
+ """
+ # Update component when the input is a component function
+ if isinstance(component, types.FunctionType):
+ component = _refine_component(component)
+ if version is not None:
+ component.version = version
+ # In the non-registry scenario, if the component does not have a version, there is no need to get the
+ # next version here. Because the Component `version` property setter updates `_auto_increment_version`
+ # in place, the component would otherwise get a version after its creation and always reuse that version
+ # in future creation operations, which breaks the version auto-increment mechanism.
+ if self._registry_name and not component.version and component._auto_increment_version:
+ component.version = _get_next_version_from_container(
+ name=component.name,
+ container_operation=self._container_operation,
+ resource_group_name=self._operation_scope.resource_group_name,
+ workspace_name=self._workspace_name,
+ registry_name=self._registry_name,
+ **self._init_args,
+ )
+
+ if not component._is_anonymous:
+ component._is_anonymous = kwargs.pop("is_anonymous", False)
+
+ if not skip_validation:
+ self._validate(component, raise_on_failure=True, skip_remote_validation=True)
+
+ # Create all dependent resources
+ # Only upload dependencies if component is NOT IPP
+ if not component._intellectual_property:
+ self._resolve_arm_id_or_upload_dependencies(component)
+
+ name, version = component._get_rest_name_version()
+ if not component._is_anonymous and kwargs.get("skip_if_no_change"):
+ version, rest_component_resource = self._reset_version_if_no_change(
+ component,
+ current_name=name,
+ current_version=str(version),
+ )
+ else:
+ rest_component_resource = component._to_rest_object()
+
+ # TODO: remove this after the server side supports directly using client-created code
+ from azure.ai.ml.entities._component.flow import FlowComponent
+
+ if isinstance(component, FlowComponent):
+ self._update_flow_rest_object(rest_component_resource)
+
+ result = self._create_or_update_component_version(
+ component,
+ name,
+ version,
+ rest_component_resource,
+ )
+
+ if not result:
+ component = self.get(name=component.name, version=component.version)
+ else:
+ component = Component._from_rest_object(result)
+
+ self._resolve_azureml_id(
+ component=component,
+ jobs_only=True,
+ )
+ return component
+
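+ # Illustrative usage sketch (added comments, not original SDK source): besides
+ # `skip_validation`, the code path above also reads a `skip_if_no_change` keyword
+ # that reuses the default version when the component hash is unchanged, e.g.:
+ #
+ #     registered = ml_client.components.create_or_update(
+ #         component, skip_validation=True, skip_if_no_change=True
+ #     )
+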
+ @experimental
+ def prepare_for_sign(self, component: Component) -> None:
+ ignore_file = IgnoreFile()
+
+ if isinstance(component, ComponentCodeMixin):
+ with component._build_code() as code:
+ delete_two_catalog_files(code.path)
+ ignore_file = get_ignore_file(code.path) if code._ignore_file is None else ignore_file
+ file_list = get_upload_files_from_folder(code.path, ignore_file=ignore_file)
+ json_stub = {}
+ json_stub["HashAlgorithm"] = "SHA256"
+ json_stub["CatalogItems"] = {} # type: ignore
+
+ for file_path, file_name in sorted(file_list, key=lambda x: str(x[1]).lower()):
+ file_hash = _get_file_hash(file_path, hashlib.sha256()).hexdigest().upper()
+ json_stub["CatalogItems"][file_name] = file_hash # type: ignore
+
+ json_stub["CatalogItems"] = collections.OrderedDict( # type: ignore
+ sorted(json_stub["CatalogItems"].items()) # type: ignore
+ )
+ create_catalog_files(code.path, json_stub)
+
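+ # Note (added comments, not original SDK source): prepare_for_sign above rebuilds
+ # catalog files containing SHA256 hashes of every file in the component's code
+ # folder, presumably so the snapshot can be signed before registration. A call
+ # could look like:
+ #
+ #     ml_client.components.prepare_for_sign(component)
+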
+ @monitor_with_telemetry_mixin(ops_logger, "Component.Archive", ActivityType.PUBLICAPI)
+ def archive(
+ self,
+ name: str,
+ version: Optional[str] = None,
+ label: Optional[str] = None,
+ # pylint:disable=unused-argument
+ **kwargs: Any,
+ ) -> None:
+ """Archive a component.
+
+ :param name: Name of the component.
+ :type name: str
+ :param version: Version of the component.
+ :type version: str
+ :param label: Label of the component. (mutually exclusive with version).
+ :type label: str
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_misc.py
+ :start-after: [START component_operations_archive]
+ :end-before: [END component_operations_archive]
+ :language: python
+ :dedent: 8
+ :caption: Archive component example.
+ """
+ _archive_or_restore(
+ asset_operations=self,
+ version_operation=self._version_operation,
+ container_operation=self._container_operation,
+ is_archived=True,
+ name=name,
+ version=version,
+ label=label,
+ )
+
+ @monitor_with_telemetry_mixin(ops_logger, "Component.Restore", ActivityType.PUBLICAPI)
+ def restore(
+ self,
+ name: str,
+ version: Optional[str] = None,
+ label: Optional[str] = None,
+ # pylint:disable=unused-argument
+ **kwargs: Any,
+ ) -> None:
+ """Restore an archived component.
+
+ :param name: Name of the component.
+ :type name: str
+ :param version: Version of the component.
+ :type version: str
+ :param label: Label of the component. (mutually exclusive with version).
+ :type label: str
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_misc.py
+ :start-after: [START component_operations_restore]
+ :end-before: [END component_operations_restore]
+ :language: python
+ :dedent: 8
+ :caption: Restore component example.
+ """
+ _archive_or_restore(
+ asset_operations=self,
+ version_operation=self._version_operation,
+ container_operation=self._container_operation,
+ is_archived=False,
+ name=name,
+ version=version,
+ label=label,
+ )
+
+ def _get_latest_version(self, component_name: str) -> Component:
+ """Returns the latest version of the asset with the given name.
+
+ Latest is defined as the most recently created, not the most
+ recently updated.
+
+ :param component_name: The component name
+ :type component_name: str
+ :return: A latest version of the named Component
+ :rtype: Component
+ """
+
+ result = (
+ _get_latest(
+ component_name,
+ self._version_operation,
+ self._resource_group_name,
+ workspace_name=None,
+ registry_name=self._registry_name,
+ )
+ if self._registry_name
+ else _get_latest(
+ component_name,
+ self._version_operation,
+ self._resource_group_name,
+ self._workspace_name,
+ )
+ )
+ return Component._from_rest_object(result)
+
+ @classmethod
+ def _try_resolve_environment_for_component(
+ cls, component: Union[BaseNode, str], _: str, resolver: _AssetResolver
+ ) -> None:
+ if isinstance(component, BaseNode):
+ component = component._component # pylint: disable=protected-access
+
+ if isinstance(component, str):
+ return
+ potential_parents: List[BaseNode] = [component]
+ if hasattr(component, "task"):
+ potential_parents.append(component.task)
+ for parent in potential_parents:
+ # for an internal component, environment may be a dict or an InternalEnvironment object;
+ # in these two scenarios, we don't need to resolve the environment.
+ # Note on not importing InternalEnvironment directly and checking with `isinstance`:
+ # importing from azure.ai.ml._internal would enable the internal component feature for all users,
+ # so we use type().__name__ for the type check to avoid the import
+ if not hasattr(parent, "environment"):
+ continue
+ if isinstance(parent.environment, dict):
+ continue
+ if type(parent.environment).__name__ == "InternalEnvironment":
+ continue
+ parent.environment = resolver(parent.environment, azureml_type=AzureMLResourceType.ENVIRONMENT)
+
+ def _resolve_azureml_id(self, component: Component, jobs_only: bool = False) -> None:
+ # TODO: remove the parameter `jobs_only`. Some tests are expecting an arm id after resolving for now.
+ resolver = self._orchestrators.resolve_azureml_id
+ self._resolve_dependencies_for_component(component, resolver, jobs_only=jobs_only)
+
+ def _resolve_arm_id_or_upload_dependencies(self, component: Component) -> None:
+ resolver = OperationOrchestrator(
+ self._all_operations, self._operation_scope, self._operation_config
+ ).get_asset_arm_id
+
+ self._resolve_dependencies_for_component(component, resolver)
+
+ def _resolve_dependencies_for_component(
+ self,
+ component: Component,
+ resolver: Callable,
+ *,
+ jobs_only: bool = False,
+ ) -> None:
+ # for now, many tests are expecting long arm id instead of short id for environment and code
+ if not jobs_only:
+ if isinstance(component, AutoMLComponent):
+ # no extra dependency for automl component
+ return
+
+ # type check for potential Job type, which is unexpected here.
+ if not isinstance(component, Component):
+ msg = f"Non supported component type: {type(component)}"
+ raise ValidationException(
+ message=msg,
+ target=ErrorTarget.COMPONENT,
+ no_personal_data_message=msg,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+
+ # resolve component's code
+ _try_resolve_code_for_component(component=component, resolver=resolver)
+ # resolve component's environment
+ self._try_resolve_environment_for_component(
+ component=component, # type: ignore
+ resolver=resolver,
+ _="",
+ )
+
+ self._resolve_dependencies_for_pipeline_component_jobs(
+ component,
+ resolver=resolver,
+ )
+
+ def _resolve_inputs_for_pipeline_component_jobs(self, jobs: Dict[str, Any], base_path: str) -> None:
+ """Resolve inputs for jobs in a pipeline component.
+
+ :param jobs: A dict of nodes in a pipeline component.
+ :type jobs: Dict[str, Any]
+ :param base_path: The base path used to resolve inputs. Usually it's the base path of the pipeline component.
+ :type base_path: str
+ """
+ from azure.ai.ml.entities._job.automl.automl_job import AutoMLJob
+
+ for _, job_instance in jobs.items():
+ # resolve inputs for each job's component
+ if isinstance(job_instance, BaseNode):
+ node: BaseNode = job_instance
+ self._job_operations._resolve_job_inputs(
+ # parameter group input need to be flattened first
+ self._job_operations._flatten_group_inputs(node.inputs),
+ base_path,
+ )
+ elif isinstance(job_instance, AutoMLJob):
+ self._job_operations._resolve_automl_job_inputs(job_instance)
+
+ @classmethod
+ def _resolve_binding_on_supported_fields_for_node(cls, node: BaseNode) -> None:
+ """Resolve all PipelineInput(binding from sdk) on supported fields to string.
+
+ :param node: The node
+ :type node: BaseNode
+ """
+ from azure.ai.ml.entities._job.pipeline._attr_dict import (
+ try_get_non_arbitrary_attr,
+ )
+ from azure.ai.ml.entities._job.pipeline._io import PipelineInput
+
+ # compute binding to pipeline input is supported on node.
+ supported_fields = ["compute", "compute_name"]
+ for field_name in supported_fields:
+ val = try_get_non_arbitrary_attr(node, field_name)
+ if isinstance(val, PipelineInput):
+ # Put binding string to field
+ setattr(node, field_name, val._data_binding())
+
+ @classmethod
+ def _try_resolve_node_level_task_for_parallel_node(cls, node: BaseNode, _: str, resolver: _AssetResolver) -> None:
+ """Resolve node.task.code for parallel node if it's a reference to node.component.task.code.
+
+ This is a hack operation.
+
+ parallel_node.task.code won't be resolved directly for now, but it will be resolved if
+ parallel_node.task is a reference to parallel_node.component.task. Then when filling back
+ parallel_node.component.task.code, parallel_node.task.code will be changed as well.
+
+ However, if we enable in-memory/on-disk cache for component resolution, such change
+ won't happen, so we resolve node level task code manually here.
+
+ Note that we will always use resolved node.component.code to fill back node.task.code
+ given code overwrite on parallel node won't take effect for now. This is to make behaviors
+ consistent across os and python versions.
+
+ The ideal solution should be done after PRS team decides how to handle parallel.task.code
+
+ :param node: The node
+ :type node: BaseNode
+ :param _: The component name
+ :type _: str
+ :param resolver: The resolver function
+ :type resolver: _AssetResolver
+ """
+ from azure.ai.ml.entities import Parallel, ParallelComponent
+
+ if not isinstance(node, Parallel):
+ return
+ component = node._component # pylint: disable=protected-access
+ if not isinstance(component, ParallelComponent):
+ return
+ if not node.task:
+ return
+
+ if node.task.code:
+ _try_resolve_code_for_component(
+ component,
+ resolver=resolver,
+ )
+ node.task.code = component.code
+ if node.task.environment:
+ node.task.environment = resolver(component.environment, azureml_type=AzureMLResourceType.ENVIRONMENT)
+
+ @classmethod
+ def _set_default_display_name_for_anonymous_component_in_node(cls, node: BaseNode, default_name: str) -> None:
+ """Set default display name for anonymous component in a node.
+ If node._component is an anonymous component and without display name, set the default display name.
+
+ :param node: The node
+ :type node: BaseNode
+ :param default_name: The default name to set
+ :type default_name: str
+ """
+ if not isinstance(node, BaseNode):
+ return
+ component = node._component
+ if isinstance(component, PipelineComponent):
+ return
+ # Set display name as node name
+ # TODO: the same anonymous component with different node name will have different anonymous hash
+ # as their display name will be different.
+ if (
+ isinstance(component, Component)
+ # check if component is anonymous and not created based on its id. We can't directly check
+ # node._component._is_anonymous as it will be set to True on component creation,
+ # which is later than this check
+ and not component.id
+ and not component.display_name
+ ):
+ component.display_name = default_name
+
+ @classmethod
+ def _try_resolve_compute_for_node(cls, node: BaseNode, _: str, resolver: _AssetResolver) -> None:
+ """Resolve compute for base node.
+
+ :param node: The node
+ :type node: BaseNode
+ :param _: The node name
+ :type _: str
+ :param resolver: The resolver function
+ :type resolver: _AssetResolver
+ """
+ if not isinstance(node, BaseNode):
+ return
+ if not isinstance(node._component, PipelineComponent):
+ # Resolve compute for other type
+ # Keep data binding expression as they are
+ if not is_data_binding_expression(node.compute):
+ # Get compute for each job
+ node.compute = resolver(node.compute, azureml_type=AzureMLResourceType.COMPUTE)
+ if has_attr_safe(node, "compute_name") and not is_data_binding_expression(node.compute_name):
+ node.compute_name = resolver(node.compute_name, azureml_type=AzureMLResourceType.COMPUTE)
+
+ @classmethod
+ def _divide_nodes_to_resolve_into_layers(
+ cls,
+ component: PipelineComponent,
+ extra_operations: List[Callable[[BaseNode, str], Any]],
+ ) -> List:
+ """Traverse the pipeline component and divide nodes to resolve into layers. Note that all leaf nodes will be
+ put in the last layer.
+ For example, for below pipeline component, assuming that all nodes need to be resolved:
+   A
+  /|\
+ B C D
+ | |
+ E F
+ |
+ G
+ return value will be:
+ [
+ [("B", B), ("C", C)],
+ [("E", E)],
+ [("D", D), ("F", F), ("G", G)],
+ ]
+
+ :param component: The pipeline component to resolve.
+ :type component: PipelineComponent
+ :param extra_operations: Extra operations to apply on nodes during the traversal.
+ :type extra_operations: List[Callable[[BaseNode, str], Any]]
+ :return: A list of layers of nodes to resolve.
+ :rtype: List[List[Tuple[str, BaseNode]]]
+ """
+ nodes_to_process = list(component.jobs.items())
+ layers: List = []
+ leaf_nodes = []
+
+ while nodes_to_process:
+ layers.append([])
+ new_nodes_to_process = []
+ for key, job_instance in nodes_to_process:
+ cls._resolve_binding_on_supported_fields_for_node(job_instance)
+ if isinstance(job_instance, LoopNode):
+ job_instance = job_instance.body
+
+ for extra_operation in extra_operations:
+ extra_operation(job_instance, key)
+
+ if isinstance(job_instance, BaseNode) and isinstance(job_instance._component, PipelineComponent):
+ # candidates for next layer
+ new_nodes_to_process.extend(job_instance.component.jobs.items())
+ # use layers to store pipeline nodes in each layer for now
+ layers[-1].append((key, job_instance))
+ else:
+ # note that LoopNode has already been replaced by its body here
+ leaf_nodes.append((key, job_instance))
+ nodes_to_process = new_nodes_to_process
+
+ # if there is a subgraph, the last item in layers will be empty for now as all leaf nodes are stored in leaf_nodes
+ if len(layers) != 0:
+ layers.pop()
+ layers.append(leaf_nodes)
+
+ return layers
+
+ def _get_workspace_key(self) -> str:
+ try:
+ workspace_rest = self._workspace_operations._operation.get(
+ resource_group_name=self._resource_group_name,
+ workspace_name=self._workspace_name,
+ )
+ return str(workspace_rest.workspace_id)
+ except HttpResponseError:
+ return "{}/{}/{}".format(self._subscription_id, self._resource_group_name, self._workspace_name)
+
+ def _get_registry_key(self) -> str:
+ """Get key for used registry.
+
+ Note that, although the registry id is in the registry discovery response, it is not in RegistryDiscoveryDto,
+ so the information is lost after deserialization.
+ To avoid changing the related rest client, we simply use registry-related information from self to construct
+ the registry key, which means the on-disk cache will be invalidated if a registry is deleted and then created
+ again with the same name.
+
+ :return: The registry key
+ :rtype: str
+ """
+ return "{}/{}/{}".format(self._subscription_id, self._resource_group_name, self._registry_name)
+
+ def _get_client_key(self) -> str:
+ """Get key for used client.
+ Key should be able to uniquely identify used registry or workspace.
+
+ :return: The client key
+ :rtype: str
+ """
+ # check cache first
+ if self._client_key:
+ return self._client_key
+
+ # registry name has a higher priority compared to workspace name according to the current __init__
+ # implementation of MLClient
+ if self._registry_name:
+ self._client_key = "registry/" + self._get_registry_key()
+ elif self._workspace_name:
+ self._client_key = "workspace/" + self._get_workspace_key()
+ else:
+ # This should never happen.
+ raise ValueError("Either workspace name or registry name must be provided to use component operations.")
+ return self._client_key
+
+ def _resolve_dependencies_for_pipeline_component_jobs(
+ self,
+ component: Union[Component, str],
+ resolver: _AssetResolver,
+ ) -> None:
+ """Resolve dependencies for pipeline component jobs.
+ Returns immediately if the component is not a pipeline component.
+
+ :param component: The pipeline component to resolve.
+ :type component: Union[Component, str]
+ :param resolver: The resolver to resolve the dependencies.
+ :type resolver: _AssetResolver
+ """
+ if not isinstance(component, PipelineComponent) or not component.jobs:
+ return
+
+ from azure.ai.ml.entities._job.automl.automl_job import AutoMLJob
+
+ self._resolve_inputs_for_pipeline_component_jobs(component.jobs, component._base_path)
+
+ # This is a preparation for concurrent resolution. Nodes will be resolved later layer by layer
+ # from bottom to top, as hash calculation of a parent node will be impacted by resolution
+ # of its child nodes.
+ layers = self._divide_nodes_to_resolve_into_layers(
+ component,
+ extra_operations=[
+ # no need to do this as we now keep the original component name for anonymous components
+ # self._set_default_display_name_for_anonymous_component_in_node,
+ partial(
+ self._try_resolve_node_level_task_for_parallel_node,
+ resolver=resolver,
+ ),
+ partial(self._try_resolve_environment_for_component, resolver=resolver),
+ partial(self._try_resolve_compute_for_node, resolver=resolver),
+ # should we resolve code here after we do extra operations concurrently?
+ ],
+ )
+
+ # cache anonymous components only for now
+ # a request-level in-memory cache may be a better solution for other types of assets, as they are
+ # relatively simple and have a small number of distinct instances
+ component_cache = CachedNodeResolver(
+ resolver=resolver,
+ client_key=self._get_client_key(),
+ )
+
+ for layer in reversed(layers):
+ for _, job_instance in layer:
+ if isinstance(job_instance, AutoMLJob):
+ # only compute is resolved here
+ self._job_operations._resolve_arm_id_for_automl_job(job_instance, resolver, inside_pipeline=True)
+ elif isinstance(job_instance, BaseNode):
+ component_cache.register_node_for_lazy_resolution(job_instance)
+ elif isinstance(job_instance, ConditionNode):
+ pass
+ else:
+ msg = f"Non supported job type in Pipeline: {type(job_instance)}"
+ raise ComponentException(
+ message=msg,
+ target=ErrorTarget.COMPONENT,
+ no_personal_data_message=msg,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+
+ component_cache.resolve_nodes()
+
+
+def _refine_component(component_func: Any) -> Component:
+ """Return the component of function that is decorated by command
+ component decorator.
+
+ :param component_func: Function that is decorated by command component decorator
+ :type component_func: types.FunctionType
+ :return: Component entity of target function
+ :rtype: Component
+ """
+
+ def check_parameter_type(f: Any) -> None:
+ """Check all parameter is annotated or has a default value with clear type(not None).
+
+ :param f: The component function
+ :type f: types.FunctionType
+ """
+ annotations = getattr(f, "__annotations__", {})
+ func_parameters = signature(f).parameters
+ defaults_dict = {key: val.default for key, val in func_parameters.items()}
+ variable_inputs = [
+ key for key, val in func_parameters.items() if val.kind in [val.VAR_POSITIONAL, val.VAR_KEYWORD]
+ ]
+ if variable_inputs:
+ msg = "Cannot register the component {} with variable inputs {!r}."
+ raise ValidationException(
+ message=msg.format(f.__name__, variable_inputs),
+ no_personal_data_message=msg.format("[keys]", "[name]"),
+ target=ErrorTarget.COMPONENT,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+ unknown_type_keys = [
+ key for key, val in defaults_dict.items() if key not in annotations and val is Parameter.empty
+ ]
+ if unknown_type_keys:
+ msg = "Unknown type of parameter {} in pipeline func {!r}, please add type annotation."
+ raise ValidationException(
+ message=msg.format(unknown_type_keys, f.__name__),
+ no_personal_data_message=msg.format("[keys]", "[name]"),
+ target=ErrorTarget.COMPONENT,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+
+ def check_non_pipeline_inputs(f: Any) -> None:
+ """Check whether non_pipeline_inputs exist in pipeline builder.
+
+ :param f: The component function
+ :type f: types.FunctionType
+ """
+ if f._pipeline_builder.non_pipeline_parameter_names:
+ msg = "Cannot register pipeline component {!r} with non_pipeline_inputs."
+ raise ValidationException(
+ message=msg.format(f.__name__),
+ no_personal_data_message=msg.format(""),
+ target=ErrorTarget.COMPONENT,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+
+ if hasattr(component_func, "_is_mldesigner_component") and component_func._is_mldesigner_component:
+ return component_func.component
+ if hasattr(component_func, "_is_dsl_func") and component_func._is_dsl_func:
+ check_non_pipeline_inputs(component_func)
+ check_parameter_type(component_func)
+ if component_func._job_settings:
+ module_logger.warning(
+ "Job settings %s on pipeline function '%s' are ignored when creating PipelineComponent.",
+ component_func._job_settings,
+ component_func.__name__,
+ )
+ # Normally a pipeline component is created when dsl.pipeline inputs are provided,
+ # so pipeline input .result() can resolve to the correct value.
+ # When a pipeline component is created without dsl.pipeline inputs, pipeline input .result() won't work.
+ return component_func._pipeline_builder.build(user_provided_kwargs={})
+ msg = "Function must be a dsl or mldesigner component function: {!r}"
+ raise ValidationException(
+ message=msg.format(component_func),
+ no_personal_data_message=msg.format("component"),
+ error_category=ErrorCategory.USER_ERROR,
+ target=ErrorTarget.COMPONENT,
+ )
+
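check_parameter_type above relies on inspect.signature to reject variable arguments and unannotated parameters that have no default. A small standalone illustration of that rule (the function name is hypothetical and only used for demonstration):

    from inspect import Parameter, signature

    def bad_pipeline(a: int, b, *args):
        ...

    params = signature(bad_pipeline).parameters
    variable_inputs = [k for k, v in params.items() if v.kind in (v.VAR_POSITIONAL, v.VAR_KEYWORD)]
    unannotated = [
        k for k, v in params.items()
        if k not in bad_pipeline.__annotations__ and v.default is Parameter.empty
    ]
    print(variable_inputs)  # ['args'] -> registration would be rejected
    print(unannotated)      # ['b', 'args'] -> 'b' needs a type annotation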
+
+def _try_resolve_code_for_component(component: Component, resolver: _AssetResolver) -> None:
+ if isinstance(component, ComponentCodeMixin):
+ with component._build_code() as code:
+ if code is None:
+ code = component._get_origin_code_value()
+ if code is None:
+ return
+ component._fill_back_code_value(resolver(code, azureml_type=AzureMLResourceType.CODE)) # type: ignore
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_compute_operations.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_compute_operations.py
new file mode 100644
index 00000000..7990a3fa
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_compute_operations.py
@@ -0,0 +1,447 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+from typing import Any, Dict, Iterable, Optional, cast
+
+from azure.ai.ml._restclient.v2023_08_01_preview import AzureMachineLearningWorkspaces as ServiceClient022023Preview
+from azure.ai.ml._restclient.v2024_04_01_preview import AzureMachineLearningWorkspaces as ServiceClient042024Preview
+from azure.ai.ml._restclient.v2024_04_01_preview.models import SsoSetting
+from azure.ai.ml._scope_dependent_operations import OperationConfig, OperationScope, _ScopeDependentOperations
+from azure.ai.ml._telemetry import ActivityType, monitor_with_activity
+from azure.ai.ml._utils._experimental import experimental
+from azure.ai.ml._utils._logger_utils import OpsLogger
+from azure.ai.ml.constants._common import COMPUTE_UPDATE_ERROR
+from azure.ai.ml.constants._compute import ComputeType
+from azure.ai.ml.entities import AmlComputeNodeInfo, Compute, Usage, VmSize
+from azure.core.polling import LROPoller
+from azure.core.tracing.decorator import distributed_trace
+
+ops_logger = OpsLogger(__name__)
+module_logger = ops_logger.module_logger
+
+
+class ComputeOperations(_ScopeDependentOperations):
+ """ComputeOperations.
+
+ This class should not be instantiated directly. Instead, use the `compute` attribute of an MLClient object.
+
+ :param operation_scope: Scope variables for the operations classes of an MLClient object.
+ :type operation_scope: ~azure.ai.ml._scope_dependent_operations.OperationScope
+ :param operation_config: Common configuration for operations classes of an MLClient object.
+ :type operation_config: ~azure.ai.ml._scope_dependent_operations.OperationConfig
+ :param service_client: Service client to allow end users to operate on Azure Machine Learning
+ Workspace resources.
+ :type service_client: ~azure.ai.ml._restclient.v2023_02_01_preview.AzureMachineLearningWorkspaces
+ """
+
+ def __init__(
+ self,
+ operation_scope: OperationScope,
+ operation_config: OperationConfig,
+ service_client: ServiceClient022023Preview,
+ service_client_2024: ServiceClient042024Preview,
+ **kwargs: Dict,
+ ) -> None:
+ super(ComputeOperations, self).__init__(operation_scope, operation_config)
+ ops_logger.update_filter()
+ self._operation = service_client.compute
+ self._operation2024 = service_client_2024.compute
+ self._workspace_operations = service_client.workspaces
+ self._vmsize_operations = service_client.virtual_machine_sizes
+ self._usage_operations = service_client.usages
+ self._init_kwargs = kwargs
+
+ @distributed_trace
+ @monitor_with_activity(ops_logger, "Compute.List", ActivityType.PUBLICAPI)
+ def list(self, *, compute_type: Optional[str] = None) -> Iterable[Compute]:
+ """List computes of the workspace.
+
+ :keyword compute_type: The type of compute to filter by, case-insensitive. Defaults to None,
+ which lists computes of all types.
+ :paramtype compute_type: Optional[str]
+ :return: An iterator like instance of Compute objects.
+ :rtype: ~azure.core.paging.ItemPaged[~azure.ai.ml.entities.Compute]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_compute.py
+ :start-after: [START compute_operations_list]
+ :end-before: [END compute_operations_list]
+ :language: python
+ :dedent: 8
+ :caption: Retrieving a list of the AzureML Kubernetes compute resources in a workspace.
+ """
+
+ return cast(
+ Iterable[Compute],
+ self._operation.list(
+ self._operation_scope.resource_group_name,
+ self._workspace_name,
+ cls=lambda objs: [
+ Compute._from_rest_object(obj)
+ for obj in objs
+ if compute_type is None or str(Compute._from_rest_object(obj).type).lower() == compute_type.lower()
+ ],
+ ),
+ )
+
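A minimal usage sketch for the listing above, assuming an authenticated MLClient named ml_client (the subscription, resource group, and workspace values are placeholders):

    from azure.identity import DefaultAzureCredential
    from azure.ai.ml import MLClient

    ml_client = MLClient(DefaultAzureCredential(), "<subscription-id>", "<resource-group>", "<workspace>")
    # compute_type filtering is case-insensitive; omit it to list every compute
    for compute in ml_client.compute.list(compute_type="AmlCompute"):
        print(compute.name, compute.type)
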
+ @distributed_trace
+ @monitor_with_activity(ops_logger, "Compute.Get", ActivityType.PUBLICAPI)
+ def get(self, name: str) -> Compute:
+ """Get a compute resource.
+
+ :param name: Name of the compute resource.
+ :type name: str
+ :return: A Compute object.
+ :rtype: ~azure.ai.ml.entities.Compute
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_compute.py
+ :start-after: [START compute_operations_get]
+ :end-before: [END compute_operations_get]
+ :language: python
+ :dedent: 8
+ :caption: Retrieving a compute resource from a workspace.
+ """
+
+ rest_obj = self._operation.get(
+ self._operation_scope.resource_group_name,
+ self._workspace_name,
+ name,
+ )
+ return Compute._from_rest_object(rest_obj)
+
+ @distributed_trace
+ @monitor_with_activity(ops_logger, "Compute.ListNodes", ActivityType.PUBLICAPI)
+ def list_nodes(self, name: str) -> Iterable[AmlComputeNodeInfo]:
+ """Retrieve a list of a compute resource's nodes.
+
+ :param name: Name of the compute resource.
+ :type name: str
+ :return: An iterator-like instance of AmlComputeNodeInfo objects.
+ :rtype: ~azure.core.paging.ItemPaged[~azure.ai.ml.entities.AmlComputeNodeInfo]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_compute.py
+ :start-after: [START compute_operations_list_nodes]
+ :end-before: [END compute_operations_list_nodes]
+ :language: python
+ :dedent: 8
+ :caption: Retrieving a list of nodes from a compute resource.
+ """
+ return cast(
+ Iterable[AmlComputeNodeInfo],
+ self._operation.list_nodes(
+ self._operation_scope.resource_group_name,
+ self._workspace_name,
+ name,
+ cls=lambda objs: [AmlComputeNodeInfo._from_rest_object(obj) for obj in objs],
+ ),
+ )
+
+ @distributed_trace
+ @monitor_with_activity(ops_logger, "Compute.BeginCreateOrUpdate", ActivityType.PUBLICAPI)
+ def begin_create_or_update(self, compute: Compute) -> LROPoller[Compute]:
+ """Create and register a compute resource.
+
+ :param compute: The compute resource definition.
+ :type compute: ~azure.ai.ml.entities.Compute
+ :return: An instance of LROPoller that returns a Compute object once the
+ long-running operation is complete.
+ :rtype: ~azure.core.polling.LROPoller[~azure.ai.ml.entities.Compute]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_compute.py
+ :start-after: [START compute_operations_create_update]
+ :end-before: [END compute_operations_create_update]
+ :language: python
+ :dedent: 8
+ :caption: Creating and registering a compute resource.
+ """
+ if compute.type != ComputeType.AMLCOMPUTE:
+ if compute.location:
+ module_logger.warning(
+ "Warning: 'Location' is not supported for compute type %s and will not be used.",
+ compute.type,
+ )
+ compute.location = self._get_workspace_location()
+
+ if not compute.location:
+ compute.location = self._get_workspace_location()
+
+ compute._set_full_subnet_name(
+ self._operation_scope.subscription_id,
+ self._operation_scope.resource_group_name,
+ )
+
+ compute_rest_obj = compute._to_rest_object()
+
+ poller = self._operation.begin_create_or_update(
+ self._operation_scope.resource_group_name,
+ self._workspace_name,
+ compute_name=compute.name,
+ parameters=compute_rest_obj,
+ polling=True,
+ cls=lambda response, deserialized, headers: Compute._from_rest_object(deserialized),
+ )
+
+ return poller
+
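A hedged example of creating an AmlCompute cluster through this operation and waiting on the returned poller; the cluster name and VM size are placeholders, and ml_client is assumed to be an authenticated MLClient:

    from azure.ai.ml.entities import AmlCompute

    cluster = AmlCompute(
        name="cpu-cluster",
        size="STANDARD_DS3_V2",
        min_instances=0,
        max_instances=4,
        idle_time_before_scale_down=120,
    )
    poller = ml_client.compute.begin_create_or_update(cluster)
    created = poller.result()  # blocks until the long-running operation completes
    print(created.provisioning_state)
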
+ @distributed_trace
+ @monitor_with_activity(ops_logger, "Compute.Attach", ActivityType.PUBLICAPI)
+ def begin_attach(self, compute: Compute, **kwargs: Any) -> LROPoller[Compute]:
+ """Attach a compute resource to the workspace.
+
+ :param compute: The compute resource definition.
+ :type compute: ~azure.ai.ml.entities.Compute
+ :return: An instance of LROPoller that returns a Compute object once the
+ long-running operation is complete.
+ :rtype: ~azure.core.polling.LROPoller[~azure.ai.ml.entities.Compute]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_compute.py
+ :start-after: [START compute_operations_attach]
+ :end-before: [END compute_operations_attach]
+ :language: python
+ :dedent: 8
+ :caption: Attaching a compute resource to the workspace.
+ """
+ return self.begin_create_or_update(compute=compute, **kwargs)
+
+ @distributed_trace
+ @monitor_with_activity(ops_logger, "Compute.BeginUpdate", ActivityType.PUBLICAPI)
+ def begin_update(self, compute: Compute) -> LROPoller[Compute]:
+ """Update a compute resource. Currently only valid for AmlCompute resource types.
+
+ :param compute: The compute resource definition.
+ :type compute: ~azure.ai.ml.entities.Compute
+ :return: An instance of LROPoller that returns a Compute object once the
+ long-running operation is complete.
+ :rtype: ~azure.core.polling.LROPoller[~azure.ai.ml.entities.Compute]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_compute.py
+ :start-after: [START compute_operations_update]
+ :end-before: [END compute_operations_update]
+ :language: python
+ :dedent: 8
+ :caption: Updating an AmlCompute resource.
+ """
+ if not compute.type == ComputeType.AMLCOMPUTE:
+ COMPUTE_UPDATE_ERROR.format(compute.name, compute.type)
+
+ compute_rest_obj = compute._to_rest_object()
+
+ poller = self._operation.begin_create_or_update(
+ self._operation_scope.resource_group_name,
+ self._workspace_name,
+ compute_name=compute.name,
+ parameters=compute_rest_obj,
+ polling=True,
+ cls=lambda response, deserialized, headers: Compute._from_rest_object(deserialized),
+ )
+
+ return poller
+
+ @distributed_trace
+ @monitor_with_activity(ops_logger, "Compute.BeginDelete", ActivityType.PUBLICAPI)
+ def begin_delete(self, name: str, *, action: str = "Delete") -> LROPoller[None]:
+ """Delete or detach a compute resource.
+
+ :param name: The name of the compute resource.
+ :type name: str
+ :keyword action: Action to perform. Possible values: ["Delete", "Detach"]. Defaults to "Delete".
+ :type action: str
+ :return: A poller to track the operation status.
+ :rtype: ~azure.core.polling.LROPoller[None]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_compute.py
+ :start-after: [START compute_operations_delete]
+ :end-before: [END compute_operations_delete]
+ :language: python
+ :dedent: 8
+ :caption: Delete compute example.
+ """
+ return self._operation.begin_delete(
+ resource_group_name=self._operation_scope.resource_group_name,
+ workspace_name=self._workspace_name,
+ compute_name=name,
+ underlying_resource_action=action,
+ **self._init_kwargs,
+ )
+
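For attached computes, passing action="Detach" removes only the workspace reference and leaves the underlying resource intact; a short sketch assuming ml_client is available and "my-attached-compute" is a placeholder name:

    poller = ml_client.compute.begin_delete("my-attached-compute", action="Detach")
    poller.wait()  # returns None once the operation finishes
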
+ @distributed_trace
+ @monitor_with_activity(ops_logger, "Compute.BeginStart", ActivityType.PUBLICAPI)
+ def begin_start(self, name: str) -> LROPoller[None]:
+ """Start a compute instance.
+
+ :param name: The name of the compute instance.
+ :type name: str
+ :return: A poller to track the operation status.
+ :rtype: azure.core.polling.LROPoller[None]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_compute.py
+ :start-after: [START compute_operations_start]
+ :end-before: [END compute_operations_start]
+ :language: python
+ :dedent: 8
+ :caption: Starting a compute instance.
+ """
+
+ return self._operation.begin_start(
+ self._operation_scope.resource_group_name,
+ self._workspace_name,
+ name,
+ )
+
+ @distributed_trace
+ @monitor_with_activity(ops_logger, "Compute.BeginStop", ActivityType.PUBLICAPI)
+ def begin_stop(self, name: str) -> LROPoller[None]:
+ """Stop a compute instance.
+
+ :param name: The name of the compute instance.
+ :type name: str
+ :return: A poller to track the operation status.
+ :rtype: azure.core.polling.LROPoller[None]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_compute.py
+ :start-after: [START compute_operations_stop]
+ :end-before: [END compute_operations_stop]
+ :language: python
+ :dedent: 8
+ :caption: Stopping a compute instance.
+ """
+ return self._operation.begin_stop(
+ self._operation_scope.resource_group_name,
+ self._workspace_name,
+ name,
+ )
+
+ @distributed_trace
+ @monitor_with_activity(ops_logger, "Compute.BeginRestart", ActivityType.PUBLICAPI)
+ def begin_restart(self, name: str) -> LROPoller[None]:
+ """Restart a compute instance.
+
+ :param name: The name of the compute instance.
+ :type name: str
+ :return: A poller to track the operation status.
+ :rtype: azure.core.polling.LROPoller[None]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_compute.py
+ :start-after: [START compute_operations_restart]
+ :end-before: [END compute_operations_restart]
+ :language: python
+ :dedent: 8
+ :caption: Restarting a stopped compute instance.
+ """
+ return self._operation.begin_restart(
+ self._operation_scope.resource_group_name,
+ self._workspace_name,
+ name,
+ )
+
+ @distributed_trace
+ @monitor_with_activity(ops_logger, "Compute.ListUsage", ActivityType.PUBLICAPI)
+ def list_usage(self, *, location: Optional[str] = None) -> Iterable[Usage]:
+ """List the current usage information as well as AzureML resource limits for the
+ given subscription and location.
+
+ :keyword location: The location for which resource usage is queried.
+ Defaults to workspace location.
+ :paramtype location: Optional[str]
+ :return: An iterator over current usage info objects.
+ :rtype: ~azure.core.paging.ItemPaged[~azure.ai.ml.entities.Usage]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_compute.py
+ :start-after: [START compute_operations_list_usage]
+ :end-before: [END compute_operations_list_usage]
+ :language: python
+ :dedent: 8
+ :caption: Listing resource usage for the workspace location.
+ """
+ if not location:
+ location = self._get_workspace_location()
+ return cast(
+ Iterable[Usage],
+ self._usage_operations.list(
+ location=location,
+ cls=lambda objs: [Usage._from_rest_object(obj) for obj in objs],
+ ),
+ )
+
+ @distributed_trace
+ @monitor_with_activity(ops_logger, "Compute.ListSizes", ActivityType.PUBLICAPI)
+ def list_sizes(self, *, location: Optional[str] = None, compute_type: Optional[str] = None) -> Iterable[VmSize]:
+ """List the supported VM sizes in a location.
+
+ :keyword location: The location for which supported VM sizes are queried.
+ Defaults to the workspace location.
+ :paramtype location: Optional[str]
+ :keyword compute_type: The compute type to filter the supported VM sizes by, case-insensitive.
+ Defaults to None, which returns all supported sizes.
+ :paramtype compute_type: Optional[str]
+ :return: An iterator over virtual machine size objects.
+ :rtype: Iterable[~azure.ai.ml.entities.VmSize]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_compute.py
+ :start-after: [START compute_operations_list_sizes]
+ :end-before: [END compute_operations_list_sizes]
+ :language: python
+ :dedent: 8
+ :caption: Listing the supported VM sizes in the workspace location.
+ """
+ if not location:
+ location = self._get_workspace_location()
+ size_list = self._vmsize_operations.list(location=location)
+ if not size_list:
+ return []
+ if compute_type:
+ return [
+ VmSize._from_rest_object(item)
+ for item in size_list.value
+ if compute_type.lower() in (supported_type.lower() for supported_type in item.supported_compute_types)
+ ]
+ return [VmSize._from_rest_object(item) for item in size_list.value]
+
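A quick sketch of listing VM sizes filtered by compute type in the workspace location (ml_client assumed to be an authenticated MLClient):

    for vm_size in ml_client.compute.list_sizes(compute_type="AmlCompute"):
        print(vm_size.name)
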
+ @distributed_trace
+ @monitor_with_activity(ops_logger, "Compute.enablesso", ActivityType.PUBLICAPI)
+ @experimental
+ def enable_sso(self, *, name: str, enable_sso: bool = True, **kwargs: Any) -> None:
+ """enable sso for a compute instance.
+
+ :keyword name: Name of the compute instance.
+ :paramtype name: str
+ :keyword enable_sso: enable sso bool flag
+ Default to True
+ :paramtype enable_sso: bool
+ """
+
+ self._operation2024.update_sso_settings(
+ self._operation_scope.resource_group_name,
+ self._workspace_name,
+ name,
+ parameters=SsoSetting(enable_sso=enable_sso),
+ **kwargs,
+ )
+
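enable_sso is experimental; a minimal call sketch, with the compute instance name as a placeholder:

    ml_client.compute.enable_sso(name="my-compute-instance", enable_sso=True)
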
+ def _get_workspace_location(self) -> str:
+ workspace = self._workspace_operations.get(self._resource_group_name, self._workspace_name)
+ return str(workspace.location)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_data_operations.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_data_operations.py
new file mode 100644
index 00000000..a54b1739
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_data_operations.py
@@ -0,0 +1,891 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access,no-value-for-parameter
+
+import os
+import time
+import uuid
+from contextlib import contextmanager
+from pathlib import Path
+from typing import Any, Dict, Generator, Iterable, List, Optional, Union, cast
+
+from marshmallow.exceptions import ValidationError as SchemaValidationError
+
+from azure.ai.ml._artifacts._artifact_utilities import _check_and_upload_path
+from azure.ai.ml._artifacts._constants import (
+ ASSET_PATH_ERROR,
+ CHANGED_ASSET_PATH_MSG,
+ CHANGED_ASSET_PATH_MSG_NO_PERSONAL_DATA,
+)
+from azure.ai.ml._exception_helper import log_and_raise_error
+from azure.ai.ml._restclient.v2021_10_01_dataplanepreview import (
+ AzureMachineLearningWorkspaces as ServiceClient102021Dataplane,
+)
+from azure.ai.ml._restclient.v2023_04_01_preview import AzureMachineLearningWorkspaces as ServiceClient042023_preview
+from azure.ai.ml._restclient.v2023_04_01_preview.models import ListViewType
+from azure.ai.ml._restclient.v2024_01_01_preview import AzureMachineLearningWorkspaces as ServiceClient012024_preview
+from azure.ai.ml._restclient.v2024_01_01_preview.models import ComputeInstanceDataMount
+from azure.ai.ml._scope_dependent_operations import (
+ OperationConfig,
+ OperationsContainer,
+ OperationScope,
+ _ScopeDependentOperations,
+)
+from azure.ai.ml._telemetry import ActivityType, monitor_with_activity
+from azure.ai.ml._utils._asset_utils import (
+ _archive_or_restore,
+ _check_or_modify_auto_delete_setting,
+ _create_or_update_autoincrement,
+ _get_latest_version_from_container,
+ _resolve_label_to_asset,
+ _validate_auto_delete_setting_in_data_output,
+ _validate_workspace_managed_datastore,
+)
+from azure.ai.ml._utils._data_utils import (
+ download_mltable_metadata_schema,
+ read_local_mltable_metadata_contents,
+ read_remote_mltable_metadata_contents,
+ validate_mltable_metadata,
+)
+from azure.ai.ml._utils._experimental import experimental
+from azure.ai.ml._utils._http_utils import HttpPipeline
+from azure.ai.ml._utils._logger_utils import OpsLogger
+from azure.ai.ml._utils._registry_utils import (
+ get_asset_body_for_registry_storage,
+ get_registry_client,
+ get_sas_uri_for_registry_asset,
+)
+from azure.ai.ml._utils.utils import is_url
+from azure.ai.ml.constants._common import (
+ ASSET_ID_FORMAT,
+ MLTABLE_METADATA_SCHEMA_URL_FALLBACK,
+ AssetTypes,
+ AzureMLResourceType,
+)
+from azure.ai.ml.data_transfer import import_data as import_data_func
+from azure.ai.ml.entities import PipelineJob, PipelineJobSettings
+from azure.ai.ml.entities._assets import Data, WorkspaceAssetReference
+from azure.ai.ml.entities._data.mltable_metadata import MLTableMetadata
+from azure.ai.ml.entities._data_import.data_import import DataImport
+from azure.ai.ml.entities._inputs_outputs import Output
+from azure.ai.ml.entities._inputs_outputs.external_data import Database
+from azure.ai.ml.exceptions import (
+ AssetPathException,
+ ErrorCategory,
+ ErrorTarget,
+ MlException,
+ ValidationErrorType,
+ ValidationException,
+)
+from azure.ai.ml.operations._datastore_operations import DatastoreOperations
+from azure.core.exceptions import HttpResponseError, ResourceNotFoundError
+from azure.core.paging import ItemPaged
+
+ops_logger = OpsLogger(__name__)
+module_logger = ops_logger.module_logger
+
+
+class DataOperations(_ScopeDependentOperations):
+ """DataOperations.
+
+ You should not instantiate this class directly. Instead, you should
+ create an MLClient instance that instantiates it for you and
+ attaches it as an attribute.
+
+ :param operation_scope: Scope variables for the operations classes of an MLClient object.
+ :type operation_scope: ~azure.ai.ml._scope_dependent_operations.OperationScope
+ :param operation_config: Common configuration for operations classes of an MLClient object.
+ :type operation_config: ~azure.ai.ml._scope_dependent_operations.OperationConfig
+ :param service_client: Service client to allow end users to operate on Azure Machine Learning Workspace
+ resources (ServiceClient042023Preview or ServiceClient102021Dataplane).
+ :type service_client: typing.Union[
+ ~azure.ai.ml._restclient.v2023_04_01_preview._azure_machine_learning_workspaces.AzureMachineLearningWorkspaces,
+ ~azure.ai.ml._restclient.v2021_10_01_dataplanepreview._azure_machine_learning_workspaces.
+ AzureMachineLearningWorkspaces]
+ :param datastore_operations: Represents a client for performing operations on Datastores.
+ :type datastore_operations: ~azure.ai.ml.operations._datastore_operations.DatastoreOperations
+ """
+
+ def __init__(
+ self,
+ operation_scope: OperationScope,
+ operation_config: OperationConfig,
+ service_client: Union[ServiceClient042023_preview, ServiceClient102021Dataplane],
+ service_client_012024_preview: ServiceClient012024_preview,
+ datastore_operations: DatastoreOperations,
+ **kwargs: Any,
+ ):
+ super(DataOperations, self).__init__(operation_scope, operation_config)
+ ops_logger.update_filter()
+ self._operation = service_client.data_versions
+ self._container_operation = service_client.data_containers
+ self._datastore_operation = datastore_operations
+ self._compute_operation = service_client_012024_preview.compute
+ self._service_client = service_client
+ self._init_kwargs = kwargs
+ self._requests_pipeline: HttpPipeline = kwargs.pop("requests_pipeline")
+ self._all_operations: OperationsContainer = kwargs.pop("all_operations")
+ # Maps a label to a function which given an asset name,
+ # returns the asset associated with the label
+ self._managed_label_resolver = {"latest": self._get_latest_version}
+
+ @monitor_with_activity(ops_logger, "Data.List", ActivityType.PUBLICAPI)
+ def list(
+ self,
+ name: Optional[str] = None,
+ *,
+ list_view_type: ListViewType = ListViewType.ACTIVE_ONLY,
+ ) -> ItemPaged[Data]:
+ """List the data assets of the workspace.
+
+ :param name: Name of a specific data asset, optional.
+ :type name: Optional[str]
+ :keyword list_view_type: View type for including/excluding (for example) archived data assets.
+ Default: ACTIVE_ONLY.
+ :type list_view_type: Optional[ListViewType]
+ :return: An iterator like instance of Data objects
+ :rtype: ~azure.core.paging.ItemPaged[Data]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_misc.py
+ :start-after: [START data_operations_list]
+ :end-before: [END data_operations_list]
+ :language: python
+ :dedent: 8
+ :caption: List data assets example.
+ """
+ if name:
+ return (
+ self._operation.list(
+ name=name,
+ registry_name=self._registry_name,
+ cls=lambda objs: [Data._from_rest_object(obj) for obj in objs],
+ list_view_type=list_view_type,
+ **self._scope_kwargs,
+ )
+ if self._registry_name
+ else self._operation.list(
+ name=name,
+ workspace_name=self._workspace_name,
+ cls=lambda objs: [Data._from_rest_object(obj) for obj in objs],
+ list_view_type=list_view_type,
+ **self._scope_kwargs,
+ )
+ )
+ return (
+ self._container_operation.list(
+ registry_name=self._registry_name,
+ cls=lambda objs: [Data._from_container_rest_object(obj) for obj in objs],
+ list_view_type=list_view_type,
+ **self._scope_kwargs,
+ )
+ if self._registry_name
+ else self._container_operation.list(
+ workspace_name=self._workspace_name,
+ cls=lambda objs: [Data._from_container_rest_object(obj) for obj in objs],
+ list_view_type=list_view_type,
+ **self._scope_kwargs,
+ )
+ )
+
+ def _get(self, name: Optional[str], version: Optional[str] = None) -> Data:
+ if version:
+ return (
+ self._operation.get(
+ name=name,
+ version=version,
+ registry_name=self._registry_name,
+ **self._scope_kwargs,
+ **self._init_kwargs,
+ )
+ if self._registry_name
+ else self._operation.get(
+ resource_group_name=self._resource_group_name,
+ workspace_name=self._workspace_name,
+ name=name,
+ version=version,
+ **self._init_kwargs,
+ )
+ )
+ return (
+ self._container_operation.get(
+ name=name,
+ registry_name=self._registry_name,
+ **self._scope_kwargs,
+ **self._init_kwargs,
+ )
+ if self._registry_name
+ else self._container_operation.get(
+ resource_group_name=self._resource_group_name,
+ workspace_name=self._workspace_name,
+ name=name,
+ **self._init_kwargs,
+ )
+ )
+
+ @monitor_with_activity(ops_logger, "Data.Get", ActivityType.PUBLICAPI)
+ def get(self, name: str, version: Optional[str] = None, label: Optional[str] = None) -> Data: # type: ignore
+ """Get the specified data asset.
+
+ :param name: Name of data asset.
+ :type name: str
+ :param version: Version of data asset.
+ :type version: str
+ :param label: Label of the data asset. (mutually exclusive with version)
+ :type label: str
+ :raises ~azure.ai.ml.exceptions.ValidationException: Raised if Data cannot be successfully
+ identified and retrieved. Details will be provided in the error message.
+ :return: Data asset object.
+ :rtype: ~azure.ai.ml.entities.Data
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_misc.py
+ :start-after: [START data_operations_get]
+ :end-before: [END data_operations_get]
+ :language: python
+ :dedent: 8
+ :caption: Get data assets example.
+ """
+ try:
+ if version and label:
+ msg = "Cannot specify both version and label."
+ raise ValidationException(
+ message=msg,
+ target=ErrorTarget.DATA,
+ no_personal_data_message=msg,
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
+
+ if label:
+ return _resolve_label_to_asset(self, name, label)
+
+ if not version:
+ msg = "Must provide either version or label."
+ raise ValidationException(
+ message=msg,
+ target=ErrorTarget.DATA,
+ no_personal_data_message=msg,
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.MISSING_FIELD,
+ )
+ data_version_resource = self._get(name, version)
+ return Data._from_rest_object(data_version_resource)
+ except (ValidationException, SchemaValidationError) as ex:
+ log_and_raise_error(ex)
+
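Either a version or a label must be supplied to get; a sketch assuming ml_client is an authenticated MLClient and "my-data" is an existing asset name:

    data_v1 = ml_client.data.get(name="my-data", version="1")
    data_latest = ml_client.data.get(name="my-data", label="latest")
    print(data_latest.id, data_latest.path)
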
+ @monitor_with_activity(ops_logger, "Data.CreateOrUpdate", ActivityType.PUBLICAPI)
+ def create_or_update(self, data: Data) -> Data:
+ """Returns created or updated data asset.
+
+ If not already in storage, asset will be uploaded to the workspace's blob storage.
+
+ :param data: Data asset object.
+ :type data: azure.ai.ml.entities.Data
+ :raises ~azure.ai.ml.exceptions.AssetPathException: Raised when the Data artifact path is
+ already linked to another asset
+ :raises ~azure.ai.ml.exceptions.ValidationException: Raised if Data cannot be successfully validated.
+ Details will be provided in the error message.
+ :raises ~azure.ai.ml.exceptions.EmptyDirectoryError: Raised if local path provided points to an empty directory.
+ :return: Data asset object.
+ :rtype: ~azure.ai.ml.entities.Data
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_misc.py
+ :start-after: [START data_operations_create_or_update]
+ :end-before: [END data_operations_create_or_update]
+ :language: python
+ :dedent: 8
+ :caption: Create data assets example.
+ """
+ try:
+ name = data.name
+ if not data.version and self._registry_name:
+ msg = "Data asset version is required for registry"
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.DATA,
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.MISSING_FIELD,
+ )
+ version = data.version
+
+ sas_uri = None
+ if self._registry_name:
+ # If the data asset is a workspace asset, promote to registry
+ if isinstance(data, WorkspaceAssetReference):
+ try:
+ self._operation.get(
+ name=data.name,
+ version=data.version,
+ resource_group_name=self._resource_group_name,
+ registry_name=self._registry_name,
+ )
+ except Exception as err: # pylint: disable=W0718
+ if isinstance(err, ResourceNotFoundError):
+ pass
+ else:
+ raise err
+ else:
+ msg = "An data asset with this name and version already exists in registry"
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.DATA,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+ data_res_obj = data._to_rest_object()
+ result = self._service_client.resource_management_asset_reference.begin_import_method(
+ resource_group_name=self._resource_group_name,
+ registry_name=self._registry_name,
+ body=data_res_obj,
+ ).result()
+
+ if not result:
+ data_res_obj = self._get(name=data.name, version=data.version)
+ return Data._from_rest_object(data_res_obj)
+
+ sas_uri = get_sas_uri_for_registry_asset(
+ service_client=self._service_client,
+ name=name,
+ version=version,
+ resource_group=self._resource_group_name,
+ registry=self._registry_name,
+ body=get_asset_body_for_registry_storage(self._registry_name, "data", name, version),
+ )
+
+ referenced_uris = self._validate(data)
+ if referenced_uris:
+ data._referenced_uris = referenced_uris
+
+ data, _ = _check_and_upload_path(
+ artifact=data,
+ asset_operations=self,
+ sas_uri=sas_uri,
+ artifact_type=ErrorTarget.DATA,
+ show_progress=self._show_progress,
+ )
+
+ _check_or_modify_auto_delete_setting(data.auto_delete_setting)
+
+ data_version_resource = data._to_rest_object()
+ auto_increment_version = data._auto_increment_version
+
+ if auto_increment_version:
+ result = _create_or_update_autoincrement(
+ name=data.name,
+ body=data_version_resource,
+ version_operation=self._operation,
+ container_operation=self._container_operation,
+ resource_group_name=self._operation_scope.resource_group_name,
+ workspace_name=self._workspace_name,
+ **self._init_kwargs,
+ )
+ else:
+ result = (
+ self._operation.begin_create_or_update(
+ name=name,
+ version=version,
+ registry_name=self._registry_name,
+ body=data_version_resource,
+ **self._scope_kwargs,
+ ).result()
+ if self._registry_name
+ else self._operation.create_or_update(
+ name=name,
+ version=version,
+ workspace_name=self._workspace_name,
+ body=data_version_resource,
+ **self._scope_kwargs,
+ )
+ )
+
+ if not result and self._registry_name:
+ result = self._get(name=name, version=version)
+
+ return Data._from_rest_object(result)
+ except Exception as ex:
+ if isinstance(ex, (ValidationException, SchemaValidationError)):
+ log_and_raise_error(ex)
+ elif isinstance(ex, HttpResponseError):
+ # service side raises an exception if we attempt to update an existing asset's asset path
+ if str(ex) == ASSET_PATH_ERROR:
+ raise AssetPathException(
+ message=CHANGED_ASSET_PATH_MSG,
+ target=ErrorTarget.DATA,
+ no_personal_data_message=CHANGED_ASSET_PATH_MSG_NO_PERSONAL_DATA,
+ error_category=ErrorCategory.USER_ERROR,
+ ) from ex
+ raise ex
+
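A hedged example of registering a local folder as a uri_folder data asset; paths, names, and the ml_client instance are placeholders:

    from azure.ai.ml.constants import AssetTypes
    from azure.ai.ml.entities import Data

    my_data = Data(
        name="my-data",
        version="1",
        type=AssetTypes.URI_FOLDER,
        path="./sample_data",  # local folder; uploaded to the workspace blob store on create
        description="Sample data asset",
    )
    registered = ml_client.data.create_or_update(my_data)
    print(registered.id)
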
+ @monitor_with_activity(ops_logger, "Data.ImportData", ActivityType.PUBLICAPI)
+ @experimental
+ def import_data(self, data_import: DataImport, **kwargs: Any) -> PipelineJob:
+ """Returns the data import job that is creating the data asset.
+
+ :param data_import: DataImport object.
+ :type data_import: azure.ai.ml.entities.DataImport
+ :return: data import job object.
+ :rtype: ~azure.ai.ml.entities.PipelineJob
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_misc.py
+ :start-after: [START data_operations_import_data]
+ :end-before: [END data_operations_import_data]
+ :language: python
+ :dedent: 8
+ :caption: Import data assets example.
+ """
+
+ experiment_name = "data_import_" + str(data_import.name)
+ data_import.type = AssetTypes.MLTABLE if isinstance(data_import.source, Database) else AssetTypes.URI_FOLDER
+
+ # avoid specifying auto_delete_setting in job output now
+ _validate_auto_delete_setting_in_data_output(data_import.auto_delete_setting)
+
+ # block customer-specified path on managed datastore
+ data_import.path = _validate_workspace_managed_datastore(data_import.path)
+
+ if "${{name}}" not in str(data_import.path):
+ data_import.path = data_import.path.rstrip("/") + "/${{name}}" # type: ignore
+ import_job = import_data_func(
+ description=data_import.description or experiment_name,
+ display_name=experiment_name,
+ experiment_name=experiment_name,
+ compute="serverless",
+ source=data_import.source,
+ outputs={
+ "sink": Output(
+ type=data_import.type,
+ path=data_import.path, # type: ignore
+ name=data_import.name,
+ version=data_import.version,
+ )
+ },
+ )
+ import_pipeline = PipelineJob(
+ description=data_import.description or experiment_name,
+ tags=data_import.tags,
+ display_name=experiment_name,
+ experiment_name=experiment_name,
+ properties=data_import.properties or {},
+ settings=PipelineJobSettings(force_rerun=True),
+ jobs={experiment_name: import_job},
+ )
+ import_pipeline.properties["azureml.materializationAssetName"] = data_import.name
+ return self._all_operations.all_operations[AzureMLResourceType.JOB].create_or_update(
+ job=import_pipeline, skip_validation=True, **kwargs
+ )
+
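A sketch of importing from a database source into an MLTable asset via this experimental API; the import paths, connection name, and query are assumptions chosen only to illustrate the DataImport and Database types referenced above:

    from azure.ai.ml.entities import DataImport
    from azure.ai.ml.data_transfer import Database

    data_import = DataImport(
        name="imported_orders",
        source=Database(connection="azureml:my_sql_connection", query="select * from orders"),
        path="azureml://datastores/workspaceblobstore/paths/imports/orders",
    )
    job = ml_client.data.import_data(data_import)
    print(job.name)  # pipeline job that materializes the asset
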
+ @monitor_with_activity(ops_logger, "Data.ListMaterializationStatus", ActivityType.PUBLICAPI)
+ def list_materialization_status(
+ self,
+ name: str,
+ *,
+ list_view_type: ListViewType = ListViewType.ACTIVE_ONLY,
+ **kwargs: Any,
+ ) -> Iterable[PipelineJob]:
+ """List materialization jobs of the asset.
+
+ :param name: name of asset being created by the materialization jobs.
+ :type name: str
+ :keyword list_view_type: View type for including/excluding (for example) archived jobs. Default: ACTIVE_ONLY.
+ :paramtype list_view_type: Optional[ListViewType]
+ :return: An iterator like instance of Job objects.
+ :rtype: ~azure.core.paging.ItemPaged[PipelineJob]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_misc.py
+ :start-after: [START data_operations_list_materialization_status]
+ :end-before: [END data_operations_list_materialization_status]
+ :language: python
+ :dedent: 8
+ :caption: List materialization jobs example.
+ """
+
+ return cast(
+ Iterable[PipelineJob],
+ self._all_operations.all_operations[AzureMLResourceType.JOB].list(
+ job_type="Pipeline",
+ asset_name=name,
+ list_view_type=list_view_type,
+ **kwargs,
+ ),
+ )
+
+ @monitor_with_activity(ops_logger, "Data.Validate", ActivityType.INTERNALCALL)
+ def _validate(self, data: Data) -> Optional[List[str]]:
+ if not data.path:
+ msg = "Missing data path. Path is required for data."
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ error_type=ValidationErrorType.MISSING_FIELD,
+ target=ErrorTarget.DATA,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+
+ asset_path = str(data.path)
+ asset_type = data.type
+ base_path = data.base_path
+
+ if asset_type == AssetTypes.MLTABLE:
+ if is_url(asset_path):
+ try:
+ metadata_contents = read_remote_mltable_metadata_contents(
+ base_uri=asset_path,
+ datastore_operations=self._datastore_operation,
+ requests_pipeline=self._requests_pipeline,
+ )
+ metadata_yaml_path = None
+ except Exception: # pylint: disable=W0718
+ # skip validation for remote MLTable when the contents cannot be read
+ module_logger.info("Unable to access MLTable metadata at path %s", asset_path)
+ return None
+ else:
+ metadata_contents = read_local_mltable_metadata_contents(path=asset_path)
+ metadata_yaml_path = Path(asset_path, "MLTable")
+ metadata = MLTableMetadata._load(metadata_contents, metadata_yaml_path)
+ mltable_metadata_schema = self._try_get_mltable_metadata_jsonschema(data._mltable_schema_url)
+ if mltable_metadata_schema and not data._skip_validation:
+ validate_mltable_metadata(
+ mltable_metadata_dict=metadata_contents,
+ mltable_schema=mltable_metadata_schema,
+ )
+ return cast(Optional[List[str]], metadata.referenced_uris())
+
+ if is_url(asset_path):
+ # skip validation for remote URI_FILE or URI_FOLDER
+ pass
+ elif os.path.isabs(asset_path):
+ _assert_local_path_matches_asset_type(asset_path, asset_type)
+ else:
+ abs_path = Path(base_path, asset_path).resolve()
+ _assert_local_path_matches_asset_type(str(abs_path), asset_type)
+
+ return None
+
+ def _try_get_mltable_metadata_jsonschema(self, mltable_schema_url: Optional[str]) -> Optional[Dict]:
+ if mltable_schema_url is None:
+ mltable_schema_url = MLTABLE_METADATA_SCHEMA_URL_FALLBACK
+ try:
+ return cast(Optional[Dict], download_mltable_metadata_schema(mltable_schema_url, self._requests_pipeline))
+ except Exception: # pylint: disable=W0718
+ module_logger.info(
+ 'Failed to download MLTable metadata jsonschema from "%s", skipping validation',
+ mltable_schema_url,
+ )
+ return None
+
+ @monitor_with_activity(ops_logger, "Data.Archive", ActivityType.PUBLICAPI)
+ def archive(
+ self,
+ name: str,
+ version: Optional[str] = None,
+ label: Optional[str] = None,
+ # pylint:disable=unused-argument
+ **kwargs: Any,
+ ) -> None:
+ """Archive a data asset.
+
+ :param name: Name of data asset.
+ :type name: str
+ :param version: Version of data asset.
+ :type version: str
+ :param label: Label of the data asset. (mutually exclusive with version)
+ :type label: str
+ :return: None
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_misc.py
+ :start-after: [START data_operations_archive]
+ :end-before: [END data_operations_archive]
+ :language: python
+ :dedent: 8
+ :caption: Archive data asset example.
+ """
+
+ _archive_or_restore(
+ asset_operations=self,
+ version_operation=self._operation,
+ container_operation=self._container_operation,
+ is_archived=True,
+ name=name,
+ version=version,
+ label=label,
+ )
+
+ @monitor_with_activity(ops_logger, "Data.Restore", ActivityType.PUBLICAPI)
+ def restore(
+ self,
+ name: str,
+ version: Optional[str] = None,
+ label: Optional[str] = None,
+ # pylint:disable=unused-argument
+ **kwargs: Any,
+ ) -> None:
+ """Restore an archived data asset.
+
+ :param name: Name of data asset.
+ :type name: str
+ :param version: Version of data asset.
+ :type version: str
+ :param label: Label of the data asset. (mutually exclusive with version)
+ :type label: str
+ :return: None
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_misc.py
+ :start-after: [START data_operations_restore]
+ :end-before: [END data_operations_restore]
+ :language: python
+ :dedent: 8
+ :caption: Restore data asset example.
+ """
+
+ _archive_or_restore(
+ asset_operations=self,
+ version_operation=self._operation,
+ container_operation=self._container_operation,
+ is_archived=False,
+ name=name,
+ version=version,
+ label=label,
+ )
+
+ def _get_latest_version(self, name: str) -> Data:
+ """Returns the latest version of the asset with the given name. Latest is defined as the most recently created,
+ not the most recently updated.
+
+ :param name: The asset name
+ :type name: str
+ :return: The latest asset
+ :rtype: Data
+ """
+ latest_version = _get_latest_version_from_container(
+ name,
+ self._container_operation,
+ self._resource_group_name,
+ self._workspace_name,
+ self._registry_name,
+ )
+ return self.get(name, version=latest_version)
+
+ @monitor_with_activity(ops_logger, "data.Share", ActivityType.PUBLICAPI)
+ @experimental
+ def share(
+ self,
+ name: str,
+ version: str,
+ *,
+ share_with_name: str,
+ share_with_version: str,
+ registry_name: str,
+ **kwargs: Any,
+ ) -> Data:
+ """Share a data asset from workspace to registry.
+
+ :param name: Name of data asset.
+ :type name: str
+ :param version: Version of data asset.
+ :type version: str
+ :keyword share_with_name: Name of data asset to share with.
+ :paramtype share_with_name: str
+ :keyword share_with_version: Version of data asset to share with.
+ :paramtype share_with_version: str
+ :keyword registry_name: Name of the destination registry.
+ :paramtype registry_name: str
+ :return: Data asset object.
+ :rtype: ~azure.ai.ml.entities.Data
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_misc.py
+ :start-after: [START data_operations_share]
+ :end-before: [END data_operations_share]
+ :language: python
+ :dedent: 8
+ :caption: Share data asset example.
+ """
+
+ # Get workspace info to get workspace GUID
+ workspace = self._service_client.workspaces.get(
+ resource_group_name=self._resource_group_name,
+ workspace_name=self._workspace_name,
+ **kwargs,
+ )
+ workspace_guid = workspace.workspace_id
+ workspace_location = workspace.location
+
+ # Get data asset ID
+ asset_id = ASSET_ID_FORMAT.format(
+ workspace_location,
+ workspace_guid,
+ AzureMLResourceType.DATA,
+ name,
+ version,
+ )
+
+ data_ref = WorkspaceAssetReference(
+ name=share_with_name if share_with_name else name,
+ version=share_with_version if share_with_version else version,
+ asset_id=asset_id,
+ )
+
+ with self._set_registry_client(registry_name):
+ return self.create_or_update(data_ref)
+
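share is experimental; a sketch that promotes a workspace data asset into a registry, where the asset and registry names are placeholders:

    shared = ml_client.data.share(
        name="my-data",
        version="1",
        share_with_name="my-data-shared",
        share_with_version="1",
        registry_name="my-registry",
    )
    print(shared.id)
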
+ @monitor_with_activity(ops_logger, "data.Mount", ActivityType.PUBLICAPI)
+ @experimental
+ def mount(
+ self,
+ path: str,
+ *,
+ mount_point: Optional[str] = None,
+ mode: str = "ro_mount",
+ debug: bool = False,
+ persistent: bool = False,
+ **kwargs,
+ ) -> None:
+ """Mount a data asset to a local path, so that you can access data inside it
+ under a local path with any tools of your choice.
+
+ :param path: The data asset path to mount, in the form of `azureml:<name>` or `azureml:<name>:<version>`.
+ :type path: str
+ :keyword mount_point: A local path used as mount point.
+ :type mount_point: str
+ :keyword mode: Mount mode. Only `ro_mount` (read-only) is supported for data asset mount.
+ :type mode: str
+ :keyword debug: Whether to enable verbose logging.
+ :type debug: bool
+ :keyword persistent: Whether to persist the mount after reboot. Applies only when running on a Compute Instance,
+ where the 'CI_NAME' environment variable is set.
+ :type persistent: bool
+ :return: None
+ """
+
+ assert mode in ["ro_mount", "rw_mount"], "mode should be either `ro_mount` or `rw_mount`"
+ read_only = mode == "ro_mount"
+ assert read_only, "read-write mount for data asset is not supported yet"
+
+ ci_name = os.environ.get("CI_NAME")
+ assert not persistent or (
+ persistent and ci_name is not None
+ ), "persistent mount is only supported on Compute Instance, where the 'CI_NAME' environment variable is set."
+
+ try:
+ from azureml.dataprep import rslex_fuse_subprocess_wrapper
+ except ImportError as exc:
+ raise MlException(
+ "Mount operations requires package azureml-dataprep-rslex installed. "
+ + "You can install it with Azure ML SDK with `pip install azure-ai-ml[mount]`."
+ ) from exc
+
+ uri = rslex_fuse_subprocess_wrapper.build_data_asset_uri(
+ self._operation_scope._subscription_id, self._resource_group_name, self._workspace_name, path
+ )
+ if persistent and ci_name is not None:
+ mount_name = f"unified_mount_{str(uuid.uuid4()).replace('-', '')}"
+ self._compute_operation.update_data_mounts(
+ self._resource_group_name,
+ self._workspace_name,
+ ci_name,
+ [
+ ComputeInstanceDataMount(
+ source=uri,
+ source_type="URI",
+ mount_name=mount_name,
+ mount_action="Mount",
+ mount_path=mount_point or "",
+ )
+ ],
+ api_version="2021-01-01",
+ **kwargs,
+ )
+ print(f"Mount requested [name: {mount_name}]. Waiting for completion ...")
+ while True:
+ compute = self._compute_operation.get(self._resource_group_name, self._workspace_name, ci_name)
+ mounts = compute.properties.properties.data_mounts
+ try:
+ mount = [mount for mount in mounts if mount.mount_name == mount_name][0]
+ if mount.mount_state == "Mounted":
+ print(f"Mounted [name: {mount_name}].")
+ break
+ if mount.mount_state == "MountRequested":
+ pass
+ elif mount.mount_state == "MountFailed":
+ msg = f"Mount failed [name: {mount_name}]: {mount.error}"
+ raise MlException(message=msg, no_personal_data_message=msg)
+ else:
+ msg = f"Got unexpected mount state [name: {mount_name}]: {mount.mount_state}"
+ raise MlException(message=msg, no_personal_data_message=msg)
+ except IndexError:
+ pass
+ time.sleep(5)
+
+ else:
+ rslex_fuse_subprocess_wrapper.start_fuse_mount_subprocess(
+ uri, mount_point, read_only, debug, credential=self._operation._config.credential
+ )
+
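mount is experimental and currently read-only; a minimal sketch, assuming the azureml-dataprep-rslex extra is installed and ml_client is an authenticated MLClient (the asset name, version, and mount point are placeholders):

    ml_client.data.mount("azureml:my-data:1", mount_point="/tmp/my_data", mode="ro_mount")
    # the asset's files are now readable under /tmp/my_data
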
+ @contextmanager
+ # pylint: disable-next=docstring-missing-return,docstring-missing-rtype
+ def _set_registry_client(self, registry_name: str) -> Generator:
+ """Sets the registry client for the data operations.
+
+ :param registry_name: Name of the registry.
+ :type registry_name: str
+ """
+ rg_ = self._operation_scope._resource_group_name
+ sub_ = self._operation_scope._subscription_id
+ registry_ = self._operation_scope.registry_name
+ client_ = self._service_client
+ data_versions_operation_ = self._operation
+
+ try:
+ _client, _rg, _sub = get_registry_client(self._service_client._config.credential, registry_name)
+ self._operation_scope.registry_name = registry_name
+ self._operation_scope._resource_group_name = _rg
+ self._operation_scope._subscription_id = _sub
+ self._service_client = _client
+ self._operation = _client.data_versions
+ yield
+ finally:
+ self._operation_scope.registry_name = registry_
+ self._operation_scope._resource_group_name = rg_
+ self._operation_scope._subscription_id = sub_
+ self._service_client = client_
+ self._operation = data_versions_operation_
+
+
+def _assert_local_path_matches_asset_type(
+ local_path: str,
+ asset_type: str,
+) -> None:
+ # assert file system type matches asset type
+ if asset_type == AssetTypes.URI_FOLDER and not os.path.isdir(local_path):
+ raise ValidationException(
+ message="File path does not match asset type {}: {}".format(asset_type, local_path),
+ no_personal_data_message="File path does not match asset type {}".format(asset_type),
+ target=ErrorTarget.DATA,
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.FILE_OR_FOLDER_NOT_FOUND,
+ )
+ if asset_type == AssetTypes.URI_FILE and not os.path.isfile(local_path):
+ raise ValidationException(
+ message="File path does not match asset type {}: {}".format(asset_type, local_path),
+ no_personal_data_message="File path does not match asset type {}".format(asset_type),
+ target=ErrorTarget.DATA,
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.FILE_OR_FOLDER_NOT_FOUND,
+ )
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_dataset_dataplane_operations.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_dataset_dataplane_operations.py
new file mode 100644
index 00000000..d9a95074
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_dataset_dataplane_operations.py
@@ -0,0 +1,32 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+import logging
+from typing import List
+
+from azure.ai.ml._restclient.dataset_dataplane import AzureMachineLearningWorkspaces as ServiceClientDatasetDataplane
+from azure.ai.ml._restclient.dataset_dataplane.models import BatchDataUriResponse, BatchGetResolvedURIs
+from azure.ai.ml._scope_dependent_operations import OperationConfig, OperationScope, _ScopeDependentOperations
+
+module_logger = logging.getLogger(__name__)
+
+
+class DatasetDataplaneOperations(_ScopeDependentOperations):
+ def __init__(
+ self,
+ operation_scope: OperationScope,
+ operation_config: OperationConfig,
+ service_client: ServiceClientDatasetDataplane,
+ ):
+ super().__init__(operation_scope, operation_config)
+ self._operation = service_client.data_version
+
+ def get_batch_dataset_uris(self, dataset_ids: List[str]) -> BatchDataUriResponse:
+ batch_uri_request = BatchGetResolvedURIs(values=dataset_ids)
+ return self._operation.batch_get_resolved_uris(
+ self._operation_scope.subscription_id,
+ self._operation_scope.resource_group_name,
+ self._workspace_name,
+ body=batch_uri_request,
+ )
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_datastore_operations.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_datastore_operations.py
new file mode 100644
index 00000000..74a518e9
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_datastore_operations.py
@@ -0,0 +1,329 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+import time
+import uuid
+from typing import Dict, Iterable, Optional, cast
+
+from marshmallow.exceptions import ValidationError as SchemaValidationError
+
+from azure.ai.ml._exception_helper import log_and_raise_error
+from azure.ai.ml._restclient.v2024_01_01_preview import AzureMachineLearningWorkspaces as ServiceClient012024Preview
+from azure.ai.ml._restclient.v2024_01_01_preview.models import ComputeInstanceDataMount
+from azure.ai.ml._restclient.v2024_07_01_preview import AzureMachineLearningWorkspaces as ServiceClient072024Preview
+from azure.ai.ml._restclient.v2024_07_01_preview.models import Datastore as DatastoreData
+from azure.ai.ml._restclient.v2024_07_01_preview.models import DatastoreSecrets, NoneDatastoreCredentials, SecretExpiry
+from azure.ai.ml._scope_dependent_operations import OperationConfig, OperationScope, _ScopeDependentOperations
+from azure.ai.ml._telemetry import ActivityType, monitor_with_activity
+from azure.ai.ml._utils._experimental import experimental
+from azure.ai.ml._utils._logger_utils import OpsLogger
+from azure.ai.ml.entities._datastore.datastore import Datastore
+from azure.ai.ml.exceptions import MlException, ValidationException
+
+ops_logger = OpsLogger(__name__)
+module_logger = ops_logger.module_logger
+
+
+class DatastoreOperations(_ScopeDependentOperations):
+ """Represents a client for performing operations on Datastores.
+
+ You should not instantiate this class directly. Instead, you should create an MLClient and use this client via
+ the MLClient.datastores property.
+
+ :param operation_scope: Scope variables for the operations classes of an MLClient object.
+ :type operation_scope: ~azure.ai.ml._scope_dependent_operations.OperationScope
+ :param operation_config: Common configuration for operations classes of an MLClient object.
+ :type operation_config: ~azure.ai.ml._scope_dependent_operations.OperationConfig
+ :param serviceclient_2024_01_01_preview: Service client to allow end users to operate on Azure Machine Learning
+ Workspace resources.
+ :type serviceclient_2024_01_01_preview: ~azure.ai.ml._restclient.v2023_01_01_preview.
+ _azure_machine_learning_workspaces.AzureMachineLearningWorkspaces
+ :param serviceclient_2024_07_01_preview: Service client to allow end users to operate on Azure Machine Learning
+ Workspace resources.
+ :type serviceclient_2024_07_01_preview: ~azure.ai.ml._restclient.v2024_07_01_preview.
+ _azure_machine_learning_workspaces.AzureMachineLearningWorkspaces
+ """
+
+ def __init__(
+ self,
+ operation_scope: OperationScope,
+ operation_config: OperationConfig,
+ serviceclient_2024_01_01_preview: ServiceClient012024Preview,
+ serviceclient_2024_07_01_preview: ServiceClient072024Preview,
+ **kwargs: Dict,
+ ):
+ super(DatastoreOperations, self).__init__(operation_scope, operation_config)
+ ops_logger.update_filter()
+ self._operation = serviceclient_2024_07_01_preview.datastores
+ self._compute_operation = serviceclient_2024_01_01_preview.compute
+ self._credential = serviceclient_2024_07_01_preview._config.credential
+ self._init_kwargs = kwargs
+
+ @monitor_with_activity(ops_logger, "Datastore.List", ActivityType.PUBLICAPI)
+ def list(self, *, include_secrets: bool = False) -> Iterable[Datastore]:
+ """Lists all datastores and associated information within a workspace.
+
+ :keyword include_secrets: Include datastore secrets in returned datastores, defaults to False
+ :paramtype include_secrets: bool
+ :return: An iterator like instance of Datastore objects
+ :rtype: ~azure.core.paging.ItemPaged[Datastore]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_misc.py
+ :start-after: [START datastore_operations_list]
+ :end-before: [END datastore_operations_list]
+ :language: python
+ :dedent: 8
+ :caption: List datastore example.
+ """
+
+ def _list_helper(datastore_resource: Datastore, include_secrets: bool) -> Datastore:
+ if include_secrets:
+ self._fetch_and_populate_secret(datastore_resource)
+ return Datastore._from_rest_object(datastore_resource)
+
+ return cast(
+ Iterable[Datastore],
+ self._operation.list(
+ resource_group_name=self._operation_scope.resource_group_name,
+ workspace_name=self._workspace_name,
+ cls=lambda objs: [_list_helper(obj, include_secrets) for obj in objs],
+ **self._init_kwargs,
+ ),
+ )
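A minimal usage sketch (not part of the packaged file): enumerating workspace datastores through MLClient.datastores, assuming a client already configured for the target workspace; angle-bracket values are placeholders.

# Hypothetical usage sketch: list the datastores attached to a workspace.
from azure.ai.ml import MLClient
from azure.identity import DefaultAzureCredential

ml_client = MLClient(DefaultAzureCredential(), "<subscription-id>", "<resource-group>", "<workspace>")
for ds in ml_client.datastores.list():
    print(ds.name, ds.type)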
+
+ @monitor_with_activity(ops_logger, "Datastore.ListSecrets", ActivityType.PUBLICAPI)
+ def _list_secrets(self, name: str, expirable_secret: bool = False) -> DatastoreSecrets:
+ return self._operation.list_secrets(
+ name=name,
+ body=SecretExpiry(expirable_secret=expirable_secret),
+ resource_group_name=self._operation_scope.resource_group_name,
+ workspace_name=self._workspace_name,
+ **self._init_kwargs,
+ )
+
+ @monitor_with_activity(ops_logger, "Datastore.Delete", ActivityType.PUBLICAPI)
+ def delete(self, name: str) -> None:
+ """Deletes a datastore reference with the given name from the workspace. This method does not delete the actual
+ datastore or underlying data in the datastore.
+
+ :param name: Name of the datastore
+ :type name: str
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_misc.py
+ :start-after: [START datastore_operations_delete]
+ :end-before: [END datastore_operations_delete]
+ :language: python
+ :dedent: 8
+ :caption: Delete datastore example.
+ """
+
+ self._operation.delete(
+ name=name,
+ resource_group_name=self._operation_scope.resource_group_name,
+ workspace_name=self._workspace_name,
+ **self._init_kwargs,
+ )
+
+ @monitor_with_activity(ops_logger, "Datastore.Get", ActivityType.PUBLICAPI)
+ def get(self, name: str, *, include_secrets: bool = False) -> Datastore: # type: ignore
+ """Returns information about the datastore referenced by the given name.
+
+ :param name: Datastore name
+ :type name: str
+ :keyword include_secrets: Include datastore secrets in the returned datastore, defaults to False
+ :paramtype include_secrets: bool
+ :return: Datastore with the specified name.
+ :rtype: Datastore
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_misc.py
+ :start-after: [START datastore_operations_get]
+ :end-before: [END datastore_operations_get]
+ :language: python
+ :dedent: 8
+ :caption: Get datastore example.
+ """
+ try:
+ datastore_resource = self._operation.get(
+ name=name,
+ resource_group_name=self._operation_scope.resource_group_name,
+ workspace_name=self._workspace_name,
+ **self._init_kwargs,
+ )
+ if include_secrets:
+ self._fetch_and_populate_secret(datastore_resource)
+ return Datastore._from_rest_object(datastore_resource)
+ except (ValidationException, SchemaValidationError) as ex:
+ log_and_raise_error(ex)
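For example, a hedged sketch of fetching one datastore with its credentials populated (listing secrets requires the appropriate workspace permission; the datastore name is a placeholder):

# Hypothetical usage sketch: get a datastore and include its secrets.
from azure.ai.ml import MLClient
from azure.identity import DefaultAzureCredential

ml_client = MLClient(DefaultAzureCredential(), "<subscription-id>", "<resource-group>", "<workspace>")
blob_store = ml_client.datastores.get("workspaceblobstore", include_secrets=True)
print(blob_store.name, type(blob_store.credentials).__name__)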
+
+ def _fetch_and_populate_secret(self, datastore_resource: DatastoreData) -> None:
+ if datastore_resource.name and not isinstance(
+ datastore_resource.properties.credentials, NoneDatastoreCredentials
+ ):
+ secrets = self._list_secrets(name=datastore_resource.name, expirable_secret=True)
+ datastore_resource.properties.credentials.secrets = secrets
+
+ @monitor_with_activity(ops_logger, "Datastore.GetDefault", ActivityType.PUBLICAPI)
+ def get_default(self, *, include_secrets: bool = False) -> Datastore: # type: ignore
+ """Returns the workspace's default datastore.
+
+ :keyword include_secrets: Include datastore secrets in the returned datastore, defaults to False
+ :paramtype include_secrets: bool
+ :return: The default datastore.
+ :rtype: Datastore
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_misc.py
+ :start-after: [START datastore_operations_get_default]
+ :end-before: [END datastore_operations_get_default]
+ :language: python
+ :dedent: 8
+ :caption: Get default datastore example.
+ """
+ try:
+ datastore_resource = self._operation.list(
+ resource_group_name=self._operation_scope.resource_group_name,
+ workspace_name=self._workspace_name,
+ is_default=True,
+ **self._init_kwargs,
+ ).next()
+ if include_secrets:
+ self._fetch_and_populate_secret(datastore_resource)
+ return Datastore._from_rest_object(datastore_resource)
+ except (ValidationException, SchemaValidationError) as ex:
+ log_and_raise_error(ex)
+
+ @monitor_with_activity(ops_logger, "Datastore.CreateOrUpdate", ActivityType.PUBLICAPI)
+ def create_or_update(self, datastore: Datastore) -> Datastore: # type: ignore
+ """Attaches the passed in datastore to the workspace or updates the datastore if it already exists.
+
+ :param datastore: The configuration of the datastore to attach.
+ :type datastore: Datastore
+ :return: The attached datastore.
+ :rtype: Datastore
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_misc.py
+ :start-after: [START datastore_operations_create_or_update]
+ :end-before: [END datastore_operations_create_or_update]
+ :language: python
+ :dedent: 8
+ :caption: Create datastore example.
+ """
+ try:
+ ds_request = datastore._to_rest_object()
+ datastore_resource = self._operation.create_or_update(
+ name=datastore.name,
+ resource_group_name=self._operation_scope.resource_group_name,
+ workspace_name=self._workspace_name,
+ body=ds_request,
+ skip_validation=True,
+ )
+ return Datastore._from_rest_object(datastore_resource)
+ except Exception as ex: # pylint: disable=W0718
+ if isinstance(ex, (ValidationException, SchemaValidationError)):
+ log_and_raise_error(ex)
+ else:
+ raise ex
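A hedged sketch of attaching a new blob datastore via this method (account, container, and key values are placeholders):

# Hypothetical usage sketch: attach an Azure Blob container as a workspace datastore.
from azure.ai.ml import MLClient
from azure.ai.ml.entities import AccountKeyConfiguration, AzureBlobDatastore
from azure.identity import DefaultAzureCredential

ml_client = MLClient(DefaultAzureCredential(), "<subscription-id>", "<resource-group>", "<workspace>")
store = AzureBlobDatastore(
    name="my_blob_store",  # placeholder datastore name
    account_name="<storage-account>",
    container_name="<container>",
    credentials=AccountKeyConfiguration(account_key="<account-key>"),
)
created = ml_client.datastores.create_or_update(store)
print(created.name)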
+
+ @monitor_with_activity(ops_logger, "Datastore.Mount", ActivityType.PUBLICAPI)
+ @experimental
+ def mount(
+ self,
+ path: str,
+ *,
+ mount_point: Optional[str] = None,
+ mode: str = "ro_mount",
+ debug: bool = False,
+ persistent: bool = False,
+ **kwargs,
+ ) -> None:
+        """Mount a datastore to a local path, so that you can access the data inside it
+        with any tools of your choice.
+
+ :param path: The data store path to mount, in the form of `<name>` or `azureml://datastores/<name>`.
+ :type path: str
+        :keyword mount_point: A local path used as the mount point.
+        :paramtype mount_point: str
+        :keyword mode: Mount mode, either `ro_mount` (read-only) or `rw_mount` (read-write).
+        :paramtype mode: str
+        :keyword debug: Whether to enable verbose logging.
+        :paramtype debug: bool
+        :keyword persistent: Whether to persist the mount after reboot. Applies only when running on a Compute
+            Instance, where the 'CI_NAME' environment variable is set.
+        :paramtype persistent: bool
+ :return: None
+ """
+
+ assert mode in ["ro_mount", "rw_mount"], "mode should be either `ro_mount` or `rw_mount`"
+ read_only = mode == "ro_mount"
+
+ import os
+
+ ci_name = os.environ.get("CI_NAME")
+ assert not persistent or (
+ persistent and ci_name is not None
+ ), "persistent mount is only supported on Compute Instance, where the 'CI_NAME' environment variable is set."
+
+ try:
+ from azureml.dataprep import rslex_fuse_subprocess_wrapper
+ except ImportError as exc:
+            msg = "The mount operation requires the azureml-dataprep-rslex package to be installed. \
+                You can install it along with the Azure ML SDK with `pip install azure-ai-ml[mount]`."
+ raise MlException(message=msg, no_personal_data_message=msg) from exc
+
+ uri = rslex_fuse_subprocess_wrapper.build_datastore_uri(
+ self._operation_scope._subscription_id, self._resource_group_name, self._workspace_name, path
+ )
+ if persistent and ci_name is not None:
+ mount_name = f"unified_mount_{str(uuid.uuid4()).replace('-', '')}"
+ self._compute_operation.update_data_mounts(
+ self._resource_group_name,
+ self._workspace_name,
+ ci_name,
+ [
+ ComputeInstanceDataMount(
+ source=uri,
+ source_type="URI",
+ mount_name=mount_name,
+ mount_action="Mount",
+ mount_path=mount_point or "",
+ )
+ ],
+ api_version="2021-01-01",
+ **kwargs,
+ )
+ print(f"Mount requested [name: {mount_name}]. Waiting for completion ...")
+ while True:
+ compute = self._compute_operation.get(self._resource_group_name, self._workspace_name, ci_name)
+ mounts = compute.properties.properties.data_mounts
+ try:
+ mount = [mount for mount in mounts if mount.mount_name == mount_name][0]
+ if mount.mount_state == "Mounted":
+ print(f"Mounted [name: {mount_name}].")
+ break
+ if mount.mount_state == "MountRequested":
+ pass
+ elif mount.mount_state == "MountFailed":
+ msg = f"Mount failed [name: {mount_name}]: {mount.error}"
+ raise MlException(message=msg, no_personal_data_message=msg)
+ else:
+ msg = f"Got unexpected mount state [name: {mount_name}]: {mount.mount_state}"
+ raise MlException(message=msg, no_personal_data_message=msg)
+ except IndexError:
+ pass
+ time.sleep(5)
+ else:
+ rslex_fuse_subprocess_wrapper.start_fuse_mount_subprocess(
+ uri, mount_point, read_only, debug, credential=self._operation._config.credential
+ )
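A hedged sketch of calling this experimental API (assumes azureml-dataprep-rslex is installed, e.g. via `pip install azure-ai-ml[mount]`; the mount point is a placeholder):

# Hypothetical usage sketch: mount a datastore read-only at a local path.
from azure.ai.ml import MLClient
from azure.identity import DefaultAzureCredential

ml_client = MLClient(DefaultAzureCredential(), "<subscription-id>", "<resource-group>", "<workspace>")
ml_client.datastores.mount(
    "azureml://datastores/workspaceblobstore",
    mount_point="/tmp/workspaceblobstore",  # placeholder local path
    mode="ro_mount",
)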
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_environment_operations.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_environment_operations.py
new file mode 100644
index 00000000..228204a7
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_environment_operations.py
@@ -0,0 +1,569 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+from contextlib import contextmanager
+from typing import Any, Generator, Iterable, Optional, Union, cast
+
+from marshmallow.exceptions import ValidationError as SchemaValidationError
+
+from azure.ai.ml._artifacts._artifact_utilities import _check_and_upload_env_build_context
+from azure.ai.ml._exception_helper import log_and_raise_error
+from azure.ai.ml._restclient.v2021_10_01_dataplanepreview import (
+ AzureMachineLearningWorkspaces as ServiceClient102021Dataplane,
+)
+from azure.ai.ml._restclient.v2023_04_01_preview import AzureMachineLearningWorkspaces as ServiceClient042023Preview
+from azure.ai.ml._restclient.v2023_04_01_preview.models import EnvironmentVersion, ListViewType
+from azure.ai.ml._scope_dependent_operations import (
+ OperationConfig,
+ OperationsContainer,
+ OperationScope,
+ _ScopeDependentOperations,
+)
+from azure.ai.ml._telemetry import ActivityType, monitor_with_activity
+from azure.ai.ml._utils._asset_utils import (
+ _archive_or_restore,
+ _get_latest,
+ _get_next_latest_versions_from_container,
+ _resolve_label_to_asset,
+)
+from azure.ai.ml._utils._experimental import experimental
+from azure.ai.ml._utils._logger_utils import OpsLogger
+from azure.ai.ml._utils._registry_utils import (
+ get_asset_body_for_registry_storage,
+ get_registry_client,
+ get_sas_uri_for_registry_asset,
+)
+from azure.ai.ml.constants._common import ARM_ID_PREFIX, ASSET_ID_FORMAT, AzureMLResourceType
+from azure.ai.ml.entities._assets import Environment, WorkspaceAssetReference
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException
+from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError
+
+ops_logger = OpsLogger(__name__)
+module_logger = ops_logger.module_logger
+
+
+class EnvironmentOperations(_ScopeDependentOperations):
+ """EnvironmentOperations.
+
+ You should not instantiate this class directly. Instead, you should
+ create an MLClient instance that instantiates it for you and
+ attaches it as an attribute.
+
+ :param operation_scope: Scope variables for the operations classes of an MLClient object.
+ :type operation_scope: ~azure.ai.ml._scope_dependent_operations.OperationScope
+ :param operation_config: Common configuration for operations classes of an MLClient object.
+ :type operation_config: ~azure.ai.ml._scope_dependent_operations.OperationConfig
+ :param service_client: Service client to allow end users to operate on Azure Machine Learning Workspace
+ resources (ServiceClient042023Preview or ServiceClient102021Dataplane).
+ :type service_client: typing.Union[
+ ~azure.ai.ml._restclient.v2023_04_01_preview._azure_machine_learning_workspaces.AzureMachineLearningWorkspaces,
+ ~azure.ai.ml._restclient.v2021_10_01_dataplanepreview._azure_machine_learning_workspaces.
+ AzureMachineLearningWorkspaces]
+ :param all_operations: All operations classes of an MLClient object.
+ :type all_operations: ~azure.ai.ml._scope_dependent_operations.OperationsContainer
+ """
+
+ def __init__(
+ self,
+ operation_scope: OperationScope,
+ operation_config: OperationConfig,
+ service_client: Union[ServiceClient042023Preview, ServiceClient102021Dataplane],
+ all_operations: OperationsContainer,
+ **kwargs: Any,
+ ):
+ super(EnvironmentOperations, self).__init__(operation_scope, operation_config)
+ ops_logger.update_filter()
+ self._kwargs = kwargs
+ self._containers_operations = service_client.environment_containers
+ self._version_operations = service_client.environment_versions
+ self._service_client = service_client
+ self._all_operations = all_operations
+ self._datastore_operation = all_operations.all_operations[AzureMLResourceType.DATASTORE]
+
+ # Maps a label to a function which given an asset name,
+ # returns the asset associated with the label
+ self._managed_label_resolver = {"latest": self._get_latest_version}
+
+ @monitor_with_activity(ops_logger, "Environment.CreateOrUpdate", ActivityType.PUBLICAPI)
+ def create_or_update(self, environment: Environment) -> Environment: # type: ignore
+ """Returns created or updated environment asset.
+
+ :param environment: Environment object
+ :type environment: ~azure.ai.ml.entities._assets.Environment
+ :raises ~azure.ai.ml.exceptions.ValidationException: Raised if Environment cannot be successfully validated.
+ Details will be provided in the error message.
+ :raises ~azure.ai.ml.exceptions.EmptyDirectoryError: Raised if local path provided points to an empty directory.
+ :return: Created or updated Environment object
+ :rtype: ~azure.ai.ml.entities.Environment
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_misc.py
+ :start-after: [START env_operations_create_or_update]
+ :end-before: [END env_operations_create_or_update]
+ :language: python
+ :dedent: 8
+ :caption: Create environment.
+ """
+ try:
+ if not environment.version and environment._auto_increment_version:
+
+ next_version, latest_version = _get_next_latest_versions_from_container(
+ name=environment.name,
+ container_operation=self._containers_operations,
+ resource_group_name=self._operation_scope.resource_group_name,
+ workspace_name=self._workspace_name,
+ registry_name=self._registry_name,
+ **self._kwargs,
+ )
+                # If the user did not pass a version, the SDK will try to update the latest version
+ return self._try_update_latest_version(next_version, latest_version, environment)
+
+ sas_uri = None
+ if self._registry_name:
+ if isinstance(environment, WorkspaceAssetReference):
+ # verify that environment is not already in registry
+ try:
+ self._version_operations.get(
+ name=environment.name,
+ version=environment.version,
+ resource_group_name=self._resource_group_name,
+ registry_name=self._registry_name,
+ )
+ except Exception as err: # pylint: disable=W0718
+ if isinstance(err, ResourceNotFoundError):
+ pass
+ else:
+ raise err
+ else:
+                        msg = "An environment with this name and version already exists in the registry"
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.ENVIRONMENT,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+
+ environment_rest = environment._to_rest_object()
+ result = self._service_client.resource_management_asset_reference.begin_import_method(
+ resource_group_name=self._resource_group_name,
+ registry_name=self._registry_name,
+ body=environment_rest,
+ **self._kwargs,
+ ).result()
+
+ if not result:
+ env_rest_obj = self._get(name=environment.name, version=environment.version)
+ return Environment._from_rest_object(env_rest_obj)
+
+ sas_uri = get_sas_uri_for_registry_asset(
+ service_client=self._service_client,
+ name=environment.name,
+ version=environment.version,
+ resource_group=self._resource_group_name,
+ registry=self._registry_name,
+ body=get_asset_body_for_registry_storage(
+ self._registry_name,
+ "environments",
+ environment.name,
+ environment.version,
+ ),
+ )
+
+            # Upload only when this is not a registry operation,
+            # or when a SAS URI was successfully acquired
+ if not self._registry_name or sas_uri:
+ environment = _check_and_upload_env_build_context(
+ environment=environment,
+ operations=self,
+ sas_uri=sas_uri,
+ show_progress=self._show_progress,
+ )
+ env_version_resource = environment._to_rest_object()
+ env_rest_obj = (
+ self._version_operations.begin_create_or_update(
+ name=environment.name,
+ version=environment.version,
+ registry_name=self._registry_name,
+ body=env_version_resource,
+ **self._scope_kwargs,
+ **self._kwargs,
+ ).result()
+ if self._registry_name
+ else self._version_operations.create_or_update(
+ name=environment.name,
+ version=environment.version,
+ workspace_name=self._workspace_name,
+ body=env_version_resource,
+ **self._scope_kwargs,
+ **self._kwargs,
+ )
+ )
+ if not env_rest_obj and self._registry_name:
+ env_rest_obj = self._get(name=str(environment.name), version=environment.version)
+ return Environment._from_rest_object(env_rest_obj)
+ except Exception as ex: # pylint: disable=W0718
+ if isinstance(ex, SchemaValidationError):
+ log_and_raise_error(ex)
+ else:
+ raise ex
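A minimal sketch of registering an environment through this method (the image tag and conda file path are illustrative; omitting the version typically lets the auto-increment path above choose the next one):

# Hypothetical usage sketch: register an environment from a base image plus a conda file.
from azure.ai.ml import MLClient
from azure.ai.ml.entities import Environment
from azure.identity import DefaultAzureCredential

ml_client = MLClient(DefaultAzureCredential(), "<subscription-id>", "<resource-group>", "<workspace>")
env = Environment(
    name="my-training-env",  # placeholder asset name
    image="mcr.microsoft.com/azureml/openmpi4.1.0-ubuntu20.04",
    conda_file="./environment/conda.yml",  # local conda specification file
)
registered = ml_client.environments.create_or_update(env)
print(registered.name, registered.version)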
+
+ def _try_update_latest_version(
+ self, next_version: str, latest_version: str, environment: Environment
+ ) -> Environment:
+ env = None
+ if self._registry_name:
+ environment.version = next_version
+ env = self.create_or_update(environment)
+ else:
+ environment.version = latest_version
+ try: # Try to update the latest version
+ env = self.create_or_update(environment)
+ except Exception as ex: # pylint: disable=W0718
+ if isinstance(ex, ResourceExistsError):
+ environment.version = next_version
+ env = self.create_or_update(environment)
+ else:
+ raise ex
+ return env
+
+ def _get(self, name: str, version: Optional[str] = None) -> EnvironmentVersion:
+ if version:
+ return (
+ self._version_operations.get(
+ name=name,
+ version=version,
+ registry_name=self._registry_name,
+ **self._scope_kwargs,
+ **self._kwargs,
+ )
+ if self._registry_name
+ else self._version_operations.get(
+ name=name,
+ version=version,
+ workspace_name=self._workspace_name,
+ **self._scope_kwargs,
+ **self._kwargs,
+ )
+ )
+ return (
+ self._containers_operations.get(
+ name=name,
+ registry_name=self._registry_name,
+ **self._scope_kwargs,
+ **self._kwargs,
+ )
+ if self._registry_name
+ else self._containers_operations.get(
+ name=name,
+ workspace_name=self._workspace_name,
+ **self._scope_kwargs,
+ **self._kwargs,
+ )
+ )
+
+ @monitor_with_activity(ops_logger, "Environment.Get", ActivityType.PUBLICAPI)
+ def get(self, name: str, version: Optional[str] = None, label: Optional[str] = None) -> Environment:
+ """Returns the specified environment asset.
+
+ :param name: Name of the environment.
+ :type name: str
+ :param version: Version of the environment.
+ :type version: str
+ :param label: Label of the environment. (mutually exclusive with version)
+ :type label: str
+ :raises ~azure.ai.ml.exceptions.ValidationException: Raised if Environment cannot be successfully validated.
+ Details will be provided in the error message.
+ :return: Environment object
+ :rtype: ~azure.ai.ml.entities.Environment
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_misc.py
+ :start-after: [START env_operations_get]
+ :end-before: [END env_operations_get]
+ :language: python
+ :dedent: 8
+ :caption: Get example.
+ """
+ if version and label:
+ msg = "Cannot specify both version and label."
+ raise ValidationException(
+ message=msg,
+ target=ErrorTarget.ENVIRONMENT,
+ no_personal_data_message=msg,
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
+
+ if label:
+ return _resolve_label_to_asset(self, name, label)
+
+ if not version:
+ msg = "Must provide either version or label."
+ raise ValidationException(
+ message=msg,
+ target=ErrorTarget.ENVIRONMENT,
+ no_personal_data_message=msg,
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.MISSING_FIELD,
+ )
+ name = _preprocess_environment_name(name)
+ env_version_resource = self._get(name, version)
+
+ return Environment._from_rest_object(env_version_resource)
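A minimal sketch of the two lookup modes (version and label are mutually exclusive; asset names are placeholders):

# Hypothetical usage sketch: fetch a pinned version, or resolve the "latest" label.
from azure.ai.ml import MLClient
from azure.identity import DefaultAzureCredential

ml_client = MLClient(DefaultAzureCredential(), "<subscription-id>", "<resource-group>", "<workspace>")
pinned = ml_client.environments.get("my-training-env", version="1")
latest = ml_client.environments.get("my-training-env", label="latest")
print(pinned.version, latest.version)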
+
+ @monitor_with_activity(ops_logger, "Environment.List", ActivityType.PUBLICAPI)
+ def list(
+ self,
+ name: Optional[str] = None,
+ *,
+ list_view_type: ListViewType = ListViewType.ACTIVE_ONLY,
+ ) -> Iterable[Environment]:
+ """List all environment assets in workspace.
+
+ :param name: Name of the environment.
+ :type name: Optional[str]
+ :keyword list_view_type: View type for including/excluding (for example) archived environments.
+ Default: ACTIVE_ONLY.
+        :paramtype list_view_type: Optional[ListViewType]
+ :return: An iterator like instance of Environment objects.
+ :rtype: ~azure.core.paging.ItemPaged[Environment]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_misc.py
+ :start-after: [START env_operations_list]
+ :end-before: [END env_operations_list]
+ :language: python
+ :dedent: 8
+ :caption: List example.
+ """
+ if name:
+ return cast(
+ Iterable[Environment],
+ (
+ self._version_operations.list(
+ name=name,
+ registry_name=self._registry_name,
+ cls=lambda objs: [Environment._from_rest_object(obj) for obj in objs],
+ **self._scope_kwargs,
+ **self._kwargs,
+ )
+ if self._registry_name
+ else self._version_operations.list(
+ name=name,
+ workspace_name=self._workspace_name,
+ cls=lambda objs: [Environment._from_rest_object(obj) for obj in objs],
+ list_view_type=list_view_type,
+ **self._scope_kwargs,
+ **self._kwargs,
+ )
+ ),
+ )
+ return cast(
+ Iterable[Environment],
+ (
+ self._containers_operations.list(
+ registry_name=self._registry_name,
+ cls=lambda objs: [Environment._from_container_rest_object(obj) for obj in objs],
+ **self._scope_kwargs,
+ **self._kwargs,
+ )
+ if self._registry_name
+ else self._containers_operations.list(
+ workspace_name=self._workspace_name,
+ cls=lambda objs: [Environment._from_container_rest_object(obj) for obj in objs],
+ list_view_type=list_view_type,
+ **self._scope_kwargs,
+ **self._kwargs,
+ )
+ ),
+ )
+
+ @monitor_with_activity(ops_logger, "Environment.Delete", ActivityType.PUBLICAPI)
+ def archive(
+ self,
+ name: str,
+ version: Optional[str] = None,
+ label: Optional[str] = None,
+ # pylint:disable=unused-argument
+ **kwargs: Any,
+ ) -> None:
+ """Archive an environment or an environment version.
+
+ :param name: Name of the environment.
+ :type name: str
+ :param version: Version of the environment.
+ :type version: str
+ :param label: Label of the environment. (mutually exclusive with version)
+ :type label: str
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_misc.py
+ :start-after: [START env_operations_archive]
+ :end-before: [END env_operations_archive]
+ :language: python
+ :dedent: 8
+ :caption: Archive example.
+ """
+ name = _preprocess_environment_name(name)
+ _archive_or_restore(
+ asset_operations=self,
+ version_operation=self._version_operations,
+ container_operation=self._containers_operations,
+ is_archived=True,
+ name=name,
+ version=version,
+ label=label,
+ )
+
+ @monitor_with_activity(ops_logger, "Environment.Restore", ActivityType.PUBLICAPI)
+ def restore(
+ self,
+ name: str,
+ version: Optional[str] = None,
+ label: Optional[str] = None,
+ # pylint:disable=unused-argument
+ **kwargs: Any,
+ ) -> None:
+ """Restore an archived environment version.
+
+ :param name: Name of the environment.
+ :type name: str
+ :param version: Version of the environment.
+ :type version: str
+ :param label: Label of the environment. (mutually exclusive with version)
+ :type label: str
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_misc.py
+ :start-after: [START env_operations_restore]
+ :end-before: [END env_operations_restore]
+ :language: python
+ :dedent: 8
+ :caption: Restore example.
+ """
+ name = _preprocess_environment_name(name)
+ _archive_or_restore(
+ asset_operations=self,
+ version_operation=self._version_operations,
+ container_operation=self._containers_operations,
+ is_archived=False,
+ name=name,
+ version=version,
+ label=label,
+ )
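A short sketch pairing archive and restore (archiving only hides the version from ACTIVE_ONLY listings; it does not delete the asset):

# Hypothetical usage sketch: archive an environment version, then restore it.
from azure.ai.ml import MLClient
from azure.identity import DefaultAzureCredential

ml_client = MLClient(DefaultAzureCredential(), "<subscription-id>", "<resource-group>", "<workspace>")
ml_client.environments.archive(name="my-training-env", version="1")
ml_client.environments.restore(name="my-training-env", version="1")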
+
+ def _get_latest_version(self, name: str) -> Environment:
+ """Returns the latest version of the asset with the given name.
+
+ Latest is defined as the most recently created, not the most
+ recently updated.
+
+ :param name: The asset name
+ :type name: str
+ :return: The latest version of the named environment
+ :rtype: Environment
+ """
+ result = _get_latest(
+ name,
+ self._version_operations,
+ self._resource_group_name,
+ self._workspace_name,
+ self._registry_name,
+ )
+ return Environment._from_rest_object(result)
+
+ @monitor_with_activity(ops_logger, "Environment.Share", ActivityType.PUBLICAPI)
+ @experimental
+ def share(
+ self,
+ name: str,
+ version: str,
+ *,
+ share_with_name: str,
+ share_with_version: str,
+ registry_name: str,
+ ) -> Environment:
+        """Share an environment asset from workspace to registry.
+
+ :param name: Name of environment asset.
+ :type name: str
+ :param version: Version of environment asset.
+ :type version: str
+ :keyword share_with_name: Name of environment asset to share with.
+ :paramtype share_with_name: str
+ :keyword share_with_version: Version of environment asset to share with.
+ :paramtype share_with_version: str
+ :keyword registry_name: Name of the destination registry.
+ :paramtype registry_name: str
+ :return: Environment asset object.
+ :rtype: ~azure.ai.ml.entities.Environment
+ """
+
+ # Get workspace info to get workspace GUID
+ workspace = self._service_client.workspaces.get(
+ resource_group_name=self._resource_group_name,
+ workspace_name=self._workspace_name,
+ )
+ workspace_guid = workspace.workspace_id
+ workspace_location = workspace.location
+
+ # Get environment asset ID
+ asset_id = ASSET_ID_FORMAT.format(
+ workspace_location,
+ workspace_guid,
+ AzureMLResourceType.ENVIRONMENT,
+ name,
+ version,
+ )
+
+ environment_ref = WorkspaceAssetReference(
+ name=share_with_name if share_with_name else name,
+ version=share_with_version if share_with_version else version,
+ asset_id=asset_id,
+ )
+
+ with self._set_registry_client(registry_name):
+ return self.create_or_update(environment_ref)
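A hedged sketch of sharing a workspace environment version into a registry through this experimental method (registry and asset names are placeholders):

# Hypothetical usage sketch: promote a workspace environment version into a registry.
from azure.ai.ml import MLClient
from azure.identity import DefaultAzureCredential

ml_client = MLClient(DefaultAzureCredential(), "<subscription-id>", "<resource-group>", "<workspace>")
shared = ml_client.environments.share(
    name="my-training-env",
    version="1",
    share_with_name="my-training-env",  # name to use in the registry
    share_with_version="1",
    registry_name="<registry-name>",
)
print(shared.id)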
+
+ @contextmanager
+ # pylint: disable-next=docstring-missing-return,docstring-missing-rtype
+ def _set_registry_client(self, registry_name: str) -> Generator:
+ """Sets the registry client for the environment operations.
+
+ :param registry_name: Name of the registry.
+ :type registry_name: str
+ """
+ rg_ = self._operation_scope._resource_group_name
+ sub_ = self._operation_scope._subscription_id
+ registry_ = self._operation_scope.registry_name
+ client_ = self._service_client
+ environment_versions_operation_ = self._version_operations
+
+ try:
+ _client, _rg, _sub = get_registry_client(self._service_client._config.credential, registry_name)
+ self._operation_scope.registry_name = registry_name
+ self._operation_scope._resource_group_name = _rg
+ self._operation_scope._subscription_id = _sub
+ self._service_client = _client
+ self._version_operations = _client.environment_versions
+ yield
+ finally:
+ self._operation_scope.registry_name = registry_
+ self._operation_scope._resource_group_name = rg_
+ self._operation_scope._subscription_id = sub_
+ self._service_client = client_
+ self._version_operations = environment_versions_operation_
+
+
+def _preprocess_environment_name(environment_name: str) -> str:
+ if environment_name.startswith(ARM_ID_PREFIX):
+ return environment_name[len(ARM_ID_PREFIX) :]
+ return environment_name
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_evaluator_operations.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_evaluator_operations.py
new file mode 100644
index 00000000..47167947
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_evaluator_operations.py
@@ -0,0 +1,222 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+from os import PathLike
+from typing import Any, Dict, Iterable, Optional, Union, cast
+
+from azure.ai.ml._restclient.v2021_10_01_dataplanepreview import (
+ AzureMachineLearningWorkspaces as ServiceClient102021Dataplane,
+)
+from azure.ai.ml._restclient.v2023_08_01_preview import AzureMachineLearningWorkspaces as ServiceClient082023Preview
+from azure.ai.ml._restclient.v2023_08_01_preview.models import ListViewType
+from azure.ai.ml._scope_dependent_operations import (
+ OperationConfig,
+ OperationsContainer,
+ OperationScope,
+ _ScopeDependentOperations,
+)
+from azure.ai.ml._telemetry import ActivityType, monitor_with_activity
+from azure.ai.ml._utils._logger_utils import OpsLogger
+from azure.ai.ml._utils.utils import _get_evaluator_properties, _is_evaluator
+from azure.ai.ml.entities._assets import Model
+from azure.ai.ml.entities._assets.workspace_asset_reference import WorkspaceAssetReference
+from azure.ai.ml.exceptions import UnsupportedOperationError
+from azure.ai.ml.operations._datastore_operations import DatastoreOperations
+from azure.ai.ml.operations._model_operations import ModelOperations
+from azure.core.exceptions import ResourceNotFoundError
+
+ops_logger = OpsLogger(__name__)
+module_logger = ops_logger.module_logger
+
+
+class EvaluatorOperations(_ScopeDependentOperations):
+ """EvaluatorOperations.
+
+ You should not instantiate this class directly. Instead, you should create an MLClient instance that instantiates it
+ for you and attaches it as an attribute.
+
+ :param operation_scope: Scope variables for the operations classes of an MLClient object.
+ :type operation_scope: ~azure.ai.ml._scope_dependent_operations.OperationScope
+ :param operation_config: Common configuration for operations classes of an MLClient object.
+ :type operation_config: ~azure.ai.ml._scope_dependent_operations.OperationConfig
+ :param service_client: Service client to allow end users to operate on Azure Machine Learning Workspace
+ resources (ServiceClient082023Preview or ServiceClient102021Dataplane).
+ :type service_client: typing.Union[
+ azure.ai.ml._restclient.v2023_04_01_preview._azure_machine_learning_workspaces.AzureMachineLearningWorkspaces,
+ azure.ai.ml._restclient.v2021_10_01_dataplanepreview._azure_machine_learning_workspaces.
+ AzureMachineLearningWorkspaces]
+ :param datastore_operations: Represents a client for performing operations on Datastores.
+ :type datastore_operations: ~azure.ai.ml.operations._datastore_operations.DatastoreOperations
+ :param all_operations: All operations classes of an MLClient object.
+ :type all_operations: ~azure.ai.ml._scope_dependent_operations.OperationsContainer
+ :param kwargs: A dictionary of additional configuration parameters.
+ :type kwargs: dict
+ """
+
+ # pylint: disable=unused-argument
+ def __init__(
+ self,
+ operation_scope: OperationScope,
+ operation_config: OperationConfig,
+ service_client: Union[ServiceClient082023Preview, ServiceClient102021Dataplane],
+ datastore_operations: DatastoreOperations,
+ all_operations: Optional[OperationsContainer] = None,
+ **kwargs,
+ ):
+ super(EvaluatorOperations, self).__init__(operation_scope, operation_config)
+
+ ops_logger.update_filter()
+ self._model_op = ModelOperations(
+ operation_scope=operation_scope,
+ operation_config=operation_config,
+ service_client=service_client,
+ datastore_operations=datastore_operations,
+ all_operations=all_operations,
+ **{ModelOperations._IS_EVALUATOR: True},
+ **kwargs,
+ )
+ self._operation_scope = self._model_op._operation_scope
+ self._datastore_operation = self._model_op._datastore_operation
+
+ @monitor_with_activity(ops_logger, "Evaluator.CreateOrUpdate", ActivityType.PUBLICAPI)
+ def create_or_update( # type: ignore
+ self, model: Union[Model, WorkspaceAssetReference], **kwargs: Any
+ ) -> Model: # TODO: Are we going to implement job_name?
+ """Returns created or updated model asset.
+
+ :param model: Model asset object.
+ :type model: ~azure.ai.ml.entities.Model
+ :raises ~azure.ai.ml.exceptions.AssetPathException: Raised when the Model artifact path is
+ already linked to another asset
+ :raises ~azure.ai.ml.exceptions.ValidationException: Raised if Model cannot be successfully validated.
+ Details will be provided in the error message.
+ :raises ~azure.ai.ml.exceptions.EmptyDirectoryError: Raised if local path provided points to an empty directory.
+ :return: Model asset object.
+ :rtype: ~azure.ai.ml.entities.Model
+ """
+ model.properties.update(_get_evaluator_properties())
+ return self._model_op.create_or_update(model)
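A hedged sketch of registering an evaluator (it assumes the MLClient exposes this operations class as ml_client.evaluators and that ./my_evaluator_flow is a local prompt-flow based evaluator; both are assumptions, not taken from this file):

# Hypothetical usage sketch: register a local evaluator flow as a model asset.
from azure.ai.ml import MLClient
from azure.ai.ml.entities import Model
from azure.identity import DefaultAzureCredential

ml_client = MLClient(DefaultAzureCredential(), "<subscription-id>", "<resource-group>", "<workspace>")
evaluator = Model(
    name="my-evaluator",          # placeholder asset name
    path="./my_evaluator_flow",   # assumed local evaluator folder
    description="Custom evaluator",
)
registered = ml_client.evaluators.create_or_update(evaluator)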
+
+ def _raise_if_not_evaluator(self, properties: Optional[Dict[str, Any]], message: str) -> None:
+ """
+ :param properties: The properties of a model.
+ :type properties: dict[str, str]
+ :param message: The message to be set on exception.
+ :type message: str
+ :raises ~azure.ai.ml.exceptions.ValidationException: Raised if model is not an
+ evaluator.
+ """
+ if properties is not None and not _is_evaluator(properties):
+ raise ResourceNotFoundError(
+ message=message,
+ response=None,
+ )
+
+ @monitor_with_activity(ops_logger, "Evaluator.Get", ActivityType.PUBLICAPI)
+ def get(self, name: str, *, version: Optional[str] = None, label: Optional[str] = None, **kwargs) -> Model:
+ """Returns information about the specified model asset.
+
+ :param name: Name of the model.
+ :type name: str
+ :keyword version: Version of the model.
+ :paramtype version: str
+ :keyword label: Label of the model. (mutually exclusive with version)
+ :paramtype label: str
+ :raises ~azure.ai.ml.exceptions.ValidationException: Raised if Model cannot be successfully validated.
+ Details will be provided in the error message.
+ :return: Model asset object.
+ :rtype: ~azure.ai.ml.entities.Model
+ """
+ model = self._model_op.get(name, version, label)
+
+ properties = None if model is None else model.properties
+ self._raise_if_not_evaluator(
+ properties,
+ f"Evaluator {name} with version {version} not found.",
+ )
+
+ return model
+
+ @monitor_with_activity(ops_logger, "Evaluator.Download", ActivityType.PUBLICAPI)
+ def download(self, name: str, version: str, download_path: Union[PathLike, str] = ".", **kwargs: Any) -> None:
+ """Download files related to a model.
+
+ :param name: Name of the model.
+ :type name: str
+ :param version: Version of the model.
+ :type version: str
+ :param download_path: Local path as download destination, defaults to current working directory of the current
+ user. Contents will be overwritten.
+ :type download_path: Union[PathLike, str]
+        :raises ResourceNotFoundError: Raised if a model matching the provided name cannot be found.
+ """
+ self._model_op.download(name, version, download_path)
+
+ @monitor_with_activity(ops_logger, "Evaluator.List", ActivityType.PUBLICAPI)
+ def list(
+ self,
+ name: str,
+ stage: Optional[str] = None,
+ *,
+ list_view_type: ListViewType = ListViewType.ACTIVE_ONLY,
+ **kwargs: Any,
+ ) -> Iterable[Model]:
+ """List all model assets in workspace.
+
+ :param name: Name of the model.
+ :type name: str
+ :param stage: The Model stage
+ :type stage: Optional[str]
+ :keyword list_view_type: View type for including/excluding (for example) archived models.
+ Defaults to :attr:`ListViewType.ACTIVE_ONLY`.
+ :paramtype list_view_type: ListViewType
+ :return: An iterator like instance of Model objects
+ :rtype: ~azure.core.paging.ItemPaged[~azure.ai.ml.entities.Model]
+ """
+ properties_str = "is-promptflow=true,is-evaluator=true"
+ if name:
+ return cast(
+ Iterable[Model],
+ (
+ self._model_op._model_versions_operation.list(
+ name=name,
+ registry_name=self._model_op._registry_name,
+ cls=lambda objs: [Model._from_rest_object(obj) for obj in objs],
+ properties=properties_str,
+ **self._model_op._scope_kwargs,
+ )
+ if self._registry_name
+ else self._model_op._model_versions_operation.list(
+ name=name,
+ workspace_name=self._model_op._workspace_name,
+ cls=lambda objs: [Model._from_rest_object(obj) for obj in objs],
+ list_view_type=list_view_type,
+ properties=properties_str,
+ stage=stage,
+ **self._model_op._scope_kwargs,
+ )
+ ),
+ )
+ # ModelContainer object does not carry properties.
+ raise UnsupportedOperationError("list on evaluation operations without name provided")
+ # TODO: Implement filtering of the ModelContainerOperations list output
+ # return cast(
+ # Iterable[Model], (
+ # self._model_container_operation.list(
+ # registry_name=self._registry_name,
+ # cls=lambda objs: [Model._from_container_rest_object(obj) for obj in objs],
+ # list_view_type=list_view_type,
+ # **self._scope_kwargs,
+ # )
+ # if self._registry_name
+ # else self._model_container_operation.list(
+ # workspace_name=self._workspace_name,
+ # cls=lambda objs: [Model._from_container_rest_object(obj) for obj in objs],
+ # list_view_type=list_view_type,
+ # **self._scope_kwargs,
+ # )
+ # )
+ # )
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_feature_set_operations.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_feature_set_operations.py
new file mode 100644
index 00000000..db51295f
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_feature_set_operations.py
@@ -0,0 +1,456 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+import json
+import os
+from copy import deepcopy
+from datetime import datetime
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Union
+
+from marshmallow.exceptions import ValidationError as SchemaValidationError
+
+from azure.ai.ml._artifacts._artifact_utilities import _check_and_upload_path
+from azure.ai.ml._exception_helper import log_and_raise_error
+from azure.ai.ml._restclient.v2023_08_01_preview import AzureMachineLearningWorkspaces as ServiceClient082023Preview
+from azure.ai.ml._restclient.v2023_10_01 import AzureMachineLearningServices as ServiceClient102023
+from azure.ai.ml._restclient.v2023_10_01.models import (
+ FeaturesetVersion,
+ FeaturesetVersionBackfillRequest,
+ FeatureWindow,
+)
+from azure.ai.ml._scope_dependent_operations import OperationConfig, OperationScope, _ScopeDependentOperations
+from azure.ai.ml._telemetry import ActivityType, monitor_with_activity
+from azure.ai.ml._utils._feature_store_utils import (
+ _archive_or_restore,
+ _datetime_to_str,
+ read_feature_set_metadata,
+ read_remote_feature_set_spec_metadata,
+)
+from azure.ai.ml._utils._logger_utils import OpsLogger
+from azure.ai.ml._utils.utils import is_url
+from azure.ai.ml.constants import ListViewType
+from azure.ai.ml.entities._assets._artifacts.feature_set import FeatureSet
+from azure.ai.ml.entities._feature_set.data_availability_status import DataAvailabilityStatus
+from azure.ai.ml.entities._feature_set.feature import Feature
+from azure.ai.ml.entities._feature_set.feature_set_backfill_metadata import FeatureSetBackfillMetadata
+from azure.ai.ml.entities._feature_set.feature_set_materialization_metadata import FeatureSetMaterializationMetadata
+from azure.ai.ml.entities._feature_set.featureset_spec_metadata import FeaturesetSpecMetadata
+from azure.ai.ml.entities._feature_set.materialization_compute_resource import MaterializationComputeResource
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException
+from azure.ai.ml.operations._datastore_operations import DatastoreOperations
+from azure.core.paging import ItemPaged
+from azure.core.polling import LROPoller
+from azure.core.tracing.decorator import distributed_trace
+
+ops_logger = OpsLogger(__name__)
+module_logger = ops_logger.module_logger
+
+
+class FeatureSetOperations(_ScopeDependentOperations):
+ """FeatureSetOperations.
+
+ You should not instantiate this class directly. Instead, you should
+ create an MLClient instance that instantiates it for you and
+ attaches it as an attribute.
+ """
+
+ def __init__(
+ self,
+ operation_scope: OperationScope,
+ operation_config: OperationConfig,
+ service_client: ServiceClient102023,
+ service_client_for_jobs: ServiceClient082023Preview,
+ datastore_operations: DatastoreOperations,
+ **kwargs: Dict,
+ ):
+ super(FeatureSetOperations, self).__init__(operation_scope, operation_config)
+ ops_logger.update_filter()
+ self._operation = service_client.featureset_versions
+ self._container_operation = service_client.featureset_containers
+ self._jobs_operation = service_client_for_jobs.jobs
+ self._feature_operation = service_client.features
+ self._service_client = service_client
+ self._datastore_operation = datastore_operations
+ self._init_kwargs = kwargs
+
+ @distributed_trace
+ @monitor_with_activity(ops_logger, "FeatureSet.List", ActivityType.PUBLICAPI)
+ def list(
+ self,
+ name: Optional[str] = None,
+ *,
+ list_view_type: ListViewType = ListViewType.ACTIVE_ONLY,
+ **kwargs: Dict,
+ ) -> ItemPaged[FeatureSet]:
+ """List the FeatureSet assets of the workspace.
+
+ :param name: Name of a specific FeatureSet asset, optional.
+ :type name: Optional[str]
+ :keyword list_view_type: View type for including/excluding (for example) archived FeatureSet assets.
+ Defaults to ACTIVE_ONLY.
+        :paramtype list_view_type: Optional[ListViewType]
+ :return: An iterator like instance of FeatureSet objects
+ :rtype: ~azure.core.paging.ItemPaged[FeatureSet]
+ """
+ if name:
+ return self._operation.list(
+ workspace_name=self._workspace_name,
+ name=name,
+ cls=lambda objs: [FeatureSet._from_rest_object(obj) for obj in objs],
+ list_view_type=list_view_type,
+ **self._scope_kwargs,
+ **kwargs,
+ )
+ return self._container_operation.list(
+ workspace_name=self._workspace_name,
+ cls=lambda objs: [FeatureSet._from_container_rest_object(obj) for obj in objs],
+ list_view_type=list_view_type,
+ **self._scope_kwargs,
+ **kwargs,
+ )
+
+ def _get(self, name: str, version: Optional[str] = None, **kwargs: Dict) -> FeaturesetVersion:
+ return self._operation.get(
+ resource_group_name=self._resource_group_name,
+ workspace_name=self._workspace_name,
+ name=name,
+ version=version,
+ **self._init_kwargs,
+ **kwargs,
+ )
+
+ @distributed_trace
+ @monitor_with_activity(ops_logger, "FeatureSet.Get", ActivityType.PUBLICAPI)
+ def get(self, name: str, version: str, **kwargs: Dict) -> FeatureSet: # type: ignore
+ """Get the specified FeatureSet asset.
+
+ :param name: Name of FeatureSet asset.
+ :type name: str
+ :param version: Version of FeatureSet asset.
+ :type version: str
+ :raises ~azure.ai.ml.exceptions.ValidationException: Raised if FeatureSet cannot be successfully
+ identified and retrieved. Details will be provided in the error message.
+ :raises ~azure.core.exceptions.HttpResponseError: Raised if the corresponding name and version cannot be
+ retrieved from the service.
+ :return: FeatureSet asset object.
+ :rtype: ~azure.ai.ml.entities.FeatureSet
+ """
+ try:
+ featureset_version_resource = self._get(name, version, **kwargs)
+ return FeatureSet._from_rest_object(featureset_version_resource) # type: ignore[return-value]
+ except (ValidationException, SchemaValidationError) as ex:
+ log_and_raise_error(ex)
+
+ @distributed_trace
+ @monitor_with_activity(ops_logger, "FeatureSet.BeginCreateOrUpdate", ActivityType.PUBLICAPI)
+ def begin_create_or_update(self, featureset: FeatureSet, **kwargs: Dict) -> LROPoller[FeatureSet]:
+ """Create or update FeatureSet
+
+ :param featureset: FeatureSet definition.
+ :type featureset: FeatureSet
+ :return: An instance of LROPoller that returns a FeatureSet.
+ :rtype: ~azure.core.polling.LROPoller[~azure.ai.ml.entities.FeatureSet]
+ """
+ featureset_copy = deepcopy(featureset)
+
+ featureset_spec = self._validate_and_get_feature_set_spec(featureset_copy)
+ featureset_copy.properties["featuresetPropertiesVersion"] = "1"
+ featureset_copy.properties["featuresetProperties"] = json.dumps(featureset_spec._to_dict())
+
+ sas_uri = None
+
+ if not is_url(featureset_copy.path):
+ with open(os.path.join(str(featureset_copy.path), ".amlignore"), mode="w", encoding="utf-8") as f:
+ f.write(".*\n*.amltmp\n*.amltemp")
+
+ featureset_copy, _ = _check_and_upload_path(
+ artifact=featureset_copy, asset_operations=self, sas_uri=sas_uri, artifact_type=ErrorTarget.FEATURE_SET
+ )
+
+ featureset_resource = FeatureSet._to_rest_object(featureset_copy)
+
+ return self._operation.begin_create_or_update(
+ resource_group_name=self._resource_group_name,
+ workspace_name=self._workspace_name,
+ name=featureset_copy.name,
+ version=featureset_copy.version,
+ body=featureset_resource,
+ **kwargs,
+ cls=lambda response, deserialized, headers: FeatureSet._from_rest_object(deserialized),
+ )
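A minimal sketch of registering a feature set against a feature store workspace (the entity reference and spec folder are placeholders; the spec folder is expected to contain a FeatureSetSpec.yaml, as checked by the validation helper at the end of this class):

# Hypothetical usage sketch: register a feature set from a local specification folder.
from azure.ai.ml import MLClient
from azure.ai.ml.entities import FeatureSet, FeatureSetSpecification
from azure.identity import DefaultAzureCredential

fs_client = MLClient(DefaultAzureCredential(), "<subscription-id>", "<resource-group>", "<feature-store-name>")
transactions = FeatureSet(
    name="transactions",             # placeholder asset name
    version="1",
    entities=["azureml:account:1"],  # placeholder entity reference
    specification=FeatureSetSpecification(path="./featuresets/transactions/spec"),
)
created = fs_client.feature_sets.begin_create_or_update(transactions).result()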
+
+ @distributed_trace
+ @monitor_with_activity(ops_logger, "FeatureSet.BeginBackFill", ActivityType.PUBLICAPI)
+ def begin_backfill(
+ self,
+ *,
+ name: str,
+ version: str,
+ feature_window_start_time: Optional[datetime] = None,
+ feature_window_end_time: Optional[datetime] = None,
+ display_name: Optional[str] = None,
+ description: Optional[str] = None,
+ tags: Optional[Dict[str, str]] = None,
+ compute_resource: Optional[MaterializationComputeResource] = None,
+ spark_configuration: Optional[Dict[str, str]] = None,
+ data_status: Optional[List[Union[str, DataAvailabilityStatus]]] = None,
+ job_id: Optional[str] = None,
+ **kwargs: Dict,
+ ) -> LROPoller[FeatureSetBackfillMetadata]:
+ """Backfill.
+
+ :keyword name: Feature set name. This is case-sensitive.
+ :paramtype name: str
+ :keyword version: Version identifier. This is case-sensitive.
+ :paramtype version: str
+ :keyword feature_window_start_time: Start time of the feature window to be materialized.
+ :paramtype feature_window_start_time: datetime
+ :keyword feature_window_end_time: End time of the feature window to be materialized.
+ :paramtype feature_window_end_time: datetime
+        :keyword display_name: Specifies the display name.
+        :paramtype display_name: str
+        :keyword description: Specifies the description.
+        :paramtype description: str
+ :keyword tags: A set of tags. Specifies the tags.
+ :paramtype tags: dict[str, str]
+ :keyword compute_resource: Specifies the compute resource settings.
+ :paramtype compute_resource: ~azure.ai.ml.entities.MaterializationComputeResource
+ :keyword spark_configuration: Specifies the spark compute settings.
+ :paramtype spark_configuration: dict[str, str]
+ :keyword data_status: Specifies the data status that you want to backfill.
+ :paramtype data_status: list[str or ~azure.ai.ml.entities.DataAvailabilityStatus]
+ :keyword job_id: The job id.
+ :paramtype job_id: str
+ :return: An instance of LROPoller that returns ~azure.ai.ml.entities.FeatureSetBackfillMetadata
+ :rtype: ~azure.core.polling.LROPoller[~azure.ai.ml.entities.FeatureSetBackfillMetadata]
+ """
+ request_body: FeaturesetVersionBackfillRequest = FeaturesetVersionBackfillRequest(
+ description=description,
+ display_name=display_name,
+ feature_window=FeatureWindow(
+ feature_window_start=feature_window_start_time, feature_window_end=feature_window_end_time
+ ),
+ resource=compute_resource._to_rest_object() if compute_resource else None,
+ spark_configuration=spark_configuration,
+ data_availability_status=data_status,
+ job_id=job_id,
+ tags=tags,
+ )
+
+ return self._operation.begin_backfill(
+ resource_group_name=self._resource_group_name,
+ workspace_name=self._workspace_name,
+ name=name,
+ version=version,
+ body=request_body,
+ **kwargs,
+ cls=lambda response, deserialized, headers: FeatureSetBackfillMetadata._from_rest_object(deserialized),
+ )
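A hedged sketch of requesting a backfill for a one-month feature window (the compute size and asset names are placeholders):

# Hypothetical usage sketch: materialize a feature window into the offline store.
from datetime import datetime
from azure.ai.ml import MLClient
from azure.ai.ml.entities import MaterializationComputeResource
from azure.identity import DefaultAzureCredential

fs_client = MLClient(DefaultAzureCredential(), "<subscription-id>", "<resource-group>", "<feature-store-name>")
poller = fs_client.feature_sets.begin_backfill(
    name="transactions",
    version="1",
    feature_window_start_time=datetime(2024, 1, 1),
    feature_window_end_time=datetime(2024, 2, 1),
    compute_resource=MaterializationComputeResource(instance_type="standard_e8s_v3"),
)
backfill_metadata = poller.result()  # FeatureSetBackfillMetadata for the submitted materialization job(s)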
+
+ @distributed_trace
+ @monitor_with_activity(ops_logger, "FeatureSet.ListMaterializationOperation", ActivityType.PUBLICAPI)
+ def list_materialization_operations(
+ self,
+ name: str,
+ version: str,
+ *,
+ feature_window_start_time: Optional[Union[str, datetime]] = None,
+ feature_window_end_time: Optional[Union[str, datetime]] = None,
+ filters: Optional[str] = None,
+ **kwargs: Dict,
+ ) -> ItemPaged[FeatureSetMaterializationMetadata]:
+ """List Materialization operation.
+
+ :param name: Feature set name.
+ :type name: str
+ :param version: Feature set version.
+ :type version: str
+ :keyword feature_window_start_time: Start time of the feature window to filter materialization jobs.
+ :paramtype feature_window_start_time: Union[str, datetime]
+ :keyword feature_window_end_time: End time of the feature window to filter materialization jobs.
+ :paramtype feature_window_end_time: Union[str, datetime]
+ :keyword filters: Comma-separated list of tag names (and optionally values). Example: tag1,tag2=value2.
+ :paramtype filters: str
+ :return: An iterator like instance of ~azure.ai.ml.entities.FeatureSetMaterializationMetadata objects
+ :rtype: ~azure.core.paging.ItemPaged[FeatureSetMaterializationMetadata]
+ """
+ feature_window_start_time = _datetime_to_str(feature_window_start_time) if feature_window_start_time else None
+ feature_window_end_time = _datetime_to_str(feature_window_end_time) if feature_window_end_time else None
+ properties = f"azureml.FeatureSetName={name},azureml.FeatureSetVersion={version}"
+ if feature_window_start_time:
+ properties = properties + f",azureml.FeatureWindowStart={feature_window_start_time}"
+ if feature_window_end_time:
+ properties = properties + f",azureml.FeatureWindowEnd={feature_window_end_time}"
+
+ materialization_jobs = self._jobs_operation.list(
+ resource_group_name=self._resource_group_name,
+ workspace_name=self._workspace_name,
+ properties=properties,
+ tag=filters,
+ cls=lambda objs: [FeatureSetMaterializationMetadata._from_rest_object(obj) for obj in objs],
+ **kwargs,
+ )
+ return materialization_jobs
+
+ @distributed_trace
+ @monitor_with_activity(ops_logger, "FeatureSet.ListFeatures", ActivityType.PUBLICAPI)
+ def list_features(
+ self,
+ feature_set_name: str,
+ version: str,
+ *,
+ feature_name: Optional[str] = None,
+ description: Optional[str] = None,
+ tags: Optional[str] = None,
+ **kwargs: Dict,
+ ) -> ItemPaged[Feature]:
+ """List features
+
+ :param feature_set_name: Feature set name.
+ :type feature_set_name: str
+ :param version: Feature set version.
+ :type version: str
+ :keyword feature_name: feature name.
+ :paramtype feature_name: str
+ :keyword description: Description of the featureset.
+ :paramtype description: str
+ :keyword tags: Comma-separated list of tag names (and optionally values). Example: tag1,tag2=value2.
+ :paramtype tags: str
+ :return: An iterator like instance of Feature objects
+ :rtype: ~azure.core.paging.ItemPaged[Feature]
+ """
+ features = self._feature_operation.list(
+ resource_group_name=self._resource_group_name,
+ workspace_name=self._workspace_name,
+ featureset_name=feature_set_name,
+ featureset_version=version,
+ tags=tags,
+ feature_name=feature_name,
+ description=description,
+ **kwargs,
+ cls=lambda objs: [Feature._from_rest_object(obj) for obj in objs],
+ )
+ return features
+
+ @distributed_trace
+ @monitor_with_activity(ops_logger, "FeatureSet.GetFeature", ActivityType.PUBLICAPI)
+ def get_feature(
+ self, feature_set_name: str, version: str, *, feature_name: str, **kwargs: Dict
+ ) -> Optional["Feature"]:
+ """Get Feature
+
+ :param feature_set_name: Feature set name.
+ :type feature_set_name: str
+ :param version: Feature set version.
+ :type version: str
+ :keyword feature_name: The feature name. This argument is case-sensitive.
+ :paramtype feature_name: str
+ :return: Feature object
+ :rtype: ~azure.ai.ml.entities.Feature
+ """
+ feature = self._feature_operation.get(
+ resource_group_name=self._resource_group_name,
+ workspace_name=self._workspace_name,
+ featureset_name=feature_set_name,
+ featureset_version=version,
+ feature_name=feature_name,
+ **kwargs,
+ )
+
+ return Feature._from_rest_object(feature) # type: ignore[return-value]
+
+ @distributed_trace
+ @monitor_with_activity(ops_logger, "FeatureSet.Archive", ActivityType.PUBLICAPI)
+ def archive(
+ self,
+ name: str,
+ version: str,
+ **kwargs: Dict,
+ ) -> None:
+ """Archive a FeatureSet asset.
+
+ :param name: Name of FeatureSet asset.
+ :type name: str
+ :param version: Version of FeatureSet asset.
+ :type version: str
+ :return: None
+ """
+
+ _archive_or_restore(
+ asset_operations=self,
+ version_operation=self._operation,
+ is_archived=True,
+ name=name,
+ version=version,
+ **kwargs,
+ )
+
+ @distributed_trace
+ @monitor_with_activity(ops_logger, "FeatureSet.Restore", ActivityType.PUBLICAPI)
+ def restore(
+ self,
+ name: str,
+ version: str,
+ **kwargs: Dict,
+ ) -> None:
+ """Restore an archived FeatureSet asset.
+
+ :param name: Name of FeatureSet asset.
+ :type name: str
+ :param version: Version of FeatureSet asset.
+ :type version: str
+ :return: None
+ """
+
+ _archive_or_restore(
+ asset_operations=self,
+ version_operation=self._operation,
+ is_archived=False,
+ name=name,
+ version=version,
+ **kwargs,
+ )
+
+ def _validate_and_get_feature_set_spec(self, featureset: FeatureSet) -> FeaturesetSpecMetadata:
+ if not (featureset.specification and featureset.specification.path):
+ msg = "Missing FeatureSet specification path. Path is required for feature set."
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ error_type=ValidationErrorType.MISSING_FIELD,
+ target=ErrorTarget.FEATURE_SET,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+
+ featureset_spec_path: Any = str(featureset.specification.path)
+ if is_url(featureset_spec_path):
+ try:
+ featureset_spec_contents = read_remote_feature_set_spec_metadata(
+ base_uri=featureset_spec_path,
+ datastore_operations=self._datastore_operation,
+ )
+ featureset_spec_path = None
+ except Exception as ex:
+ module_logger.info("Unable to access FeaturesetSpec at path %s", featureset_spec_path)
+ raise ex
+ return FeaturesetSpecMetadata._load(featureset_spec_contents, featureset_spec_path)
+
+ if not os.path.isabs(featureset_spec_path):
+ featureset_spec_path = Path(featureset.base_path, featureset_spec_path).resolve()
+
+ if not os.path.isdir(featureset_spec_path):
+ raise ValidationException(
+ message="No such directory: {}".format(featureset_spec_path),
+ no_personal_data_message="No such directory",
+ target=ErrorTarget.FEATURE_SET,
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.FILE_OR_FOLDER_NOT_FOUND,
+ )
+
+ featureset_spec_contents = read_feature_set_metadata(path=featureset_spec_path)
+ featureset_spec_yaml_path = Path(featureset_spec_path, "FeatureSetSpec.yaml")
+ return FeaturesetSpecMetadata._load(featureset_spec_contents, featureset_spec_yaml_path)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_feature_store_entity_operations.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_feature_store_entity_operations.py
new file mode 100644
index 00000000..e0ce9587
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_feature_store_entity_operations.py
@@ -0,0 +1,191 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+from typing import Dict, Optional
+
+from marshmallow.exceptions import ValidationError as SchemaValidationError
+
+from azure.ai.ml._exception_helper import log_and_raise_error
+from azure.ai.ml._restclient.v2023_10_01 import AzureMachineLearningServices as ServiceClient102023
+from azure.ai.ml._restclient.v2023_10_01.models import FeaturestoreEntityVersion
+from azure.ai.ml._scope_dependent_operations import OperationConfig, OperationScope, _ScopeDependentOperations
+from azure.ai.ml._telemetry import ActivityType, monitor_with_activity
+from azure.ai.ml._utils._feature_store_utils import _archive_or_restore
+from azure.ai.ml._utils._logger_utils import OpsLogger
+from azure.ai.ml.constants import ListViewType
+from azure.ai.ml.entities._feature_store_entity.feature_store_entity import FeatureStoreEntity
+from azure.ai.ml.exceptions import ValidationException
+from azure.core.paging import ItemPaged
+from azure.core.polling import LROPoller
+from azure.core.tracing.decorator import distributed_trace
+
+ops_logger = OpsLogger(__name__)
+module_logger = ops_logger.module_logger
+
+
+class FeatureStoreEntityOperations(_ScopeDependentOperations):
+ """FeatureStoreEntityOperations.
+
+ You should not instantiate this class directly. Instead, you should
+ create an MLClient instance that instantiates it for you and
+ attaches it as an attribute.
+ """
+
+ def __init__(
+ self,
+ operation_scope: OperationScope,
+ operation_config: OperationConfig,
+ service_client: ServiceClient102023,
+ **kwargs: Dict,
+ ):
+ super(FeatureStoreEntityOperations, self).__init__(operation_scope, operation_config)
+ ops_logger.update_filter()
+ self._operation = service_client.featurestore_entity_versions
+ self._container_operation = service_client.featurestore_entity_containers
+ self._service_client = service_client
+ self._init_kwargs = kwargs
+
+ @distributed_trace
+ @monitor_with_activity(ops_logger, "FeatureStoreEntity.List", ActivityType.PUBLICAPI)
+ def list(
+ self,
+ name: Optional[str] = None,
+ *,
+ list_view_type: ListViewType = ListViewType.ACTIVE_ONLY,
+ **kwargs: Dict,
+ ) -> ItemPaged[FeatureStoreEntity]:
+ """List the FeatureStoreEntity assets of the workspace.
+
+ :param name: Name of a specific FeatureStoreEntity asset, optional.
+ :type name: Optional[str]
+ :keyword list_view_type: View type for including/excluding (for example) archived FeatureStoreEntity assets.
+ Default: ACTIVE_ONLY.
+ :paramtype list_view_type: Optional[ListViewType]
+ :return: An iterator like instance of FeatureStoreEntity objects
+ :rtype: ~azure.core.paging.ItemPaged[FeatureStoreEntity]
+ """
+ if name:
+ return self._operation.list(
+ workspace_name=self._workspace_name,
+ name=name,
+ cls=lambda objs: [FeatureStoreEntity._from_rest_object(obj) for obj in objs],
+ list_view_type=list_view_type,
+ **self._scope_kwargs,
+ **kwargs,
+ )
+ return self._container_operation.list(
+ workspace_name=self._workspace_name,
+ cls=lambda objs: [FeatureStoreEntity._from_container_rest_object(obj) for obj in objs],
+ list_view_type=list_view_type,
+ **self._scope_kwargs,
+ **kwargs,
+ )
+
+ def _get(self, name: str, version: Optional[str] = None, **kwargs: Dict) -> FeaturestoreEntityVersion:
+ return self._operation.get(
+ resource_group_name=self._resource_group_name,
+ workspace_name=self._workspace_name,
+ name=name,
+ version=version,
+ **self._init_kwargs,
+ **kwargs,
+ )
+
+ @distributed_trace
+ @monitor_with_activity(ops_logger, "FeatureStoreEntity.Get", ActivityType.PUBLICAPI)
+ def get(self, name: str, version: str, **kwargs: Dict) -> FeatureStoreEntity: # type: ignore
+ """Get the specified FeatureStoreEntity asset.
+
+ :param name: Name of FeatureStoreEntity asset.
+ :type name: str
+ :param version: Version of FeatureStoreEntity asset.
+ :type version: str
+ :raises ~azure.ai.ml.exceptions.ValidationException: Raised if FeatureStoreEntity cannot be successfully
+ identified and retrieved. Details will be provided in the error message.
+ :return: FeatureStoreEntity asset object.
+ :rtype: ~azure.ai.ml.entities.FeatureStoreEntity
+ """
+ try:
+ feature_store_entity_version_resource = self._get(name, version, **kwargs)
+ return FeatureStoreEntity._from_rest_object(feature_store_entity_version_resource)
+ except (ValidationException, SchemaValidationError) as ex:
+ log_and_raise_error(ex)
+
+ @distributed_trace
+ @monitor_with_activity(ops_logger, "FeatureStoreEntity.BeginCreateOrUpdate", ActivityType.PUBLICAPI)
+ def begin_create_or_update(
+ self, feature_store_entity: FeatureStoreEntity, **kwargs: Dict
+ ) -> LROPoller[FeatureStoreEntity]:
+ """Create or update FeatureStoreEntity
+
+ :param feature_store_entity: FeatureStoreEntity definition.
+ :type feature_store_entity: FeatureStoreEntity
+ :return: An instance of LROPoller that returns a FeatureStoreEntity.
+ :rtype: ~azure.core.polling.LROPoller[~azure.ai.ml.entities.FeatureStoreEntity]
+ """
+ feature_store_entity_resource = FeatureStoreEntity._to_rest_object(feature_store_entity)
+
+ return self._operation.begin_create_or_update(
+ resource_group_name=self._resource_group_name,
+ workspace_name=self._workspace_name,
+ name=feature_store_entity.name,
+ version=feature_store_entity.version,
+ body=feature_store_entity_resource,
+ cls=lambda response, deserialized, headers: FeatureStoreEntity._from_rest_object(deserialized),
+ **kwargs,
+ )
+
+ @distributed_trace
+ @monitor_with_activity(ops_logger, "FeatureStoreEntity.Archive", ActivityType.PUBLICAPI)
+ def archive(
+ self,
+ name: str,
+ version: str,
+ **kwargs: Dict,
+ ) -> None:
+ """Archive a FeatureStoreEntity asset.
+
+ :param name: Name of FeatureStoreEntity asset.
+ :type name: str
+ :param version: Version of FeatureStoreEntity asset.
+ :type version: str
+ :return: None
+ """
+
+ _archive_or_restore(
+ asset_operations=self,
+ version_operation=self._operation,
+ is_archived=True,
+ name=name,
+ version=version,
+ **kwargs,
+ )
+
+ @distributed_trace
+ @monitor_with_activity(ops_logger, "FeatureStoreEntity.Restore", ActivityType.PUBLICAPI)
+ def restore(
+ self,
+ name: str,
+ version: str,
+ **kwargs: Dict,
+ ) -> None:
+ """Restore an archived FeatureStoreEntity asset.
+
+ :param name: Name of FeatureStoreEntity asset.
+ :type name: str
+ :param version: Version of FeatureStoreEntity asset.
+ :type version: str
+ :return: None
+ """
+
+ _archive_or_restore(
+ asset_operations=self,
+ version_operation=self._operation,
+ is_archived=False,
+ name=name,
+ version=version,
+ **kwargs,
+ )
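Editorial note (not part of the patch): a small sketch of how FeatureStoreEntityOperations is typically driven through MLClient, assuming DataColumn, DataColumnType, and FeatureStoreEntity are importable from azure.ai.ml.entities and that the client exposes the operations as feature_store_entities; the entity name, version, and column are placeholders.

from azure.identity import DefaultAzureCredential
from azure.ai.ml import MLClient
from azure.ai.ml.entities import DataColumn, DataColumnType, FeatureStoreEntity

ml_client = MLClient(
    DefaultAzureCredential(),
    subscription_id="<subscription-id>",
    resource_group_name="<resource-group>",
    workspace_name="<feature-store-name>",
)

# begin_create_or_update returns an LROPoller; result() blocks until the entity version exists.
account_entity = FeatureStoreEntity(
    name="account",
    version="1",
    index_columns=[DataColumn(name="accountID", type=DataColumnType.STRING)],
)
created = ml_client.feature_store_entities.begin_create_or_update(account_entity).result()

# get() requires both name and version; list() pages containers or the versions of one entity.
entity = ml_client.feature_store_entities.get(name="account", version="1")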
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_feature_store_operations.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_feature_store_operations.py
new file mode 100644
index 00000000..a7f8e93d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_feature_store_operations.py
@@ -0,0 +1,566 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+import re
+import uuid
+from typing import Any, Dict, Iterable, Optional, cast
+
+from marshmallow import ValidationError
+
+from azure.ai.ml._restclient.v2024_10_01_preview import AzureMachineLearningWorkspaces as ServiceClient102024Preview
+from azure.ai.ml._restclient.v2024_10_01_preview.models import ManagedNetworkProvisionOptions
+from azure.ai.ml._scope_dependent_operations import OperationsContainer, OperationScope
+from azure.ai.ml._telemetry import ActivityType, monitor_with_activity
+from azure.ai.ml._utils._logger_utils import OpsLogger
+from azure.ai.ml._utils.utils import camel_to_snake
+from azure.ai.ml.constants import ManagedServiceIdentityType
+from azure.ai.ml.constants._common import Scope, WorkspaceKind
+from azure.ai.ml.entities import (
+ IdentityConfiguration,
+ ManagedIdentityConfiguration,
+ ManagedNetworkProvisionStatus,
+ WorkspaceConnection,
+)
+from azure.ai.ml.entities._feature_store._constants import (
+ OFFLINE_MATERIALIZATION_STORE_TYPE,
+ OFFLINE_STORE_CONNECTION_CATEGORY,
+ OFFLINE_STORE_CONNECTION_NAME,
+ ONLINE_MATERIALIZATION_STORE_TYPE,
+ ONLINE_STORE_CONNECTION_CATEGORY,
+ ONLINE_STORE_CONNECTION_NAME,
+ STORE_REGEX_PATTERN,
+)
+from azure.ai.ml.entities._feature_store.feature_store import FeatureStore
+from azure.ai.ml.entities._feature_store.materialization_store import MaterializationStore
+from azure.ai.ml.entities._workspace.feature_store_settings import FeatureStoreSettings
+from azure.core.credentials import TokenCredential
+from azure.core.exceptions import ResourceNotFoundError
+from azure.core.polling import LROPoller
+from azure.core.tracing.decorator import distributed_trace
+
+from ._workspace_operations_base import WorkspaceOperationsBase
+
+ops_logger = OpsLogger(__name__)
+module_logger = ops_logger.module_logger
+
+
+class FeatureStoreOperations(WorkspaceOperationsBase):
+ """FeatureStoreOperations.
+
+ You should not instantiate this class directly. Instead, you should
+ create an MLClient instance that instantiates it for you and
+ attaches it as an attribute.
+ """
+
+ def __init__(
+ self,
+ operation_scope: OperationScope,
+ service_client: ServiceClient102024Preview,
+ all_operations: OperationsContainer,
+ credentials: Optional[TokenCredential] = None,
+ **kwargs: Dict,
+ ):
+ ops_logger.update_filter()
+ self._provision_network_operation = service_client.managed_network_provisions
+ super().__init__(
+ operation_scope=operation_scope,
+ service_client=service_client,
+ all_operations=all_operations,
+ credentials=credentials,
+ **kwargs,
+ )
+ self._workspace_connection_operation = service_client.workspace_connections
+
+ @distributed_trace
+ @monitor_with_activity(ops_logger, "FeatureStore.List", ActivityType.PUBLICAPI)
+ # pylint: disable=unused-argument
+ def list(self, *, scope: str = Scope.RESOURCE_GROUP, **kwargs: Dict) -> Iterable[FeatureStore]:
+ """List all feature stores that the user has access to in the current
+ resource group or subscription.
+
+ :keyword scope: scope of the listing, "resource_group" or "subscription", defaults to "resource_group"
+ :paramtype scope: str
+ :return: An iterator like instance of FeatureStore objects
+ :rtype: ~azure.core.paging.ItemPaged[FeatureStore]
+ """
+
+ if scope == Scope.SUBSCRIPTION:
+ return cast(
+ Iterable[FeatureStore],
+ self._operation.list_by_subscription(
+ cls=lambda objs: [
+ FeatureStore._from_rest_object(filterObj)
+ for filterObj in filter(lambda ws: ws.kind.lower() == WorkspaceKind.FEATURE_STORE, objs)
+ ],
+ ),
+ )
+ return cast(
+ Iterable[FeatureStore],
+ self._operation.list_by_resource_group(
+ self._resource_group_name,
+ cls=lambda objs: [
+ FeatureStore._from_rest_object(filterObj)
+ for filterObj in filter(lambda ws: ws.kind.lower() == WorkspaceKind.FEATURE_STORE, objs)
+ ],
+ ),
+ )
+
+ @distributed_trace
+ @monitor_with_activity(ops_logger, "FeatureStore.Get", ActivityType.PUBLICAPI)
+ # pylint: disable=arguments-renamed
+ def get(self, name: str, **kwargs: Any) -> FeatureStore:
+ """Get a feature store by name.
+
+ :param name: Name of the feature store.
+ :type name: str
+ :raises ~azure.core.exceptions.HttpResponseError: Raised if a feature store with the given name cannot be
+ retrieved from the service.
+ :return: The feature store with the provided name.
+ :rtype: FeatureStore
+ """
+
+ feature_store: Any = None
+ resource_group = kwargs.get("resource_group") or self._resource_group_name
+ rest_workspace_obj = kwargs.get("rest_workspace_obj", None) or self._operation.get(resource_group, name)
+ if (
+ rest_workspace_obj
+ and rest_workspace_obj.kind
+ and rest_workspace_obj.kind.lower() == WorkspaceKind.FEATURE_STORE
+ ):
+ feature_store = FeatureStore._from_rest_object(rest_workspace_obj)
+
+ if feature_store:
+ offline_store_connection = None
+ if (
+ rest_workspace_obj.feature_store_settings
+ and rest_workspace_obj.feature_store_settings.offline_store_connection_name
+ ):
+ try:
+ offline_store_connection = self._workspace_connection_operation.get(
+ resource_group, name, rest_workspace_obj.feature_store_settings.offline_store_connection_name
+ )
+ except ResourceNotFoundError:
+ pass
+
+ if offline_store_connection:
+ if (
+ offline_store_connection.properties
+ and offline_store_connection.properties.category == OFFLINE_STORE_CONNECTION_CATEGORY
+ ):
+ feature_store.offline_store = MaterializationStore(
+ type=OFFLINE_MATERIALIZATION_STORE_TYPE, target=offline_store_connection.properties.target
+ )
+
+ online_store_connection = None
+ if (
+ rest_workspace_obj.feature_store_settings
+ and rest_workspace_obj.feature_store_settings.online_store_connection_name
+ ):
+ try:
+ online_store_connection = self._workspace_connection_operation.get(
+ resource_group, name, rest_workspace_obj.feature_store_settings.online_store_connection_name
+ )
+ except ResourceNotFoundError:
+ pass
+
+ if online_store_connection:
+ if (
+ online_store_connection.properties
+ and online_store_connection.properties.category == ONLINE_STORE_CONNECTION_CATEGORY
+ ):
+ feature_store.online_store = MaterializationStore(
+ type=ONLINE_MATERIALIZATION_STORE_TYPE, target=online_store_connection.properties.target
+ )
+
+ # materialization identity = identity when created through feature store operations
+ if (
+ offline_store_connection and offline_store_connection.name.startswith(OFFLINE_STORE_CONNECTION_NAME)
+ ) or (online_store_connection and online_store_connection.name.startswith(ONLINE_STORE_CONNECTION_NAME)):
+ if (
+ feature_store.identity
+ and feature_store.identity.user_assigned_identities
+ and isinstance(feature_store.identity.user_assigned_identities[0], ManagedIdentityConfiguration)
+ ):
+ feature_store.materialization_identity = feature_store.identity.user_assigned_identities[0]
+
+ return feature_store
+
+ @distributed_trace
+ @monitor_with_activity(ops_logger, "FeatureStore.BeginCreate", ActivityType.PUBLICAPI)
+ # pylint: disable=arguments-differ
+ def begin_create(
+ self,
+ feature_store: FeatureStore,
+ *,
+ grant_materialization_permissions: bool = True,
+ update_dependent_resources: bool = False,
+ **kwargs: Dict,
+ ) -> LROPoller[FeatureStore]:
+ """Create a new FeatureStore.
+
+ If a feature store with the same name already exists, it is updated and returned instead.
+
+ :param feature_store: FeatureStore definition.
+ :type feature_store: FeatureStore
+ :keyword grant_materialization_permissions: Whether or not to grant materialization permissions.
+ Defaults to True.
+ :paramtype grant_materialization_permissions: bool
+ :keyword update_dependent_resources: Whether or not to update dependent resources. Defaults to False.
+ :paramtype update_dependent_resources: bool
+ :return: An instance of LROPoller that returns a FeatureStore.
+ :rtype: ~azure.core.polling.LROPoller[~azure.ai.ml.entities.FeatureStore]
+ """
+ resource_group = kwargs.get("resource_group", self._resource_group_name)
+ try:
+ rest_workspace_obj = self._operation.get(resource_group, feature_store.name)
+ if rest_workspace_obj:
+ return self.begin_update(
+ feature_store=feature_store, update_dependent_resources=update_dependent_resources, **kwargs
+ )
+ except Exception: # pylint: disable=W0718
+ pass
+
+ if feature_store.offline_store and feature_store.offline_store.type != OFFLINE_MATERIALIZATION_STORE_TYPE:
+ raise ValidationError("offline store type should be azure_data_lake_gen2")
+
+ if feature_store.online_store and feature_store.online_store.type != ONLINE_MATERIALIZATION_STORE_TYPE:
+ raise ValidationError("online store type should be redis")
+
+ # generate a random suffix for online/offline store connection name,
+ # please don't refer to OFFLINE_STORE_CONNECTION_NAME and
+ # ONLINE_STORE_CONNECTION_NAME directly from FeatureStore
+ random_string = uuid.uuid4().hex[:8]
+ if feature_store._feature_store_settings is not None:
+ feature_store._feature_store_settings.offline_store_connection_name = (
+ f"{OFFLINE_STORE_CONNECTION_NAME}-{random_string}"
+ )
+ feature_store._feature_store_settings.online_store_connection_name = (
+ f"{ONLINE_STORE_CONNECTION_NAME}-{random_string}"
+ if feature_store.online_store and feature_store.online_store.target
+ else None
+ )
+
+ def get_callback() -> FeatureStore:
+ return self.get(feature_store.name)
+
+ return super().begin_create(
+ workspace=feature_store,
+ update_dependent_resources=update_dependent_resources,
+ get_callback=get_callback,
+ offline_store_target=feature_store.offline_store.target if feature_store.offline_store else None,
+ online_store_target=feature_store.online_store.target if feature_store.online_store else None,
+ materialization_identity=feature_store.materialization_identity,
+ grant_materialization_permissions=grant_materialization_permissions,
+ **kwargs,
+ )
+
+ @distributed_trace
+ @monitor_with_activity(ops_logger, "FeatureStore.BeginUpdate", ActivityType.PUBLICAPI)
+ # pylint: disable=arguments-renamed
+ # pylint: disable=too-many-locals, too-many-branches, too-many-statements
+ def begin_update( # pylint: disable=C4758
+ self,
+ feature_store: FeatureStore,
+ *,
+ grant_materialization_permissions: bool = True,
+ update_dependent_resources: bool = False,
+ **kwargs: Any,
+ ) -> LROPoller[FeatureStore]:
+ """Update friendly name, description, online store connection, offline store connection, materialization
+ identities or tags of a feature store.
+
+ :param feature_store: FeatureStore resource.
+ :type feature_store: FeatureStore
+ :keyword grant_materialization_permissions: Whether or not to grant materialization permissions.
+ Defaults to True.
+ :paramtype grant_materialization_permissions: bool
+ :keyword update_dependent_resources: Gives your consent to update the feature store's dependent resources.
+ Note that updating the Azure Container Registry resource attached to the feature store may break the lineage
+ of previous jobs or your ability to rerun earlier jobs in this feature store.
+ Likewise, updating the attached Azure Application Insights resource may break the lineage of
+ inference endpoints deployed from this feature store. Only set this argument if you are sure that you want
+ to perform this operation. If this argument is not set, the command to update
+ Azure Container Registry and Azure Application Insights will fail.
+ :paramtype update_dependent_resources: bool
+ :keyword application_insights: Application insights resource for feature store. Defaults to None.
+ :paramtype application_insights: Optional[str]
+ :keyword container_registry: Container registry resource for feature store. Defaults to None.
+ :paramtype container_registry: Optional[str]
+ :return: An instance of LROPoller that returns a FeatureStore.
+ :rtype: ~azure.core.polling.LROPoller[~azure.ai.ml.entities.FeatureStore]
+ """
+ resource_group = kwargs.get("resource_group", self._resource_group_name)
+ rest_workspace_obj = self._operation.get(resource_group, feature_store.name)
+ if not (
+ rest_workspace_obj
+ and rest_workspace_obj.kind
+ and rest_workspace_obj.kind.lower() == WorkspaceKind.FEATURE_STORE
+ ):
+ raise ValidationError("{0} is not a feature store".format(feature_store.name))
+
+ resource_group = kwargs.get("resource_group") or self._resource_group_name
+ offline_store = kwargs.get("offline_store", feature_store.offline_store)
+ online_store = kwargs.get("online_store", feature_store.online_store)
+ offline_store_target_to_update = offline_store.target if offline_store else None
+ online_store_target_to_update = online_store.target if online_store else None
+ update_workspace_role_assignment = False
+ update_offline_store_role_assignment = False
+ update_online_store_role_assignment = False
+
+ update_offline_store_connection = False
+ update_online_store_connection = False
+
+ existing_materialization_identity = None
+ if rest_workspace_obj.identity:
+ identity = IdentityConfiguration._from_workspace_rest_object(rest_workspace_obj.identity)
+ if (
+ identity
+ and identity.user_assigned_identities
+ and isinstance(identity.user_assigned_identities[0], ManagedIdentityConfiguration)
+ ):
+ existing_materialization_identity = identity.user_assigned_identities[0]
+
+ materialization_identity = kwargs.get(
+ "materialization_identity", feature_store.materialization_identity or existing_materialization_identity
+ )
+
+ if (
+ feature_store.materialization_identity
+ and feature_store.materialization_identity.resource_id
+ and (
+ not existing_materialization_identity
+ or feature_store.materialization_identity.resource_id != existing_materialization_identity.resource_id
+ )
+ ):
+ update_workspace_role_assignment = True
+ update_offline_store_role_assignment = True
+ update_online_store_role_assignment = True
+
+ self._validate_offline_store(offline_store=offline_store)
+
+ if (
+ rest_workspace_obj.feature_store_settings
+ and rest_workspace_obj.feature_store_settings.offline_store_connection_name
+ ):
+ existing_offline_store_connection = self._workspace_connection_operation.get(
+ resource_group,
+ feature_store.name,
+ rest_workspace_obj.feature_store_settings.offline_store_connection_name,
+ )
+
+ offline_store_target_to_update = (
+ offline_store_target_to_update or existing_offline_store_connection.properties.target
+ )
+ if offline_store and (
+ not existing_offline_store_connection.properties
+ or existing_offline_store_connection.properties.target != offline_store.target
+ ):
+ update_offline_store_connection = True
+ update_offline_store_role_assignment = True
+ module_logger.info(
+ "Warning: You have changed the offline store connection, "
+ "any data that was materialized "
+ "earlier will not be available. You have to run backfill again."
+ )
+ elif offline_store_target_to_update:
+ update_offline_store_connection = True
+ update_offline_store_role_assignment = True
+
+ if online_store and online_store.type != ONLINE_MATERIALIZATION_STORE_TYPE:
+ raise ValidationError("online store type should be redis")
+
+ if (
+ rest_workspace_obj.feature_store_settings
+ and rest_workspace_obj.feature_store_settings.online_store_connection_name
+ ):
+ existing_online_store_connection = self._workspace_connection_operation.get(
+ resource_group,
+ feature_store.name,
+ rest_workspace_obj.feature_store_settings.online_store_connection_name,
+ )
+
+ online_store_target_to_update = (
+ online_store_target_to_update or existing_online_store_connection.properties.target
+ )
+ if online_store and (
+ not existing_online_store_connection.properties
+ or existing_online_store_connection.properties.target != online_store.target
+ ):
+ update_online_store_connection = True
+ update_online_store_role_assignment = True
+ module_logger.info(
+ "Warning: You have changed the online store connection, "
+ "any data that was materialized earlier "
+ "will not be available. You have to run backfill again."
+ )
+ elif online_store_target_to_update:
+ update_online_store_connection = True
+ update_online_store_role_assignment = True
+
+ feature_store_settings: Any = FeatureStoreSettings._from_rest_object(rest_workspace_obj.feature_store_settings)
+
+ # generate a random suffix for online/offline store connection name
+ random_string = uuid.uuid4().hex[:8]
+ if offline_store:
+ if materialization_identity:
+ if update_offline_store_connection:
+ offline_store_connection_name_new = f"{OFFLINE_STORE_CONNECTION_NAME}-{random_string}"
+ offline_store_connection = WorkspaceConnection(
+ name=offline_store_connection_name_new,
+ type=offline_store.type,
+ target=offline_store.target,
+ credentials=materialization_identity,
+ )
+ rest_offline_store_connection = offline_store_connection._to_rest_object()
+ self._workspace_connection_operation.create(
+ resource_group_name=resource_group,
+ workspace_name=feature_store.name,
+ connection_name=offline_store_connection_name_new,
+ body=rest_offline_store_connection,
+ )
+ feature_store_settings.offline_store_connection_name = offline_store_connection_name_new
+ else:
+ module_logger.info(
+ "No need to update Offline store connection, name: %s.\n",
+ feature_store_settings.offline_store_connection_name,
+ )
+ else:
+ raise ValidationError("Materialization identity is required to setup offline store connection")
+
+ if online_store:
+ if materialization_identity:
+ if update_online_store_connection:
+ online_store_connection_name_new = f"{ONLINE_STORE_CONNECTION_NAME}-{random_string}"
+ online_store_connection = WorkspaceConnection(
+ name=online_store_connection_name_new,
+ type=online_store.type,
+ target=online_store.target,
+ credentials=materialization_identity,
+ )
+ rest_online_store_connection = online_store_connection._to_rest_object()
+ self._workspace_connection_operation.create(
+ resource_group_name=resource_group,
+ workspace_name=feature_store.name,
+ connection_name=online_store_connection_name_new,
+ body=rest_online_store_connection,
+ )
+ feature_store_settings.online_store_connection_name = online_store_connection_name_new
+ else:
+ module_logger.info(
+ "No need to update Online store connection, name: %s.\n",
+ feature_store_settings.online_store_connection_name,
+ )
+ else:
+ raise ValidationError("Materialization identity is required to setup online store connection")
+
+ if not offline_store_target_to_update:
+ update_offline_store_role_assignment = False
+ if not online_store_target_to_update:
+ update_online_store_role_assignment = False
+
+ user_defined_cr = feature_store.compute_runtime
+ if (
+ user_defined_cr
+ and user_defined_cr.spark_runtime_version != feature_store_settings.compute_runtime.spark_runtime_version
+ ):
+ # update user defined compute runtime
+ feature_store_settings.compute_runtime = feature_store.compute_runtime
+
+ identity = kwargs.pop("identity", feature_store.identity)
+ if materialization_identity:
+ identity = IdentityConfiguration(
+ type=camel_to_snake(ManagedServiceIdentityType.SYSTEM_ASSIGNED_USER_ASSIGNED),
+ # At most 1 UAI can be attached to workspace when materialization is enabled
+ user_assigned_identities=[materialization_identity],
+ )
+
+ def deserialize_callback(rest_obj: Any) -> FeatureStore:
+ return self.get(rest_obj.name, rest_workspace_obj=rest_obj)
+
+ return super().begin_update(
+ workspace=feature_store,
+ update_dependent_resources=update_dependent_resources,
+ deserialize_callback=deserialize_callback,
+ feature_store_settings=feature_store_settings,
+ identity=identity,
+ grant_materialization_permissions=grant_materialization_permissions,
+ update_workspace_role_assignment=update_workspace_role_assignment,
+ update_offline_store_role_assignment=update_offline_store_role_assignment,
+ update_online_store_role_assignment=update_online_store_role_assignment,
+ materialization_identity_id=(
+ materialization_identity.resource_id
+ if update_workspace_role_assignment
+ or update_offline_store_role_assignment
+ or update_online_store_role_assignment
+ else None
+ ),
+ offline_store_target=offline_store_target_to_update if update_offline_store_role_assignment else None,
+ online_store_target=online_store_target_to_update if update_online_store_role_assignment else None,
+ **kwargs,
+ )
+
+ @distributed_trace
+ @monitor_with_activity(ops_logger, "FeatureStore.BeginDelete", ActivityType.PUBLICAPI)
+ def begin_delete(self, name: str, *, delete_dependent_resources: bool = False, **kwargs: Any) -> LROPoller[None]:
+ """Delete a FeatureStore.
+
+ :param name: Name of the FeatureStore
+ :type name: str
+ :keyword delete_dependent_resources: Whether to delete resources associated with the feature store,
+ i.e., container registry, storage account, key vault, and application insights.
+ The default is False. Set to True to delete these resources.
+ :paramtype delete_dependent_resources: bool
+ :return: A poller to track the operation status.
+ :rtype: ~azure.core.polling.LROPoller[None]
+ """
+ resource_group = kwargs.get("resource_group") or self._resource_group_name
+ rest_workspace_obj = self._operation.get(resource_group, name)
+ if not (
+ rest_workspace_obj
+ and rest_workspace_obj.kind
+ and rest_workspace_obj.kind.lower() == WorkspaceKind.FEATURE_STORE
+ ):
+ raise ValidationError("{0} is not a feature store".format(name))
+
+ return super().begin_delete(name=name, delete_dependent_resources=delete_dependent_resources, **kwargs)
+
+ @distributed_trace
+ @monitor_with_activity(ops_logger, "FeatureStore.BeginProvisionNetwork", ActivityType.PUBLICAPI)
+ def begin_provision_network(
+ self,
+ *,
+ feature_store_name: Optional[str] = None,
+ include_spark: bool = False,
+ **kwargs: Any,
+ ) -> LROPoller[ManagedNetworkProvisionStatus]:
+ """Triggers the feature store to provision the managed network. Specifying spark enabled
+ as true prepares the feature store managed network for supporting Spark.
+
+ :keyword feature_store_name: Name of the feature store.
+ :paramtype feature_store_name: str
+ :keyword include_spark: Whether to include spark in the network provisioning. Defaults to False.
+ :paramtype include_spark: bool
+ :return: An instance of LROPoller.
+ :rtype: ~azure.core.polling.LROPoller[~azure.ai.ml.entities.ManagedNetworkProvisionStatus]
+ """
+ workspace_name = self._check_workspace_name(feature_store_name)
+
+ poller = self._provision_network_operation.begin_provision_managed_network(
+ self._resource_group_name,
+ workspace_name,
+ ManagedNetworkProvisionOptions(include_spark=include_spark),
+ polling=True,
+ cls=lambda response, deserialized, headers: ManagedNetworkProvisionStatus._from_rest_object(deserialized),
+ **kwargs,
+ )
+ module_logger.info("Provision network request initiated for feature store: %s\n", workspace_name)
+ return poller
+
+ def _validate_offline_store(self, offline_store: MaterializationStore) -> None:
+ store_regex = re.compile(STORE_REGEX_PATTERN)
+ if offline_store and store_regex.match(offline_store.target) is None:
+ raise ValidationError(f"Invalid AzureML offlinestore target ARM Id {offline_store.target}")
+ if offline_store and offline_store.type != OFFLINE_MATERIALIZATION_STORE_TYPE:
+ raise ValidationError("offline store type should be azure_data_lake_gen2")
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_index_operations.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_index_operations.py
new file mode 100644
index 00000000..28e409c7
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_index_operations.py
@@ -0,0 +1,483 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+import json
+from typing import Any, Callable, Dict, Iterable, List, Optional, Union
+
+from azure.ai.ml._artifacts._artifact_utilities import _check_and_upload_path
+
+# cspell:disable-next-line
+from azure.ai.ml._restclient.azure_ai_assets_v2024_04_01.azureaiassetsv20240401 import (
+ MachineLearningServicesClient as AzureAiAssetsClient042024,
+)
+
+# cspell:disable-next-line
+from azure.ai.ml._restclient.azure_ai_assets_v2024_04_01.azureaiassetsv20240401.models import Index as RestIndex
+from azure.ai.ml._restclient.v2023_04_01_preview.models import ListViewType
+from azure.ai.ml._scope_dependent_operations import (
+ OperationConfig,
+ OperationsContainer,
+ OperationScope,
+ _ScopeDependentOperations,
+)
+from azure.ai.ml._telemetry import ActivityType, monitor_with_activity
+from azure.ai.ml._utils._asset_utils import (
+ _resolve_label_to_asset,
+ _validate_auto_delete_setting_in_data_output,
+ _validate_workspace_managed_datastore,
+)
+from azure.ai.ml._utils._http_utils import HttpPipeline
+from azure.ai.ml._utils._logger_utils import OpsLogger
+from azure.ai.ml._utils.utils import _get_base_urls_from_discovery_service
+from azure.ai.ml.constants._common import AssetTypes, AzureMLResourceType, WorkspaceDiscoveryUrlKey
+from azure.ai.ml.dsl import pipeline
+from azure.ai.ml.entities import PipelineJob, PipelineJobSettings
+from azure.ai.ml.entities._assets import Index
+from azure.ai.ml.entities._credentials import ManagedIdentityConfiguration, UserIdentityConfiguration
+from azure.ai.ml.entities._indexes import (
+ AzureAISearchConfig,
+ GitSource,
+ IndexDataSource,
+ LocalSource,
+ ModelConfiguration,
+)
+from azure.ai.ml.entities._indexes.data_index_func import index_data as index_data_func
+from azure.ai.ml.entities._indexes.entities.data_index import (
+ CitationRegex,
+ Data,
+ DataIndex,
+ Embedding,
+ IndexSource,
+ IndexStore,
+)
+from azure.ai.ml.entities._indexes.utils._open_ai_utils import build_connection_id, build_open_ai_protocol
+from azure.ai.ml.entities._inputs_outputs import Input
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException
+from azure.ai.ml.operations._datastore_operations import DatastoreOperations
+from azure.core.credentials import TokenCredential
+
+ops_logger = OpsLogger(__name__)
+module_logger = ops_logger.module_logger
+
+
+class IndexOperations(_ScopeDependentOperations):
+ """Represents a client for performing operations on index assets.
+
+ You should not instantiate this class directly. Instead, you should create an MLClient instance and access
+ this client via the property MLClient.index.
+ """
+
+ def __init__(
+ self,
+ *,
+ operation_scope: OperationScope,
+ operation_config: OperationConfig,
+ credential: TokenCredential,
+ datastore_operations: DatastoreOperations,
+ all_operations: OperationsContainer,
+ **kwargs: Any,
+ ):
+ super().__init__(operation_scope, operation_config)
+ ops_logger.update_filter()
+ self._credential = credential
+ # Dataplane service clients are lazily created as they are needed
+ self.__azure_ai_assets_client: Optional[AzureAiAssetsClient042024] = None
+ # Kwargs to propagate to dataplane service clients
+ self._service_client_kwargs: Dict[str, Any] = kwargs.pop("_service_client_kwargs", {})
+ self._all_operations = all_operations
+
+ self._datastore_operation: DatastoreOperations = datastore_operations
+ self._requests_pipeline: HttpPipeline = kwargs.pop("requests_pipeline")
+
+ # Maps a label to a function which given an asset name,
+ # returns the asset associated with the label
+ self._managed_label_resolver: Dict[str, Callable[[str], Index]] = {"latest": self._get_latest_version}
+
+ @property
+ def _azure_ai_assets(self) -> AzureAiAssetsClient042024:
+ """Lazily instantiated client for azure ai assets api.
+
+ .. note::
+
+ Property is lazily instantiated since the api's base url depends on the user's workspace, and
+ must be fetched remotely.
+ """
+ if self.__azure_ai_assets_client is None:
+ workspace_operations = self._all_operations.all_operations[AzureMLResourceType.WORKSPACE]
+
+ endpoint = _get_base_urls_from_discovery_service(
+ workspace_operations, self._operation_scope.workspace_name, self._requests_pipeline
+ )[WorkspaceDiscoveryUrlKey.API]
+
+ self.__azure_ai_assets_client = AzureAiAssetsClient042024(
+ endpoint=endpoint,
+ subscription_id=self._operation_scope.subscription_id,
+ resource_group_name=self._operation_scope.resource_group_name,
+ workspace_name=self._operation_scope.workspace_name,
+ credential=self._credential,
+ **self._service_client_kwargs,
+ )
+
+ return self.__azure_ai_assets_client
+
+ @monitor_with_activity(ops_logger, "Index.CreateOrUpdate", ActivityType.PUBLICAPI)
+ def create_or_update(self, index: Index, **kwargs) -> Index:
+ """Returns created or updated index asset.
+
+ If not already in storage, the asset will be uploaded to the workspace's default datastore.
+
+ :param index: Index asset object.
+ :type index: Index
+ :return: Index asset object.
+ :rtype: ~azure.ai.ml.entities.Index
+ """
+
+ if not index.name:
+ msg = "Must specify a name."
+ raise ValidationException(
+ message=msg,
+ target=ErrorTarget.INDEX,
+ no_personal_data_message=msg,
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.MISSING_FIELD,
+ )
+
+ if not index.version:
+ if not index._auto_increment_version:
+ msg = "Must specify a version."
+ raise ValidationException(
+ message=msg,
+ target=ErrorTarget.INDEX,
+ no_personal_data_message=msg,
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.MISSING_FIELD,
+ )
+
+ next_version = self._azure_ai_assets.indexes.get_next_version(index.name).next_version
+
+ if next_version is None:
+ msg = "Version not specified, could not automatically increment version. Set a version to resolve."
+ raise ValidationException(
+ message=msg,
+ target=ErrorTarget.INDEX,
+ no_personal_data_message=msg,
+ error_category=ErrorCategory.SYSTEM_ERROR,
+ error_type=ValidationErrorType.MISSING_FIELD,
+ )
+
+ index.version = str(next_version)
+
+ _ = _check_and_upload_path(
+ artifact=index,
+ asset_operations=self,
+ datastore_name=index.datastore,
+ artifact_type=ErrorTarget.INDEX,
+ show_progress=self._show_progress,
+ )
+
+ return Index._from_rest_object(
+ self._azure_ai_assets.indexes.create_or_update(
+ name=index.name, version=index.version, body=index._to_rest_object(), **kwargs
+ )
+ )
+
+ @monitor_with_activity(ops_logger, "Index.Get", ActivityType.PUBLICAPI)
+ def get(self, name: str, *, version: Optional[str] = None, label: Optional[str] = None, **kwargs) -> Index:
+ """Returns information about the specified index asset.
+
+ :param str name: Name of the index asset.
+ :keyword Optional[str] version: Version of the index asset. Mutually exclusive with `label`.
+ :keyword Optional[str] label: The label of the index asset. Mutually exclusive with `version`.
+ :raises ~azure.ai.ml.exceptions.ValidationException: Raised if Index cannot be successfully validated.
+ Details will be provided in the error message.
+ :return: Index asset object.
+ :rtype: ~azure.ai.ml.entities.Index
+ """
+ if version and label:
+ msg = "Cannot specify both version and label."
+ raise ValidationException(
+ message=msg,
+ target=ErrorTarget.INDEX,
+ no_personal_data_message=msg,
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
+
+ if label:
+ return _resolve_label_to_asset(self, name, label)
+
+ if not version:
+ msg = "Must provide either version or label."
+ raise ValidationException(
+ message=msg,
+ target=ErrorTarget.INDEX,
+ no_personal_data_message=msg,
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.MISSING_FIELD,
+ )
+
+ index_version_resource = self._azure_ai_assets.indexes.get(name=name, version=version, **kwargs)
+
+ return Index._from_rest_object(index_version_resource)
+
+ def _get_latest_version(self, name: str) -> Index:
+ return Index._from_rest_object(self._azure_ai_assets.indexes.get_latest(name))
+
+ @monitor_with_activity(ops_logger, "Index.List", ActivityType.PUBLICAPI)
+ def list(
+ self, name: Optional[str] = None, *, list_view_type: ListViewType = ListViewType.ACTIVE_ONLY, **kwargs
+ ) -> Iterable[Index]:
+ """List all Index assets in workspace.
+
+ If a name is provided, lists all versions of that Index.
+
+ :param name: Name of the Index asset.
+ :type name: Optional[str]
+ :keyword list_view_type: View type for including/excluding (for example) archived Index assets.
+ Defaults to :attr:`ListViewType.ACTIVE_ONLY`.
+ :paramtype list_view_type: ListViewType
+ :return: An iterator like instance of Index objects
+ :rtype: ~azure.core.paging.ItemPaged[~azure.ai.ml.entities.Index]
+ """
+
+ def cls(rest_indexes: Iterable[RestIndex]) -> List[Index]:
+ return [Index._from_rest_object(i) for i in rest_indexes]
+
+ if name is None:
+ return self._azure_ai_assets.indexes.list_latest(cls=cls, **kwargs)
+
+ return self._azure_ai_assets.indexes.list(name, list_view_type=list_view_type, cls=cls, **kwargs)
+
+ def build_index(
+ self,
+ *,
+ ######## required args ##########
+ name: str,
+ embeddings_model_config: ModelConfiguration,
+ ######## chunking information ##########
+ data_source_citation_url: Optional[str] = None,
+ tokens_per_chunk: Optional[int] = None,
+ token_overlap_across_chunks: Optional[int] = None,
+ input_glob: Optional[str] = None,
+ ######## other generic args ########
+ document_path_replacement_regex: Optional[str] = None,
+ ######## Azure AI Search index info ########
+ index_config: Optional[AzureAISearchConfig] = None, # todo better name?
+ ######## data source info ########
+ input_source: Union[IndexDataSource, str],
+ input_source_credential: Optional[Union[ManagedIdentityConfiguration, UserIdentityConfiguration]] = None,
+ ) -> Union["Index", "Job"]: # type: ignore[name-defined]
+ """Builds an index on the cloud using the Azure AI Resources service.
+
+ :keyword name: The name of the index to be created.
+ :paramtype name: str
+ :keyword embeddings_model_config: Model config for the embedding model.
+ :paramtype embeddings_model_config: ~azure.ai.ml.entities._indexes.ModelConfiguration
+ :keyword data_source_citation_url: The URL of the data source.
+ :paramtype data_source_citation_url: Optional[str]
+ :keyword tokens_per_chunk: The size of chunks to be used for indexing.
+ :paramtype tokens_per_chunk: Optional[int]
+ :keyword token_overlap_across_chunks: The amount of overlap between chunks.
+ :paramtype token_overlap_across_chunks: Optional[int]
+ :keyword input_glob: The glob pattern to be used for indexing.
+ :paramtype input_glob: Optional[str]
+ :keyword document_path_replacement_regex: The regex pattern for replacing document paths.
+ :paramtype document_path_replacement_regex: Optional[str]
+ :keyword index_config: The configuration for the ACS output.
+ :paramtype index_config: Optional[~azure.ai.ml.entities._indexes.AzureAISearchConfig]
+ :keyword input_source: The input source for the index.
+ :paramtype input_source: Union[~azure.ai.ml.entities._indexes.IndexDataSource, str]
+ :keyword input_source_credential: The identity to be used for the index.
+ :paramtype input_source_credential: Optional[Union[~azure.ai.ml.entities.ManagedIdentityConfiguration,
+ ~azure.ai.ml.entities.UserIdentityConfiguration]]
+ :return: If the `input_source` is a GitSource, returns the submitted DataIndex pipeline job.
+ :rtype: Union[~azure.ai.ml.entities.Index, ~azure.ai.ml.entities.Job]
+ :raises ValueError: If `input_source` is not of type str, ~azure.ai.ml.entities._indexes.LocalSource,
+ or ~azure.ai.ml.entities._indexes.GitSource.
+ """
+ if document_path_replacement_regex:
+ document_path_replacement_regex = json.loads(document_path_replacement_regex)
+
+ data_index = DataIndex(
+ name=name,
+ source=IndexSource(
+ input_data=Data(
+ type="uri_folder",
+ path=".",
+ ),
+ input_glob=input_glob,
+ chunk_size=tokens_per_chunk,
+ chunk_overlap=token_overlap_across_chunks,
+ citation_url=data_source_citation_url,
+ citation_url_replacement_regex=(
+ CitationRegex(
+ match_pattern=document_path_replacement_regex["match_pattern"], # type: ignore[index]
+ replacement_pattern=document_path_replacement_regex[
+ "replacement_pattern" # type: ignore[index]
+ ],
+ )
+ if document_path_replacement_regex
+ else None
+ ),
+ ),
+ embedding=Embedding(
+ model=build_open_ai_protocol(
+ model=embeddings_model_config.model_name, deployment=embeddings_model_config.deployment_name
+ ),
+ connection=build_connection_id(embeddings_model_config.connection_name, self._operation_scope),
+ ),
+ index=(
+ IndexStore(
+ type="acs",
+ connection=build_connection_id(index_config.connection_id, self._operation_scope),
+ name=index_config.index_name,
+ )
+ if index_config is not None
+ else IndexStore(type="faiss")
+ ),
+ # name is replaced with a unique value each time the job is run
+ path=f"azureml://datastores/workspaceblobstore/paths/indexes/{name}/{{name}}",
+ )
+
+ if isinstance(input_source, LocalSource):
+ data_index.source.input_data = Data(
+ type="uri_folder",
+ path=input_source.input_data.path,
+ )
+
+ return self._create_data_indexing_job(data_index=data_index, identity=input_source_credential)
+
+ if isinstance(input_source, str):
+ data_index.source.input_data = Data(
+ type="uri_folder",
+ path=input_source,
+ )
+
+ return self._create_data_indexing_job(data_index=data_index, identity=input_source_credential)
+
+ if isinstance(input_source, GitSource):
+ from azure.ai.ml import MLClient
+
+ ml_registry = MLClient(credential=self._credential, registry_name="azureml")
+ git_clone_component = ml_registry.components.get("llm_rag_git_clone", label="latest")
+
+ # Clone Git Repo and use as input to index_job
+ @pipeline(default_compute="serverless") # type: ignore[call-overload]
+ def git_to_index(
+ git_url,
+ branch_name="",
+ git_connection_id="",
+ ):
+ git_clone = git_clone_component(git_repository=git_url, branch_name=branch_name)
+ git_clone.environment_variables["AZUREML_WORKSPACE_CONNECTION_ID_GIT"] = git_connection_id
+
+ index_job = index_data_func(
+ description=data_index.description,
+ data_index=data_index,
+ input_data_override=git_clone.outputs.output_data,
+ ml_client=MLClient(
+ subscription_id=self._subscription_id,
+ resource_group_name=self._resource_group_name,
+ workspace_name=self._workspace_name,
+ credential=self._credential,
+ ),
+ )
+ # pylint: disable=no-member
+ return index_job.outputs
+
+ git_index_job = git_to_index(
+ git_url=input_source.url,
+ branch_name=input_source.branch_name,
+ git_connection_id=input_source.connection_id,
+ )
+ # Ensure the repo is cloned on each run to pick up the latest content; comment out to reuse the first clone.
+ git_index_job.settings.force_rerun = True
+
+ # Submit the DataIndex Job
+ return self._all_operations.all_operations[AzureMLResourceType.JOB].create_or_update(git_index_job)
+ raise ValueError(f"Unsupported input source type {type(input_source)}")
+
+ def _create_data_indexing_job(
+ self,
+ data_index: DataIndex,
+ identity: Optional[Union[ManagedIdentityConfiguration, UserIdentityConfiguration]] = None,
+ compute: str = "serverless",
+ serverless_instance_type: Optional[str] = None,
+ input_data_override: Optional[Input] = None,
+ submit_job: bool = True,
+ **kwargs,
+ ) -> PipelineJob:
+ """
+ Returns the data indexing pipeline job that creates the index data asset.
+
+ :param data_index: DataIndex object.
+ :type data_index: ~azure.ai.ml.entities._indexes.entities.data_index.DataIndex
+ :param identity: Identity configuration for the job.
+ :type identity: Optional[Union[ManagedIdentityConfiguration, UserIdentityConfiguration]]
+ :param compute: The compute target to use for the job. Default: "serverless".
+ :type compute: str
+ :param serverless_instance_type: The instance type to use for serverless compute.
+ :type serverless_instance_type: Optional[str]
+ :param input_data_override: Input data override for the job.
+ Used to pipe output of step into DataIndex Job in a pipeline.
+ :type input_data_override: Optional[Input]
+ :param submit_job: Whether to submit the job to the service. Default: True.
+ :type submit_job: bool
+ :return: The data indexing pipeline job object.
+ :rtype: ~azure.ai.ml.entities.PipelineJob
+ """
+ # pylint: disable=no-member
+ from azure.ai.ml import MLClient
+
+ default_name = "data_index_" + data_index.name if data_index.name is not None else ""
+ experiment_name = kwargs.pop("experiment_name", None) or default_name
+ data_index.type = AssetTypes.URI_FOLDER
+
+ # avoid specifying auto_delete_setting in job output now
+ _validate_auto_delete_setting_in_data_output(data_index.auto_delete_setting)
+
+ # block customer specified path on managed datastore
+ data_index.path = _validate_workspace_managed_datastore(data_index.path)
+
+ if "${{name}}" not in str(data_index.path) and "{name}" not in str(data_index.path):
+ data_index.path = str(data_index.path).rstrip("/") + "/${{name}}"
+
+ index_job = index_data_func(
+ description=data_index.description or kwargs.pop("description", None) or default_name,
+ name=data_index.name or kwargs.pop("name", None),
+ display_name=kwargs.pop("display_name", None) or default_name,
+ experiment_name=experiment_name,
+ compute=compute,
+ serverless_instance_type=serverless_instance_type,
+ data_index=data_index,
+ ml_client=MLClient(
+ subscription_id=self._subscription_id,
+ resource_group_name=self._resource_group_name,
+ workspace_name=self._workspace_name,
+ credential=self._credential,
+ ),
+ identity=identity,
+ input_data_override=input_data_override,
+ **kwargs,
+ )
+ index_pipeline = PipelineJob(
+ description=index_job.description,
+ tags=index_job.tags,
+ name=index_job.name,
+ display_name=index_job.display_name,
+ experiment_name=experiment_name,
+ properties=index_job.properties or {},
+ settings=PipelineJobSettings(force_rerun=True, default_compute=compute),
+ jobs={default_name: index_job},
+ )
+ index_pipeline.properties["azureml.mlIndexAssetName"] = data_index.name
+ index_pipeline.properties["azureml.mlIndexAssetKind"] = data_index.index.type
+ index_pipeline.properties["azureml.mlIndexAssetSource"] = kwargs.pop("mlindex_asset_source", "Data Asset")
+
+ if submit_job:
+ return self._all_operations.all_operations[AzureMLResourceType.JOB].create_or_update(
+ job=index_pipeline, skip_validation=True, **kwargs
+ )
+ return index_pipeline
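Editorial note (not part of the patch): per the IndexOperations class docstring, this client is reached through the MLClient.index property. A minimal sketch of fetching and listing index assets follows; the index name is a placeholder and the index surface may still be a preview feature.

from azure.identity import DefaultAzureCredential
from azure.ai.ml import MLClient

ml_client = MLClient(
    DefaultAzureCredential(),
    subscription_id="<subscription-id>",
    resource_group_name="<resource-group>",
    workspace_name="<workspace-name>",
)

# Exactly one of version/label may be supplied; label="latest" resolves the newest version.
idx = ml_client.index.get(name="product-docs", label="latest")

# Listing with a name enumerates versions of that index; without a name it lists latest versions.
for index_version in ml_client.index.list(name="product-docs"):
    print(index_version.name, index_version.version)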
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_job_operations.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_job_operations.py
new file mode 100644
index 00000000..0003c8cd
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_job_operations.py
@@ -0,0 +1,1677 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access, too-many-instance-attributes, too-many-statements, too-many-lines
+import json
+import os.path
+from os import PathLike
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Set, Union, cast
+
+import jwt
+from marshmallow import ValidationError
+
+from azure.ai.ml._artifacts._artifact_utilities import (
+ _upload_and_generate_remote_uri,
+ aml_datastore_path_exists,
+ download_artifact_from_aml_uri,
+)
+from azure.ai.ml._azure_environments import (
+ _get_aml_resource_id_from_metadata,
+ _get_base_url_from_metadata,
+ _resource_to_scopes,
+)
+from azure.ai.ml._exception_helper import log_and_raise_error
+from azure.ai.ml._restclient.dataset_dataplane import AzureMachineLearningWorkspaces as ServiceClientDatasetDataplane
+from azure.ai.ml._restclient.model_dataplane import AzureMachineLearningWorkspaces as ServiceClientModelDataplane
+from azure.ai.ml._restclient.runhistory import AzureMachineLearningWorkspaces as ServiceClientRunHistory
+from azure.ai.ml._restclient.runhistory.models import Run
+from azure.ai.ml._restclient.v2023_04_01_preview import AzureMachineLearningWorkspaces as ServiceClient022023Preview
+from azure.ai.ml._restclient.v2023_04_01_preview.models import JobBase, ListViewType, UserIdentity
+from azure.ai.ml._restclient.v2023_08_01_preview.models import JobType as RestJobType
+from azure.ai.ml._restclient.v2024_01_01_preview.models import JobBase as JobBase_2401
+from azure.ai.ml._restclient.v2024_10_01_preview.models import JobType as RestJobType_20241001Preview
+from azure.ai.ml._scope_dependent_operations import (
+ OperationConfig,
+ OperationsContainer,
+ OperationScope,
+ _ScopeDependentOperations,
+)
+from azure.ai.ml._telemetry import ActivityType, monitor_with_activity, monitor_with_telemetry_mixin
+from azure.ai.ml._utils._http_utils import HttpPipeline
+from azure.ai.ml._utils._logger_utils import OpsLogger
+from azure.ai.ml._utils.utils import (
+ create_requests_pipeline_with_retry,
+ download_text_from_url,
+ is_data_binding_expression,
+ is_private_preview_enabled,
+ is_url,
+)
+from azure.ai.ml.constants._common import (
+ AZUREML_RESOURCE_PROVIDER,
+ BATCH_JOB_CHILD_RUN_OUTPUT_NAME,
+ COMMON_RUNTIME_ENV_VAR,
+ DEFAULT_ARTIFACT_STORE_OUTPUT_NAME,
+ GIT_PATH_PREFIX,
+ LEVEL_ONE_NAMED_RESOURCE_ID_FORMAT,
+ LOCAL_COMPUTE_TARGET,
+ SERVERLESS_COMPUTE,
+ SHORT_URI_FORMAT,
+ SWEEP_JOB_BEST_CHILD_RUN_ID_PROPERTY_NAME,
+ TID_FMT,
+ AssetTypes,
+ AzureMLResourceType,
+ WorkspaceDiscoveryUrlKey,
+)
+from azure.ai.ml.constants._compute import ComputeType
+from azure.ai.ml.constants._job.pipeline import PipelineConstants
+from azure.ai.ml.entities import Compute, Job, PipelineJob, ServiceInstance, ValidationResult
+from azure.ai.ml.entities._assets._artifacts.code import Code
+from azure.ai.ml.entities._builders import BaseNode, Command, Spark
+from azure.ai.ml.entities._datastore._constants import WORKSPACE_BLOB_STORE
+from azure.ai.ml.entities._inputs_outputs import Input
+from azure.ai.ml.entities._job.automl.automl_job import AutoMLJob
+from azure.ai.ml.entities._job.base_job import _BaseJob
+from azure.ai.ml.entities._job.distillation.distillation_job import DistillationJob
+from azure.ai.ml.entities._job.finetuning.finetuning_job import FineTuningJob
+from azure.ai.ml.entities._job.import_job import ImportJob
+from azure.ai.ml.entities._job.job import _is_pipeline_child_job
+from azure.ai.ml.entities._job.parallel.parallel_job import ParallelJob
+from azure.ai.ml.entities._job.to_rest_functions import to_rest_job_object
+from azure.ai.ml.entities._validation import PathAwareSchemaValidatableMixin
+from azure.ai.ml.exceptions import (
+ ComponentException,
+ ErrorCategory,
+ ErrorTarget,
+ JobException,
+ JobParsingError,
+ MlException,
+ PipelineChildJobError,
+ UserErrorException,
+ ValidationErrorType,
+ ValidationException,
+)
+from azure.ai.ml.operations._run_history_constants import RunHistoryConstants
+from azure.ai.ml.sweep import SweepJob
+from azure.core.credentials import TokenCredential
+from azure.core.exceptions import HttpResponseError, ResourceNotFoundError
+from azure.core.polling import LROPoller
+from azure.core.tracing.decorator import distributed_trace
+
+from ..constants._component import ComponentSource
+from ..entities._builders.control_flow_node import ControlFlowNode
+from ..entities._job.pipeline._io import InputOutputBase, PipelineInput, _GroupAttrDict
+from ._component_operations import ComponentOperations
+from ._compute_operations import ComputeOperations
+from ._dataset_dataplane_operations import DatasetDataplaneOperations
+from ._job_ops_helper import get_git_properties, get_job_output_uris_from_dataplane, stream_logs_until_completion
+from ._local_job_invoker import is_local_run, start_run_if_local
+from ._model_dataplane_operations import ModelDataplaneOperations
+from ._operation_orchestrator import (
+ OperationOrchestrator,
+ _AssetResolver,
+ is_ARM_id_for_resource,
+ is_registry_id_for_resource,
+ is_singularity_full_name_for_resource,
+ is_singularity_id_for_resource,
+ is_singularity_short_name_for_resource,
+)
+from ._run_operations import RunOperations
+from ._virtual_cluster_operations import VirtualClusterOperations
+
+try:
+ pass
+except ImportError:
+ pass
+
+if TYPE_CHECKING:
+ from azure.ai.ml.operations import DatastoreOperations
+
+ops_logger = OpsLogger(__name__)
+module_logger = ops_logger.module_logger
+
+
+class JobOperations(_ScopeDependentOperations):
+ """Initiates an instance of JobOperations
+
+ This class should not be instantiated directly. Instead, use the `jobs` attribute of an MLClient object.
+
+ :param operation_scope: Scope variables for the operations classes of an MLClient object.
+ :type operation_scope: ~azure.ai.ml._scope_dependent_operations.OperationScope
+ :param operation_config: Common configuration for operations classes of an MLClient object.
+ :type operation_config: ~azure.ai.ml._scope_dependent_operations.OperationConfig
+ :param service_client_02_2023_preview: Service client to allow end users to operate on Azure Machine Learning
+ Workspace resources.
+ :type service_client_02_2023_preview: ~azure.ai.ml._restclient.v2023_02_01_preview.AzureMachineLearningWorkspaces
+ :param all_operations: All operations classes of an MLClient object.
+ :type all_operations: ~azure.ai.ml._scope_dependent_operations.OperationsContainer
+ :param credential: Credential to use for authentication.
+ :type credential: ~azure.core.credentials.TokenCredential
+ """
+
+ def __init__(
+ self,
+ operation_scope: OperationScope,
+ operation_config: OperationConfig,
+ service_client_02_2023_preview: ServiceClient022023Preview,
+ all_operations: OperationsContainer,
+ credential: TokenCredential,
+ **kwargs: Any,
+ ) -> None:
+ super(JobOperations, self).__init__(operation_scope, operation_config)
+ ops_logger.update_filter()
+
+ self._operation_2023_02_preview = service_client_02_2023_preview.jobs
+ self._service_client = service_client_02_2023_preview
+ self._all_operations = all_operations
+ self._stream_logs_until_completion = stream_logs_until_completion
+ # Dataplane service clients are lazily created as they are needed
+ self._runs_operations_client: Optional[RunOperations] = None
+ self._dataset_dataplane_operations_client: Optional[DatasetDataplaneOperations] = None
+ self._model_dataplane_operations_client: Optional[ModelDataplaneOperations] = None
+ # Kwargs to propagate to dataplane service clients
+ self._service_client_kwargs = kwargs.pop("_service_client_kwargs", {})
+ self._api_base_url: Optional[str] = None
+ self._container = "azureml"
+ self._credential = credential
+ self._orchestrators = OperationOrchestrator(self._all_operations, self._operation_scope, self._operation_config)
+
+ self.service_client_01_2024_preview = kwargs.pop("service_client_01_2024_preview", None)
+ self.service_client_10_2024_preview = kwargs.pop("service_client_10_2024_preview", None)
+ self.service_client_01_2025_preview = kwargs.pop("service_client_01_2025_preview", None)
+ self._kwargs = kwargs
+
+ self._requests_pipeline: HttpPipeline = kwargs.pop("requests_pipeline")
+
+ @property
+ def _component_operations(self) -> ComponentOperations:
+ return cast(
+ ComponentOperations,
+ self._all_operations.get_operation( # type: ignore[misc]
+ AzureMLResourceType.COMPONENT, lambda x: isinstance(x, ComponentOperations)
+ ),
+ )
+
+ @property
+ def _compute_operations(self) -> ComputeOperations:
+ return cast(
+ ComputeOperations,
+ self._all_operations.get_operation( # type: ignore[misc]
+ AzureMLResourceType.COMPUTE, lambda x: isinstance(x, ComputeOperations)
+ ),
+ )
+
+ @property
+ def _virtual_cluster_operations(self) -> VirtualClusterOperations:
+ return cast(
+ VirtualClusterOperations,
+ self._all_operations.get_operation( # type: ignore[misc]
+ AzureMLResourceType.VIRTUALCLUSTER,
+ lambda x: isinstance(x, VirtualClusterOperations),
+ ),
+ )
+
+ @property
+ def _datastore_operations(self) -> "DatastoreOperations":
+ from azure.ai.ml.operations import DatastoreOperations
+
+ return cast(DatastoreOperations, self._all_operations.all_operations[AzureMLResourceType.DATASTORE])
+
+ @property
+ def _runs_operations(self) -> RunOperations:
+ if not self._runs_operations_client:
+ service_client_run_history = ServiceClientRunHistory(
+ self._credential, base_url=self._api_url, **self._service_client_kwargs
+ )
+ self._runs_operations_client = RunOperations(
+ self._operation_scope, self._operation_config, service_client_run_history
+ )
+ return self._runs_operations_client
+
+ @property
+ def _dataset_dataplane_operations(self) -> DatasetDataplaneOperations:
+ if not self._dataset_dataplane_operations_client:
+ service_client_dataset_dataplane = ServiceClientDatasetDataplane(
+ self._credential, base_url=self._api_url, **self._service_client_kwargs
+ )
+ self._dataset_dataplane_operations_client = DatasetDataplaneOperations(
+ self._operation_scope,
+ self._operation_config,
+ service_client_dataset_dataplane,
+ )
+ return self._dataset_dataplane_operations_client
+
+ @property
+ def _model_dataplane_operations(self) -> ModelDataplaneOperations:
+ if not self._model_dataplane_operations_client:
+ service_client_model_dataplane = ServiceClientModelDataplane(
+ self._credential, base_url=self._api_url, **self._service_client_kwargs
+ )
+ self._model_dataplane_operations_client = ModelDataplaneOperations(
+ self._operation_scope,
+ self._operation_config,
+ service_client_model_dataplane,
+ )
+ return self._model_dataplane_operations_client
+
+ @property
+ def _api_url(self) -> str:
+ if not self._api_base_url:
+ self._api_base_url = self._get_workspace_url(url_key=WorkspaceDiscoveryUrlKey.API)
+ return self._api_base_url
+
+ @distributed_trace
+ @monitor_with_activity(ops_logger, "Job.List", ActivityType.PUBLICAPI)
+ def list(
+ self,
+ *,
+ parent_job_name: Optional[str] = None,
+ list_view_type: ListViewType = ListViewType.ACTIVE_ONLY,
+ **kwargs: Any,
+ ) -> Iterable[Job]:
+ """Lists jobs in the workspace.
+
+ :keyword parent_job_name: When provided, only returns jobs that are children of the named job. Defaults to None,
+ listing all jobs in the workspace.
+ :paramtype parent_job_name: Optional[str]
+ :keyword list_view_type: The view type for including/excluding archived jobs. Defaults to
+        ~azure.mgmt.machinelearningservices.models.ListViewType.ACTIVE_ONLY, excluding archived jobs.
+ :paramtype list_view_type: ~azure.mgmt.machinelearningservices.models.ListViewType
+ :return: An iterator-like instance of Job objects.
+ :rtype: ~azure.core.paging.ItemPaged[~azure.ai.ml.entities.Job]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_misc.py
+ :start-after: [START job_operations_list]
+ :end-before: [END job_operations_list]
+ :language: python
+ :dedent: 8
+ :caption: Retrieving a list of the archived jobs in a workspace with parent job named
+ "iris-dataset-jobs".
+ """
+
+ schedule_defined = kwargs.pop("schedule_defined", None)
+ scheduled_job_name = kwargs.pop("scheduled_job_name", None)
+ max_results = kwargs.pop("max_results", None)
+
+ if parent_job_name:
+ parent_job = self.get(parent_job_name)
+ return self._runs_operations.get_run_children(parent_job.name, max_results=max_results)
+
+ return cast(
+ Iterable[Job],
+ self.service_client_01_2024_preview.jobs.list(
+ self._operation_scope.resource_group_name,
+ self._workspace_name,
+ cls=lambda objs: [self._handle_rest_errors(obj) for obj in objs],
+ list_view_type=list_view_type,
+ scheduled=schedule_defined,
+ schedule_id=scheduled_job_name,
+ **self._kwargs,
+ **kwargs,
+ ),
+ )
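+
+    # Usage sketch for list() (illustrative, not from the samples file): list the children of a
+    # parent job, assuming `ml_client` is an authenticated azure.ai.ml.MLClient.
+    #
+    #     for child in ml_client.jobs.list(parent_job_name="iris-dataset-jobs"):
+    #         print(child.name, child.status)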
+
+ def _handle_rest_errors(self, job_object: Union[JobBase, Run]) -> Optional[Job]:
+ """Handle errors while resolving azureml_id's during list operation.
+
+ :param job_object: The REST object to turn into a Job
+ :type job_object: Union[JobBase, Run]
+ :return: The resolved job
+ :rtype: Optional[Job]
+ """
+ try:
+ return self._resolve_azureml_id(Job._from_rest_object(job_object))
+ except JobParsingError:
+ return None
+
+ @distributed_trace
+ @monitor_with_telemetry_mixin(ops_logger, "Job.Get", ActivityType.PUBLICAPI)
+ def get(self, name: str) -> Job:
+ """Gets a job resource.
+
+ :param name: The name of the job.
+ :type name: str
+ :raises azure.core.exceptions.ResourceNotFoundError: Raised if no job with the given name can be found.
+ :raises ~azure.ai.ml.exceptions.UserErrorException: Raised if the name parameter is not a string.
+ :return: Job object retrieved from the service.
+ :rtype: ~azure.ai.ml.entities.Job
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_misc.py
+ :start-after: [START job_operations_get]
+ :end-before: [END job_operations_get]
+ :language: python
+ :dedent: 8
+ :caption: Retrieving a job named "iris-dataset-job-1".
+ """
+ if not isinstance(name, str):
+            raise UserErrorException(f"{name} is an invalid input for client.jobs.get().")
+ job_object = self._get_job(name)
+
+ job: Any = None
+ if not _is_pipeline_child_job(job_object):
+ job = Job._from_rest_object(job_object)
+ if job_object.properties.job_type != RestJobType.AUTO_ML:
+ # resolvers do not work with the old contract, leave the ids as is
+ job = self._resolve_azureml_id(job)
+ else:
+ # Child jobs are no longer available through MFE, fetch
+ # through run history instead
+ job = self._runs_operations._translate_from_rest_object(self._runs_operations.get_run(name))
+
+ return job
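+
+    # Usage sketch for get(), assuming `ml_client` is an authenticated MLClient and the job name
+    # is illustrative:
+    #
+    #     retrieved_job = ml_client.jobs.get("iris-dataset-job-1")
+    #     print(retrieved_job.status)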
+
+ @distributed_trace
+ @monitor_with_telemetry_mixin(ops_logger, "Job.ShowServices", ActivityType.PUBLICAPI)
+ def show_services(self, name: str, node_index: int = 0) -> Optional[Dict[str, ServiceInstance]]:
+ """Gets services associated with a job's node.
+
+ :param name: The name of the job.
+ :type name: str
+ :param node_index: The node's index (zero-based). Defaults to 0.
+ :type node_index: int
+ :return: The services associated with the job for the given node.
+ :rtype: dict[str, ~azure.ai.ml.entities.ServiceInstance]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_misc.py
+ :start-after: [START job_operations_show_services]
+ :end-before: [END job_operations_show_services]
+ :language: python
+ :dedent: 8
+ :caption: Retrieving the services associated with a job's 1st node.
+ """
+
+ service_instances_dict = self._runs_operations._operation.get_run_service_instances(
+ self._subscription_id,
+ self._operation_scope.resource_group_name,
+ self._workspace_name,
+ name,
+ node_index,
+ )
+ if not service_instances_dict.instances:
+ return None
+
+ return {
+ k: ServiceInstance._from_rest_object(v, node_index) for k, v in service_instances_dict.instances.items()
+ }
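+
+    # Usage sketch for show_services() (illustrative names), assuming `ml_client` is an
+    # authenticated MLClient:
+    #
+    #     services = ml_client.jobs.show_services("iris-dataset-job-1", node_index=0)
+    #     if services:
+    #         for service_name, instance in services.items():
+    #             print(service_name, instance)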
+
+ @distributed_trace
+ @monitor_with_activity(ops_logger, "Job.Cancel", ActivityType.PUBLICAPI)
+ def begin_cancel(self, name: str, **kwargs: Any) -> LROPoller[None]:
+ """Cancels a job.
+
+ :param name: The name of the job.
+ :type name: str
+ :raises azure.core.exceptions.ResourceNotFoundError: Raised if no job with the given name can be found.
+ :return: A poller to track the operation status.
+ :rtype: ~azure.core.polling.LROPoller[None]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_misc.py
+ :start-after: [START job_operations_begin_cancel]
+ :end-before: [END job_operations_begin_cancel]
+ :language: python
+ :dedent: 8
+ :caption: Canceling the job named "iris-dataset-job-1" and checking the poller for status.
+ """
+ tag = kwargs.pop("tag", None)
+
+ if not tag:
+ return self._operation_2023_02_preview.begin_cancel(
+ id=name,
+ resource_group_name=self._operation_scope.resource_group_name,
+ workspace_name=self._workspace_name,
+ **self._kwargs,
+ **kwargs,
+ )
+
+        # Note: The batch cancel below is experimental and intended for private use
+ results = []
+ jobs = self.list(tag=tag)
+        # TODO: Do we need to show an error message when no jobs are returned for the given tag?
+ for job in jobs:
+ result = self._operation_2023_02_preview.begin_cancel(
+ id=job.name,
+ resource_group_name=self._operation_scope.resource_group_name,
+ workspace_name=self._workspace_name,
+ **self._kwargs,
+ )
+ results.append(result)
+ return results
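+
+    # Usage sketch for begin_cancel() on the single-job path, assuming `ml_client` is an
+    # authenticated MLClient:
+    #
+    #     poller = ml_client.jobs.begin_cancel("iris-dataset-job-1")
+    #     poller.wait()  # block until the cancellation request is accepted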
+
+ def _try_get_compute_arm_id(self, compute: Union[Compute, str]) -> Optional[Union[Compute, str]]:
+ # pylint: disable=too-many-return-statements
+ # TODO: Remove in PuP with native import job/component type support in MFE/Designer
+ # DataFactory 'clusterless' job
+ if str(compute) == ComputeType.ADF:
+ return compute
+
+ if compute is not None:
+ # Singularity
+ if isinstance(compute, str) and is_singularity_id_for_resource(compute):
+ return self._virtual_cluster_operations.get(compute)["id"]
+ if isinstance(compute, str) and is_singularity_full_name_for_resource(compute):
+ return self._orchestrators._get_singularity_arm_id_from_full_name(compute)
+ if isinstance(compute, str) and is_singularity_short_name_for_resource(compute):
+ return self._orchestrators._get_singularity_arm_id_from_short_name(compute)
+ # other compute
+ if is_ARM_id_for_resource(compute, resource_type=AzureMLResourceType.COMPUTE):
+ # compute is not a sub-workspace resource
+ compute_name = compute.split("/")[-1] # type: ignore
+ elif isinstance(compute, Compute):
+ compute_name = compute.name
+ elif isinstance(compute, str):
+ compute_name = compute
+ elif isinstance(compute, PipelineInput):
+ compute_name = str(compute)
+ else:
+ raise ValueError(
+ "compute must be either an arm id of Compute, a Compute object or a compute name but"
+ f" got {type(compute)}"
+ )
+
+ if is_data_binding_expression(compute_name):
+ return compute_name
+ if compute_name == SERVERLESS_COMPUTE:
+ return compute_name
+ try:
+ return self._compute_operations.get(compute_name).id
+ except ResourceNotFoundError as e:
+ # the original error is not helpful (Operation returned an invalid status 'Not Found'),
+ # so we raise a more helpful one
+ response = e.response
+ response.reason = "Not found compute with name {}".format(compute_name)
+ raise ResourceNotFoundError(response=response) from e
+ return None
+
+ @distributed_trace
+ @monitor_with_telemetry_mixin(ops_logger, "Job.Validate", ActivityType.PUBLICAPI)
+ def validate(self, job: Job, *, raise_on_failure: bool = False, **kwargs: Any) -> ValidationResult:
+ """Validates a Job object before submitting to the service. Anonymous assets may be created if there are inline
+ defined entities such as Component, Environment, and Code. Only pipeline jobs are supported for validation
+ currently.
+
+ :param job: The job object to be validated.
+ :type job: ~azure.ai.ml.entities.Job
+ :keyword raise_on_failure: Specifies if an error should be raised if validation fails. Defaults to False.
+ :paramtype raise_on_failure: bool
+ :return: A ValidationResult object containing all found errors.
+ :rtype: ~azure.ai.ml.entities.ValidationResult
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_misc.py
+ :start-after: [START job_operations_validate]
+ :end-before: [END job_operations_validate]
+ :language: python
+ :dedent: 8
+ :caption: Validating a PipelineJob object and printing out the found errors.
+ """
+ return self._validate(job, raise_on_failure=raise_on_failure, **kwargs)
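+
+    # Usage sketch for validate(), assuming `ml_client` is an authenticated MLClient and
+    # `pipeline_job` is a PipelineJob built elsewhere; `error_messages` is the assumed
+    # accessor on ValidationResult:
+    #
+    #     result = ml_client.jobs.validate(pipeline_job)
+    #     if not result.passed:
+    #         print(result.error_messages)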
+
+ @monitor_with_telemetry_mixin(ops_logger, "Job.Validate", ActivityType.INTERNALCALL)
+ def _validate(
+ self,
+ job: Job,
+ *,
+ raise_on_failure: bool = False,
+ # pylint:disable=unused-argument
+ **kwargs: Any,
+ ) -> ValidationResult:
+ """Implementation of validate.
+
+        This function exists so that create_or_update() does not call validate()
+        directly, which would skew telemetry statistics and surface an
+        experimental warning in create_or_update().
+
+ :param job: The job to validate
+ :type job: Job
+ :keyword raise_on_failure: Whether to raise on validation failure
+ :paramtype raise_on_failure: bool
+ :return: The validation result
+ :rtype: ValidationResult
+ """
+ git_code_validation_result = PathAwareSchemaValidatableMixin._create_empty_validation_result()
+ # TODO: move this check to Job._validate after validation is supported for all job types
+        # If the job has a str code value that is a git path, it is only valid when private
+        # preview features are enabled; otherwise we should raise a ValidationException
+        # saying that the code value is not a valid code value
+ if (
+ hasattr(job, "code")
+ and job.code is not None
+ and isinstance(job.code, str)
+ and job.code.startswith(GIT_PATH_PREFIX)
+ and not is_private_preview_enabled()
+ ):
+ git_code_validation_result.append_error(
+ message=f"Invalid code value: {job.code}. Git paths are not supported.",
+ yaml_path="code",
+ )
+
+ if not isinstance(job, PathAwareSchemaValidatableMixin):
+
+ def error_func(msg: str, no_personal_data_msg: str) -> ValidationException:
+ return ValidationException(
+ message=msg,
+ no_personal_data_message=no_personal_data_msg,
+ error_target=ErrorTarget.JOB,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+
+ return git_code_validation_result.try_raise(
+ raise_error=raise_on_failure,
+ error_func=error_func,
+ )
+
+ validation_result = job._validate(raise_error=raise_on_failure)
+ validation_result.merge_with(git_code_validation_result)
+ # fast return to avoid remote call if local validation not passed
+ # TODO: use remote call to validate the entire job after MFE API is ready
+ if validation_result.passed and isinstance(job, PipelineJob):
+ try:
+ job.compute = self._try_get_compute_arm_id(job.compute)
+ except Exception as e: # pylint: disable=W0718
+ validation_result.append_error(yaml_path="compute", message=str(e))
+
+ for node_name, node in job.jobs.items():
+ try:
+ # TODO(1979547): refactor, not all nodes have compute
+ if not isinstance(node, ControlFlowNode):
+ node.compute = self._try_get_compute_arm_id(node.compute)
+ except Exception as e: # pylint: disable=W0718
+ validation_result.append_error(yaml_path=f"jobs.{node_name}.compute", message=str(e))
+
+ validation_result.resolve_location_for_diagnostics(str(job._source_path))
+ return job._try_raise(validation_result, raise_error=raise_on_failure) # pylint: disable=protected-access
+
+ @distributed_trace
+ @monitor_with_telemetry_mixin(ops_logger, "Job.CreateOrUpdate", ActivityType.PUBLICAPI)
+ def create_or_update(
+ self,
+ job: Job,
+ *,
+ description: Optional[str] = None,
+ compute: Optional[str] = None,
+ tags: Optional[dict] = None,
+ experiment_name: Optional[str] = None,
+ skip_validation: bool = False,
+ **kwargs: Any,
+ ) -> Job:
+ """Creates or updates a job. If entities such as Environment or Code are defined inline, they'll be created
+ together with the job.
+
+ :param job: The job object.
+ :type job: ~azure.ai.ml.entities.Job
+ :keyword description: The job description.
+ :paramtype description: Optional[str]
+ :keyword compute: The compute target for the job.
+ :paramtype compute: Optional[str]
+ :keyword tags: The tags for the job.
+ :paramtype tags: Optional[dict]
+ :keyword experiment_name: The name of the experiment the job will be created under. If None is provided,
+            the job will be created under the experiment 'Default'.
+ :paramtype experiment_name: Optional[str]
+ :keyword skip_validation: Specifies whether or not to skip validation before creating or updating the job. Note
+ that validation for dependent resources such as an anonymous component will not be skipped. Defaults to
+ False.
+ :paramtype skip_validation: bool
+ :raises Union[~azure.ai.ml.exceptions.UserErrorException, ~azure.ai.ml.exceptions.ValidationException]: Raised
+ if Job cannot be successfully validated. Details will be provided in the error message.
+ :raises ~azure.ai.ml.exceptions.AssetException: Raised if Job assets
+ (e.g. Data, Code, Model, Environment) cannot be successfully validated.
+ Details will be provided in the error message.
+ :raises ~azure.ai.ml.exceptions.ModelException: Raised if Job model cannot be successfully validated.
+ Details will be provided in the error message.
+        :raises ~azure.ai.ml.exceptions.JobException: Raised if the Job object or its attributes are not correctly
+            formatted.
+ Details will be provided in the error message.
+ :raises ~azure.ai.ml.exceptions.EmptyDirectoryError: Raised if local path provided points to an empty
+ directory.
+ :raises ~azure.ai.ml.exceptions.DockerEngineNotAvailableError: Raised if Docker Engine is not available for
+ local job.
+ :return: Created or updated job.
+ :rtype: ~azure.ai.ml.entities.Job
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_misc.py
+ :start-after: [START job_operations_create_and_update]
+ :end-before: [END job_operations_create_and_update]
+ :language: python
+ :dedent: 8
+ :caption: Creating a new job and then updating its compute.
+ """
+ if isinstance(job, BaseNode) and not (
+ isinstance(job, (Command, Spark))
+ ): # Command/Spark objects can be used directly
+ job = job._to_job()
+
+ # Set job properties before submission
+ if description is not None:
+ job.description = description
+ if compute is not None:
+ job.compute = compute
+ if tags is not None:
+ job.tags = tags
+ if experiment_name is not None:
+ job.experiment_name = experiment_name
+
+ if job.compute == LOCAL_COMPUTE_TARGET:
+ job.environment_variables[COMMON_RUNTIME_ENV_VAR] = "true" # type: ignore
+
+        # TODO: why is the logging logic here instead of inside self._validate()?
+ try:
+ if not skip_validation:
+ self._validate(job, raise_on_failure=True)
+
+ # Create all dependent resources
+ self._resolve_arm_id_or_upload_dependencies(job)
+ except (ValidationException, ValidationError) as ex:
+ log_and_raise_error(ex)
+
+ git_props = get_git_properties()
+ # Do not add git props if they already exist in job properties.
+        # This matters for updates specifically: if the user switches branches and tries to update
+        # their job, the request would fail because the git props would be repopulated.
+        # MFE does not allow existing properties to be updated; only new properties may be added.
+ if not any(prop_name in job.properties for prop_name in git_props):
+ job.properties = {**job.properties, **git_props}
+ rest_job_resource = to_rest_job_object(job)
+
+        # Make a copy of self._kwargs instead of contaminating the original
+ kwargs = {**self._kwargs}
+ # set headers with user aml token if job is a pipeline or has a user identity setting
+ if (rest_job_resource.properties.job_type == RestJobType.PIPELINE) or (
+ hasattr(rest_job_resource.properties, "identity")
+ and (isinstance(rest_job_resource.properties.identity, UserIdentity))
+ ):
+ self._set_headers_with_user_aml_token(kwargs)
+
+ result = self._create_or_update_with_different_version_api(rest_job_resource=rest_job_resource, **kwargs)
+
+ if is_local_run(result):
+ ws_base_url = self._all_operations.all_operations[
+ AzureMLResourceType.WORKSPACE
+ ]._operation._client._base_url
+ snapshot_id = start_run_if_local(
+ result,
+ self._credential,
+ ws_base_url,
+ self._requests_pipeline,
+ )
+            # For a local run, the first create/update call to MFE returns the request
+            # to submit to ES. Once we submit to ES and start the run, we need to put
+            # the same body to MFE to append user tags etc.
+ if rest_job_resource.properties.job_type == RestJobType.PIPELINE:
+ job_object = self._get_job_2401(rest_job_resource.name)
+ else:
+ job_object = self._get_job(rest_job_resource.name)
+ if result.properties.tags is not None:
+ for tag_name, tag_value in rest_job_resource.properties.tags.items():
+ job_object.properties.tags[tag_name] = tag_value
+ if result.properties.properties is not None:
+ for (
+ prop_name,
+ prop_value,
+ ) in rest_job_resource.properties.properties.items():
+ job_object.properties.properties[prop_name] = prop_value
+ if snapshot_id is not None:
+ job_object.properties.properties["ContentSnapshotId"] = snapshot_id
+
+ result = self._create_or_update_with_different_version_api(rest_job_resource=job_object, **kwargs)
+
+ return self._resolve_azureml_id(Job._from_rest_object(result))
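+
+    # Usage sketch for create_or_update(), assuming `ml_client` is an authenticated MLClient and
+    # `command_job` is a Command or other Job entity built elsewhere:
+    #
+    #     submitted = ml_client.jobs.create_or_update(
+    #         command_job,
+    #         experiment_name="my-experiment",
+    #         tags={"team": "demo"},
+    #     )
+    #     print(submitted.name, submitted.status)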
+
+ def _create_or_update_with_different_version_api(self, rest_job_resource: JobBase, **kwargs: Any) -> JobBase:
+ service_client_operation = self._operation_2023_02_preview
+ if rest_job_resource.properties.job_type == RestJobType_20241001Preview.FINE_TUNING:
+ service_client_operation = self.service_client_10_2024_preview.jobs
+ if rest_job_resource.properties.job_type == RestJobType.PIPELINE:
+ service_client_operation = self.service_client_01_2024_preview.jobs
+ if rest_job_resource.properties.job_type == RestJobType.AUTO_ML:
+ service_client_operation = self.service_client_01_2024_preview.jobs
+ if rest_job_resource.properties.job_type == RestJobType.SWEEP:
+ service_client_operation = self.service_client_01_2024_preview.jobs
+ if rest_job_resource.properties.job_type == RestJobType.COMMAND:
+ service_client_operation = self.service_client_01_2025_preview.jobs
+
+ result = service_client_operation.create_or_update(
+ id=rest_job_resource.name,
+ resource_group_name=self._operation_scope.resource_group_name,
+ workspace_name=self._workspace_name,
+ body=rest_job_resource,
+ **kwargs,
+ )
+
+ return result
+
+ def _create_or_update_with_latest_version_api(self, rest_job_resource: JobBase, **kwargs: Any) -> JobBase:
+ service_client_operation = self.service_client_01_2024_preview.jobs
+ result = service_client_operation.create_or_update(
+ id=rest_job_resource.name,
+ resource_group_name=self._operation_scope.resource_group_name,
+ workspace_name=self._workspace_name,
+ body=rest_job_resource,
+ **kwargs,
+ )
+
+ return result
+
+ def _archive_or_restore(self, name: str, is_archived: bool) -> None:
+ job_object = self._get_job(name)
+ if job_object.properties.job_type == RestJobType.PIPELINE:
+ job_object = self._get_job_2401(name)
+ if _is_pipeline_child_job(job_object):
+ raise PipelineChildJobError(job_id=job_object.id)
+ job_object.properties.is_archived = is_archived
+
+ self._create_or_update_with_different_version_api(rest_job_resource=job_object)
+
+ @distributed_trace
+ @monitor_with_telemetry_mixin(ops_logger, "Job.Archive", ActivityType.PUBLICAPI)
+ def archive(self, name: str) -> None:
+ """Archives a job.
+
+ :param name: The name of the job.
+ :type name: str
+ :raises azure.core.exceptions.ResourceNotFoundError: Raised if no job with the given name can be found.
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_misc.py
+ :start-after: [START job_operations_archive]
+ :end-before: [END job_operations_archive]
+ :language: python
+ :dedent: 8
+ :caption: Archiving a job.
+ """
+
+ self._archive_or_restore(name=name, is_archived=True)
+
+ @distributed_trace
+ @monitor_with_telemetry_mixin(ops_logger, "Job.Restore", ActivityType.PUBLICAPI)
+ def restore(self, name: str) -> None:
+ """Restores an archived job.
+
+ :param name: The name of the job.
+ :type name: str
+ :raises azure.core.exceptions.ResourceNotFoundError: Raised if no job with the given name can be found.
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_misc.py
+ :start-after: [START job_operations_restore]
+ :end-before: [END job_operations_restore]
+ :language: python
+ :dedent: 8
+ :caption: Restoring an archived job.
+ """
+
+ self._archive_or_restore(name=name, is_archived=False)
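+
+    # Usage sketch covering archive() and restore() above, assuming `ml_client` is an
+    # authenticated MLClient and the job name is illustrative:
+    #
+    #     ml_client.jobs.archive("iris-dataset-job-1")
+    #     ml_client.jobs.restore("iris-dataset-job-1")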
+
+ @distributed_trace
+ @monitor_with_activity(ops_logger, "Job.Stream", ActivityType.PUBLICAPI)
+ def stream(self, name: str) -> None:
+ """Streams the logs of a running job.
+
+ :param name: The name of the job.
+ :type name: str
+ :raises azure.core.exceptions.ResourceNotFoundError: Raised if no job with the given name can be found.
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_misc.py
+ :start-after: [START job_operations_stream_logs]
+ :end-before: [END job_operations_stream_logs]
+ :language: python
+ :dedent: 8
+ :caption: Streaming a running job.
+ """
+ job_object = self._get_job(name)
+
+ if _is_pipeline_child_job(job_object):
+ raise PipelineChildJobError(job_id=job_object.id)
+
+ self._stream_logs_until_completion(
+ self._runs_operations,
+ job_object,
+ self._datastore_operations,
+ requests_pipeline=self._requests_pipeline,
+ )
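+
+    # Usage sketch for stream(), assuming `ml_client` is an authenticated MLClient; blocks until
+    # the job reaches a terminal state while echoing its logs:
+    #
+    #     ml_client.jobs.stream("iris-dataset-job-1")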
+
+ @distributed_trace
+ @monitor_with_activity(ops_logger, "Job.Download", ActivityType.PUBLICAPI)
+ def download(
+ self,
+ name: str,
+ *,
+ download_path: Union[PathLike, str] = ".",
+ output_name: Optional[str] = None,
+ all: bool = False, # pylint: disable=redefined-builtin
+ ) -> None:
+ """Downloads the logs and output of a job.
+
+ :param name: The name of a job.
+ :type name: str
+ :keyword download_path: The local path to be used as the download destination. Defaults to ".".
+ :paramtype download_path: Union[PathLike, str]
+ :keyword output_name: The name of the output to download. Defaults to None.
+ :paramtype output_name: Optional[str]
+ :keyword all: Specifies if all logs and named outputs should be downloaded. Defaults to False.
+ :paramtype all: bool
+ :raises ~azure.ai.ml.exceptions.JobException: Raised if Job is not yet in a terminal state.
+ Details will be provided in the error message.
+ :raises ~azure.ai.ml.exceptions.MlException: Raised if logs and outputs cannot be successfully downloaded.
+ Details will be provided in the error message.
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_misc.py
+ :start-after: [START job_operations_download]
+ :end-before: [END job_operations_download]
+ :language: python
+ :dedent: 8
+ :caption: Downloading all logs and named outputs of the job "job-1" into local directory "job-1-logs".
+ """
+ job_details = self.get(name)
+ # job is reused, get reused job to download
+ if (
+ job_details.properties.get(PipelineConstants.REUSED_FLAG_FIELD) == PipelineConstants.REUSED_FLAG_TRUE
+ and PipelineConstants.REUSED_JOB_ID in job_details.properties
+ ):
+ reused_job_name = job_details.properties[PipelineConstants.REUSED_JOB_ID]
+ reused_job_detail = self.get(reused_job_name)
+ module_logger.info(
+ "job %s reuses previous job %s, download from the reused job.",
+ name,
+ reused_job_name,
+ )
+ name, job_details = reused_job_name, reused_job_detail
+ job_status = job_details.status
+ if job_status not in RunHistoryConstants.TERMINAL_STATUSES:
+ msg = "This job is in state {}. Download is allowed only in states {}".format(
+ job_status, RunHistoryConstants.TERMINAL_STATUSES
+ )
+ raise JobException(
+ message=msg,
+ target=ErrorTarget.JOB,
+ no_personal_data_message=msg,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+
+ is_batch_job = (
+ job_details.tags.get("azureml.batchrun", None) == "true"
+ and job_details.tags.get("azureml.jobtype", None) != PipelineConstants.PIPELINE_JOB_TYPE
+ )
+ outputs = {}
+ download_path = Path(download_path)
+ artifact_directory_name = "artifacts"
+ output_directory_name = "named-outputs"
+
+ def log_missing_uri(what: str) -> None:
+ module_logger.debug(
+ 'Could not download %s for job "%s" (job status: %s)',
+ what,
+ job_details.name,
+ job_details.status,
+ )
+
+ if isinstance(job_details, SweepJob):
+ best_child_run_id = job_details.properties.get(SWEEP_JOB_BEST_CHILD_RUN_ID_PROPERTY_NAME, None)
+ if best_child_run_id:
+ self.download(
+ best_child_run_id,
+ download_path=download_path,
+ output_name=output_name,
+ all=all,
+ )
+ else:
+ log_missing_uri(what="from best child run")
+
+ if output_name:
+ # Don't need to download anything from the parent
+ return
+ # only download default artifacts (logs + default outputs) from parent
+ artifact_directory_name = "hd-artifacts"
+ output_name = None
+ all = False
+
+ if is_batch_job:
+ scoring_uri = self._get_batch_job_scoring_output_uri(job_details.name)
+ if scoring_uri:
+ outputs = {BATCH_JOB_CHILD_RUN_OUTPUT_NAME: scoring_uri}
+ else:
+ log_missing_uri("batch job scoring file")
+ elif output_name:
+ outputs = self._get_named_output_uri(name, output_name)
+
+ if output_name not in outputs:
+ log_missing_uri(what=f'output "{output_name}"')
+ elif all:
+ outputs = self._get_named_output_uri(name)
+
+ if DEFAULT_ARTIFACT_STORE_OUTPUT_NAME not in outputs:
+ log_missing_uri(what="logs")
+ else:
+ outputs = self._get_named_output_uri(name, DEFAULT_ARTIFACT_STORE_OUTPUT_NAME)
+
+ if DEFAULT_ARTIFACT_STORE_OUTPUT_NAME not in outputs:
+ log_missing_uri(what="logs")
+
+ # Download all requested artifacts
+ for item_name, uri in outputs.items():
+ if is_batch_job:
+ destination = download_path
+ elif item_name == DEFAULT_ARTIFACT_STORE_OUTPUT_NAME:
+ destination = download_path / artifact_directory_name
+ else:
+ destination = download_path / output_directory_name / item_name
+
+ module_logger.info("Downloading artifact %s to %s", uri, destination)
+ download_artifact_from_aml_uri(
+ uri=uri,
+ destination=destination, # type: ignore[arg-type]
+ datastore_operation=self._datastore_operations,
+ )
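+
+    # Usage sketch for download(), matching the docstring caption and assuming `ml_client` is an
+    # authenticated MLClient:
+    #
+    #     ml_client.jobs.download("job-1", download_path="job-1-logs", all=True)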
+
+ def _get_named_output_uri(
+ self, job_name: Optional[str], output_names: Optional[Union[Iterable[str], str]] = None
+ ) -> Dict[str, str]:
+ """Gets the URIs to the specified named outputs of job.
+
+ :param job_name: Run ID of the job
+ :type job_name: str
+ :param output_names: Either an output name, or an iterable of output names. If omitted, all outputs are
+ returned.
+ :type output_names: Optional[Union[Iterable[str], str]]
+ :return: Map of output_names to URIs. Note that URIs that could not be found will not be present in the map.
+ :rtype: Dict[str, str]
+ """
+
+ if isinstance(output_names, str):
+ output_names = {output_names}
+ elif output_names:
+ output_names = set(output_names)
+
+ outputs = get_job_output_uris_from_dataplane(
+ job_name,
+ self._runs_operations,
+ self._dataset_dataplane_operations,
+ self._model_dataplane_operations,
+ output_names=output_names,
+ )
+
+ missing_outputs: Set = set()
+ if output_names is not None:
+ missing_outputs = set(output_names).difference(outputs.keys())
+ else:
+ missing_outputs = set().difference(outputs.keys())
+
+ # Include default artifact store in outputs
+ if (not output_names) or DEFAULT_ARTIFACT_STORE_OUTPUT_NAME in missing_outputs:
+ try:
+ job = self.get(job_name)
+ artifact_store_uri = job.outputs[DEFAULT_ARTIFACT_STORE_OUTPUT_NAME]
+ if artifact_store_uri is not None and artifact_store_uri.path:
+ outputs[DEFAULT_ARTIFACT_STORE_OUTPUT_NAME] = artifact_store_uri.path
+ except (AttributeError, KeyError):
+ outputs[DEFAULT_ARTIFACT_STORE_OUTPUT_NAME] = SHORT_URI_FORMAT.format(
+ "workspaceartifactstore", f"ExperimentRun/dcid.{job_name}/"
+ )
+ missing_outputs.discard(DEFAULT_ARTIFACT_STORE_OUTPUT_NAME)
+
+ # A job's output is not always reported in the outputs dict, but
+ # doesn't currently have a user configurable location.
+ # Perform a search of known paths to find output
+ # TODO: Remove once job output locations are reliably returned from the service
+
+ default_datastore = self._datastore_operations.get_default().name
+
+ for name in missing_outputs:
+ potential_uris = [
+ SHORT_URI_FORMAT.format(default_datastore, f"azureml/{job_name}/{name}/"),
+ SHORT_URI_FORMAT.format(default_datastore, f"dataset/{job_name}/{name}/"),
+ ]
+
+ for potential_uri in potential_uris:
+ if aml_datastore_path_exists(potential_uri, self._datastore_operations):
+ outputs[name] = potential_uri
+ break
+
+ return outputs
+
+ def _get_batch_job_scoring_output_uri(self, job_name: str) -> Optional[str]:
+ uri = None
+ # Download scoring output, which is the "score" output of the child job named "batchscoring"
+        # Batch jobs are pipeline jobs with only one child, so this should terminate after one iteration
+ for child in self._runs_operations.get_run_children(job_name):
+ uri = self._get_named_output_uri(child.name, BATCH_JOB_CHILD_RUN_OUTPUT_NAME).get(
+ BATCH_JOB_CHILD_RUN_OUTPUT_NAME, None
+ )
+ # After the correct child is found, break to prevent unnecessary looping
+ if uri is not None:
+ break
+ return uri
+
+ def _get_job(self, name: str) -> JobBase:
+ job = self.service_client_01_2024_preview.jobs.get(
+ id=name,
+ resource_group_name=self._operation_scope.resource_group_name,
+ workspace_name=self._workspace_name,
+ **self._kwargs,
+ )
+
+ if (
+ hasattr(job, "properties")
+ and job.properties
+ and hasattr(job.properties, "job_type")
+ and job.properties.job_type == RestJobType_20241001Preview.FINE_TUNING
+ ):
+ return self.service_client_10_2024_preview.jobs.get(
+ id=name,
+ resource_group_name=self._operation_scope.resource_group_name,
+ workspace_name=self._workspace_name,
+ **self._kwargs,
+ )
+
+ return job
+
+    # Upgrades the API from 2023-04-01-preview to 2024-01-01-preview for pipeline jobs.
+    # We can remove this function once `_get_job` has also been upgraded to the same API version for pipelines.
+ def _get_job_2401(self, name: str) -> JobBase_2401:
+ service_client_operation = self.service_client_01_2024_preview.jobs
+ return service_client_operation.get(
+ id=name,
+ resource_group_name=self._operation_scope.resource_group_name,
+ workspace_name=self._workspace_name,
+ **self._kwargs,
+ )
+
+ def _get_workspace_url(self, url_key: WorkspaceDiscoveryUrlKey) -> str:
+ discovery_url = (
+ self._all_operations.all_operations[AzureMLResourceType.WORKSPACE]
+ .get(self._operation_scope.workspace_name)
+ .discovery_url
+ )
+ all_urls = json.loads(
+ download_text_from_url(
+ discovery_url,
+ create_requests_pipeline_with_retry(requests_pipeline=self._requests_pipeline),
+ )
+ )
+ return all_urls[url_key]
+
+ def _resolve_arm_id_or_upload_dependencies(self, job: Job) -> None:
+ """This method converts name or name:version to ARM id. Or it
+ registers/uploads nested dependencies.
+
+ :param job: the job resource entity
+ :type job: Job
+        :return: None. The job entity is modified in place with its nested dependencies resolved.
+        :rtype: None
+ """
+
+ self._resolve_arm_id_or_azureml_id(job, self._orchestrators.get_asset_arm_id)
+
+ if isinstance(job, PipelineJob):
+ # Resolve top-level inputs
+ self._resolve_job_inputs(self._flatten_group_inputs(job.inputs), job._base_path)
+ # inputs in sub-pipelines has been resolved in
+ # self._resolve_arm_id_or_azureml_id(job, self._orchestrators.get_asset_arm_id)
+ # as they are part of the pipeline component
+ elif isinstance(job, AutoMLJob):
+ self._resolve_automl_job_inputs(job)
+ elif isinstance(job, FineTuningJob):
+ self._resolve_finetuning_job_inputs(job)
+ elif isinstance(job, DistillationJob):
+ self._resolve_distillation_job_inputs(job)
+ elif isinstance(job, Spark):
+ self._resolve_job_inputs(job._job_inputs.values(), job._base_path)
+ elif isinstance(job, Command):
+ # TODO: switch to use inputs of Command objects, once the inputs/outputs building
+ # logic is removed from the BaseNode constructor.
+ try:
+ self._resolve_job_inputs(job._job_inputs.values(), job._base_path)
+ except AttributeError:
+ # If the job object doesn't have "inputs" attribute, we don't need to resolve. E.g. AutoML jobs
+ pass
+ else:
+ try:
+ self._resolve_job_inputs(job.inputs.values(), job._base_path) # type: ignore
+ except AttributeError:
+ # If the job object doesn't have "inputs" attribute, we don't need to resolve. E.g. AutoML jobs
+ pass
+
+ def _resolve_automl_job_inputs(self, job: AutoMLJob) -> None:
+ """This method resolves the inputs for AutoML jobs.
+
+ :param job: the job resource entity
+ :type job: AutoMLJob
+ """
+ if isinstance(job, AutoMLJob):
+ self._resolve_job_input(job.training_data, job._base_path)
+ if job.validation_data is not None:
+ self._resolve_job_input(job.validation_data, job._base_path)
+ if hasattr(job, "test_data") and job.test_data is not None:
+ self._resolve_job_input(job.test_data, job._base_path)
+
+ def _resolve_finetuning_job_inputs(self, job: FineTuningJob) -> None:
+ """This method resolves the inputs for FineTuning jobs.
+
+ :param job: the job resource entity
+ :type job: FineTuningJob
+ """
+ from azure.ai.ml.entities._job.finetuning.finetuning_vertical import FineTuningVertical
+
+ if isinstance(job, FineTuningVertical):
+ # self._resolve_job_input(job.model, job._base_path)
+ self._resolve_job_input(job.training_data, job._base_path)
+ if job.validation_data is not None:
+ self._resolve_job_input(job.validation_data, job._base_path)
+
+ def _resolve_distillation_job_inputs(self, job: DistillationJob) -> None:
+ """This method resolves the inputs for Distillation jobs.
+
+ :param job: the job resource entity
+ :type job: DistillationJob
+ """
+ if isinstance(job, DistillationJob):
+ if job.training_data:
+ self._resolve_job_input(job.training_data, job._base_path)
+ if job.validation_data is not None:
+ self._resolve_job_input(job.validation_data, job._base_path)
+
+ def _resolve_azureml_id(self, job: Job) -> Job:
+ """This method converts ARM id to name or name:version for nested
+ entities.
+
+ :param job: the job resource entity
+ :type job: Job
+        :return: the job resource entity with its nested dependencies resolved
+ :rtype: Job
+ """
+ self._append_tid_to_studio_url(job)
+ self._resolve_job_inputs_arm_id(job)
+ return self._resolve_arm_id_or_azureml_id(job, self._orchestrators.resolve_azureml_id)
+
+ def _resolve_compute_id(self, resolver: _AssetResolver, target: Any) -> Any:
+ # special case for local runs
+ if target is not None and target.lower() == LOCAL_COMPUTE_TARGET:
+ return LOCAL_COMPUTE_TARGET
+ try:
+ modified_target_name = target
+ if target.lower().startswith(AzureMLResourceType.VIRTUALCLUSTER + "/"):
+ # Compute target can be either workspace-scoped compute type,
+ # or AML scoped VC. In the case of VC, resource name will be of form
+ # azureml:virtualClusters/<name> to disambiguate from azureml:name (which is always compute)
+ modified_target_name = modified_target_name[len(AzureMLResourceType.VIRTUALCLUSTER) + 1 :]
+ modified_target_name = LEVEL_ONE_NAMED_RESOURCE_ID_FORMAT.format(
+ self._operation_scope.subscription_id,
+ self._operation_scope.resource_group_name,
+ AZUREML_RESOURCE_PROVIDER,
+ AzureMLResourceType.VIRTUALCLUSTER,
+ modified_target_name,
+ )
+ return resolver(
+ modified_target_name,
+ azureml_type=AzureMLResourceType.VIRTUALCLUSTER,
+ sub_workspace_resource=False,
+ )
+ except Exception: # pylint: disable=W0718
+ return resolver(target, azureml_type=AzureMLResourceType.COMPUTE)
+
+ def _resolve_job_inputs(self, entries: Iterable[Union[Input, str, bool, int, float]], base_path: str) -> None:
+ """resolve job inputs as ARM id or remote url.
+
+ :param entries: An iterable of job inputs
+ :type entries: Iterable[Union[Input, str, bool, int, float]]
+ :param base_path: The base path
+ :type base_path: str
+ """
+ for entry in entries:
+ self._resolve_job_input(entry, base_path)
+
+ # TODO: move it to somewhere else?
+ @classmethod
+ def _flatten_group_inputs(
+ cls, inputs: Dict[str, Union[Input, str, bool, int, float]]
+ ) -> List[Union[Input, str, bool, int, float]]:
+ """Get flatten values from an InputDict.
+
+ :param inputs: The input dict
+ :type inputs: Dict[str, Union[Input, str, bool, int, float]]
+ :return: A list of values from the Input Dict
+ :rtype: List[Union[Input, str, bool, int, float]]
+ """
+ input_values: List = []
+ # Flatten inputs for pipeline job
+ for key, item in inputs.items():
+ if isinstance(item, _GroupAttrDict):
+ input_values.extend(item.flatten(group_parameter_name=key))
+ else:
+ if not isinstance(item, (str, bool, int, float)):
+ # skip resolving inferred optional input without path (in do-while + dynamic input case)
+ if isinstance(item._data, Input): # type: ignore
+ if not item._data.path and item._meta._is_inferred_optional: # type: ignore
+ continue
+ input_values.append(item._data) # type: ignore
+ return input_values
+
+ def _resolve_job_input(self, entry: Union[Input, str, bool, int, float], base_path: str) -> None:
+ """resolve job input as ARM id or remote url.
+
+ :param entry: The job input
+ :type entry: Union[Input, str, bool, int, float]
+ :param base_path: The base path
+ :type base_path: str
+ """
+
+ # path can be empty if the job was created from builder functions
+ if isinstance(entry, Input) and not entry.path:
+ msg = "Input path can't be empty for jobs."
+ raise ValidationException(
+ message=msg,
+ target=ErrorTarget.JOB,
+ no_personal_data_message=msg,
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.MISSING_FIELD,
+ )
+
+ if (
+ not isinstance(entry, Input)
+ or is_ARM_id_for_resource(entry.path)
+ or is_url(entry.path)
+ or is_data_binding_expression(entry.path) # literal value but set mode in pipeline yaml
+ ): # Literal value, ARM id or remote url. Pass through
+ return
+ try:
+ datastore_name = (
+ entry.datastore if hasattr(entry, "datastore") and entry.datastore else WORKSPACE_BLOB_STORE
+ )
+
+ # absolute local path, upload, transform to remote url
+ if os.path.isabs(entry.path): # type: ignore
+ if entry.type == AssetTypes.URI_FOLDER and not os.path.isdir(entry.path): # type: ignore
+ raise ValidationException(
+ message="There is no dir on target path: {}".format(entry.path),
+ target=ErrorTarget.JOB,
+ no_personal_data_message="There is no dir on target path",
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.FILE_OR_FOLDER_NOT_FOUND,
+ )
+ if entry.type == AssetTypes.URI_FILE and not os.path.isfile(entry.path): # type: ignore
+ raise ValidationException(
+ message="There is no file on target path: {}".format(entry.path),
+ target=ErrorTarget.JOB,
+ no_personal_data_message="There is no file on target path",
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.FILE_OR_FOLDER_NOT_FOUND,
+ )
+ # absolute local path
+ entry.path = _upload_and_generate_remote_uri(
+ self._operation_scope,
+ self._datastore_operations,
+ entry.path,
+ datastore_name=datastore_name,
+ show_progress=self._show_progress,
+ )
+ # TODO : Move this part to a common place
+ if entry.type == AssetTypes.URI_FOLDER and entry.path and not entry.path.endswith("/"):
+ entry.path = entry.path + "/"
+ # Check for AzureML id, is there a better way?
+ elif ":" in entry.path or "@" in entry.path: # type: ignore
+ asset_type = AzureMLResourceType.DATA
+ if entry.type in [AssetTypes.MLFLOW_MODEL, AssetTypes.CUSTOM_MODEL]:
+ asset_type = AzureMLResourceType.MODEL
+
+ entry.path = self._orchestrators.get_asset_arm_id(entry.path, asset_type) # type: ignore
+ else: # relative local path, upload, transform to remote url
+ # Base path will be None for dsl pipeline component for now. We have 2 choices if the dsl pipeline
+ # function is imported from another file:
+ # 1) Use cwd as default base path;
+ # 2) Use the file path of the dsl pipeline function as default base path.
+ # Pick solution 1 for now as defining input path in the script to submit is a more common scenario.
+ local_path = Path(base_path or Path.cwd(), entry.path).resolve() # type: ignore
+ entry.path = _upload_and_generate_remote_uri(
+ self._operation_scope,
+ self._datastore_operations,
+ local_path,
+ datastore_name=datastore_name,
+ show_progress=self._show_progress,
+ )
+ # TODO : Move this part to a common place
+ if entry.type == AssetTypes.URI_FOLDER and entry.path and not entry.path.endswith("/"):
+ entry.path = entry.path + "/"
+ except (MlException, HttpResponseError) as e:
+ raise e
+ except Exception as e:
+ raise ValidationException(
+ message=f"Supported input path value are ARM id, AzureML id, remote uri or local path.\n"
+ f"Met {type(e)}:\n{e}",
+ target=ErrorTarget.JOB,
+ no_personal_data_message=(
+ "Supported input path value are ARM id, AzureML id, " "remote uri or local path."
+ ),
+ error=e,
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ ) from e
+
+ def _resolve_job_inputs_arm_id(self, job: Job) -> None:
+ try:
+ inputs: Dict[str, Union[Input, InputOutputBase, str, bool, int, float]] = job.inputs # type: ignore
+ for _, entry in inputs.items():
+ if isinstance(entry, InputOutputBase):
+                    # extract the original input from the input builder.
+ entry = entry._data
+ if not isinstance(entry, Input) or is_url(entry.path): # Literal value or remote url
+ continue
+ # ARM id
+ entry.path = self._orchestrators.resolve_azureml_id(entry.path)
+
+ except AttributeError:
+ # If the job object doesn't have "inputs" attribute, we don't need to resolve. E.g. AutoML jobs
+ pass
+
+ def _resolve_arm_id_or_azureml_id(self, job: Job, resolver: Union[Callable, _AssetResolver]) -> Job:
+ """Resolve arm_id for a given job.
+
+
+ :param job: The job
+ :type job: Job
+ :param resolver: The asset resolver function
+ :type resolver: _AssetResolver
+ :return: The provided job, with fields resolved to full ARM IDs
+ :rtype: Job
+ """
+ # TODO: this will need to be parallelized when multiple tasks
+ # are required. Also consider the implications for dependencies.
+
+ if isinstance(job, _BaseJob):
+ job.compute = self._resolve_compute_id(resolver, job.compute)
+ elif isinstance(job, Command):
+ job = self._resolve_arm_id_for_command_job(job, resolver)
+ elif isinstance(job, ImportJob):
+ job = self._resolve_arm_id_for_import_job(job, resolver)
+ elif isinstance(job, Spark):
+ job = self._resolve_arm_id_for_spark_job(job, resolver)
+ elif isinstance(job, ParallelJob):
+ job = self._resolve_arm_id_for_parallel_job(job, resolver)
+ elif isinstance(job, SweepJob):
+ job = self._resolve_arm_id_for_sweep_job(job, resolver)
+ elif isinstance(job, AutoMLJob):
+ job = self._resolve_arm_id_for_automl_job(job, resolver, inside_pipeline=False)
+ elif isinstance(job, PipelineJob):
+ job = self._resolve_arm_id_for_pipeline_job(job, resolver)
+ elif isinstance(job, FineTuningJob):
+ pass
+ elif isinstance(job, DistillationJob):
+ pass
+ else:
+ msg = f"Non supported job type: {type(job)}"
+ raise ValidationException(
+ message=msg,
+ target=ErrorTarget.JOB,
+ no_personal_data_message=msg,
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
+ return job
+
+ def _resolve_arm_id_for_command_job(self, job: Command, resolver: _AssetResolver) -> Command:
+ """Resolve arm_id for CommandJob.
+
+
+ :param job: The Command job
+ :type job: Command
+ :param resolver: The asset resolver function
+ :type resolver: _AssetResolver
+ :return: The provided Command job, with resolved fields
+ :rtype: Command
+ """
+ if job.code is not None and is_registry_id_for_resource(job.code):
+ msg = "Format not supported for code asset: {}"
+ raise ValidationException(
+ message=msg.format(job.code),
+ target=ErrorTarget.JOB,
+ no_personal_data_message=msg.format("[job.code]"),
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
+
+ if job.code is not None and not is_ARM_id_for_resource(job.code, AzureMLResourceType.CODE):
+ job.code = resolver( # type: ignore
+ Code(base_path=job._base_path, path=job.code),
+ azureml_type=AzureMLResourceType.CODE,
+ )
+ job.environment = resolver(job.environment, azureml_type=AzureMLResourceType.ENVIRONMENT)
+ job.compute = self._resolve_compute_id(resolver, job.compute)
+ return job
+
+ def _resolve_arm_id_for_spark_job(self, job: Spark, resolver: _AssetResolver) -> Spark:
+ """Resolve arm_id for SparkJob.
+
+ :param job: The Spark job
+ :type job: Spark
+ :param resolver: The asset resolver function
+ :type resolver: _AssetResolver
+ :return: The provided SparkJob, with resolved fields
+ :rtype: Spark
+ """
+ if job.code is not None and is_registry_id_for_resource(job.code):
+ msg = "Format not supported for code asset: {}"
+ raise JobException(
+ message=msg.format(job.code),
+ target=ErrorTarget.JOB,
+ no_personal_data_message=msg.format("[job.code]"),
+ error_category=ErrorCategory.USER_ERROR,
+ )
+
+ if job.code is not None and not is_ARM_id_for_resource(job.code, AzureMLResourceType.CODE):
+ job.code = resolver( # type: ignore
+ Code(base_path=job._base_path, path=job.code),
+ azureml_type=AzureMLResourceType.CODE,
+ )
+ job.environment = resolver(job.environment, azureml_type=AzureMLResourceType.ENVIRONMENT)
+ job.compute = self._resolve_compute_id(resolver, job.compute)
+ return job
+
+ def _resolve_arm_id_for_import_job(self, job: ImportJob, resolver: _AssetResolver) -> ImportJob:
+ """Resolve arm_id for ImportJob.
+
+ :param job: The Import job
+ :type job: ImportJob
+ :param resolver: The asset resolver function
+ :type resolver: _AssetResolver
+ :return: The provided ImportJob, with resolved fields
+ :rtype: ImportJob
+ """
+        # The compute property will no longer be applicable once the import job type is ready on MFE in PuP.
+        # For PrP, we use the command job type instead for import jobs, where the compute property is required.
+        # However, MFE only validates the compute resource url format. The execution service owns the real
+        # validation today but supports reserved compute names like AmlCompute, ContainerInstance and
+        # DataFactory here for 'clusterless' jobs.
+ job.compute = self._resolve_compute_id(resolver, ComputeType.ADF)
+ return job
+
+ def _resolve_arm_id_for_parallel_job(self, job: ParallelJob, resolver: _AssetResolver) -> ParallelJob:
+ """Resolve arm_id for ParallelJob.
+
+ :param job: The Parallel job
+ :type job: ParallelJob
+ :param resolver: The asset resolver function
+ :type resolver: _AssetResolver
+ :return: The provided ParallelJob, with resolved fields
+ :rtype: ParallelJob
+ """
+ if job.code is not None and not is_ARM_id_for_resource(job.code, AzureMLResourceType.CODE): # type: ignore
+ job.code = resolver( # type: ignore
+ Code(base_path=job._base_path, path=job.code), # type: ignore
+ azureml_type=AzureMLResourceType.CODE,
+ )
+ job.environment = resolver(job.environment, azureml_type=AzureMLResourceType.ENVIRONMENT) # type: ignore
+ job.compute = self._resolve_compute_id(resolver, job.compute)
+ return job
+
+ def _resolve_arm_id_for_sweep_job(self, job: SweepJob, resolver: _AssetResolver) -> SweepJob:
+ """Resolve arm_id for SweepJob.
+
+ :param job: The Sweep job
+ :type job: SweepJob
+ :param resolver: The asset resolver function
+ :type resolver: _AssetResolver
+ :return: The provided SweepJob, with resolved fields
+ :rtype: SweepJob
+ """
+ if (
+ job.trial is not None
+ and job.trial.code is not None
+ and not is_ARM_id_for_resource(job.trial.code, AzureMLResourceType.CODE)
+ ):
+ job.trial.code = resolver( # type: ignore[assignment]
+ Code(base_path=job._base_path, path=job.trial.code),
+ azureml_type=AzureMLResourceType.CODE,
+ )
+ if (
+ job.trial is not None
+ and job.trial.environment is not None
+ and not is_ARM_id_for_resource(job.trial.environment, AzureMLResourceType.ENVIRONMENT)
+ ):
+ job.trial.environment = resolver( # type: ignore[assignment]
+ job.trial.environment, azureml_type=AzureMLResourceType.ENVIRONMENT
+ )
+ job.compute = self._resolve_compute_id(resolver, job.compute)
+ return job
+
+ def _resolve_arm_id_for_automl_job(
+ self, job: AutoMLJob, resolver: _AssetResolver, inside_pipeline: bool
+ ) -> AutoMLJob:
+ """Resolve arm_id for AutoMLJob.
+
+ :param job: The AutoML job
+ :type job: AutoMLJob
+ :param resolver: The asset resolver function
+ :type resolver: _AssetResolver
+ :param inside_pipeline: Whether the job is within a pipeline
+ :type inside_pipeline: bool
+ :return: The provided AutoMLJob, with resolved fields
+ :rtype: AutoMLJob
+ """
+ # AutoML does not have dependency uploads. Only need to resolve reference to arm id.
+
+ # automl node in pipeline has optional compute
+ if inside_pipeline and job.compute is None:
+ return job
+ job.compute = resolver(job.compute, azureml_type=AzureMLResourceType.COMPUTE)
+ return job
+
+ def _resolve_arm_id_for_pipeline_job(self, pipeline_job: PipelineJob, resolver: _AssetResolver) -> PipelineJob:
+ """Resolve arm_id for pipeline_job.
+
+ :param pipeline_job: The pipeline job
+ :type pipeline_job: PipelineJob
+ :param resolver: The asset resolver function
+ :type resolver: _AssetResolver
+ :return: The provided PipelineJob, with resolved fields
+ :rtype: PipelineJob
+ """
+ # Get top-level job compute
+ _get_job_compute_id(pipeline_job, resolver)
+
+ # Process job defaults:
+ if pipeline_job.settings:
+ pipeline_job.settings.default_datastore = resolver(
+ pipeline_job.settings.default_datastore,
+ azureml_type=AzureMLResourceType.DATASTORE,
+ )
+ pipeline_job.settings.default_compute = resolver(
+ pipeline_job.settings.default_compute,
+ azureml_type=AzureMLResourceType.COMPUTE,
+ )
+
+ # Process each component job
+ try:
+ self._component_operations._resolve_dependencies_for_pipeline_component_jobs(
+ pipeline_job.component, resolver
+ )
+ except ComponentException as e:
+ raise JobException(
+ message=e.message,
+ target=ErrorTarget.JOB,
+ no_personal_data_message=e.no_personal_data_message,
+ error_category=e.error_category,
+ ) from e
+
+ # Create a pipeline component for pipeline job if user specified component in job yaml.
+ if (
+ not isinstance(pipeline_job.component, str)
+ and getattr(pipeline_job.component, "_source", None) == ComponentSource.YAML_COMPONENT
+ ):
+ pipeline_job.component = resolver( # type: ignore
+ pipeline_job.component,
+ azureml_type=AzureMLResourceType.COMPONENT,
+ )
+
+ return pipeline_job
+
+ def _append_tid_to_studio_url(self, job: Job) -> None:
+ """Appends the user's tenant ID to the end of the studio URL.
+
+ Allows the UI to authenticate against the correct tenant.
+
+ :param job: The job
+ :type job: Job
+ """
+ try:
+ if job.services is not None:
+ studio_endpoint = job.services.get("Studio", None)
+ studio_url = studio_endpoint.endpoint
+ default_scopes = _resource_to_scopes(_get_base_url_from_metadata())
+ module_logger.debug("default_scopes used: `%s`\n", default_scopes)
+ # Extract the tenant id from the credential using PyJWT
+ decode = jwt.decode(
+ self._credential.get_token(*default_scopes).token,
+ options={"verify_signature": False, "verify_aud": False},
+ )
+ tid = decode["tid"]
+ formatted_tid = TID_FMT.format(tid)
+ studio_endpoint.endpoint = studio_url + formatted_tid
+ except Exception: # pylint: disable=W0718
+ module_logger.info("Proceeding with no tenant id appended to studio URL\n")
+
+ def _set_headers_with_user_aml_token(self, kwargs: Any) -> None:
+ aml_resource_id = _get_aml_resource_id_from_metadata()
+ azure_ml_scopes = _resource_to_scopes(aml_resource_id)
+ module_logger.debug("azure_ml_scopes used: `%s`\n", azure_ml_scopes)
+ aml_token = self._credential.get_token(*azure_ml_scopes).token
+ # validate token has aml audience
+ decoded_token = jwt.decode(
+ aml_token,
+ options={"verify_signature": False, "verify_aud": False},
+ )
+ if decoded_token.get("aud") != aml_resource_id:
+ msg = """AAD token with aml scope could not be fetched using the credentials being used.
+ Please validate if token with {0} scope can be fetched using credentials provided to MLClient.
+ Token with {0} scope can be fetched using credentials.get_token({0})
+ """
+ raise ValidationException(
+ message=msg.format(*azure_ml_scopes),
+ target=ErrorTarget.JOB,
+ error_type=ValidationErrorType.RESOURCE_NOT_FOUND,
+ no_personal_data_message=msg.format("[job.code]"),
+ error_category=ErrorCategory.USER_ERROR,
+ )
+
+ headers = kwargs.pop("headers", {})
+ headers["x-azureml-token"] = aml_token
+ kwargs["headers"] = headers
+
+
+def _get_job_compute_id(job: Union[Job, Command], resolver: _AssetResolver) -> None:
+ job.compute = resolver(job.compute, azureml_type=AzureMLResourceType.COMPUTE)
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_job_ops_helper.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_job_ops_helper.py
new file mode 100644
index 00000000..4c0802d9
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_job_ops_helper.py
@@ -0,0 +1,513 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+import json
+import logging
+import os
+import re
+import subprocess
+import sys
+import time
+from typing import Any, Dict, Iterable, List, Optional, TextIO, Union
+
+from azure.ai.ml._artifacts._artifact_utilities import get_datastore_info, list_logs_in_datastore
+from azure.ai.ml._restclient.runhistory.models import Run, RunDetails, TypedAssetReference
+from azure.ai.ml._restclient.v2022_02_01_preview.models import DataType
+from azure.ai.ml._restclient.v2022_02_01_preview.models import JobType as RestJobType
+from azure.ai.ml._restclient.v2022_02_01_preview.models import ModelType
+from azure.ai.ml._restclient.v2022_10_01.models import JobBase
+from azure.ai.ml._utils._http_utils import HttpPipeline
+from azure.ai.ml._utils.utils import create_requests_pipeline_with_retry, download_text_from_url
+from azure.ai.ml.constants._common import GitProperties
+from azure.ai.ml.constants._job.job import JobLogPattern, JobType
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, JobException
+from azure.ai.ml.operations._dataset_dataplane_operations import DatasetDataplaneOperations
+from azure.ai.ml.operations._datastore_operations import DatastoreOperations
+from azure.ai.ml.operations._model_dataplane_operations import ModelDataplaneOperations
+from azure.ai.ml.operations._run_history_constants import JobStatus, RunHistoryConstants
+from azure.ai.ml.operations._run_operations import RunOperations
+
+STATUS_KEY = "status"
+
+module_logger = logging.getLogger(__name__)
+
+
+def _get_sorted_filtered_logs(
+ logs_iterable: Iterable[str],
+ job_type: str,
+ processed_logs: Optional[Dict[str, int]] = None,
+ only_streamable: bool = True,
+) -> List[str]:
+ """Filters log file names, sorts, and returns list starting with where we left off last iteration.
+
+ :param logs_iterable: An iterable of log paths.
+ :type logs_iterable: Iterable[str]
+ :param job_type: the job type to filter log files
+ :type job_type: str
+ :param processed_logs: dictionary tracking the state of how many lines of each file have been written out
+ :type processed_logs: dict[str, int]
+ :param only_streamable: Whether to only get streamable logs
+ :type only_streamable: bool
+ :return: List of logs to continue from
+ :rtype: list[str]
+ """
+ processed_logs = processed_logs if processed_logs else {}
+ # First, attempt to read logs in new Common Runtime form
+ output_logs_pattern = (
+ JobLogPattern.COMMON_RUNTIME_STREAM_LOG_PATTERN
+ if only_streamable
+ else JobLogPattern.COMMON_RUNTIME_ALL_USER_LOG_PATTERN
+ )
+ logs = list(logs_iterable)
+ filtered_logs = [x for x in logs if re.match(output_logs_pattern, x)]
+
+ # fall back to legacy log format
+ if filtered_logs is None or len(filtered_logs) == 0:
+ job_type = job_type.lower()
+ if job_type in JobType.COMMAND:
+ output_logs_pattern = JobLogPattern.COMMAND_JOB_LOG_PATTERN
+ elif job_type in JobType.PIPELINE:
+ output_logs_pattern = JobLogPattern.PIPELINE_JOB_LOG_PATTERN
+ elif job_type in JobType.SWEEP:
+ output_logs_pattern = JobLogPattern.SWEEP_JOB_LOG_PATTERN
+
+ filtered_logs = [x for x in logs if re.match(output_logs_pattern, x)]
+ filtered_logs.sort()
+ previously_printed_index = 0
+ for i, v in enumerate(filtered_logs):
+ if processed_logs.get(v):
+ previously_printed_index = i
+ else:
+ break
+ # Slice inclusive from the last printed log (can be updated before printing new files)
+ return filtered_logs[previously_printed_index:]
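+
+# A small illustration of the resume behaviour above (file names are hypothetical; the
+# exact patterns matched come from JobLogPattern):
+#   processed_logs = {"user_logs/std_log.txt": 40}   # 40 lines already streamed
+#   _get_sorted_filtered_logs(all_log_names, "command", processed_logs)
+# returns the sorted, filtered names starting at "user_logs/std_log.txt", so streaming
+# resumes from the partially printed file instead of re-printing earlier, finished logs.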
+
+
+def _incremental_print(log: str, processed_logs: Dict[str, int], current_log_name: str, fileout: TextIO) -> None:
+ """Incremental print.
+
+    :param log: The full content of the log file to print incrementally
+    :type log: str
+    :param processed_logs: The record of how many lines have been written for each log file
+    :type processed_logs: dict[str, int]
+    :param current_log_name: the file name being read out, used in header writing and accessing processed_logs
+    :type current_log_name: str
+    :param fileout: The stream to write log output to
+    :type fileout: TextIO
+ """
+ lines = log.splitlines()
+ doc_length = len(lines)
+ if doc_length == 0:
+ # If a file is empty, skip writing out.
+        # This addresses an issue where batch endpoint jobs can create log files before they are needed.
+ return
+ previous_printed_lines = processed_logs.get(current_log_name, 0)
+ # when a new log is first being written to console, print spacing and label
+ if previous_printed_lines == 0:
+ fileout.write("\n")
+ fileout.write("Streaming " + current_log_name + "\n")
+ fileout.write("=" * (len(current_log_name) + 10) + "\n")
+ fileout.write("\n")
+ # otherwise, continue to log the file where we left off
+ for line in lines[previous_printed_lines:]:
+ fileout.write(line + "\n")
+ # update state to record number of lines written for this log file
+ processed_logs[current_log_name] = doc_length
+
+
+def _get_last_log_primary_instance(logs: List) -> Any:
+ """Return last log for primary instance.
+
+    :param logs: The list of log file names
+    :type logs: builtin.list
+    :return: Returns the last log for the primary instance.
+    :rtype: str
+ """
+ primary_ranks = ["rank_0", "worker_0"]
+ rank_match_re = re.compile(r"(.*)_(.*?_.*?)\.txt")
+ last_log_name = logs[-1]
+
+ last_log_match = rank_match_re.match(last_log_name)
+ if not last_log_match:
+ return last_log_name
+
+ last_log_prefix = last_log_match.group(1)
+ matching_logs = sorted(filter(lambda x: x.startswith(last_log_prefix), logs))
+
+ # we have some specific ranks that denote the primary, use those if found
+ for log_name in matching_logs:
+ match = rank_match_re.match(log_name)
+ if not match:
+ continue
+ if match.group(2) in primary_ranks:
+ return log_name
+
+ # no definitively primary instance, just return the highest sorted
+ return matching_logs[0]
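+
+# Illustration (hypothetical log names): given
+#   ["70_driver_log_rank_0.txt", "70_driver_log_rank_1.txt"]
+# the last entry's prefix matches both files and "rank_0" is one of the recognized
+# primary ranks, so "70_driver_log_rank_0.txt" is returned as the primary instance log.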
+
+
+def _wait_before_polling(current_seconds: float) -> int:
+ if current_seconds < 0:
+ msg = "current_seconds must be positive"
+ raise JobException(
+ message=msg,
+ target=ErrorTarget.JOB,
+ no_personal_data_message=msg,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+ import math
+
+    # Sigmoid that tapers off near the max at ~ 3 min
+ duration = int(
+ int(RunHistoryConstants._WAIT_COMPLETION_POLLING_INTERVAL_MAX) / (1.0 + 100 * math.exp(-current_seconds / 20.0))
+ )
+ return max(int(RunHistoryConstants._WAIT_COMPLETION_POLLING_INTERVAL_MIN), duration)
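+
+# Illustration of the back-off curve (assuming, purely for the sake of example, a
+# 5 second minimum and 180 second maximum polling interval; the real values come from
+# RunHistoryConstants):
+#   _wait_before_polling(0)   -> 5    (clamped to the minimum)
+#   _wait_before_polling(60)  -> ~30
+#   _wait_before_polling(180) -> ~177 (approaching the maximum)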
+
+
+def list_logs(run_operations: RunOperations, job_resource: JobBase) -> Dict:
+ details: RunDetails = run_operations.get_run_details(job_resource.name)
+ logs_dict = details.log_files
+ keys = _get_sorted_filtered_logs(logs_dict, job_resource.properties.job_type)
+ return {key: logs_dict[key] for key in keys}
+
+
+# pylint: disable=too-many-statements,too-many-locals
+def stream_logs_until_completion(
+ run_operations: RunOperations,
+ job_resource: JobBase,
+ datastore_operations: Optional[DatastoreOperations] = None,
+ raise_exception_on_failed_job: bool = True,
+ *,
+ requests_pipeline: HttpPipeline
+) -> None:
+ """Stream the experiment run output to the specified file handle. By default the the file handle points to stdout.
+
+ :param run_operations: The run history operations class.
+ :type run_operations: RunOperations
+ :param job_resource: The job to stream
+ :type job_resource: JobBase
+ :param datastore_operations: Optional, the datastore operations class, used to get logs from datastore
+ :type datastore_operations: Optional[DatastoreOperations]
+ :param raise_exception_on_failed_job: Should this method fail if job fails
+    :type raise_exception_on_failed_job: bool
+ :keyword requests_pipeline: The HTTP pipeline to use for requests.
+ :type requests_pipeline: ~azure.ai.ml._utils._http_utils.HttpPipeline
+ :return:
+ :rtype: None
+ """
+ job_type = job_resource.properties.job_type
+ job_name = job_resource.name
+ studio_endpoint = job_resource.properties.services.get("Studio", None)
+ studio_endpoint = studio_endpoint.endpoint if studio_endpoint else None
+ # Feature store jobs should be linked to the Feature Store Workspace UI.
+ # Todo: Consolidate this logic to service side
+ if "azureml.FeatureStoreJobType" in job_resource.properties.properties:
+ url_format = (
+ "https://ml.azure.com/featureStore/{fs_name}/featureSets/{fset_name}/{fset_version}/matJobs/"
+ "jobs/{run_id}?wsid=/subscriptions/{fs_sub_id}/resourceGroups/{fs_rg_name}/providers/"
+ "Microsoft.MachineLearningServices/workspaces/{fs_name}"
+ )
+ studio_endpoint = url_format.format(
+ fs_name=job_resource.properties.properties["azureml.FeatureStoreName"],
+ fs_sub_id=run_operations._subscription_id,
+ fs_rg_name=run_operations._resource_group_name,
+ fset_name=job_resource.properties.properties["azureml.FeatureSetName"],
+ fset_version=job_resource.properties.properties["azureml.FeatureSetVersion"],
+ run_id=job_name,
+ )
+ elif "FeatureStoreJobType" in job_resource.properties.properties:
+ url_format = (
+ "https://ml.azure.com/featureStore/{fs_name}/featureSets/{fset_name}/{fset_version}/matJobs/"
+ "jobs/{run_id}?wsid=/subscriptions/{fs_sub_id}/resourceGroups/{fs_rg_name}/providers/"
+ "Microsoft.MachineLearningServices/workspaces/{fs_name}"
+ )
+ studio_endpoint = url_format.format(
+ fs_name=job_resource.properties.properties["FeatureStoreName"],
+ fs_sub_id=run_operations._subscription_id,
+ fs_rg_name=run_operations._resource_group_name,
+ fset_name=job_resource.properties.properties["FeatureSetName"],
+ fset_version=job_resource.properties.properties["FeatureSetVersion"],
+ run_id=job_name,
+ )
+
+ file_handle = sys.stdout
+ ds_properties = None
+ prefix = None
+ if (
+ hasattr(job_resource.properties, "outputs")
+ and job_resource.properties.job_type != RestJobType.AUTO_ML
+ and datastore_operations
+ ):
+ # Get default output location
+
+ default_output = (
+ job_resource.properties.outputs.get("default", None) if job_resource.properties.outputs else None
+ )
+ is_uri_folder = default_output and default_output.job_output_type == DataType.URI_FOLDER
+ if is_uri_folder:
+ output_uri = default_output.uri # type: ignore
+ # Parse the uri format
+ output_uri = output_uri.split("datastores/")[1]
+ datastore_name, prefix = output_uri.split("/", 1)
+ ds_properties = get_datastore_info(datastore_operations, datastore_name)
+
+ try:
+ file_handle.write("RunId: {}\n".format(job_name))
+ file_handle.write("Web View: {}\n".format(studio_endpoint))
+
+ _current_details: RunDetails = run_operations.get_run_details(job_name)
+
+ processed_logs: Dict = {}
+
+ poll_start_time = time.time()
+ pipeline_with_retries = create_requests_pipeline_with_retry(requests_pipeline=requests_pipeline)
+ while (
+ _current_details.status in RunHistoryConstants.IN_PROGRESS_STATUSES
+ or _current_details.status == JobStatus.FINALIZING
+ ):
+ file_handle.flush()
+ time.sleep(_wait_before_polling(time.time() - poll_start_time))
+ _current_details = run_operations.get_run_details(job_name) # TODO use FileWatcher
+ if job_type.lower() in JobType.PIPELINE:
+ legacy_folder_name = "/logs/azureml/"
+ else:
+ legacy_folder_name = "/azureml-logs/"
+ _current_logs_dict = (
+ list_logs_in_datastore(
+ ds_properties,
+ prefix=str(prefix),
+ legacy_log_folder_name=legacy_folder_name,
+ )
+ if ds_properties is not None
+ else _current_details.log_files
+ )
+ # Get the list of new logs available after filtering out the processed ones
+ available_logs = _get_sorted_filtered_logs(_current_logs_dict, job_type, processed_logs)
+ content = ""
+ for current_log in available_logs:
+ content = download_text_from_url(
+ _current_logs_dict[current_log],
+ pipeline_with_retries,
+ timeout=RunHistoryConstants._DEFAULT_GET_CONTENT_TIMEOUT,
+ )
+
+ _incremental_print(content, processed_logs, current_log, file_handle)
+
+ # TODO: Temporary solution to wait for all the logs to be printed in the finalizing state.
+ if (
+ _current_details.status not in RunHistoryConstants.IN_PROGRESS_STATUSES
+ and _current_details.status == JobStatus.FINALIZING
+ and "The activity completed successfully. Finalizing run..." in content
+ ):
+ break
+
+ file_handle.write("\n")
+ file_handle.write("Execution Summary\n")
+ file_handle.write("=================\n")
+ file_handle.write("RunId: {}\n".format(job_name))
+ file_handle.write("Web View: {}\n".format(studio_endpoint))
+
+ warnings = _current_details.warnings
+ if warnings:
+ messages = [x.message for x in warnings if x.message]
+ if len(messages) > 0:
+ file_handle.write("\nWarnings:\n")
+ for message in messages:
+ file_handle.write(message + "\n")
+ file_handle.write("\n")
+
+ if _current_details.status == JobStatus.FAILED:
+ error = (
+ _current_details.error.as_dict()
+ if _current_details.error
+ else "Detailed error not set on the Run. Please check the logs for details."
+ )
+            # If we are raising the error later on, don't double print it here.
+ if not raise_exception_on_failed_job:
+ file_handle.write("\nError:\n")
+ file_handle.write(json.dumps(error, indent=4))
+ file_handle.write("\n")
+ else:
+ raise JobException(
+ message="Exception : \n {} ".format(json.dumps(error, indent=4)),
+ target=ErrorTarget.JOB,
+ no_personal_data_message="Exception raised on failed job.",
+ error_category=ErrorCategory.SYSTEM_ERROR,
+ )
+
+ file_handle.write("\n")
+ file_handle.flush()
+ except KeyboardInterrupt as e:
+ error_message = (
+ "The output streaming for the run interrupted.\n"
+ "But the run is still executing on the compute target. \n"
+ "Details for canceling the run can be found here: "
+ "https://aka.ms/aml-docs-cancel-run"
+ )
+ raise JobException(
+ message=error_message,
+ target=ErrorTarget.JOB,
+ no_personal_data_message=error_message,
+ error_category=ErrorCategory.USER_ERROR,
+ ) from e
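+
+# A minimal usage sketch (illustrative only; it assumes internal RunOperations,
+# DatastoreOperations and HttpPipeline instances are already available, as they are
+# inside the SDK's job operations, and that job_resource is a JobBase already fetched
+# from the service):
+#
+#   stream_logs_until_completion(
+#       run_operations=run_ops,
+#       job_resource=job_resource,
+#       datastore_operations=datastore_ops,
+#       raise_exception_on_failed_job=False,
+#       requests_pipeline=requests_pipeline,
+#   )
+#
+# The call blocks, polling with _wait_before_polling, and writes incremental log output
+# to stdout until the run leaves the in-progress states.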
+
+
+def get_git_properties() -> Dict[str, str]:
+ """Gather Git tracking info from the local environment.
+
+ :return: Properties dictionary.
+ :rtype: dict
+ """
+
+ def _clean_git_property_bool(value: Any) -> Optional[bool]:
+ if value is None:
+ return None
+ return str(value).strip().lower() in ["true", "1"]
+
+ def _clean_git_property_str(value: Any) -> Optional[str]:
+ if value is None:
+ return None
+ return str(value).strip() or None
+
+ def _run_git_cmd(args: Iterable[str]) -> Optional[str]:
+ """Runs git with the provided arguments
+
+        :param args: An iterable of arguments for a git command. Should not include the leading "git"
+ :type args: Iterable[str]
+ :return: The output of running git with arguments, or None if it fails.
+ :rtype: Optional[str]
+ """
+ try:
+ with open(os.devnull, "wb") as devnull:
+ return subprocess.check_output(["git"] + list(args), stderr=devnull).decode()
+ except KeyboardInterrupt:
+ raise
+ except BaseException: # pylint: disable=W0718
+ return None
+
+ # Check for environment variable overrides.
+ repository_uri = os.environ.get(GitProperties.ENV_REPOSITORY_URI, None)
+ branch = os.environ.get(GitProperties.ENV_BRANCH, None)
+ commit = os.environ.get(GitProperties.ENV_COMMIT, None)
+ dirty: Optional[Union[str, bool]] = os.environ.get(GitProperties.ENV_DIRTY, None)
+ build_id = os.environ.get(GitProperties.ENV_BUILD_ID, None)
+ build_uri = os.environ.get(GitProperties.ENV_BUILD_URI, None)
+
+ is_git_repo = _run_git_cmd(["rev-parse", "--is-inside-work-tree"])
+ if _clean_git_property_bool(is_git_repo):
+ repository_uri = repository_uri or _run_git_cmd(["ls-remote", "--get-url"])
+ branch = branch or _run_git_cmd(["symbolic-ref", "--short", "HEAD"])
+ commit = commit or _run_git_cmd(["rev-parse", "HEAD"])
+ dirty = dirty or _run_git_cmd(["status", "--porcelain", "."]) and True
+
+ # Parsing logic.
+ repository_uri = _clean_git_property_str(repository_uri)
+ commit = _clean_git_property_str(commit)
+ branch = _clean_git_property_str(branch)
+ dirty = _clean_git_property_bool(dirty)
+ build_id = _clean_git_property_str(build_id)
+ build_uri = _clean_git_property_str(build_uri)
+
+ # Return with appropriate labels.
+ properties = {}
+
+ if repository_uri is not None:
+ properties[GitProperties.PROP_MLFLOW_GIT_REPO_URL] = repository_uri
+
+ if branch is not None:
+ properties[GitProperties.PROP_MLFLOW_GIT_BRANCH] = branch
+
+ if commit is not None:
+ properties[GitProperties.PROP_MLFLOW_GIT_COMMIT] = commit
+
+ if dirty is not None:
+ properties[GitProperties.PROP_DIRTY] = str(dirty)
+
+ if build_id is not None:
+ properties[GitProperties.PROP_BUILD_ID] = build_id
+
+ if build_uri is not None:
+ properties[GitProperties.PROP_BUILD_URI] = build_uri
+
+ return properties
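+
+# Illustrative use (a sketch; the property keys are the GitProperties constants and the
+# caller shown here is hypothetical):
+#
+#   git_props = get_git_properties()
+#   # e.g. {GitProperties.PROP_MLFLOW_GIT_REPO_URL: "<repo-url>",
+#   #       GitProperties.PROP_MLFLOW_GIT_BRANCH: "main",
+#   #       GitProperties.PROP_MLFLOW_GIT_COMMIT: "<sha>", ...}
+#   job.properties.update(git_props)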
+
+
+def get_job_output_uris_from_dataplane(
+ job_name: Optional[str],
+ run_operations: RunOperations,
+ dataset_dataplane_operations: DatasetDataplaneOperations,
+ model_dataplane_operations: Optional[ModelDataplaneOperations],
+ output_names: Optional[Union[Iterable[str], str]] = None,
+) -> Dict[str, str]:
+ """Returns the output path for the given output in cloud storage of the given job.
+
+ If no output names are given, the output paths for all outputs will be returned.
+ URIs obtained from the service will be in the long-form azureml:// format.
+
+ For example:
+ azureml://subscriptions/<sub>/resource[gG]roups/<rg_name>/workspaces/<ws_name>/datastores/<ds_name>/paths/<ds_path>
+
+ :param job_name: The job name
+ :type job_name: str
+ :param run_operations: The RunOperations used to fetch run data for the job
+ :type run_operations: RunOperations
+ :param dataset_dataplane_operations: The DatasetDataplaneOperations used to fetch dataset uris
+ :type dataset_dataplane_operations: DatasetDataplaneOperations
+ :param model_dataplane_operations: The ModelDataplaneOperations used to fetch dataset uris
+ :type model_dataplane_operations: ModelDataplaneOperations
+ :param output_names: The output name(s) to fetch. If not specified, retrieves all.
+    :type output_names: Optional[Union[Iterable[str], str]]
+ :return: Dictionary mapping user-defined output name to output uri
+ :rtype: Dict[str, str]
+ """
+ run_metadata: Run = run_operations.get_run_data(str(job_name)).run_metadata
+ run_outputs: Dict[str, TypedAssetReference] = run_metadata.outputs or {}
+
+ # Create a reverse mapping from internal asset id to user-defined output name
+ asset_id_to_output_name = {v.asset_id: k for k, v in run_outputs.items()}
+ if not output_names:
+ # Assume all outputs are needed if no output name is provided
+ output_names = run_outputs.keys()
+ else:
+ if isinstance(output_names, str):
+ output_names = [output_names]
+ output_names = [o for o in output_names if o in run_outputs]
+
+ # Collect all output ids that correspond to data assets
+ dataset_ids = [
+ run_outputs[output_name].asset_id
+ for output_name in output_names
+ if run_outputs[output_name].type in [o.value for o in DataType]
+ ]
+
+ # Collect all output ids that correspond to models
+ model_ids = [
+ run_outputs[output_name].asset_id
+ for output_name in output_names
+ if run_outputs[output_name].type in [o.value for o in ModelType]
+ ]
+
+ output_name_to_dataset_uri = {}
+ if dataset_ids:
+ # Get the data paths from the service
+ dataset_uris = dataset_dataplane_operations.get_batch_dataset_uris(dataset_ids)
+ # Map the user-defined output name to the output uri
+ # The service returns a mapping from internal asset id to output metadata, so we need the reverse map
+ # defined above to get the user-defined output name from the internal asset id.
+ output_name_to_dataset_uri = {asset_id_to_output_name[k]: v.uri for k, v in dataset_uris.values.items()}
+
+ # This is a repeat of the logic above for models.
+ output_name_to_model_uri = {}
+ if model_ids:
+ model_uris = (
+ model_dataplane_operations.get_batch_model_uris(model_ids)
+ if model_dataplane_operations is not None
+ else None
+ )
+ output_name_to_model_uri = {
+ asset_id_to_output_name[k]: v.path for k, v in model_uris.values.items() # type: ignore
+ }
+ return {**output_name_to_dataset_uri, **output_name_to_model_uri}
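+
+# Illustrative call (a sketch; the operations objects are the SDK's internal dataplane
+# clients and the output names are hypothetical):
+#
+#   uris = get_job_output_uris_from_dataplane(
+#       job_name="my-job",
+#       run_operations=run_ops,
+#       dataset_dataplane_operations=dataset_ops,
+#       model_dataplane_operations=model_ops,
+#       output_names=["trained_model", "metrics"],
+#   )
+#   # -> {"trained_model": "azureml://subscriptions/.../datastores/.../paths/...",
+#   #     "metrics": "azureml://subscriptions/.../datastores/.../paths/..."}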
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_local_deployment_helper.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_local_deployment_helper.py
new file mode 100644
index 00000000..f49ff59d
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_local_deployment_helper.py
@@ -0,0 +1,390 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access,too-many-locals
+
+import json
+import logging
+import os
+import shutil
+from pathlib import Path
+from typing import Any, Iterable, Optional, Union
+
+from marshmallow.exceptions import ValidationError as SchemaValidationError
+
+from azure.ai.ml._exception_helper import log_and_raise_error
+from azure.ai.ml._local_endpoints import AzureMlImageContext, DockerfileResolver, LocalEndpointMode
+from azure.ai.ml._local_endpoints.docker_client import (
+ DockerClient,
+ get_deployment_json_from_container,
+ get_status_from_container,
+)
+from azure.ai.ml._local_endpoints.mdc_config_resolver import MdcConfigResolver
+from azure.ai.ml._local_endpoints.validators.code_validator import get_code_configuration_artifacts
+from azure.ai.ml._local_endpoints.validators.environment_validator import get_environment_artifacts
+from azure.ai.ml._local_endpoints.validators.model_validator import get_model_artifacts
+from azure.ai.ml._scope_dependent_operations import OperationsContainer
+from azure.ai.ml._utils._endpoint_utils import local_endpoint_polling_wrapper
+from azure.ai.ml._utils.utils import DockerProxy
+from azure.ai.ml.constants._common import AzureMLResourceType, DefaultOpenEncoding
+from azure.ai.ml.constants._endpoint import LocalEndpointConstants
+from azure.ai.ml.entities import OnlineDeployment
+from azure.ai.ml.exceptions import InvalidLocalEndpointError, LocalEndpointNotFoundError, ValidationException
+
+docker = DockerProxy()
+module_logger = logging.getLogger(__name__)
+
+
+class _LocalDeploymentHelper(object):
+ """A helper class to interact with Azure ML endpoints locally.
+
+ Use this helper to manage Azure ML endpoints locally, e.g. create, invoke, show, list, delete.
+ """
+
+ def __init__(
+ self,
+ operation_container: OperationsContainer,
+ ):
+ self._docker_client = DockerClient()
+ self._model_operations: Any = operation_container.all_operations.get(AzureMLResourceType.MODEL)
+ self._code_operations: Any = operation_container.all_operations.get(AzureMLResourceType.CODE)
+ self._environment_operations: Any = operation_container.all_operations.get(AzureMLResourceType.ENVIRONMENT)
+
+ def create_or_update( # type: ignore
+ self,
+ deployment: OnlineDeployment,
+ local_endpoint_mode: LocalEndpointMode,
+ local_enable_gpu: Optional[bool] = False,
+ ) -> OnlineDeployment:
+ """Create or update an deployment locally using Docker.
+
+ :param deployment: OnlineDeployment object with information from user yaml.
+ :type deployment: OnlineDeployment
+ :param local_endpoint_mode: Mode for how to create the local user container.
+ :type local_endpoint_mode: LocalEndpointMode
+ :param local_enable_gpu: enable local container to access gpu
+ :type local_enable_gpu: bool
+ """
+ try:
+ if deployment is None:
+ msg = "The entity provided for local endpoint was null. Please provide valid entity."
+ raise InvalidLocalEndpointError(message=msg, no_personal_data_message=msg)
+
+ endpoint_metadata: Any = None
+ try:
+ self.get(endpoint_name=str(deployment.endpoint_name), deployment_name=str(deployment.name))
+ endpoint_metadata = json.dumps(
+ self._docker_client.get_endpoint(endpoint_name=str(deployment.endpoint_name))
+ )
+ operation_message = "Updating local deployment"
+ except LocalEndpointNotFoundError:
+ operation_message = "Creating local deployment"
+
+ deployment_metadata = json.dumps(deployment._to_dict())
+ endpoint_metadata = (
+ endpoint_metadata
+ if endpoint_metadata
+ else _get_stubbed_endpoint_metadata(endpoint_name=str(deployment.endpoint_name))
+ )
+ local_endpoint_polling_wrapper(
+ func=self._create_deployment,
+ message=f"{operation_message} ({deployment.endpoint_name} / {deployment.name}) ",
+ endpoint_name=deployment.endpoint_name,
+ deployment=deployment,
+ local_endpoint_mode=local_endpoint_mode,
+ local_enable_gpu=local_enable_gpu,
+ endpoint_metadata=endpoint_metadata,
+ deployment_metadata=deployment_metadata,
+ )
+ return self.get(endpoint_name=str(deployment.endpoint_name), deployment_name=str(deployment.name))
+ except Exception as ex: # pylint: disable=W0718
+ if isinstance(ex, (ValidationException, SchemaValidationError)):
+ log_and_raise_error(ex)
+ else:
+ raise ex
+
+ def get_deployment_logs(self, endpoint_name: str, deployment_name: str, lines: int) -> str:
+ """Get logs from a local endpoint.
+
+ :param endpoint_name: Name of endpoint to invoke.
+ :type endpoint_name: str
+ :param deployment_name: Name of specific deployment to invoke.
+ :type deployment_name: str
+ :param lines: Number of most recent lines from container logs.
+ :type lines: int
+ :return: The deployment logs
+ :rtype: str
+ """
+ return str(self._docker_client.logs(endpoint_name=endpoint_name, deployment_name=deployment_name, lines=lines))
+
+ def get(self, endpoint_name: str, deployment_name: str) -> OnlineDeployment:
+ """Get a local deployment.
+
+ :param endpoint_name: Name of endpoint.
+ :type endpoint_name: str
+ :param deployment_name: Name of deployment.
+ :type deployment_name: str
+ :return: The deployment
+ :rtype: OnlineDeployment
+ """
+ container = self._docker_client.get_endpoint_container(
+ endpoint_name=endpoint_name,
+ deployment_name=deployment_name,
+ include_stopped=True,
+ )
+ if container is None:
+ raise LocalEndpointNotFoundError(endpoint_name=endpoint_name, deployment_name=deployment_name)
+ return _convert_container_to_deployment(container=container)
+
+ def list(self) -> Iterable[OnlineDeployment]:
+ """List all local endpoints.
+
+ :return: The OnlineDeployments
+ :rtype: Iterable[OnlineDeployment]
+ """
+ containers = self._docker_client.list_containers()
+ deployments = []
+ for container in containers:
+ deployments.append(_convert_container_to_deployment(container=container))
+ return deployments
+
+ def delete(self, name: str, deployment_name: Optional[str] = None) -> None:
+ """Delete a local deployment.
+
+ :param name: Name of endpoint associated with the deployment to delete.
+ :type name: str
+ :param deployment_name: Name of specific deployment to delete.
+ :type deployment_name: str
+ """
+ self._docker_client.delete(endpoint_name=name, deployment_name=deployment_name)
+ try:
+ build_directory = _get_deployment_directory(endpoint_name=name, deployment_name=deployment_name)
+ shutil.rmtree(build_directory)
+ except (PermissionError, OSError):
+ pass
+
+ def _create_deployment(
+ self,
+ endpoint_name: str,
+ deployment: OnlineDeployment,
+ local_endpoint_mode: LocalEndpointMode,
+ local_enable_gpu: Optional[bool] = False,
+ endpoint_metadata: Optional[dict] = None,
+ deployment_metadata: Optional[dict] = None,
+ ) -> None:
+ """Create deployment locally using Docker.
+
+        :param endpoint_name: Name of the endpoint this deployment belongs to.
+ :type endpoint_name: str
+ :param deployment: Deployment to create
+ :type deployment: OnlineDeployment
+ :param local_endpoint_mode: Mode for local endpoint.
+ :type local_endpoint_mode: LocalEndpointMode
+ :param local_enable_gpu: enable local container to access gpu
+ :type local_enable_gpu: bool
+        :param endpoint_metadata: Endpoint metadata (json serialized Endpoint entity)
+ :type endpoint_metadata: dict
+        :param deployment_metadata: Deployment metadata (json serialized Deployment entity)
+ :type deployment_metadata: dict
+ """
+ deployment_name = deployment.name
+ deployment_directory = _create_build_directory(
+ endpoint_name=endpoint_name, deployment_name=str(deployment_name)
+ )
+ deployment_directory_path = str(deployment_directory.resolve())
+
+ # Get assets for mounting into the container
+ # If code_directory_path is None, consider NCD flow
+ code_directory_path = get_code_configuration_artifacts(
+ endpoint_name=endpoint_name,
+ deployment=deployment,
+ code_operations=self._code_operations,
+ download_path=deployment_directory_path,
+ )
+ # We always require the model, however it may be anonymous for local (model_name=None)
+ (
+ model_name,
+ model_version,
+ model_directory_path,
+ ) = get_model_artifacts( # type: ignore[misc]
+ endpoint_name=endpoint_name,
+ deployment=deployment,
+ model_operations=self._model_operations,
+ download_path=deployment_directory_path,
+ )
+
+ # Resolve the environment information
+ # - Image + conda file - environment.image / environment.conda_file
+ # - Docker context - environment.build
+ (
+ yaml_base_image_name,
+ yaml_env_conda_file_path,
+ yaml_env_conda_file_contents,
+ downloaded_build_context,
+ yaml_dockerfile,
+ inference_config,
+ ) = get_environment_artifacts( # type: ignore[misc]
+ endpoint_name=endpoint_name,
+ deployment=deployment,
+ environment_operations=self._environment_operations,
+ download_path=deployment_directory, # type: ignore[arg-type]
+ )
+ # Retrieve AzureML specific information
+ # - environment variables required for deployment
+ # - volumes to mount into container
+ image_context = AzureMlImageContext(
+ endpoint_name=endpoint_name,
+ deployment_name=str(deployment_name),
+ yaml_code_directory_path=str(code_directory_path),
+ yaml_code_scoring_script_file_name=(
+ deployment.code_configuration.scoring_script if code_directory_path else None # type: ignore
+ ),
+ model_directory_path=model_directory_path,
+ model_mount_path=f"/{model_name}/{model_version}" if model_name else "",
+ )
+
+        # Construct Dockerfile if necessary, i.e.
+ # - User did not provide environment.inference_config, then this is not BYOC flow, cases below:
+ # --- user provided environment.build
+ # --- user provided environment.image
+ # --- user provided environment.image + environment.conda_file
+ is_byoc = inference_config is not None
+ dockerfile: Any = None
+ if not is_byoc:
+ install_debugpy = local_endpoint_mode is LocalEndpointMode.VSCodeDevContainer
+ if yaml_env_conda_file_path:
+ _write_conda_file(
+ conda_contents=yaml_env_conda_file_contents,
+ directory_path=deployment_directory,
+ conda_file_name=LocalEndpointConstants.CONDA_FILE_NAME,
+ )
+ dockerfile = DockerfileResolver(
+ dockerfile=yaml_dockerfile,
+ docker_base_image=yaml_base_image_name,
+ docker_azureml_app_path=image_context.docker_azureml_app_path,
+ docker_conda_file_name=LocalEndpointConstants.CONDA_FILE_NAME,
+ docker_port=LocalEndpointConstants.DOCKER_PORT,
+ install_debugpy=install_debugpy,
+ )
+ else:
+ dockerfile = DockerfileResolver(
+ dockerfile=yaml_dockerfile,
+ docker_base_image=yaml_base_image_name,
+ docker_azureml_app_path=image_context.docker_azureml_app_path,
+ docker_conda_file_name=None,
+ docker_port=LocalEndpointConstants.DOCKER_PORT,
+ install_debugpy=install_debugpy,
+ )
+ dockerfile.write_file(directory_path=deployment_directory_path)
+
+ # Merge AzureML environment variables and user environment variables
+ user_environment_variables = deployment.environment_variables
+ environment_variables = {
+ **image_context.environment,
+ **user_environment_variables,
+ }
+
+ volumes = {}
+ volumes.update(image_context.volumes)
+
+ if deployment.data_collector:
+ mdc_config = MdcConfigResolver(deployment.data_collector)
+ mdc_config.write_file(deployment_directory_path)
+
+ environment_variables.update(mdc_config.environment_variables)
+ volumes.update(mdc_config.volumes)
+
+ # Determine whether we need to use local context or downloaded context
+ build_directory = downloaded_build_context if downloaded_build_context else deployment_directory
+ self._docker_client.create_deployment(
+ endpoint_name=endpoint_name,
+ deployment_name=str(deployment_name),
+ endpoint_metadata=endpoint_metadata, # type: ignore[arg-type]
+ deployment_metadata=deployment_metadata, # type: ignore[arg-type]
+ build_directory=str(build_directory),
+ dockerfile_path=None if is_byoc else dockerfile.local_path, # type: ignore[arg-type]
+ conda_source_path=yaml_env_conda_file_path,
+ conda_yaml_contents=yaml_env_conda_file_contents,
+ volumes=volumes,
+ environment=environment_variables,
+ azureml_port=(
+ inference_config.scoring_route.port if is_byoc else LocalEndpointConstants.DOCKER_PORT # type: ignore
+ ),
+ local_endpoint_mode=local_endpoint_mode,
+ prebuilt_image_name=yaml_base_image_name if is_byoc else None,
+ local_enable_gpu=local_enable_gpu,
+ )
+
+
+# Bug Item number: 2885719
+def _convert_container_to_deployment(
+ # Bug Item number: 2885719
+ container: "docker.models.containers.Container", # type: ignore
+) -> OnlineDeployment:
+ """Converts provided Container for local deployment to OnlineDeployment entity.
+
+ :param container: Container for a local deployment.
+ :type container: docker.models.containers.Container
+ :return: The OnlineDeployment entity
+ :rtype: OnlineDeployment
+ """
+ deployment_json = get_deployment_json_from_container(container=container)
+ provisioning_state = get_status_from_container(container=container)
+ if provisioning_state == LocalEndpointConstants.CONTAINER_EXITED:
+ return _convert_json_to_deployment(
+ deployment_json=deployment_json,
+ instance_type=LocalEndpointConstants.ENDPOINT_STATE_LOCATION,
+ provisioning_state=LocalEndpointConstants.ENDPOINT_STATE_FAILED,
+ )
+ return _convert_json_to_deployment(
+ deployment_json=deployment_json,
+ instance_type=LocalEndpointConstants.ENDPOINT_STATE_LOCATION,
+ provisioning_state=LocalEndpointConstants.ENDPOINT_STATE_SUCCEEDED,
+ )
+
+
+def _write_conda_file(conda_contents: str, directory_path: Union[str, os.PathLike], conda_file_name: str) -> None:
+ """Writes out conda file to provided directory.
+
+ :param conda_contents: contents of conda yaml file provided by user
+ :type conda_contents: str
+ :param directory_path: directory on user's local system to write conda file
+ :type directory_path: str
+ :param conda_file_name: The filename to write to
+ :type conda_file_name: str
+ """
+ conda_file_path = f"{directory_path}/{conda_file_name}"
+ p = Path(conda_file_path)
+ p.write_text(conda_contents, encoding=DefaultOpenEncoding.WRITE)
+
+
+def _convert_json_to_deployment(deployment_json: Optional[dict], **kwargs: Any) -> OnlineDeployment:
+ """Converts metadata json and kwargs to OnlineDeployment entity.
+
+ :param deployment_json: dictionary representation of OnlineDeployment entity.
+ :type deployment_json: dict
+ :returns: The OnlineDeployment entity
+ :rtype: OnlineDeployment
+ """
+ params_override = []
+ for k, v in kwargs.items():
+ params_override.append({k: v})
+ return OnlineDeployment._load(data=deployment_json, params_override=params_override)
+
+
+def _get_stubbed_endpoint_metadata(endpoint_name: str) -> str:
+ return json.dumps({"name": endpoint_name})
+
+
+def _create_build_directory(endpoint_name: str, deployment_name: str) -> Path:
+ build_directory = _get_deployment_directory(endpoint_name=endpoint_name, deployment_name=deployment_name)
+ build_directory.mkdir(parents=True, exist_ok=True)
+ return build_directory
+
+
+def _get_deployment_directory(endpoint_name: str, deployment_name: Optional[str]) -> Path:
+ if deployment_name is not None:
+ return Path(Path.home(), ".azureml", "inferencing", endpoint_name, deployment_name)
+
+ return Path(Path.home(), ".azureml", "inferencing", endpoint_name, "")
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_local_endpoint_helper.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_local_endpoint_helper.py
new file mode 100644
index 00000000..341c55b0
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_local_endpoint_helper.py
@@ -0,0 +1,205 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+import json
+import logging
+from typing import Any, Iterable, List, Optional
+
+from marshmallow.exceptions import ValidationError as SchemaValidationError
+
+from azure.ai.ml._exception_helper import log_and_raise_error
+from azure.ai.ml._local_endpoints import EndpointStub
+from azure.ai.ml._local_endpoints.docker_client import (
+ DockerClient,
+ get_endpoint_json_from_container,
+ get_scoring_uri_from_container,
+ get_status_from_container,
+)
+from azure.ai.ml._utils._endpoint_utils import local_endpoint_polling_wrapper
+from azure.ai.ml._utils._http_utils import HttpPipeline
+from azure.ai.ml._utils.utils import DockerProxy
+from azure.ai.ml.constants._endpoint import EndpointInvokeFields, LocalEndpointConstants
+from azure.ai.ml.entities import OnlineEndpoint
+from azure.ai.ml.exceptions import InvalidLocalEndpointError, LocalEndpointNotFoundError, ValidationException
+
+docker = DockerProxy()
+module_logger = logging.getLogger(__name__)
+
+
+class _LocalEndpointHelper(object):
+ """A helper class to interact with Azure ML endpoints locally.
+
+ Use this helper to manage Azure ML endpoints locally, e.g. create, invoke, show, list, delete.
+ """
+
+ def __init__(self, *, requests_pipeline: HttpPipeline):
+ self._docker_client = DockerClient()
+ self._endpoint_stub = EndpointStub()
+ self._requests_pipeline = requests_pipeline
+
+ def create_or_update(self, endpoint: OnlineEndpoint) -> OnlineEndpoint: # type: ignore
+ """Create or update an endpoint locally using Docker.
+
+ :param endpoint: OnlineEndpoint object with information from user yaml.
+ :type endpoint: OnlineEndpoint
+ """
+ try:
+ if endpoint is None:
+ msg = "The entity provided for local endpoint was null. Please provide valid entity."
+ raise InvalidLocalEndpointError(message=msg, no_personal_data_message=msg)
+
+ try:
+ self.get(endpoint_name=str(endpoint.name))
+ operation_message = "Updating local endpoint"
+ except LocalEndpointNotFoundError:
+ operation_message = "Creating local endpoint"
+
+ local_endpoint_polling_wrapper(
+ func=self._endpoint_stub.create_or_update,
+ message=f"{operation_message} ({endpoint.name}) ",
+ endpoint=endpoint,
+ )
+ return self.get(endpoint_name=str(endpoint.name))
+ except Exception as ex: # pylint: disable=W0718
+ if isinstance(ex, (ValidationException, SchemaValidationError)):
+ log_and_raise_error(ex)
+ else:
+ raise ex
+
+ def invoke(self, endpoint_name: str, data: dict, deployment_name: Optional[str] = None) -> str:
+ """Invoke a local endpoint.
+
+ :param endpoint_name: Name of endpoint to invoke.
+ :type endpoint_name: str
+ :param data: json data to pass
+ :type data: dict
+ :param deployment_name: Name of specific deployment to invoke.
+        :type deployment_name: Optional[str]
+ :return: The text response
+ :rtype: str
+ """
+ # get_scoring_uri will throw user error if there are multiple deployments and no deployment_name is specified
+ scoring_uri = self._docker_client.get_scoring_uri(endpoint_name=endpoint_name, deployment_name=deployment_name)
+ if scoring_uri:
+ headers = {}
+ if deployment_name is not None:
+ headers[EndpointInvokeFields.MODEL_DEPLOYMENT] = deployment_name
+ return str(self._requests_pipeline.post(scoring_uri, json=data, headers=headers).text())
+ endpoint_stub = self._endpoint_stub.get(endpoint_name=endpoint_name)
+ if endpoint_stub:
+ return str(self._endpoint_stub.invoke())
+ raise LocalEndpointNotFoundError(endpoint_name=endpoint_name, deployment_name=deployment_name)
+
+ def get(self, endpoint_name: str) -> OnlineEndpoint:
+ """Get a local endpoint.
+
+ :param endpoint_name: Name of endpoint.
+ :type endpoint_name: str
+        :return: The OnlineEndpoint entity
+        :rtype: OnlineEndpoint
+ """
+ endpoint = self._endpoint_stub.get(endpoint_name=endpoint_name)
+ container = self._docker_client.get_endpoint_container(endpoint_name=endpoint_name, include_stopped=True)
+ if endpoint:
+ if container:
+ return _convert_container_to_endpoint(container=container, endpoint_json=endpoint.dump())
+ return endpoint
+ if container:
+ return _convert_container_to_endpoint(container=container)
+ raise LocalEndpointNotFoundError(endpoint_name=endpoint_name)
+
+ def list(self) -> Iterable[OnlineEndpoint]:
+ """List all local endpoints.
+
+ :return: An iterable of local endpoints
+ :rtype: Iterable[OnlineEndpoint]
+ """
+ endpoints: List = []
+ containers = self._docker_client.list_containers()
+ endpoint_stubs = self._endpoint_stub.list()
+ # Iterate through all cached endpoint files
+ for endpoint_file in endpoint_stubs:
+ endpoint_json = json.loads(endpoint_file.read_text())
+ container = self._docker_client.get_endpoint_container(
+ endpoint_name=endpoint_json.get("name"), include_stopped=True
+ )
+            # If a deployment is associated with the endpoint,
+            # override certain endpoint properties with deployment information and remove it from the containers list.
+ # Otherwise, return endpoint spec.
+ if container:
+ endpoints.append(_convert_container_to_endpoint(endpoint_json=endpoint_json, container=container))
+ containers.remove(container)
+ else:
+ endpoints.append(
+ OnlineEndpoint._load(
+ data=endpoint_json,
+ params_override=[{"location": LocalEndpointConstants.ENDPOINT_STATE_LOCATION}],
+ )
+ )
+ # Iterate through any deployments that don't have an explicit local endpoint stub.
+ for container in containers:
+ endpoints.append(_convert_container_to_endpoint(container=container))
+ return endpoints
+
+ def delete(self, name: str) -> None:
+ """Delete a local endpoint.
+
+ :param name: Name of endpoint to delete.
+ :type name: str
+ """
+ endpoint_stub = self._endpoint_stub.get(endpoint_name=name)
+ if endpoint_stub:
+ self._endpoint_stub.delete(endpoint_name=name)
+ endpoint_container = self._docker_client.get_endpoint_container(endpoint_name=name)
+ if endpoint_container:
+ self._docker_client.delete(endpoint_name=name)
+ else:
+ raise LocalEndpointNotFoundError(endpoint_name=name)
+
+
+def _convert_container_to_endpoint(
+ # Bug Item number: 2885719
+ container: "docker.models.containers.Container", # type: ignore
+ endpoint_json: Optional[dict] = None,
+) -> OnlineEndpoint:
+ """Converts provided Container for local deployment to OnlineEndpoint entity.
+
+ :param container: Container for a local deployment.
+ :type container: docker.models.containers.Container
+ :param endpoint_json: The endpoint json
+ :type endpoint_json: Optional[dict]
+ :return: The OnlineEndpoint entity
+ :rtype: OnlineEndpoint
+ """
+ if endpoint_json is None:
+ endpoint_json = get_endpoint_json_from_container(container=container)
+ provisioning_state = get_status_from_container(container=container)
+ if provisioning_state == LocalEndpointConstants.CONTAINER_EXITED:
+ return _convert_json_to_endpoint(
+ endpoint_json=endpoint_json,
+ location=LocalEndpointConstants.ENDPOINT_STATE_LOCATION,
+ provisioning_state=LocalEndpointConstants.ENDPOINT_STATE_FAILED,
+ )
+ scoring_uri = get_scoring_uri_from_container(container=container)
+ return _convert_json_to_endpoint(
+ endpoint_json=endpoint_json,
+ location=LocalEndpointConstants.ENDPOINT_STATE_LOCATION,
+ provisioning_state=LocalEndpointConstants.ENDPOINT_STATE_SUCCEEDED,
+ scoring_uri=scoring_uri,
+ )
+
+
+def _convert_json_to_endpoint(endpoint_json: Optional[dict], **kwargs: Any) -> OnlineEndpoint:
+ """Converts metadata json and kwargs to OnlineEndpoint entity.
+
+ :param endpoint_json: dictionary representation of OnlineEndpoint entity.
+ :type endpoint_json: dict
+ :return: The OnlineEndpoint entity
+ :rtype: OnlineEndpoint
+ """
+ params_override = []
+ for k, v in kwargs.items():
+ params_override.append({k: v})
+ return OnlineEndpoint._load(data=endpoint_json, params_override=params_override) # type: ignore
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_local_job_invoker.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_local_job_invoker.py
new file mode 100644
index 00000000..90914e4c
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_local_job_invoker.py
@@ -0,0 +1,432 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+import base64
+import io
+import json
+import logging
+import os
+import re
+import shutil
+import subprocess
+import tarfile
+import tempfile
+import urllib.parse
+import zipfile
+from pathlib import Path
+from threading import Thread
+from typing import Any, Dict, Optional, Tuple
+
+from azure.ai.ml._restclient.v2022_02_01_preview.models import JobBaseData
+from azure.ai.ml._utils._http_utils import HttpPipeline
+from azure.ai.ml._utils.utils import DockerProxy
+from azure.ai.ml.constants._common import (
+ AZUREML_RUN_SETUP_DIR,
+ AZUREML_RUNS_DIR,
+ EXECUTION_SERVICE_URL_KEY,
+ INVOCATION_BASH_FILE,
+ INVOCATION_BAT_FILE,
+ LOCAL_JOB_FAILURE_MSG,
+ DefaultOpenEncoding,
+)
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, JobException, MlException
+from azure.core.credentials import TokenCredential
+from azure.core.exceptions import AzureError
+
+docker = DockerProxy()
+module_logger = logging.getLogger(__name__)
+
+
+def unzip_to_temporary_file(job_definition: JobBaseData, zip_content: Any) -> Path:
+ temp_dir = Path(tempfile.gettempdir(), AZUREML_RUNS_DIR, job_definition.name)
+ temp_dir.mkdir(parents=True, exist_ok=True)
+ with zipfile.ZipFile(io.BytesIO(zip_content)) as zip_ref:
+ zip_ref.extractall(temp_dir)
+ return temp_dir
+
+
+def _get_creationflags_and_startupinfo_for_background_process(
+ os_override: Optional[str] = None,
+) -> Dict:
+ args: Dict = {
+ "startupinfo": None,
+ "creationflags": None,
+ "stdin": None,
+ "stdout": None,
+ "stderr": None,
+ "shell": False,
+ }
+ os_name = os_override if os_override is not None else os.name
+ if os_name == "nt":
+ # Windows process creation flag to not reuse the parent console.
+
+ # Without this, the background service is associated with the
+ # starting process's console, and will block that console from
+ # exiting until the background service self-terminates. Elsewhere,
+ # fork just does the right thing.
+
+ CREATE_NEW_CONSOLE = 0x00000010
+ args["creationflags"] = CREATE_NEW_CONSOLE
+
+ # Bug Item number: 2895261
+ startupinfo = subprocess.STARTUPINFO() # type: ignore
+ startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW # type: ignore
+ startupinfo.wShowWindow = subprocess.SW_HIDE # type: ignore
+ args["startupinfo"] = startupinfo
+
+ else:
+ # On MacOS, the child inherits the parent's stdio descriptors by
+        # default; this can block the parent's stdout/stderr from closing even
+ # after the parent has exited.
+
+ args["stdin"] = subprocess.DEVNULL
+ args["stdout"] = subprocess.DEVNULL
+ args["stderr"] = subprocess.STDOUT
+
+ # filter entries with value None
+ return {k: v for (k, v) in args.items() if v}
+
+
+def patch_invocation_script_serialization(invocation_path: Path) -> None:
+ content = invocation_path.read_text()
+ searchRes = re.search(r"([\s\S]*)(--snapshots \'.*\')([\s\S]*)", content)
+ if searchRes:
+ patched_json = searchRes.group(2).replace('"', '\\"')
+ patched_json = patched_json.replace("'", '"')
+ invocation_path.write_text(searchRes.group(1) + patched_json + searchRes.group(3))
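+
+# Illustration of the rewrite above (hypothetical snapshot payload): a script line such as
+#   ... --snapshots '[{"id": "abc"}]' ...
+# becomes
+#   ... --snapshots "[{\"id\": \"abc\"}]" ...
+# i.e. the single-quoted JSON written by the service is re-quoted so that cmd.exe parses
+# it as a single double-quoted argument.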
+
+
+def invoke_command(project_temp_dir: Path) -> None:
+ if os.name == "nt":
+ invocation_script = project_temp_dir / AZUREML_RUN_SETUP_DIR / INVOCATION_BAT_FILE
+ # There is a bug in Execution service on the serialized json for snapshots.
+ # This is a client-side patch until the service fixes it, at which point it should
+ # be a no-op
+ patch_invocation_script_serialization(invocation_script)
+ invoked_command = ["cmd.exe", "/c", "{0}".format(invocation_script)]
+ else:
+ invocation_script = project_temp_dir / AZUREML_RUN_SETUP_DIR / INVOCATION_BASH_FILE
+ subprocess.check_output(["chmod", "+x", invocation_script])
+ invoked_command = ["/bin/bash", "-c", "{0}".format(invocation_script)]
+
+ env = os.environ.copy()
+ env.pop("AZUREML_TARGET_TYPE", None)
+ subprocess.Popen( # pylint: disable=consider-using-with
+ invoked_command,
+ cwd=project_temp_dir,
+ env=env,
+ **_get_creationflags_and_startupinfo_for_background_process(),
+ )
+
+
+def get_execution_service_response(
+ job_definition: JobBaseData, token: str, requests_pipeline: HttpPipeline
+) -> Tuple[Dict[str, str], str]:
+ """Get zip file containing local run information from Execution Service.
+
+ MFE will send down a mock job contract, with service 'local'.
+ This will have the URL for contacting Execution Service, with a URL-encoded JSON object following the '&fake='
+ string (aka EXECUTION_SERVICE_URL_KEY constant below). The encoded JSON should be the body to pass from the
+ client to ES. The ES response will be a zip file containing all the scripts required to invoke a local run.
+
+ :param job_definition: Job definition data
+ :type job_definition: JobBaseData
+ :param token: The bearer token to use when retrieving information from Execution Service
+ :type token: str
+ :param requests_pipeline: The HttpPipeline to use when sending network requests
+ :type requests_pipeline: HttpPipeline
+ :return: Execution service response and snapshot ID
+ :rtype: Tuple[Dict[str, str], str]
+ """
+ try:
+ local = job_definition.properties.services.get("Local", None)
+
+ (url, encodedBody) = local.endpoint.split(EXECUTION_SERVICE_URL_KEY)
+ body = urllib.parse.unquote_plus(encodedBody)
+ body_dict: Dict = json.loads(body)
+ response = requests_pipeline.post(url, json=body_dict, headers={"Authorization": "Bearer " + token})
+ response.raise_for_status()
+ return (response.content, body_dict.get("SnapshotId", None))
+ except AzureError as err:
+ raise SystemExit(err) from err
+ except Exception as e:
+ msg = "Failed to read in local executable job"
+ raise JobException(
+ message=msg,
+ target=ErrorTarget.LOCAL_JOB,
+ no_personal_data_message=msg,
+ error_category=ErrorCategory.SYSTEM_ERROR,
+ ) from e
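+
+# Illustration of the endpoint shape this function expects (values are hypothetical):
+#   https://<execution-service-host>/...&fake=%7B%22SnapshotId%22%3A%22<id>%22%7D
+# Everything after EXECUTION_SERVICE_URL_KEY ("&fake=") is URL-decoded into the JSON body
+# that is POSTed back to the Execution Service with the bearer token; the zip response
+# contains the scripts needed to invoke the run locally.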
+
+
+def is_local_run(job_definition: JobBaseData) -> bool:
+ if not job_definition.properties.services:
+ return False
+ local = job_definition.properties.services.get("Local", None)
+ return local is not None and EXECUTION_SERVICE_URL_KEY in local.endpoint
+
+
+class CommonRuntimeHelper:
+ COMMON_RUNTIME_BOOTSTRAPPER_INFO = "common_runtime_bootstrapper_info.json"
+ COMMON_RUNTIME_JOB_SPEC = "common_runtime_job_spec.json"
+ VM_BOOTSTRAPPER_FILE_NAME = "vm-bootstrapper"
+ LOCAL_JOB_ENV_VARS = {
+ "RUST_LOG": "1",
+ "AZ_BATCHAI_CLUSTER_NAME": "fake_cluster_name",
+ "AZ_BATCH_NODE_ID": "fake_id",
+ "AZ_BATCH_NODE_ROOT_DIR": ".",
+ "AZ_BATCH_CERTIFICATES_DIR": ".",
+ "AZ_BATCH_NODE_SHARED_DIR": ".",
+ "AZ_LS_CERT_THUMBPRINT": "fake_thumbprint",
+ }
+ DOCKER_IMAGE_WARNING_MSG = (
+ "Failed to pull required Docker image. "
+ "Please try removing all unused containers to free up space and then re-submit your job."
+ )
+ DOCKER_CLIENT_FAILURE_MSG = (
+ "Failed to create Docker client. Is Docker running/installed?\n "
+ "For local submissions, we need to build a Docker container to run your job in.\n Detailed message: {}"
+ )
+ DOCKER_DAEMON_FAILURE_MSG = (
+ "Unable to communicate with Docker daemon. Is Docker running/installed?\n "
+ "For local submissions, we need to build a Docker container to run your job in.\n Detailed message: {}"
+ )
+ DOCKER_LOGIN_FAILURE_MSG = "Login to Docker registry '{}' failed. See error message: {}"
+ BOOTSTRAP_BINARY_FAILURE_MSG = (
+ "Azure Common Runtime execution failed. See detailed message below for troubleshooting "
+ "information or re-submit with flag --use-local-runtime to try running on your local runtime: {}"
+ )
+
+ def __init__(self, job_name: str):
+ self.common_runtime_temp_folder = os.path.join(Path.home(), ".azureml-common-runtime", job_name)
+ if os.path.exists(self.common_runtime_temp_folder):
+ shutil.rmtree(self.common_runtime_temp_folder)
+ Path(self.common_runtime_temp_folder).mkdir(parents=True)
+ self.vm_bootstrapper_full_path = os.path.join(
+ self.common_runtime_temp_folder,
+ CommonRuntimeHelper.VM_BOOTSTRAPPER_FILE_NAME,
+ )
+ self.stdout = open( # pylint: disable=consider-using-with
+ os.path.join(self.common_runtime_temp_folder, "stdout"), "w+", encoding=DefaultOpenEncoding.WRITE
+ )
+ self.stderr = open( # pylint: disable=consider-using-with
+ os.path.join(self.common_runtime_temp_folder, "stderr"), "w+", encoding=DefaultOpenEncoding.WRITE
+ )
+
+ # Bug Item number: 2885723
+ def get_docker_client(self, registry: Dict) -> "docker.DockerClient": # type: ignore
+ """Retrieves the Docker client for performing docker operations.
+
+ :param registry: Registry information
+ :type registry: Dict[str, str]
+ :return: Docker client
+ :rtype: docker.DockerClient
+ """
+ try:
+ client = docker.from_env(version="auto")
+ except docker.errors.DockerException as e:
+ msg = self.DOCKER_CLIENT_FAILURE_MSG.format(e)
+ raise MlException(message=msg, no_personal_data_message=msg) from e
+
+ try:
+ client.version()
+ except Exception as e:
+ msg = self.DOCKER_DAEMON_FAILURE_MSG.format(e)
+ raise MlException(message=msg, no_personal_data_message=msg) from e
+
+ if registry:
+ try:
+ client.login(
+ username=registry.get("username"),
+ password=registry.get("password"),
+ registry=registry.get("url"),
+ )
+ except Exception as e:
+ raise RuntimeError(self.DOCKER_LOGIN_FAILURE_MSG.format(registry.get("url"), e)) from e
+ else:
+ raise RuntimeError("Registry information is missing from bootstrapper configuration.")
+
+ return client
+
+ # Bug Item number: 2885719
+ def copy_bootstrapper_from_container(self, container: "docker.models.containers.Container") -> None: # type: ignore
+ """Copy file/folder from container to local machine.
+
+ :param container: Docker container
+ :type container: docker.models.containers.Container
+ """
+ path_in_container = CommonRuntimeHelper.VM_BOOTSTRAPPER_FILE_NAME
+ path_in_host = self.vm_bootstrapper_full_path
+
+ try:
+ data_stream, _ = container.get_archive(path_in_container)
+ tar_file = path_in_host + ".tar"
+ with open(tar_file, "wb") as f:
+ for chunk in data_stream:
+ f.write(chunk)
+ with tarfile.open(tar_file, mode="r") as tar:
+ for file_name in tar.getnames():
+ tar.extract(file_name, os.path.dirname(path_in_host))
+ os.remove(tar_file)
+ except docker.errors.APIError as e:
+ msg = f"Copying {path_in_container} from container has failed. Detailed message: {e}"
+ raise MlException(message=msg, no_personal_data_message=msg) from e
+
+ def get_common_runtime_info_from_response(self, response: Any) -> Tuple[Dict[str, str], str]:
+ """Extract common-runtime info from Execution Service response.
+
+ :param response: Content of zip file from Execution Service containing all the
+ scripts required to invoke a local run.
+ :type response: Dict[str, str]
+ :return: Bootstrapper info and job specification
+ :rtype: Tuple[Dict[str, str], str]
+ """
+
+ with zipfile.ZipFile(io.BytesIO(response)) as zip_ref:
+ bootstrapper_path = f"{AZUREML_RUN_SETUP_DIR}/{self.COMMON_RUNTIME_BOOTSTRAPPER_INFO}"
+ job_spec_path = f"{AZUREML_RUN_SETUP_DIR}/{self.COMMON_RUNTIME_JOB_SPEC}"
+ if not all(file_path in zip_ref.namelist() for file_path in [bootstrapper_path, job_spec_path]):
+ raise RuntimeError(f"{bootstrapper_path}, {job_spec_path} are not in the execution service response.")
+
+ with zip_ref.open(bootstrapper_path, "r") as bootstrapper_file:
+ bootstrapper_json = json.loads(base64.b64decode(bootstrapper_file.read()))
+ with zip_ref.open(job_spec_path, "r") as job_spec_file:
+ job_spec = job_spec_file.read().decode("utf-8")
+
+ return bootstrapper_json, job_spec
+
+ def get_bootstrapper_binary(self, bootstrapper_info: Dict) -> None:
+ """Copy bootstrapper binary from the bootstrapper image to local machine.
+
+ :param bootstrapper_info:
+ :type bootstrapper_info: Dict[str, str]
+ """
+ Path(self.common_runtime_temp_folder).mkdir(parents=True, exist_ok=True)
+
+ # Pull and build the docker image
+ registry: Any = bootstrapper_info.get("registry")
+ docker_client = self.get_docker_client(registry)
+ repo_prefix = bootstrapper_info.get("repo_prefix")
+ repository = registry.get("url")
+ tag = bootstrapper_info.get("tag")
+
+ if repo_prefix:
+ bootstrapper_image = f"{repository}/{repo_prefix}/boot/vm-bootstrapper/binimage/linux:{tag}"
+ else:
+ bootstrapper_image = f"{repository}/boot/vm-bootstrapper/binimage/linux:{tag}"
+
+ try:
+ boot_img = docker_client.images.pull(bootstrapper_image)
+ except Exception as e:
+ module_logger.warning(self.DOCKER_IMAGE_WARNING_MSG)
+ raise e
+
+ boot_container = docker_client.containers.create(image=boot_img, command=[""])
+ self.copy_bootstrapper_from_container(boot_container)
+
+ boot_container.stop()
+ boot_container.remove()
+
+ def execute_bootstrapper(self, bootstrapper_binary: str, job_spec: str) -> subprocess.Popen:
+ """Runs vm-bootstrapper with the job specification passed to it. This will build the Docker container, create
+ all necessary files and directories, and run the job locally. Command args are defined by Common Runtime team
+        here:
+        https://msdata.visualstudio.com/Vienna/_git/vienna?path=/src/azureml-job-runtime/common-runtime/bootstrapper/vm-bootstrapper/src/main.rs&version=GBmaster&line=764&lineEnd=845&lineStartColumn=1&lineEndColumn=6&lineStyle=plain&_a=contents.
+
+ :param bootstrapper_binary: Binary file path for VM bootstrapper
+ (".azureml-common-runtime/<job_name>/vm-bootstrapper")
+ :type bootstrapper_binary: str
+ :param job_spec: JSON content of job specification
+ :type job_spec: str
+        :return: Subprocess running the bootstrapper
+        :rtype: subprocess.Popen
+ """
+ cmd = [
+ bootstrapper_binary,
+ "--job-spec",
+ job_spec,
+ "--skip-auto-update", # Skip the auto update
+ # "Disable the standard Identity Responder and use a dummy command instead."
+ "--disable-identity-responder",
+ "--skip-cleanup", # "Keep containers and volumes for debug."
+ ]
+
+ env = self.LOCAL_JOB_ENV_VARS
+
+ process = subprocess.Popen( # pylint: disable=consider-using-with
+ cmd,
+ env=env,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ cwd=self.common_runtime_temp_folder,
+ encoding="utf-8",
+ )
+ _log_subprocess(process.stdout, self.stdout)
+ _log_subprocess(process.stderr, self.stderr)
+
+ if self.check_bootstrapper_process_status(process):
+ return process
+ process.terminate()
+ process.kill()
+ raise RuntimeError(LOCAL_JOB_FAILURE_MSG.format(self.stderr.read()))
+
+ def check_bootstrapper_process_status(self, bootstrapper_process: subprocess.Popen) -> Optional[int]:
+ """Check if bootstrapper process status is non-zero.
+
+ :param bootstrapper_process: bootstrapper process
+ :type bootstrapper_process: subprocess.Popen
+        :return: The process return code, or None if the process is still running
+        :rtype: Optional[int]
+ """
+ return_code = bootstrapper_process.poll()
+ if return_code:
+ self.stderr.seek(0)
+ raise RuntimeError(self.BOOTSTRAP_BINARY_FAILURE_MSG.format(self.stderr.read()))
+ return return_code
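+
+    # Illustrative sketch (not part of the SDK): how the methods above fit together for a
+    # common-runtime local run. "invoker" is assumed to be an instance of this class and
+    # "response_bytes" the raw zip payload returned by the execution service.
+    #
+    #     bootstrapper_info, job_spec = invoker.get_common_runtime_info_from_response(response_bytes)
+    #     invoker.get_bootstrapper_binary(bootstrapper_info)  # pull the image and copy the binary locally
+    #     process = invoker.execute_bootstrapper(invoker.vm_bootstrapper_full_path, job_spec)
+    #     process.wait()  # block until the local job finishes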
+
+
+def start_run_if_local(
+ job_definition: JobBaseData,
+ credential: TokenCredential,
+ ws_base_url: str,
+ requests_pipeline: HttpPipeline,
+) -> str:
+ """Request execution bundle from ES and run job. If Linux or WSL environment, unzip and invoke job using job spec
+ and bootstrapper. Otherwise, invoke command locally.
+
+ :param job_definition: Job definition data
+ :type job_definition: JobBaseData
+ :param credential: Credential to use for authentication
+ :type credential: TokenCredential
+ :param ws_base_url: Base url to workspace
+ :type ws_base_url: str
+ :param requests_pipeline: The HttpPipeline to use when sending network requests
+ :type requests_pipeline: HttpPipeline
+ :return: snapshot ID
+ :rtype: str
+ """
+ token = credential.get_token(ws_base_url + "/.default").token
+ (zip_content, snapshot_id) = get_execution_service_response(job_definition, token, requests_pipeline)
+
+ try:
+ temp_dir = unzip_to_temporary_file(job_definition, zip_content)
+ invoke_command(temp_dir)
+ except Exception as e:
+ msg = LOCAL_JOB_FAILURE_MSG.format(e)
+ raise MlException(message=msg, no_personal_data_message=msg) from e
+
+ return snapshot_id
+
+
+def _log_subprocess(output_io: Any, file: Any, show_in_console: bool = False) -> None:
+ def log_subprocess() -> None:
+ for line in iter(output_io.readline, ""):
+ if show_in_console:
+ print(line, end="")
+ file.write(line)
+
+ thread = Thread(target=log_subprocess)
+ thread.daemon = True
+ thread.start()
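+
+# Minimal usage sketch for _log_subprocess (illustrative): it mirrors a subprocess stream into
+# any file-like object on a daemon thread, which is how execute_bootstrapper above captures the
+# bootstrapper's stdout/stderr without blocking.
+#
+#     import io
+#     import subprocess
+#
+#     proc = subprocess.Popen(["echo", "hello"], stdout=subprocess.PIPE, encoding="utf-8")
+#     buffer = io.StringIO()
+#     _log_subprocess(proc.stdout, buffer, show_in_console=True)
+#     proc.wait()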
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_marketplace_subscription_operations.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_marketplace_subscription_operations.py
new file mode 100644
index 00000000..36683e3a
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_marketplace_subscription_operations.py
@@ -0,0 +1,122 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+from typing import Iterable
+
+from azure.ai.ml._restclient.v2024_01_01_preview import (
+ AzureMachineLearningWorkspaces as ServiceClient202401Preview,
+)
+from azure.ai.ml._scope_dependent_operations import (
+ OperationConfig,
+ OperationScope,
+ _ScopeDependentOperations,
+)
+from azure.ai.ml._telemetry import ActivityType, monitor_with_activity
+from azure.ai.ml._utils._experimental import experimental
+from azure.ai.ml._utils._logger_utils import OpsLogger
+from azure.ai.ml.entities._autogen_entities.models import MarketplaceSubscription
+from azure.core.polling import LROPoller
+
+ops_logger = OpsLogger(__name__)
+module_logger = ops_logger.module_logger
+
+
+class MarketplaceSubscriptionOperations(_ScopeDependentOperations):
+ """MarketplaceSubscriptionOperations.
+
+ You should not instantiate this class directly. Instead, you should
+ create an MLClient instance that instantiates it for you and
+ attaches it as an attribute.
+ """
+
+ def __init__(
+ self,
+ operation_scope: OperationScope,
+ operation_config: OperationConfig,
+ service_client: ServiceClient202401Preview,
+ ):
+ super().__init__(operation_scope, operation_config)
+ ops_logger.update_filter()
+ self._service_client = service_client.marketplace_subscriptions
+
+ @experimental
+ @monitor_with_activity(
+ ops_logger,
+ "MarketplaceSubscription.BeginCreateOrUpdate",
+ ActivityType.PUBLICAPI,
+ )
+ def begin_create_or_update(
+ self, marketplace_subscription: MarketplaceSubscription, **kwargs
+ ) -> LROPoller[MarketplaceSubscription]:
+ """Create or update a Marketplace Subscription.
+
+ :param marketplace_subscription: The marketplace subscription entity.
+ :type marketplace_subscription: ~azure.ai.ml.entities.MarketplaceSubscription
+ :return: A poller to track the operation status
+ :rtype: ~azure.core.polling.LROPoller[~azure.ai.ml.entities.MarketplaceSubscription]
+ """
+ return self._service_client.begin_create_or_update(
+ self._resource_group_name,
+ self._workspace_name,
+ marketplace_subscription.name,
+ marketplace_subscription._to_rest_object(), # type: ignore
+ cls=lambda response, deserialized, headers: MarketplaceSubscription._from_rest_object( # type: ignore
+ deserialized
+ ),
+ **kwargs,
+ )
+
+ @experimental
+ @monitor_with_activity(ops_logger, "MarketplaceSubscription.Get", ActivityType.PUBLICAPI)
+ def get(self, name: str, **kwargs) -> MarketplaceSubscription:
+ """Get a Marketplace Subscription resource.
+
+ :param name: Name of the marketplace subscription.
+ :type name: str
+ :return: Marketplace subscription object retrieved from the service.
+ :rtype: ~azure.ai.ml.entities.MarketplaceSubscription
+ """
+ return self._service_client.get(
+ self._resource_group_name,
+ self._workspace_name,
+ name,
+ cls=lambda response, deserialized, headers: MarketplaceSubscription._from_rest_object( # type: ignore
+ deserialized
+ ),
+ **kwargs,
+ )
+
+ @experimental
+ @monitor_with_activity(ops_logger, "MarketplaceSubscription.List", ActivityType.PUBLICAPI)
+ def list(self, **kwargs) -> Iterable[MarketplaceSubscription]:
+ """List marketplace subscriptions of the workspace.
+
+ :return: A list of marketplace subscriptions
+ :rtype: ~typing.Iterable[~azure.ai.ml.entities.MarketplaceSubscription]
+ """
+ return self._service_client.list(
+ self._resource_group_name,
+ self._workspace_name,
+ cls=lambda objs: [MarketplaceSubscription._from_rest_object(obj) for obj in objs], # type: ignore
+ **kwargs,
+ )
+
+ @experimental
+ @monitor_with_activity(ops_logger, "MarketplaceSubscription.BeginDelete", ActivityType.PUBLICAPI)
+ def begin_delete(self, name: str, **kwargs) -> LROPoller[None]:
+ """Delete a Marketplace Subscription.
+
+ :param name: Name of the marketplace subscription.
+ :type name: str
+ :return: A poller to track the operation status.
+ :rtype: ~azure.core.polling.LROPoller[None]
+ """
+ return self._service_client.begin_delete(
+ self._resource_group_name,
+ self._workspace_name,
+ name=name,
+ **kwargs,
+ )
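+
+# Illustrative usage sketch (not part of this module). It assumes an authenticated MLClient that
+# exposes these operations as "ml_client.marketplace_subscriptions"; construction of the
+# MarketplaceSubscription entity itself is omitted because it is an autogenerated model.
+#
+#     for sub in ml_client.marketplace_subscriptions.list():
+#         print(sub.name)
+#     sub = ml_client.marketplace_subscriptions.get(name="my-subscription")
+#     ml_client.marketplace_subscriptions.begin_delete(name="my-subscription").result()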
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_model_dataplane_operations.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_model_dataplane_operations.py
new file mode 100644
index 00000000..da15d079
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_model_dataplane_operations.py
@@ -0,0 +1,32 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+import logging
+from typing import List
+
+from azure.ai.ml._restclient.model_dataplane import AzureMachineLearningWorkspaces as ServiceClientModelDataplane
+from azure.ai.ml._restclient.model_dataplane.models import BatchGetResolvedUrisDto, BatchModelPathResponseDto
+from azure.ai.ml._scope_dependent_operations import OperationConfig, OperationScope, _ScopeDependentOperations
+
+module_logger = logging.getLogger(__name__)
+
+
+class ModelDataplaneOperations(_ScopeDependentOperations):
+ def __init__(
+ self,
+ operation_scope: OperationScope,
+ operation_config: OperationConfig,
+ service_client: ServiceClientModelDataplane,
+ ):
+ super().__init__(operation_scope, operation_config)
+ self._operation = service_client.models
+
+ def get_batch_model_uris(self, model_ids: List[str]) -> BatchModelPathResponseDto:
+ batch_uri_request = BatchGetResolvedUrisDto(values=model_ids)
+ return self._operation.batch_get_resolved_uris(
+ self._operation_scope.subscription_id,
+ self._operation_scope.resource_group_name,
+ self._workspace_name,
+ body=batch_uri_request,
+ )
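+
+# Illustrative sketch (not part of the SDK): ModelDataplaneOperations is an internal helper that
+# MLClient wires up for resolving model asset IDs to storage URIs in a single call. The IDs below
+# are hypothetical placeholders.
+#
+#     ops = ModelDataplaneOperations(operation_scope, operation_config, dataplane_service_client)
+#     response = ops.get_batch_model_uris(["azureml://.../models/my-model/versions/1"])
+#     print(response)  # BatchModelPathResponseDto; exact shape depends on the REST contract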
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_model_operations.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_model_operations.py
new file mode 100644
index 00000000..be468089
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_model_operations.py
@@ -0,0 +1,833 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access,docstring-missing-return,docstring-missing-param,docstring-missing-rtype,line-too-long,too-many-statements
+
+import re
+from contextlib import contextmanager
+from os import PathLike, path
+from typing import Any, Dict, Generator, Iterable, Optional, Union, cast
+
+from marshmallow.exceptions import ValidationError as SchemaValidationError
+
+from azure.ai.ml._artifacts._artifact_utilities import (
+ _check_and_upload_path,
+ _get_default_datastore_info,
+ _update_metadata,
+)
+from azure.ai.ml._artifacts._constants import (
+ ASSET_PATH_ERROR,
+ CHANGED_ASSET_PATH_MSG,
+ CHANGED_ASSET_PATH_MSG_NO_PERSONAL_DATA,
+)
+from azure.ai.ml._exception_helper import log_and_raise_error
+from azure.ai.ml._restclient.v2021_10_01_dataplanepreview import (
+ AzureMachineLearningWorkspaces as ServiceClient102021Dataplane,
+)
+from azure.ai.ml._restclient.v2023_08_01_preview import AzureMachineLearningWorkspaces as ServiceClient082023Preview
+from azure.ai.ml._restclient.v2023_08_01_preview.models import ListViewType, ModelVersion
+from azure.ai.ml._scope_dependent_operations import (
+ OperationConfig,
+ OperationsContainer,
+ OperationScope,
+ _ScopeDependentOperations,
+)
+from azure.ai.ml._telemetry import ActivityType, monitor_with_activity
+from azure.ai.ml._utils._arm_id_utils import AMLVersionedArmId, is_ARM_id_for_resource
+from azure.ai.ml._utils._asset_utils import (
+ _archive_or_restore,
+ _get_latest,
+ _get_next_version_from_container,
+ _resolve_label_to_asset,
+)
+from azure.ai.ml._utils._experimental import experimental
+from azure.ai.ml._utils._logger_utils import OpsLogger
+from azure.ai.ml._utils._registry_utils import (
+ get_asset_body_for_registry_storage,
+ get_registry_client,
+ get_sas_uri_for_registry_asset,
+ get_storage_details_for_registry_assets,
+)
+from azure.ai.ml._utils._storage_utils import get_ds_name_and_path_prefix, get_storage_client
+from azure.ai.ml._utils.utils import _is_evaluator, resolve_short_datastore_url, validate_ml_flow_folder
+from azure.ai.ml.constants._common import ARM_ID_PREFIX, ASSET_ID_FORMAT, REGISTRY_URI_FORMAT, AzureMLResourceType
+from azure.ai.ml.entities import AzureDataLakeGen2Datastore
+from azure.ai.ml.entities._assets import Environment, Model, ModelPackage
+from azure.ai.ml.entities._assets._artifacts.code import Code
+from azure.ai.ml.entities._assets.workspace_asset_reference import WorkspaceAssetReference
+from azure.ai.ml.entities._credentials import AccountKeyConfiguration
+from azure.ai.ml.exceptions import (
+ AssetPathException,
+ ErrorCategory,
+ ErrorTarget,
+ ValidationErrorType,
+ ValidationException,
+)
+from azure.ai.ml.operations._datastore_operations import DatastoreOperations
+from azure.core.exceptions import ResourceNotFoundError
+
+from ._operation_orchestrator import OperationOrchestrator
+
+ops_logger = OpsLogger(__name__)
+module_logger = ops_logger.module_logger
+
+
+# pylint: disable=too-many-instance-attributes
+class ModelOperations(_ScopeDependentOperations):
+ """ModelOperations.
+
+ You should not instantiate this class directly. Instead, you should create an MLClient instance that instantiates it
+ for you and attaches it as an attribute.
+
+ :param operation_scope: Scope variables for the operations classes of an MLClient object.
+ :type operation_scope: ~azure.ai.ml._scope_dependent_operations.OperationScope
+ :param operation_config: Common configuration for operations classes of an MLClient object.
+ :type operation_config: ~azure.ai.ml._scope_dependent_operations.OperationConfig
+ :param service_client: Service client to allow end users to operate on Azure Machine Learning Workspace
+ resources (ServiceClient082023Preview or ServiceClient102021Dataplane).
+ :type service_client: typing.Union[
+ azure.ai.ml._restclient.v2023_04_01_preview._azure_machine_learning_workspaces.AzureMachineLearningWorkspaces,
+ azure.ai.ml._restclient.v2021_10_01_dataplanepreview._azure_machine_learning_workspaces.
+ AzureMachineLearningWorkspaces]
+ :param datastore_operations: Represents a client for performing operations on Datastores.
+ :type datastore_operations: ~azure.ai.ml.operations._datastore_operations.DatastoreOperations
+ :param all_operations: All operations classes of an MLClient object.
+ :type all_operations: ~azure.ai.ml._scope_dependent_operations.OperationsContainer
+ """
+
+ _IS_EVALUATOR = "__is_evaluator"
+
+ # pylint: disable=unused-argument
+ def __init__(
+ self,
+ operation_scope: OperationScope,
+ operation_config: OperationConfig,
+ service_client: Union[ServiceClient082023Preview, ServiceClient102021Dataplane],
+ datastore_operations: DatastoreOperations,
+ all_operations: Optional[OperationsContainer] = None,
+ **kwargs,
+ ):
+ super(ModelOperations, self).__init__(operation_scope, operation_config)
+ ops_logger.update_filter()
+ self._model_versions_operation = service_client.model_versions
+ self._model_container_operation = service_client.model_containers
+ self._service_client = service_client
+ self._datastore_operation = datastore_operations
+ self._all_operations = all_operations
+ self._control_plane_client: Any = kwargs.get("control_plane_client", None)
+ self._workspace_rg = kwargs.pop("workspace_rg", None)
+ self._workspace_sub = kwargs.pop("workspace_sub", None)
+ self._registry_reference = kwargs.pop("registry_reference", None)
+
+ # Maps a label to a function which given an asset name,
+ # returns the asset associated with the label
+ self._managed_label_resolver = {"latest": self._get_latest_version}
+ self.__is_evaluator = kwargs.pop(ModelOperations._IS_EVALUATOR, False)
+
+ @monitor_with_activity(ops_logger, "Model.CreateOrUpdate", ActivityType.PUBLICAPI)
+ def create_or_update( # type: ignore
+ self, model: Union[Model, WorkspaceAssetReference]
+ ) -> Model: # TODO: Are we going to implement job_name?
+ """Returns created or updated model asset.
+
+ :param model: Model asset object.
+ :type model: ~azure.ai.ml.entities.Model
+ :raises ~azure.ai.ml.exceptions.AssetPathException: Raised when the Model artifact path is
+ already linked to another asset
+ :raises ~azure.ai.ml.exceptions.ValidationException: Raised if Model cannot be successfully validated.
+ Details will be provided in the error message.
+ :raises ~azure.ai.ml.exceptions.EmptyDirectoryError: Raised if local path provided points to an empty directory.
+ :return: Model asset object.
+ :rtype: ~azure.ai.ml.entities.Model
+ """
+        # Check if a model with the same name already exists and is an
+        # evaluator. In that case raise an exception and do not create the model.
+ if not self.__is_evaluator and _is_evaluator(model.properties):
+ msg = (
+                "Unable to create the evaluator using ModelOperations. To create an "
+                "evaluator, please use EvaluatorOperations by calling "
+                "ml_client.evaluators.create_or_update(model) instead."
+ )
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.MODEL,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+ if model.name is not None:
+ model_properties = self._get_model_properties(model.name)
+ if model_properties is not None and _is_evaluator(model_properties) != _is_evaluator(model.properties):
+ if _is_evaluator(model.properties):
+ msg = (
+ f"Unable to create the model with name {model.name} "
+                        "because this version of the model is marked as a promptflow evaluator, but the previous "
+                        "version is a regular model. "
+ "Please change the model name and try again."
+ )
+ else:
+ msg = (
+                        "because the previous version of the model was marked as a promptflow evaluator, but this "
+                        "version is a regular model. "
+ "version is a regular model. "
+ "Please change the model name and try again."
+ )
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.MODEL,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+ try:
+ name = model.name
+ if not model.version and model._auto_increment_version:
+ model.version = _get_next_version_from_container(
+ name=model.name,
+ container_operation=self._model_container_operation,
+ resource_group_name=self._operation_scope.resource_group_name,
+ workspace_name=self._workspace_name,
+ registry_name=self._registry_name,
+ )
+
+ version = model.version
+
+ sas_uri = None
+
+ if self._registry_name:
+ # Case of copy model to registry
+ if isinstance(model, WorkspaceAssetReference):
+ # verify that model is not already in registry
+ try:
+ self._model_versions_operation.get(
+ name=model.name,
+ version=model.version,
+ resource_group_name=self._resource_group_name,
+ registry_name=self._registry_name,
+ )
+ except Exception as err: # pylint: disable=W0718
+ if isinstance(err, ResourceNotFoundError):
+ pass
+ else:
+ raise err
+ else:
+ msg = "A model with this name and version already exists in registry"
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.MODEL,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+
+ model_rest = model._to_rest_object()
+ result = self._service_client.resource_management_asset_reference.begin_import_method(
+ resource_group_name=self._resource_group_name,
+ registry_name=self._registry_name,
+ body=model_rest,
+ ).result()
+
+ if not result:
+ model_rest_obj = self._get(name=str(model.name), version=model.version)
+ return Model._from_rest_object(model_rest_obj)
+
+ sas_uri = get_sas_uri_for_registry_asset(
+ service_client=self._service_client,
+ name=model.name,
+ version=model.version,
+ resource_group=self._resource_group_name,
+ registry=self._registry_name,
+ body=get_asset_body_for_registry_storage(self._registry_name, "models", model.name, model.version),
+ )
+
+ model, indicator_file = _check_and_upload_path( # type: ignore[type-var]
+ artifact=model,
+ asset_operations=self,
+ sas_uri=sas_uri,
+ artifact_type=ErrorTarget.MODEL,
+ show_progress=self._show_progress,
+ )
+
+ model.path = resolve_short_datastore_url(model.path, self._operation_scope) # type: ignore
+ validate_ml_flow_folder(model.path, model.type) # type: ignore
+ model_version_resource = model._to_rest_object()
+ auto_increment_version = model._auto_increment_version
+ try:
+ result = (
+ self._model_versions_operation.begin_create_or_update(
+ name=name,
+ version=version,
+ body=model_version_resource,
+ registry_name=self._registry_name,
+ **self._scope_kwargs,
+ ).result()
+ if self._registry_name
+ else self._model_versions_operation.create_or_update(
+ name=name,
+ version=version,
+ body=model_version_resource,
+ workspace_name=self._workspace_name,
+ **self._scope_kwargs,
+ )
+ )
+
+ if not result and self._registry_name:
+ result = self._get(name=str(model.name), version=model.version)
+
+ except Exception as e:
+ # service side raises an exception if we attempt to update an existing asset's path
+ if str(e) == ASSET_PATH_ERROR:
+ raise AssetPathException(
+ message=CHANGED_ASSET_PATH_MSG,
+ target=ErrorTarget.MODEL,
+ no_personal_data_message=CHANGED_ASSET_PATH_MSG_NO_PERSONAL_DATA,
+ error_category=ErrorCategory.USER_ERROR,
+ ) from e
+ raise e
+
+ model = Model._from_rest_object(result)
+ if auto_increment_version and indicator_file:
+ datastore_info = _get_default_datastore_info(self._datastore_operation)
+ _update_metadata(model.name, model.version, indicator_file, datastore_info) # update version in storage
+
+ return model
+ except Exception as ex: # pylint: disable=W0718
+ if isinstance(ex, SchemaValidationError):
+ log_and_raise_error(ex)
+ else:
+ raise ex
+
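+    # Illustrative usage sketch (assumes an authenticated MLClient named "ml_client"):
+    #
+    #     from azure.ai.ml.constants import AssetTypes
+    #     from azure.ai.ml.entities import Model
+    #
+    #     local_model = Model(
+    #         path="./model",  # local file or folder; uploaded to the default datastore on create
+    #         name="my-model",
+    #         type=AssetTypes.CUSTOM_MODEL,  # or e.g. AssetTypes.MLFLOW_MODEL
+    #     )
+    #     registered = ml_client.models.create_or_update(local_model)
+    #     print(registered.name, registered.version)
+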
+ def _get(self, name: str, version: Optional[str] = None) -> ModelVersion: # name:latest
+ if version:
+ return (
+ self._model_versions_operation.get(
+ name=name,
+ version=version,
+ registry_name=self._registry_name,
+ **self._scope_kwargs,
+ )
+ if self._registry_name
+ else self._model_versions_operation.get(
+ name=name,
+ version=version,
+ workspace_name=self._workspace_name,
+ **self._scope_kwargs,
+ )
+ )
+
+ return (
+ self._model_container_operation.get(name=name, registry_name=self._registry_name, **self._scope_kwargs)
+ if self._registry_name
+ else self._model_container_operation.get(
+ name=name, workspace_name=self._workspace_name, **self._scope_kwargs
+ )
+ )
+
+ @monitor_with_activity(ops_logger, "Model.Get", ActivityType.PUBLICAPI)
+ def get(self, name: str, version: Optional[str] = None, label: Optional[str] = None) -> Model:
+ """Returns information about the specified model asset.
+
+ :param name: Name of the model.
+ :type name: str
+ :param version: Version of the model.
+ :type version: str
+ :param label: Label of the model. (mutually exclusive with version)
+ :type label: str
+ :raises ~azure.ai.ml.exceptions.ValidationException: Raised if Model cannot be successfully validated.
+ Details will be provided in the error message.
+ :return: Model asset object.
+ :rtype: ~azure.ai.ml.entities.Model
+ """
+ if version and label:
+ msg = "Cannot specify both version and label."
+ raise ValidationException(
+ message=msg,
+ target=ErrorTarget.MODEL,
+ no_personal_data_message=msg,
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
+
+ if label:
+ return _resolve_label_to_asset(self, name, label)
+
+ if not version:
+ msg = "Must provide either version or label"
+ raise ValidationException(
+ message=msg,
+ target=ErrorTarget.MODEL,
+ no_personal_data_message=msg,
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.MISSING_FIELD,
+ )
+ # TODO: We should consider adding an exception trigger for internal_model=None
+ model_version_resource = self._get(name, version)
+
+ return Model._from_rest_object(model_version_resource)
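+
+    # Illustrative sketch: version and label are mutually exclusive, and one of the two is required.
+    #
+    #     model_v1 = ml_client.models.get(name="my-model", version="1")
+    #     latest = ml_client.models.get(name="my-model", label="latest")  # resolved via _get_latest_version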
+
+ @monitor_with_activity(ops_logger, "Model.Download", ActivityType.PUBLICAPI)
+ def download(self, name: str, version: str, download_path: Union[PathLike, str] = ".") -> None:
+ """Download files related to a model.
+
+ :param name: Name of the model.
+ :type name: str
+ :param version: Version of the model.
+ :type version: str
+ :param download_path: Local path as download destination, defaults to current working directory of the current
+ user. Contents will be overwritten.
+ :type download_path: Union[PathLike, str]
+        :raises ResourceNotFoundError: Raised if no model matching the provided name and version can be found.
+ """
+
+ model_uri = self.get(name=name, version=version).path
+ ds_name, path_prefix = get_ds_name_and_path_prefix(model_uri, self._registry_name)
+ if self._registry_name:
+ sas_uri, auth_type = get_storage_details_for_registry_assets(
+ service_client=self._service_client,
+ asset_name=name,
+ asset_version=version,
+ reg_name=self._registry_name,
+ asset_type=AzureMLResourceType.MODEL,
+ rg_name=self._resource_group_name,
+ uri=model_uri,
+ )
+ if auth_type == "SAS":
+ storage_client = get_storage_client(credential=None, storage_account=None, account_url=sas_uri)
+ else:
+ parts = sas_uri.split("/")
+ storage_account = parts[2].split(".")[0]
+ container_name = parts[3]
+ storage_client = get_storage_client(
+ credential=None,
+ storage_account=storage_account,
+ container_name=container_name,
+ )
+
+ else:
+ ds = self._datastore_operation.get(ds_name, include_secrets=True)
+ acc_name = ds.account_name
+
+ if isinstance(ds.credentials, AccountKeyConfiguration):
+ credential = ds.credentials.account_key
+ else:
+ try:
+ credential = ds.credentials.sas_token
+ except Exception as e: # pylint: disable=W0718
+ if not hasattr(ds.credentials, "sas_token"):
+ credential = self._datastore_operation._credential
+ else:
+ raise e
+
+ if isinstance(ds, AzureDataLakeGen2Datastore):
+ container = ds.filesystem
+ try:
+ from azure.identity import ClientSecretCredential
+
+ token_credential = ClientSecretCredential(
+ tenant_id=ds.credentials["tenant_id"],
+ client_id=ds.credentials["client_id"],
+ client_secret=ds.credentials["client_secret"],
+ authority=ds.credentials["authority_url"],
+ )
+ credential = token_credential
+ except (KeyError, TypeError):
+ pass
+
+ else:
+ container = ds.container_name
+ datastore_type = ds.type
+ storage_client = get_storage_client(
+ credential=credential,
+ container_name=container,
+ storage_account=acc_name,
+ storage_type=datastore_type,
+ )
+
+ path_file = "{}{}{}".format(download_path, path.sep, name)
+ is_directory = storage_client.exists(f"{path_prefix.rstrip('/')}/")
+ if is_directory:
+ path_file = path.join(path_file, path.basename(path_prefix.rstrip("/")))
+ module_logger.info("Downloading the model %s at %s\n", path_prefix, path_file)
+ storage_client.download(starts_with=path_prefix, destination=path_file)
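+
+    # Illustrative sketch: download the model artifacts into "./artifacts/my-model".
+    #
+    #     ml_client.models.download(name="my-model", version="1", download_path="./artifacts")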
+
+ @monitor_with_activity(ops_logger, "Model.Archive", ActivityType.PUBLICAPI)
+ def archive(
+ self,
+ name: str,
+ version: Optional[str] = None,
+ label: Optional[str] = None,
+ **kwargs: Any,
+ ) -> None:
+ """Archive a model asset.
+
+ :param name: Name of model asset.
+ :type name: str
+ :param version: Version of model asset.
+ :type version: str
+ :param label: Label of the model asset. (mutually exclusive with version)
+ :type label: str
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_misc.py
+ :start-after: [START model_operations_archive]
+ :end-before: [END model_operations_archive]
+ :language: python
+ :dedent: 8
+ :caption: Archive a model.
+ """
+ _archive_or_restore(
+ asset_operations=self,
+ version_operation=self._model_versions_operation,
+ container_operation=self._model_container_operation,
+ is_archived=True,
+ name=name,
+ version=version,
+ label=label,
+ )
+
+ @monitor_with_activity(ops_logger, "Model.Restore", ActivityType.PUBLICAPI)
+ def restore(
+ self,
+ name: str,
+ version: Optional[str] = None,
+ label: Optional[str] = None,
+ **kwargs: Any,
+ ) -> None:
+ """Restore an archived model asset.
+
+ :param name: Name of model asset.
+ :type name: str
+ :param version: Version of model asset.
+ :type version: str
+ :param label: Label of the model asset. (mutually exclusive with version)
+ :type label: str
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_misc.py
+ :start-after: [START model_operations_restore]
+ :end-before: [END model_operations_restore]
+ :language: python
+ :dedent: 8
+ :caption: Restore an archived model.
+ """
+ _archive_or_restore(
+ asset_operations=self,
+ version_operation=self._model_versions_operation,
+ container_operation=self._model_container_operation,
+ is_archived=False,
+ name=name,
+ version=version,
+ label=label,
+ )
+
+ @monitor_with_activity(ops_logger, "Model.List", ActivityType.PUBLICAPI)
+ def list(
+ self,
+ name: Optional[str] = None,
+ stage: Optional[str] = None,
+ *,
+ list_view_type: ListViewType = ListViewType.ACTIVE_ONLY,
+ ) -> Iterable[Model]:
+ """List all model assets in workspace.
+
+ :param name: Name of the model.
+ :type name: Optional[str]
+ :param stage: The Model stage
+ :type stage: Optional[str]
+ :keyword list_view_type: View type for including/excluding (for example) archived models.
+ Defaults to :attr:`ListViewType.ACTIVE_ONLY`.
+ :paramtype list_view_type: ListViewType
+ :return: An iterator like instance of Model objects
+ :rtype: ~azure.core.paging.ItemPaged[~azure.ai.ml.entities.Model]
+ """
+ if name:
+ return cast(
+ Iterable[Model],
+ (
+ self._model_versions_operation.list(
+ name=name,
+ registry_name=self._registry_name,
+ cls=lambda objs: [Model._from_rest_object(obj) for obj in objs],
+ **self._scope_kwargs,
+ )
+ if self._registry_name
+ else self._model_versions_operation.list(
+ name=name,
+ workspace_name=self._workspace_name,
+ cls=lambda objs: [Model._from_rest_object(obj) for obj in objs],
+ list_view_type=list_view_type,
+ stage=stage,
+ **self._scope_kwargs,
+ )
+ ),
+ )
+
+ return cast(
+ Iterable[Model],
+ (
+ self._model_container_operation.list(
+ registry_name=self._registry_name,
+ cls=lambda objs: [Model._from_container_rest_object(obj) for obj in objs],
+ list_view_type=list_view_type,
+ **self._scope_kwargs,
+ )
+ if self._registry_name
+ else self._model_container_operation.list(
+ workspace_name=self._workspace_name,
+ cls=lambda objs: [Model._from_container_rest_object(obj) for obj in objs],
+ list_view_type=list_view_type,
+ **self._scope_kwargs,
+ )
+ ),
+ )
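+
+    # Illustrative sketch: without a name this lists model containers; with a name it lists all
+    # versions of that model. Passing list_view_type=ListViewType.ALL also includes archived assets.
+    #
+    #     for container in ml_client.models.list():
+    #         print(container.name)
+    #     for version in ml_client.models.list(name="my-model", list_view_type=ListViewType.ALL):
+    #         print(version.version)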
+
+ @monitor_with_activity(ops_logger, "Model.Share", ActivityType.PUBLICAPI)
+ @experimental
+ def share(
+ self, name: str, version: str, *, share_with_name: str, share_with_version: str, registry_name: str
+ ) -> Model:
+ """Share a model asset from workspace to registry.
+
+ :param name: Name of model asset.
+ :type name: str
+ :param version: Version of model asset.
+ :type version: str
+ :keyword share_with_name: Name of model asset to share with.
+ :paramtype share_with_name: str
+ :keyword share_with_version: Version of model asset to share with.
+ :paramtype share_with_version: str
+ :keyword registry_name: Name of the destination registry.
+ :paramtype registry_name: str
+ :return: Model asset object.
+ :rtype: ~azure.ai.ml.entities.Model
+ """
+
+ # Get workspace info to get workspace GUID
+ workspace = self._service_client.workspaces.get(
+ resource_group_name=self._resource_group_name,
+ workspace_name=self._workspace_name,
+ )
+ workspace_guid = workspace.workspace_id
+ workspace_location = workspace.location
+
+ # Get model asset ID
+ asset_id = ASSET_ID_FORMAT.format(
+ workspace_location,
+ workspace_guid,
+ AzureMLResourceType.MODEL,
+ name,
+ version,
+ )
+
+ model_ref = WorkspaceAssetReference(
+ name=share_with_name if share_with_name else name,
+ version=share_with_version if share_with_version else version,
+ asset_id=asset_id,
+ )
+
+ with self._set_registry_client(registry_name):
+ return self.create_or_update(model_ref)
+
+ def _get_latest_version(self, name: str) -> Model:
+ """Returns the latest version of the asset with the given name.
+
+ Latest is defined as the most recently created, not the most recently updated.
+ """
+ result = _get_latest(
+ name,
+ self._model_versions_operation,
+ self._resource_group_name,
+ self._workspace_name,
+ self._registry_name,
+ )
+ return Model._from_rest_object(result)
+
+ @contextmanager
+ def _set_registry_client(self, registry_name: str) -> Generator:
+ """Sets the registry client for the model operations.
+
+ :param registry_name: Name of the registry.
+ :type registry_name: str
+ """
+ rg_ = self._operation_scope._resource_group_name
+ sub_ = self._operation_scope._subscription_id
+ registry_ = self._operation_scope.registry_name
+ client_ = self._service_client
+ model_versions_operation_ = self._model_versions_operation
+
+ try:
+ _client, _rg, _sub = get_registry_client(self._service_client._config.credential, registry_name)
+ self._operation_scope.registry_name = registry_name
+ self._operation_scope._resource_group_name = _rg
+ self._operation_scope._subscription_id = _sub
+ self._service_client = _client
+ self._model_versions_operation = _client.model_versions
+ yield
+ finally:
+ self._operation_scope.registry_name = registry_
+ self._operation_scope._resource_group_name = rg_
+ self._operation_scope._subscription_id = sub_
+ self._service_client = client_
+ self._model_versions_operation = model_versions_operation_
+
+ @experimental
+ @monitor_with_activity(ops_logger, "Model.Package", ActivityType.PUBLICAPI)
+ def package(self, name: str, version: str, package_request: ModelPackage, **kwargs: Any) -> Environment:
+ """Package a model asset
+
+ :param name: Name of model asset.
+ :type name: str
+ :param version: Version of model asset.
+ :type version: str
+ :param package_request: Model package request.
+ :type package_request: ~azure.ai.ml.entities.ModelPackage
+ :return: Environment object
+ :rtype: ~azure.ai.ml.entities.Environment
+ """
+
+ is_deployment_flow = kwargs.pop("skip_to_rest", False)
+ if not is_deployment_flow:
+ orchestrators = OperationOrchestrator(
+ operation_container=self._all_operations, # type: ignore[arg-type]
+ operation_scope=self._operation_scope,
+ operation_config=self._operation_config,
+ )
+
+ # Create a code asset if code is not already an ARM ID
+ if hasattr(package_request.inferencing_server, "code_configuration"):
+ if package_request.inferencing_server.code_configuration and not is_ARM_id_for_resource(
+ package_request.inferencing_server.code_configuration.code,
+ AzureMLResourceType.CODE,
+ ):
+ if package_request.inferencing_server.code_configuration.code.startswith(ARM_ID_PREFIX):
+ package_request.inferencing_server.code_configuration.code = orchestrators.get_asset_arm_id(
+ package_request.inferencing_server.code_configuration.code[len(ARM_ID_PREFIX) :],
+ azureml_type=AzureMLResourceType.CODE,
+ )
+ else:
+ package_request.inferencing_server.code_configuration.code = orchestrators.get_asset_arm_id(
+ Code(
+ base_path=package_request._base_path,
+ path=package_request.inferencing_server.code_configuration.code,
+ ),
+ azureml_type=AzureMLResourceType.CODE,
+ )
+ if package_request.inferencing_server.code_configuration and hasattr(
+ package_request.inferencing_server.code_configuration, "code"
+ ):
+ package_request.inferencing_server.code_configuration.code = (
+ "azureml:/" + package_request.inferencing_server.code_configuration.code
+ )
+
+ if package_request.base_environment_source and hasattr(
+ package_request.base_environment_source, "resource_id"
+ ):
+ if not package_request.base_environment_source.resource_id.startswith(REGISTRY_URI_FORMAT):
+ package_request.base_environment_source.resource_id = orchestrators.get_asset_arm_id(
+ package_request.base_environment_source.resource_id,
+ azureml_type=AzureMLResourceType.ENVIRONMENT,
+ )
+
+ package_request.base_environment_source.resource_id = (
+ "azureml:/" + package_request.base_environment_source.resource_id
+ if not package_request.base_environment_source.resource_id.startswith(ARM_ID_PREFIX)
+ else package_request.base_environment_source.resource_id
+ )
+
+ # create ARM id for the target environment
+ if self._operation_scope._workspace_location and self._operation_scope._workspace_id:
+ package_request.target_environment_id = f"azureml://locations/{self._operation_scope._workspace_location}/workspaces/{self._operation_scope._workspace_id}/environments/{package_request.target_environment_id}"
+ else:
+ if self._all_operations is not None:
+ ws: Any = self._all_operations.all_operations.get("workspaces")
+ ws_details = ws.get(self._workspace_name)
+ workspace_location, workspace_id = (
+ ws_details.location,
+ ws_details._workspace_id,
+ )
+ package_request.target_environment_id = f"azureml://locations/{workspace_location}/workspaces/{workspace_id}/environments/{package_request.target_environment_id}"
+
+ if package_request.environment_version is not None:
+ package_request.target_environment_id = (
+ package_request.target_environment_id + f"/versions/{package_request.environment_version}"
+ )
+ package_request = package_request._to_rest_object()
+
+ if self._registry_reference:
+ package_request.target_environment_id = f"azureml://locations/{self._operation_scope._workspace_location}/workspaces/{self._operation_scope._workspace_id}/environments/{package_request.target_environment_id}"
+ package_out = (
+ self._model_versions_operation.begin_package(
+ name=name,
+ version=version,
+ registry_name=self._registry_name if self._registry_name else self._registry_reference,
+ body=package_request,
+ **self._scope_kwargs,
+ ).result()
+ if self._registry_name or self._registry_reference
+ else self._model_versions_operation.begin_package(
+ name=name,
+ version=version,
+ workspace_name=self._workspace_name,
+ body=package_request,
+ **self._scope_kwargs,
+ ).result()
+ )
+ if is_deployment_flow: # No need to go through the schema, as this is for deployment notification only
+ return package_out
+ if hasattr(package_out, "target_environment_id"):
+ environment_id = package_out.target_environment_id
+ else:
+ environment_id = package_out.additional_properties["targetEnvironmentId"]
+
+ pattern = r"azureml://locations/(\w+)/workspaces/([\w-]+)/environments/([\w.-]+)/versions/(\d+)"
+ parsed_id: Any = re.search(pattern, environment_id)
+
+ if parsed_id:
+ environment_name = parsed_id.group(3)
+ environment_version = parsed_id.group(4)
+ else:
+ parsed_id = AMLVersionedArmId(environment_id)
+ environment_name = parsed_id.asset_name
+ environment_version = parsed_id.asset_version
+
+ module_logger.info("\nPackage Created")
+ if package_out is not None and package_out.__class__.__name__ == "PackageResponse":
+ if self._registry_name:
+ current_rg = self._scope_kwargs.pop("resource_group_name", None)
+ self._scope_kwargs["resource_group_name"] = self._workspace_rg
+ self._control_plane_client._config.subscription_id = self._workspace_sub
+ env_out = self._control_plane_client.environment_versions.get(
+ name=environment_name,
+ version=environment_version,
+ workspace_name=self._workspace_name,
+ **self._scope_kwargs,
+ )
+ package_out = Environment._from_rest_object(env_out)
+ self._scope_kwargs["resource_group_name"] = current_rg
+ else:
+ if self._all_operations is not None:
+ environment_operation = self._all_operations.all_operations[AzureMLResourceType.ENVIRONMENT]
+ package_out = environment_operation.get(name=environment_name, version=environment_version)
+
+ return package_out
+
+ def _get_model_properties(
+ self, name: str, version: Optional[str] = None, label: Optional[str] = None
+ ) -> Optional[Dict[str, Any]]:
+ """
+ Return the model properties if the model with this name exists.
+
+ :param name: Model name.
+ :type name: str
+ :param version: Model version.
+ :type version: Optional[str]
+ :param label: model label.
+ :type label: Optional[str]
+ :return: Model properties, if the model exists, or None.
+ """
+ try:
+ if version or label:
+ return self.get(name, version, label).properties
+ return self._get_latest_version(name).properties
+ except (ResourceNotFoundError, ValidationException):
+ return None
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_online_deployment_operations.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_online_deployment_operations.py
new file mode 100644
index 00000000..13fe2357
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_online_deployment_operations.py
@@ -0,0 +1,415 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access,broad-except
+
+import random
+import re
+import subprocess
+from typing import Any, Dict, Optional
+
+from marshmallow.exceptions import ValidationError as SchemaValidationError
+
+from azure.ai.ml._exception_helper import log_and_raise_error
+from azure.ai.ml._local_endpoints import LocalEndpointMode
+from azure.ai.ml._restclient.v2022_02_01_preview.models import DeploymentLogsRequest
+from azure.ai.ml._restclient.v2023_04_01_preview import AzureMachineLearningWorkspaces as ServiceClient042023Preview
+from azure.ai.ml._scope_dependent_operations import (
+ OperationConfig,
+ OperationsContainer,
+ OperationScope,
+ _ScopeDependentOperations,
+)
+from azure.ai.ml._telemetry import ActivityType, monitor_with_activity
+from azure.ai.ml._utils._arm_id_utils import AMLVersionedArmId
+from azure.ai.ml._utils._azureml_polling import AzureMLPolling
+from azure.ai.ml._utils._endpoint_utils import upload_dependencies, validate_scoring_script
+from azure.ai.ml._utils._logger_utils import OpsLogger
+from azure.ai.ml._utils._package_utils import package_deployment
+from azure.ai.ml.constants._common import ARM_ID_PREFIX, AzureMLResourceType, LROConfigurations
+from azure.ai.ml.constants._deployment import DEFAULT_MDC_PATH, EndpointDeploymentLogContainerType, SmallSKUs
+from azure.ai.ml.entities import Data, OnlineDeployment
+from azure.ai.ml.exceptions import (
+ ErrorCategory,
+ ErrorTarget,
+ InvalidVSCodeRequestError,
+ LocalDeploymentGPUNotAvailable,
+ ValidationErrorType,
+ ValidationException,
+)
+from azure.core.credentials import TokenCredential
+from azure.core.paging import ItemPaged
+from azure.core.polling import LROPoller
+from azure.core.tracing.decorator import distributed_trace
+
+from ._local_deployment_helper import _LocalDeploymentHelper
+from ._operation_orchestrator import OperationOrchestrator
+
+ops_logger = OpsLogger(__name__)
+module_logger = ops_logger.module_logger
+
+
+class OnlineDeploymentOperations(_ScopeDependentOperations):
+ """OnlineDeploymentOperations.
+
+ You should not instantiate this class directly. Instead, you should create an MLClient instance that instantiates it
+ for you and attaches it as an attribute.
+ """
+
+ def __init__(
+ self,
+ operation_scope: OperationScope,
+ operation_config: OperationConfig,
+ service_client_04_2023_preview: ServiceClient042023Preview,
+ all_operations: OperationsContainer,
+ local_deployment_helper: _LocalDeploymentHelper,
+ credentials: Optional[TokenCredential] = None,
+ **kwargs: Dict,
+ ):
+ super(OnlineDeploymentOperations, self).__init__(operation_scope, operation_config)
+ ops_logger.update_filter()
+ self._local_deployment_helper = local_deployment_helper
+ self._online_deployment = service_client_04_2023_preview.online_deployments
+ self._online_endpoint_operations = service_client_04_2023_preview.online_endpoints
+ self._all_operations = all_operations
+ self._credentials = credentials
+ self._init_kwargs = kwargs
+
+ @distributed_trace
+ @monitor_with_activity(ops_logger, "OnlineDeployment.BeginCreateOrUpdate", ActivityType.PUBLICAPI)
+ def begin_create_or_update(
+ self,
+ deployment: OnlineDeployment,
+ *,
+ local: bool = False,
+ vscode_debug: bool = False,
+ skip_script_validation: bool = False,
+ local_enable_gpu: bool = False,
+ **kwargs: Any,
+ ) -> LROPoller[OnlineDeployment]:
+ """Create or update a deployment.
+
+ :param deployment: the deployment entity
+ :type deployment: ~azure.ai.ml.entities.OnlineDeployment
+ :keyword local: Whether deployment should be created locally, defaults to False
+ :paramtype local: bool
+ :keyword vscode_debug: Whether to open VSCode instance to debug local deployment, defaults to False
+ :paramtype vscode_debug: bool
+ :keyword skip_script_validation: Whether or not to skip validation of the deployment script. Defaults to False.
+ :paramtype skip_script_validation: bool
+        :keyword local_enable_gpu: Whether to enable the local container to access the GPU, defaults to False
+ :paramtype local_enable_gpu: bool
+ :raises ~azure.ai.ml.exceptions.ValidationException: Raised if OnlineDeployment cannot
+ be successfully validated. Details will be provided in the error message.
+ :raises ~azure.ai.ml.exceptions.AssetException: Raised if OnlineDeployment assets
+ (e.g. Data, Code, Model, Environment) cannot be successfully validated.
+ Details will be provided in the error message.
+ :raises ~azure.ai.ml.exceptions.ModelException: Raised if OnlineDeployment model cannot be
+ successfully validated. Details will be provided in the error message.
+ :raises ~azure.ai.ml.exceptions.DeploymentException: Raised if OnlineDeployment type is unsupported.
+ Details will be provided in the error message.
+ :raises ~azure.ai.ml.exceptions.LocalEndpointNotFoundError: Raised if local endpoint resource does not exist.
+ :raises ~azure.ai.ml.exceptions.LocalEndpointInFailedStateError: Raised if local endpoint is in a failed state.
+ :raises ~azure.ai.ml.exceptions.InvalidLocalEndpointError: Raised if Docker image cannot be
+ found for local deployment.
+ :raises ~azure.ai.ml.exceptions.LocalEndpointImageBuildError: Raised if Docker image cannot be
+ successfully built for local deployment.
+ :raises ~azure.ai.ml.exceptions.RequiredLocalArtifactsNotFoundError: Raised if local artifacts cannot be
+ found for local deployment.
+ :raises ~azure.ai.ml.exceptions.InvalidVSCodeRequestError: Raised if VS Debug is invoked with a remote endpoint.
+ VSCode debug is only supported for local endpoints.
+        :raises ~azure.ai.ml.exceptions.LocalDeploymentGPUNotAvailable: Raised if an Nvidia GPU is not available on
+            the system and local_enable_gpu is set for a local deployment.
+ :raises ~azure.ai.ml.exceptions.VSCodeCommandNotFound: Raised if VSCode instance cannot be instantiated.
+ :return: A poller to track the operation status
+ :rtype: ~azure.core.polling.LROPoller[~azure.ai.ml.entities.OnlineDeployment]
+ """
+ try:
+ if vscode_debug and not local:
+ raise InvalidVSCodeRequestError(
+                    msg="VSCode Debug is only supported for local endpoints. Please set local to True."
+ )
+ if local:
+ if local_enable_gpu:
+ try:
+ subprocess.run("nvidia-smi", check=True)
+ except Exception as ex:
+ raise LocalDeploymentGPUNotAvailable(
+ msg=(
+                                "Nvidia GPU is not available in your local system."
+                                " Use the nvidia-smi command to see the available GPUs."
+ )
+ ) from ex
+ return self._local_deployment_helper.create_or_update(
+ deployment=deployment,
+ local_endpoint_mode=self._get_local_endpoint_mode(vscode_debug),
+ local_enable_gpu=local_enable_gpu,
+ )
+ if deployment and deployment.instance_type and deployment.instance_type.lower() in SmallSKUs:
+ module_logger.warning(
+ "Instance type %s may be too small for compute resources. "
+ "Minimum recommended compute SKU is Standard_DS3_v2 for general purpose endpoints. Learn more about SKUs here: " # pylint: disable=line-too-long
+                    "https://learn.microsoft.com/azure/machine-learning/reference-managed-online-endpoints-vm-sku-list",
+ deployment.instance_type,
+ )
+ if (
+ not skip_script_validation
+ and deployment
+ and deployment.code_configuration
+ and not deployment.code_configuration.code.startswith(ARM_ID_PREFIX) # type: ignore[union-attr]
+ and not re.match(AMLVersionedArmId.REGEX_PATTERN, deployment.code_configuration.code) # type: ignore
+ ):
+ validate_scoring_script(deployment)
+
+ path_format_arguments = {
+ "endpointName": deployment.name,
+ "resourceGroupName": self._resource_group_name,
+ "workspaceName": self._workspace_name,
+ }
+
+            # This get() ensures the endpoint exists, so we fail fast before starting the deployment
+ module_logger.info("Check: endpoint %s exists", deployment.endpoint_name)
+ self._online_endpoint_operations.get(
+ resource_group_name=self._resource_group_name,
+ workspace_name=self._workspace_name,
+ endpoint_name=deployment.endpoint_name,
+ )
+ orchestrators = OperationOrchestrator(
+ operation_container=self._all_operations,
+ operation_scope=self._operation_scope,
+ operation_config=self._operation_config,
+ )
+ if deployment.data_collector:
+ self._register_collection_data_assets(deployment=deployment)
+
+ upload_dependencies(deployment, orchestrators)
+ try:
+ location = self._get_workspace_location()
+ is_package_model = deployment.package_model if hasattr(deployment, "package_model") else False
+ if kwargs.pop("package_model", False) or is_package_model:
+ deployment = package_deployment(deployment, self._all_operations.all_operations["models"])
+ module_logger.info("\nStarting deployment")
+
+ deployment_rest = deployment._to_rest_object(location=location) # type: ignore
+
+ poller = self._online_deployment.begin_create_or_update(
+ resource_group_name=self._resource_group_name,
+ workspace_name=self._workspace_name,
+ endpoint_name=deployment.endpoint_name,
+ deployment_name=deployment.name,
+ body=deployment_rest,
+ polling=AzureMLPolling(
+ LROConfigurations.POLL_INTERVAL,
+ path_format_arguments=path_format_arguments,
+ **self._init_kwargs,
+ ),
+ polling_interval=LROConfigurations.POLL_INTERVAL,
+ **self._init_kwargs,
+ cls=lambda response, deserialized, headers: OnlineDeployment._from_rest_object(deserialized),
+ )
+ return poller
+ except Exception as ex:
+ raise ex
+ except Exception as ex: # pylint: disable=W0718
+ if isinstance(ex, (ValidationException, SchemaValidationError)):
+ log_and_raise_error(ex)
+ else:
+ raise ex
+
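+    # Illustrative usage sketch (assumes an authenticated MLClient and an existing online endpoint):
+    #
+    #     from azure.ai.ml.entities import ManagedOnlineDeployment
+    #
+    #     deployment = ManagedOnlineDeployment(
+    #         name="blue",
+    #         endpoint_name="my-endpoint",
+    #         model="azureml:my-model:1",
+    #         instance_type="Standard_DS3_v2",
+    #         instance_count=1,
+    #     )
+    #     poller = ml_client.online_deployments.begin_create_or_update(deployment)
+    #     result = poller.result()  # blocks until the deployment is provisioned
+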
+ @distributed_trace
+ @monitor_with_activity(ops_logger, "OnlineDeployment.Get", ActivityType.PUBLICAPI)
+ def get(self, name: str, endpoint_name: str, *, local: Optional[bool] = False) -> OnlineDeployment:
+ """Get a deployment resource.
+
+ :param name: The name of the deployment
+ :type name: str
+ :param endpoint_name: The name of the endpoint
+ :type endpoint_name: str
+ :keyword local: Whether deployment should be retrieved from local docker environment, defaults to False
+ :paramtype local: Optional[bool]
+ :raises ~azure.ai.ml.exceptions.LocalEndpointNotFoundError: Raised if local endpoint resource does not exist.
+ :return: a deployment entity
+ :rtype: ~azure.ai.ml.entities.OnlineDeployment
+ """
+ if local:
+ deployment = self._local_deployment_helper.get(endpoint_name=endpoint_name, deployment_name=name)
+ else:
+ deployment = OnlineDeployment._from_rest_object(
+ self._online_deployment.get(
+ endpoint_name=endpoint_name,
+ deployment_name=name,
+ resource_group_name=self._resource_group_name,
+ workspace_name=self._workspace_name,
+ **self._init_kwargs,
+ )
+ )
+
+ deployment.endpoint_name = endpoint_name
+ return deployment
+
+ @distributed_trace
+ @monitor_with_activity(ops_logger, "OnlineDeployment.Delete", ActivityType.PUBLICAPI)
+ def begin_delete(self, name: str, endpoint_name: str, *, local: Optional[bool] = False) -> LROPoller[None]:
+ """Delete a deployment.
+
+ :param name: The name of the deployment
+ :type name: str
+ :param endpoint_name: The name of the endpoint
+ :type endpoint_name: str
+ :keyword local: Whether deployment should be retrieved from local docker environment, defaults to False
+ :paramtype local: Optional[bool]
+ :raises ~azure.ai.ml.exceptions.LocalEndpointNotFoundError: Raised if local endpoint resource does not exist.
+ :return: A poller to track the operation status
+ :rtype: ~azure.core.polling.LROPoller[None]
+ """
+ if local:
+ return self._local_deployment_helper.delete(name=endpoint_name, deployment_name=name)
+ return self._online_deployment.begin_delete(
+ endpoint_name=endpoint_name,
+ deployment_name=name,
+ resource_group_name=self._resource_group_name,
+ workspace_name=self._workspace_name,
+ **self._init_kwargs,
+ )
+
+ @distributed_trace
+ @monitor_with_activity(ops_logger, "OnlineDeployment.GetLogs", ActivityType.PUBLICAPI)
+ def get_logs(
+ self,
+ name: str,
+ endpoint_name: str,
+ lines: int,
+ *,
+ container_type: Optional[str] = None,
+ local: bool = False,
+ ) -> str:
+        """Retrieve the logs from an online deployment.
+
+ :param name: The name of the deployment
+ :type name: str
+ :param endpoint_name: The name of the endpoint
+ :type endpoint_name: str
+ :param lines: The maximum number of lines to tail
+ :type lines: int
+        :keyword container_type: The type of container to retrieve logs from. Possible values include:
+            "StorageInitializer", "InferenceServer", defaults to None
+        :paramtype container_type: Optional[str]
+        :keyword local: Whether the logs should be retrieved from a local deployment, defaults to False
+        :paramtype local: bool
+ :return: the logs
+ :rtype: str
+ """
+ if local:
+ return self._local_deployment_helper.get_deployment_logs(
+ endpoint_name=endpoint_name, deployment_name=name, lines=lines
+ )
+ if container_type:
+ container_type = self._validate_deployment_log_container_type(container_type) # type: ignore
+ log_request = DeploymentLogsRequest(container_type=container_type, tail=lines)
+ return str(
+ self._online_deployment.get_logs(
+ resource_group_name=self._resource_group_name,
+ workspace_name=self._workspace_name,
+ endpoint_name=endpoint_name,
+ deployment_name=name,
+ body=log_request,
+ **self._init_kwargs,
+ ).content
+ )
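+
+    # Illustrative sketch: tail the last 50 lines from the inference server container of a deployment.
+    #
+    #     logs = ml_client.online_deployments.get_logs(
+    #         name="blue",
+    #         endpoint_name="my-endpoint",
+    #         lines=50,
+    #         container_type="InferenceServer",  # or "StorageInitializer"
+    #     )
+    #     print(logs)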
+
+ @distributed_trace
+ @monitor_with_activity(ops_logger, "OnlineDeployment.List", ActivityType.PUBLICAPI)
+ def list(self, endpoint_name: str, *, local: bool = False) -> ItemPaged[OnlineDeployment]:
+        """List the deployments for an endpoint.
+
+ :param endpoint_name: The name of the endpoint
+ :type endpoint_name: str
+ :keyword local: Whether deployment should be retrieved from local docker environment, defaults to False
+ :paramtype local: bool
+ :return: an iterator of deployment entities
+ :rtype: Iterable[~azure.ai.ml.entities.OnlineDeployment]
+ """
+ if local:
+ return self._local_deployment_helper.list()
+ return self._online_deployment.list(
+ endpoint_name=endpoint_name,
+ resource_group_name=self._resource_group_name,
+ workspace_name=self._workspace_name,
+ cls=lambda objs: [OnlineDeployment._from_rest_object(obj) for obj in objs],
+ **self._init_kwargs,
+ )
+
+ def _validate_deployment_log_container_type(self, container_type: EndpointDeploymentLogContainerType) -> str:
+ if container_type == EndpointDeploymentLogContainerType.INFERENCE_SERVER:
+ return EndpointDeploymentLogContainerType.INFERENCE_SERVER_REST
+
+ if container_type == EndpointDeploymentLogContainerType.STORAGE_INITIALIZER:
+ return EndpointDeploymentLogContainerType.STORAGE_INITIALIZER_REST
+
+ msg = "Invalid container type '{}'. Supported container types are {} and {}"
+ msg = msg.format(
+ container_type,
+ EndpointDeploymentLogContainerType.INFERENCE_SERVER,
+ EndpointDeploymentLogContainerType.STORAGE_INITIALIZER,
+ )
+ raise ValidationException(
+ message=msg,
+ target=ErrorTarget.ONLINE_DEPLOYMENT,
+ no_personal_data_message=msg,
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
+
+ def _get_ARM_deployment_name(self, name: str) -> str:
+ random.seed(version=2)
+ return f"{self._workspace_name}-{name}-{random.randint(1, 10000000)}"
+
+ def _get_workspace_location(self) -> str:
+ """Get the workspace location
+
+ TODO[TASK 1260265]: can we cache this information and only refresh when the operation_scope is changed?
+
+ :return: The workspace location
+ :rtype: str
+ """
+ return str(
+ self._all_operations.all_operations[AzureMLResourceType.WORKSPACE].get(self._workspace_name).location
+ )
+
+ def _get_local_endpoint_mode(self, vscode_debug: Any) -> LocalEndpointMode:
+ return LocalEndpointMode.VSCodeDevContainer if vscode_debug else LocalEndpointMode.DetachedContainer
+
+ def _register_collection_data_assets(self, deployment: OnlineDeployment) -> None:
+ for name, value in deployment.data_collector.collections.items():
+ data_name = f"{deployment.endpoint_name}-{deployment.name}-{name}"
+ data_version = "1"
+ data_path = f"{DEFAULT_MDC_PATH}/{deployment.endpoint_name}/{deployment.name}/{name}"
+ if value.data:
+ if value.data.name:
+ data_name = value.data.name
+
+ if value.data.version:
+ data_version = value.data.version
+
+ if value.data.path:
+ data_path = value.data.path
+
+ data_object = Data(
+ name=data_name,
+ version=data_version,
+ path=data_path,
+ )
+
+ try:
+ result = self._all_operations._all_operations[AzureMLResourceType.DATA].create_or_update(data_object)
+ except Exception as e:
+ if "already exists" in str(e):
+ result = self._all_operations._all_operations[AzureMLResourceType.DATA].get(data_name, data_version)
+ else:
+ raise e
+ deployment.data_collector.collections[name].data = (
+ f"/subscriptions/{self._subscription_id}/resourceGroups/{self._resource_group_name}"
+ f"/providers/Microsoft.MachineLearningServices/workspaces/{self._workspace_name}"
+ f"/data/{result.name}/versions/{result.version}"
+ )
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_online_endpoint_operations.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_online_endpoint_operations.py
new file mode 100644
index 00000000..6dce4283
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_online_endpoint_operations.py
@@ -0,0 +1,471 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+import json
+from typing import Any, Dict, Optional, Union
+
+from marshmallow.exceptions import ValidationError as SchemaValidationError
+
+from azure.ai.ml._azure_environments import _resource_to_scopes
+from azure.ai.ml._exception_helper import log_and_raise_error
+from azure.ai.ml._restclient.v2022_02_01_preview import AzureMachineLearningWorkspaces as ServiceClient022022Preview
+from azure.ai.ml._restclient.v2022_02_01_preview.models import KeyType, RegenerateEndpointKeysRequest
+from azure.ai.ml._scope_dependent_operations import (
+ OperationConfig,
+ OperationsContainer,
+ OperationScope,
+ _ScopeDependentOperations,
+)
+from azure.ai.ml._telemetry import ActivityType, monitor_with_activity
+from azure.ai.ml._utils._azureml_polling import AzureMLPolling
+from azure.ai.ml._utils._endpoint_utils import validate_response
+from azure.ai.ml._utils._http_utils import HttpPipeline
+from azure.ai.ml._utils._logger_utils import OpsLogger
+from azure.ai.ml.constants._common import (
+ AAD_TOKEN,
+ AAD_TOKEN_RESOURCE_ENDPOINT,
+ EMPTY_CREDENTIALS_ERROR,
+ KEY,
+ AzureMLResourceType,
+ LROConfigurations,
+)
+from azure.ai.ml.constants._endpoint import EndpointInvokeFields, EndpointKeyType
+from azure.ai.ml.entities import OnlineDeployment, OnlineEndpoint
+from azure.ai.ml.entities._assets import Data
+from azure.ai.ml.entities._endpoint.online_endpoint import EndpointAadToken, EndpointAuthKeys, EndpointAuthToken
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, MlException, ValidationErrorType, ValidationException
+from azure.ai.ml.operations._local_endpoint_helper import _LocalEndpointHelper
+from azure.core.credentials import TokenCredential
+from azure.core.paging import ItemPaged
+from azure.core.polling import LROPoller
+from azure.core.tracing.decorator import distributed_trace
+
+from ._operation_orchestrator import OperationOrchestrator
+
+ops_logger = OpsLogger(__name__)
+module_logger = ops_logger.module_logger
+
+
+def _strip_zeroes_from_traffic(traffic: Dict[str, str]) -> Dict[str, str]:
+ return {k.lower(): v for k, v in traffic.items() if v and int(v) != 0}
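+# For example (illustrative): {"Blue": "90", "Green": "0", "canary": ""} -> {"blue": "90"};
+# keys are lower-cased and entries with empty or zero traffic are dropped.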
+
+
+class OnlineEndpointOperations(_ScopeDependentOperations):
+ """OnlineEndpointOperations.
+
+ You should not instantiate this class directly. Instead, you should
+ create an MLClient instance that instantiates it for you and
+ attaches it as an attribute.
+ """
+
+ def __init__(
+ self,
+ operation_scope: OperationScope,
+ operation_config: OperationConfig,
+ service_client_02_2022_preview: ServiceClient022022Preview,
+ all_operations: OperationsContainer,
+ local_endpoint_helper: _LocalEndpointHelper,
+ credentials: Optional[TokenCredential] = None,
+ **kwargs: Dict,
+ ):
+ super(OnlineEndpointOperations, self).__init__(operation_scope, operation_config)
+ ops_logger.update_filter()
+ self._online_operation = service_client_02_2022_preview.online_endpoints
+ self._online_deployment_operation = service_client_02_2022_preview.online_deployments
+ self._all_operations = all_operations
+ self._local_endpoint_helper = local_endpoint_helper
+ self._credentials = credentials
+ self._init_kwargs = kwargs
+
+ self._requests_pipeline: HttpPipeline = kwargs.pop("requests_pipeline")
+
+ @distributed_trace
+ @monitor_with_activity(ops_logger, "OnlineEndpoint.List", ActivityType.PUBLICAPI)
+ def list(self, *, local: bool = False) -> ItemPaged[OnlineEndpoint]:
+ """List endpoints of the workspace.
+
+ :keyword local: (Optional) Flag to indicate whether to interact with endpoints in local Docker environment.
+ Default: False
+        :paramtype local: bool
+ :return: A list of endpoints
+ :rtype: ~azure.core.paging.ItemPaged[~azure.ai.ml.entities.OnlineEndpoint]
+ """
+ if local:
+ return self._local_endpoint_helper.list()
+ return self._online_operation.list(
+ resource_group_name=self._resource_group_name,
+ workspace_name=self._workspace_name,
+ cls=lambda objs: [OnlineEndpoint._from_rest_object(obj) for obj in objs],
+ **self._init_kwargs,
+ )
+
+ @distributed_trace
+ @monitor_with_activity(ops_logger, "OnlineEndpoint.ListKeys", ActivityType.PUBLICAPI)
+ def get_keys(self, name: str) -> Union[EndpointAuthKeys, EndpointAuthToken, EndpointAadToken]:
+ """Get the auth credentials.
+
+ :param name: The endpoint name
+ :type name: str
+        :raises Exception: Raised if the online credentials cannot be retrieved.
+        :return: Depending on the endpoint's auth mode, returns either keys or a token.
+        :rtype: Union[~azure.ai.ml.entities.EndpointAuthKeys, ~azure.ai.ml.entities.EndpointAuthToken,
+            ~azure.ai.ml.entities.EndpointAadToken]
+ """
+ return self._get_online_credentials(name=name)
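+
+    # Illustrative usage sketch: assuming an MLClient named `ml_client` exposing this class as
+    # `ml_client.online_endpoints` (names are assumptions):
+    #
+    #     keys = ml_client.online_endpoints.get_keys(name="my-endpoint")
+    #     # For key auth this returns EndpointAuthKeys (primary_key/secondary_key);
+    #     # for token-based auth it returns an EndpointAuthToken or EndpointAadToken.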
+
+ @distributed_trace
+ @monitor_with_activity(ops_logger, "OnlineEndpoint.Get", ActivityType.PUBLICAPI)
+ def get(
+ self,
+ name: str,
+ *,
+ local: bool = False,
+ ) -> OnlineEndpoint:
+ """Get a Endpoint resource.
+
+ :param name: Name of the endpoint.
+ :type name: str
+ :keyword local: Indicates whether to interact with endpoints in local Docker environment. Defaults to False.
+ :paramtype local: Optional[bool]
+ :raises ~azure.ai.ml.exceptions.LocalEndpointNotFoundError: Raised if local endpoint resource does not exist.
+ :return: Endpoint object retrieved from the service.
+ :rtype: ~azure.ai.ml.entities.OnlineEndpoint
+ """
+ # first get the endpoint
+ if local:
+ return self._local_endpoint_helper.get(endpoint_name=name)
+
+ endpoint = self._online_operation.get(
+ resource_group_name=self._resource_group_name,
+ workspace_name=self._workspace_name,
+ endpoint_name=name,
+ **self._init_kwargs,
+ )
+
+ deployments_list = self._online_deployment_operation.list(
+ endpoint_name=name,
+ resource_group_name=self._resource_group_name,
+ workspace_name=self._workspace_name,
+ cls=lambda objs: [OnlineDeployment._from_rest_object(obj) for obj in objs],
+ **self._init_kwargs,
+ )
+
+ # populate deployments without traffic with zeroes in traffic map
+ converted_endpoint = OnlineEndpoint._from_rest_object(endpoint)
+ if deployments_list:
+ for deployment in deployments_list:
+ if not converted_endpoint.traffic.get(deployment.name) and not converted_endpoint.mirror_traffic.get(
+ deployment.name
+ ):
+ converted_endpoint.traffic[deployment.name] = 0
+
+ return converted_endpoint
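+
+    # Illustrative usage sketch (assumes `ml_client.online_endpoints` as above):
+    #
+    #     endpoint = ml_client.online_endpoints.get(name="my-endpoint")
+    #     print(endpoint.traffic)  # deployments without traffic are reported with a value of 0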
+
+ @distributed_trace
+ @monitor_with_activity(ops_logger, "OnlineEndpoint.BeginDelete", ActivityType.PUBLICAPI)
+ def begin_delete(self, name: Optional[str] = None, *, local: bool = False) -> LROPoller[None]:
+ """Delete an Online Endpoint.
+
+ :param name: Name of the endpoint.
+ :type name: str
+ :keyword local: Whether to interact with the endpoint in local Docker environment. Defaults to False.
+ :paramtype local: bool
+ :raises ~azure.ai.ml.exceptions.LocalEndpointNotFoundError: Raised if local endpoint resource does not exist.
+ :return: A poller to track the operation status if remote, else returns None if local.
+ :rtype: ~azure.core.polling.LROPoller[None]
+ """
+ if local:
+ return self._local_endpoint_helper.delete(name=str(name))
+
+ path_format_arguments = {
+ "endpointName": name,
+ "resourceGroupName": self._resource_group_name,
+ "workspaceName": self._workspace_name,
+ }
+
+ delete_poller = self._online_operation.begin_delete(
+ resource_group_name=self._resource_group_name,
+ workspace_name=self._workspace_name,
+ endpoint_name=name,
+ polling=AzureMLPolling(
+ LROConfigurations.POLL_INTERVAL,
+ path_format_arguments=path_format_arguments,
+ **self._init_kwargs,
+ ),
+ polling_interval=LROConfigurations.POLL_INTERVAL,
+ **self._init_kwargs,
+ )
+ return delete_poller
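+
+    # Illustrative usage sketch (assumes `ml_client.online_endpoints` as above):
+    #
+    #     poller = ml_client.online_endpoints.begin_delete(name="my-endpoint")
+    #     poller.result()  # block until the deletion completes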
+
+ @distributed_trace
+ @monitor_with_activity(ops_logger, "OnlineEndpoint.BeginDeleteOrUpdate", ActivityType.PUBLICAPI)
+ def begin_create_or_update(self, endpoint: OnlineEndpoint, *, local: bool = False) -> LROPoller[OnlineEndpoint]:
+ """Create or update an endpoint.
+
+ :param endpoint: The endpoint entity.
+ :type endpoint: ~azure.ai.ml.entities.OnlineEndpoint
+ :keyword local: Whether to interact with the endpoint in local Docker environment. Defaults to False.
+ :paramtype local: bool
+ :raises ~azure.ai.ml.exceptions.ValidationException: Raised if OnlineEndpoint cannot be successfully validated.
+ Details will be provided in the error message.
+ :raises ~azure.ai.ml.exceptions.AssetException: Raised if OnlineEndpoint assets
+ (e.g. Data, Code, Model, Environment) cannot be successfully validated.
+ Details will be provided in the error message.
+ :raises ~azure.ai.ml.exceptions.ModelException: Raised if OnlineEndpoint model cannot be successfully validated.
+ Details will be provided in the error message.
+ :raises ~azure.ai.ml.exceptions.EmptyDirectoryError: Raised if local path provided points to an empty directory.
+ :raises ~azure.ai.ml.exceptions.LocalEndpointNotFoundError: Raised if local endpoint resource does not exist.
+ :return: A poller to track the operation status if remote, else returns None if local.
+ :rtype: ~azure.core.polling.LROPoller[~azure.ai.ml.entities.OnlineEndpoint]
+ """
+ try:
+ if local:
+ return self._local_endpoint_helper.create_or_update(endpoint=endpoint)
+
+ try:
+ location = self._get_workspace_location()
+
+ if endpoint.traffic:
+ endpoint.traffic = _strip_zeroes_from_traffic(endpoint.traffic)
+
+ if endpoint.mirror_traffic:
+ endpoint.mirror_traffic = _strip_zeroes_from_traffic(endpoint.mirror_traffic)
+
+ endpoint_resource = endpoint._to_rest_online_endpoint(location=location)
+ orchestrators = OperationOrchestrator(
+ operation_container=self._all_operations,
+ operation_scope=self._operation_scope,
+ operation_config=self._operation_config,
+ )
+ if hasattr(endpoint_resource.properties, "compute"):
+ endpoint_resource.properties.compute = orchestrators.get_asset_arm_id(
+ endpoint_resource.properties.compute,
+ azureml_type=AzureMLResourceType.COMPUTE,
+ )
+ poller = self._online_operation.begin_create_or_update(
+ resource_group_name=self._resource_group_name,
+ workspace_name=self._workspace_name,
+ endpoint_name=endpoint.name,
+ body=endpoint_resource,
+ cls=lambda response, deserialized, headers: OnlineEndpoint._from_rest_object(deserialized),
+ **self._init_kwargs,
+ )
+ return poller
+
+ except Exception as ex:
+ raise ex
+ except Exception as ex: # pylint: disable=W0718
+ if isinstance(ex, (ValidationException, SchemaValidationError)):
+ log_and_raise_error(ex)
+ else:
+ raise ex
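+
+    # Illustrative usage sketch (assumes `ml_client.online_endpoints` as above; the entity class and
+    # values below are illustrative):
+    #
+    #     from azure.ai.ml.entities import ManagedOnlineEndpoint
+    #
+    #     endpoint = ManagedOnlineEndpoint(name="my-endpoint", auth_mode="key")
+    #     created = ml_client.online_endpoints.begin_create_or_update(endpoint).result()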
+
+ @distributed_trace
+ @monitor_with_activity(ops_logger, "OnlineEndpoint.BeginGenerateKeys", ActivityType.PUBLICAPI)
+ def begin_regenerate_keys(
+ self,
+ name: str,
+ *,
+ key_type: str = EndpointKeyType.PRIMARY_KEY_TYPE,
+ ) -> LROPoller[None]:
+ """Regenerate keys for endpoint.
+
+ :param name: The endpoint name.
+        :type name: str
+ :keyword key_type: One of "primary", "secondary". Defaults to "primary".
+ :paramtype key_type: str
+ :return: A poller to track the operation status.
+ :rtype: ~azure.core.polling.LROPoller[None]
+ """
+ endpoint = self._online_operation.get(
+ resource_group_name=self._resource_group_name,
+ workspace_name=self._workspace_name,
+ endpoint_name=name,
+ **self._init_kwargs,
+ )
+
+ if endpoint.properties.auth_mode.lower() == "key":
+ return self._regenerate_online_keys(name=name, key_type=key_type)
+ raise ValidationException(
+ message=f"Endpoint '{name}' does not use keys for authentication.",
+ target=ErrorTarget.ONLINE_ENDPOINT,
+ no_personal_data_message="Endpoint does not use keys for authentication.",
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
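+
+    # Illustrative usage sketch (assumes `ml_client.online_endpoints` as above); only valid for
+    # endpoints whose auth_mode is "key":
+    #
+    #     ml_client.online_endpoints.begin_regenerate_keys(name="my-endpoint", key_type="secondary").result()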
+
+ @distributed_trace
+ @monitor_with_activity(ops_logger, "OnlineEndpoint.Invoke", ActivityType.PUBLICAPI)
+ def invoke(
+ self,
+ endpoint_name: str,
+ *,
+ request_file: Optional[str] = None,
+ deployment_name: Optional[str] = None,
+ # pylint: disable=unused-argument
+ input_data: Optional[Union[str, Data]] = None,
+ params_override: Any = None,
+ local: bool = False,
+ **kwargs: Any,
+ ) -> str:
+ """Invokes the endpoint with the provided payload.
+
+ :param endpoint_name: The endpoint name
+ :type endpoint_name: str
+        :keyword request_file: File containing the request payload. This is only valid for online endpoints.
+ :paramtype request_file: Optional[str]
+ :keyword deployment_name: Name of a specific deployment to invoke. This is optional.
+ By default requests are routed to any of the deployments according to the traffic rules.
+ :paramtype deployment_name: Optional[str]
+        :keyword input_data: To use a pre-registered data asset, pass a str in the format azureml:<name>:<version>.
+ :paramtype input_data: Optional[Union[str, Data]]
+ :keyword params_override: A dictionary of payload parameters to override and their desired values.
+ :paramtype params_override: Any
+ :keyword local: Indicates whether to interact with endpoints in local Docker environment. Defaults to False.
+ :paramtype local: Optional[bool]
+ :raises ~azure.ai.ml.exceptions.LocalEndpointNotFoundError: Raised if local endpoint resource does not exist.
+ :raises ~azure.ai.ml.exceptions.MultipleLocalDeploymentsFoundError: Raised if there are multiple deployments
+ and no deployment_name is specified.
+ :raises ~azure.ai.ml.exceptions.InvalidLocalEndpointError: Raised if local endpoint is None.
+ :return: Prediction output for online endpoint.
+ :rtype: str
+ """
+ params_override = params_override or []
+ # Until this bug is resolved https://msdata.visualstudio.com/Vienna/_workitems/edit/1446538
+ if deployment_name:
+ self._validate_deployment_name(endpoint_name, deployment_name)
+
+ with open(request_file, "rb") as f: # type: ignore[arg-type]
+ data = json.loads(f.read())
+ if local:
+ return self._local_endpoint_helper.invoke(
+ endpoint_name=endpoint_name, data=data, deployment_name=deployment_name
+ )
+ endpoint = self._online_operation.get(
+ resource_group_name=self._resource_group_name,
+ workspace_name=self._workspace_name,
+ endpoint_name=endpoint_name,
+ **self._init_kwargs,
+ )
+ keys = self._get_online_credentials(name=endpoint_name, auth_mode=endpoint.properties.auth_mode)
+ if isinstance(keys, EndpointAuthKeys):
+ key = keys.primary_key
+ elif isinstance(keys, (EndpointAuthToken, EndpointAadToken)):
+ key = keys.access_token
+ else:
+ key = ""
+ headers = EndpointInvokeFields.DEFAULT_HEADER
+ if key:
+ headers[EndpointInvokeFields.AUTHORIZATION] = f"Bearer {key}"
+ if deployment_name:
+ headers[EndpointInvokeFields.MODEL_DEPLOYMENT] = deployment_name
+
+ response = self._requests_pipeline.post(endpoint.properties.scoring_uri, json=data, headers=headers)
+ validate_response(response)
+ return str(response.text())
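+
+    # Illustrative usage sketch (assumes `ml_client.online_endpoints` as above and a local
+    # sample-request.json payload file):
+    #
+    #     prediction = ml_client.online_endpoints.invoke(
+    #         endpoint_name="my-endpoint",
+    #         request_file="sample-request.json",
+    #         deployment_name="blue",  # optional; omit to follow the endpoint's traffic rules
+    #     )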
+
+ def _get_workspace_location(self) -> str:
+ return str(
+ self._all_operations.all_operations[AzureMLResourceType.WORKSPACE].get(self._workspace_name).location
+ )
+
+ def _get_online_credentials(
+ self, name: str, auth_mode: Optional[str] = None
+ ) -> Union[EndpointAuthKeys, EndpointAuthToken, EndpointAadToken]:
+ if not auth_mode:
+ endpoint = self._online_operation.get(
+ resource_group_name=self._resource_group_name,
+ workspace_name=self._workspace_name,
+ endpoint_name=name,
+ **self._init_kwargs,
+ )
+ auth_mode = endpoint.properties.auth_mode
+ if auth_mode is not None and auth_mode.lower() == KEY:
+ return self._online_operation.list_keys(
+ resource_group_name=self._resource_group_name,
+ workspace_name=self._workspace_name,
+ endpoint_name=name,
+ # pylint: disable=protected-access
+ cls=lambda x, response, z: EndpointAuthKeys._from_rest_object(response),
+ **self._init_kwargs,
+ )
+
+ if auth_mode is not None and auth_mode.lower() == AAD_TOKEN:
+ if self._credentials:
+ return EndpointAadToken(self._credentials.get_token(*_resource_to_scopes(AAD_TOKEN_RESOURCE_ENDPOINT)))
+ msg = EMPTY_CREDENTIALS_ERROR
+ raise MlException(message=msg, no_personal_data_message=msg)
+
+ return self._online_operation.get_token(
+ resource_group_name=self._resource_group_name,
+ workspace_name=self._workspace_name,
+ endpoint_name=name,
+ # pylint: disable=protected-access
+ cls=lambda x, response, z: EndpointAuthToken._from_rest_object(response),
+ **self._init_kwargs,
+ )
+
+ def _regenerate_online_keys(
+ self,
+ name: str,
+ key_type: str = EndpointKeyType.PRIMARY_KEY_TYPE,
+ ) -> LROPoller[None]:
+ keys = self._online_operation.list_keys(
+ resource_group_name=self._resource_group_name,
+ workspace_name=self._workspace_name,
+ endpoint_name=name,
+ **self._init_kwargs,
+ )
+ if key_type.lower() == EndpointKeyType.PRIMARY_KEY_TYPE:
+ key_request = RegenerateEndpointKeysRequest(key_type=KeyType.Primary, key_value=keys.primary_key)
+ elif key_type.lower() == EndpointKeyType.SECONDARY_KEY_TYPE:
+ key_request = RegenerateEndpointKeysRequest(key_type=KeyType.Secondary, key_value=keys.secondary_key)
+ else:
+ msg = "Key type must be 'primary' or 'secondary'."
+ raise ValidationException(
+ message=msg,
+ target=ErrorTarget.ONLINE_ENDPOINT,
+ no_personal_data_message=msg,
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
+
+ poller = self._online_operation.begin_regenerate_keys(
+ resource_group_name=self._resource_group_name,
+ workspace_name=self._workspace_name,
+ endpoint_name=name,
+ body=key_request,
+ **self._init_kwargs,
+ )
+
+ return poller
+
+ def _validate_deployment_name(self, endpoint_name: str, deployment_name: str) -> None:
+ deployments_list = self._online_deployment_operation.list(
+ endpoint_name=endpoint_name,
+ resource_group_name=self._resource_group_name,
+ workspace_name=self._workspace_name,
+ cls=lambda objs: [obj.name for obj in objs],
+ **self._init_kwargs,
+ )
+
+ if deployments_list:
+ if deployment_name not in deployments_list:
+ raise ValidationException(
+ message=f"Deployment name {deployment_name} not found for this endpoint",
+ target=ErrorTarget.ONLINE_ENDPOINT,
+ no_personal_data_message="Deployment name not found for this endpoint",
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.RESOURCE_NOT_FOUND,
+ )
+ else:
+ msg = "No deployment exists for this endpoint"
+ raise ValidationException(
+ message=msg,
+ target=ErrorTarget.ONLINE_ENDPOINT,
+ no_personal_data_message=msg,
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.RESOURCE_NOT_FOUND,
+ )
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_operation_orchestrator.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_operation_orchestrator.py
new file mode 100644
index 00000000..2ad44bca
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_operation_orchestrator.py
@@ -0,0 +1,571 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+import logging
+import re
+from os import PathLike
+from typing import Any, Optional, Tuple, Union
+
+from typing_extensions import Protocol
+
+from azure.ai.ml._artifacts._artifact_utilities import _check_and_upload_env_build_context, _check_and_upload_path
+from azure.ai.ml._scope_dependent_operations import (
+ OperationConfig,
+ OperationsContainer,
+ OperationScope,
+ _ScopeDependentOperations,
+)
+from azure.ai.ml._utils._arm_id_utils import (
+ AMLLabelledArmId,
+ AMLNamedArmId,
+ AMLVersionedArmId,
+ get_arm_id_with_version,
+ is_ARM_id_for_resource,
+ is_registry_id_for_resource,
+ is_singularity_full_name_for_resource,
+ is_singularity_id_for_resource,
+ is_singularity_short_name_for_resource,
+ parse_name_label,
+ parse_prefixed_name_version,
+)
+from azure.ai.ml._utils._asset_utils import _resolve_label_to_asset, get_storage_info_for_non_registry_asset
+from azure.ai.ml._utils._storage_utils import AzureMLDatastorePathUri
+from azure.ai.ml.constants._common import (
+ ARM_ID_PREFIX,
+ AZUREML_RESOURCE_PROVIDER,
+ CURATED_ENV_PREFIX,
+ DEFAULT_LABEL_NAME,
+ FILE_PREFIX,
+ FOLDER_PREFIX,
+ HTTPS_PREFIX,
+ JOB_URI_REGEX_FORMAT,
+ LABELLED_RESOURCE_ID_FORMAT,
+ LABELLED_RESOURCE_NAME,
+ MLFLOW_URI_REGEX_FORMAT,
+ NAMED_RESOURCE_ID_FORMAT,
+ REGISTRY_VERSION_PATTERN,
+ SINGULARITY_FULL_NAME_REGEX_FORMAT,
+ SINGULARITY_ID_FORMAT,
+ SINGULARITY_SHORT_NAME_REGEX_FORMAT,
+ VERSIONED_RESOURCE_ID_FORMAT,
+ VERSIONED_RESOURCE_NAME,
+ AzureMLResourceType,
+)
+from azure.ai.ml.entities import Component
+from azure.ai.ml.entities._assets import Code, Data, Environment, Model
+from azure.ai.ml.entities._assets.asset import Asset
+from azure.ai.ml.exceptions import (
+ AssetException,
+ EmptyDirectoryError,
+ ErrorCategory,
+ ErrorTarget,
+ MlException,
+ ModelException,
+ ValidationErrorType,
+ ValidationException,
+)
+from azure.core.exceptions import HttpResponseError, ResourceNotFoundError
+
+module_logger = logging.getLogger(__name__)
+
+
+class OperationOrchestrator(object):
+ def __init__(
+ self,
+ operation_container: OperationsContainer,
+ operation_scope: OperationScope,
+ operation_config: OperationConfig,
+ ):
+ self._operation_container = operation_container
+ self._operation_scope = operation_scope
+ self._operation_config = operation_config
+
+ @property
+ def _datastore_operation(self) -> _ScopeDependentOperations:
+ return self._operation_container.all_operations[AzureMLResourceType.DATASTORE]
+
+ @property
+ def _code_assets(self) -> _ScopeDependentOperations:
+ return self._operation_container.all_operations[AzureMLResourceType.CODE]
+
+ @property
+ def _model(self) -> _ScopeDependentOperations:
+ return self._operation_container.all_operations[AzureMLResourceType.MODEL]
+
+ @property
+ def _environments(self) -> _ScopeDependentOperations:
+ return self._operation_container.all_operations[AzureMLResourceType.ENVIRONMENT]
+
+ @property
+ def _data(self) -> _ScopeDependentOperations:
+ return self._operation_container.all_operations[AzureMLResourceType.DATA]
+
+ @property
+ def _component(self) -> _ScopeDependentOperations:
+ return self._operation_container.all_operations[AzureMLResourceType.COMPONENT]
+
+ @property
+ def _virtual_cluster(self) -> _ScopeDependentOperations:
+ return self._operation_container.all_operations[AzureMLResourceType.VIRTUALCLUSTER]
+
+ def get_asset_arm_id(
+ self,
+ asset: Optional[Union[str, Asset]],
+ azureml_type: str,
+ register_asset: bool = True,
+ sub_workspace_resource: bool = True,
+ ) -> Optional[Union[str, Asset]]:
+ """This method converts AzureML Id to ARM Id. Or if the given asset is entity object, it tries to
+ register/upload the asset based on register_asset and azureml_type.
+
+ :param asset: The asset to resolve/register. It can be a ARM id or a entity's object.
+ :type asset: Optional[Union[str, Asset]]
+ :param azureml_type: The AzureML resource type. Defined in AzureMLResourceType.
+ :type azureml_type: str
+ :param register_asset: Indicates if the asset should be registered, defaults to True.
+ :type register_asset: Optional[bool]
+ :param sub_workspace_resource:
+ :type sub_workspace_resource: Optional[bool]
+ :raises ~azure.ai.ml.exceptions.ValidationException: Raised if asset's ID cannot be converted
+ or asset cannot be successfully registered.
+ :return: The ARM Id or entity object
+ :rtype: Optional[Union[str, ~azure.ai.ml.entities.Asset]]
+ """
+ # pylint: disable=too-many-return-statements, too-many-branches
+ if (
+ asset is None
+ or is_ARM_id_for_resource(asset, azureml_type, sub_workspace_resource)
+ or is_registry_id_for_resource(asset)
+ or is_singularity_id_for_resource(asset)
+ ):
+ return asset
+ if is_singularity_full_name_for_resource(asset):
+ return self._get_singularity_arm_id_from_full_name(str(asset))
+ if is_singularity_short_name_for_resource(asset):
+ return self._get_singularity_arm_id_from_short_name(str(asset))
+ if isinstance(asset, str):
+ if azureml_type in AzureMLResourceType.NAMED_TYPES:
+ return NAMED_RESOURCE_ID_FORMAT.format(
+ self._operation_scope.subscription_id,
+ self._operation_scope.resource_group_name,
+ AZUREML_RESOURCE_PROVIDER,
+ self._operation_scope.workspace_name,
+ azureml_type,
+ asset,
+ )
+ if azureml_type in AzureMLResourceType.VERSIONED_TYPES:
+ # Short form of curated env will be expanded on the backend side.
+ # CLI strips off azureml: in the schema, appending it back as required by backend
+ if azureml_type == AzureMLResourceType.ENVIRONMENT:
+ azureml_prefix = "azureml:"
+ # return the same value if resolved result is passed in
+ _asset = asset[len(azureml_prefix) :] if asset.startswith(azureml_prefix) else asset
+ if _asset.startswith(CURATED_ENV_PREFIX) or re.match(
+ REGISTRY_VERSION_PATTERN, f"{azureml_prefix}{_asset}"
+ ):
+ return f"{azureml_prefix}{_asset}"
+
+ name, label = parse_name_label(asset)
+ # TODO: remove this condition after label is fully supported for all versioned resources
+ if label == DEFAULT_LABEL_NAME and azureml_type == AzureMLResourceType.COMPONENT:
+ return LABELLED_RESOURCE_ID_FORMAT.format(
+ self._operation_scope.subscription_id,
+ self._operation_scope.resource_group_name,
+ AZUREML_RESOURCE_PROVIDER,
+ self._operation_scope.workspace_name,
+ azureml_type,
+ name,
+ label,
+ )
+ name, version = self._resolve_name_version_from_name_label(asset, azureml_type)
+ if not version:
+ name, version = parse_prefixed_name_version(asset)
+
+ if not version:
+ msg = (
+ "Failed to extract version when parsing asset {} of type {} as arm id. "
+ "Version must be provided."
+ )
+ raise ValidationException(
+ message=msg.format(asset, azureml_type),
+ target=ErrorTarget.ASSET,
+ no_personal_data_message=msg.format("", azureml_type),
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.MISSING_FIELD,
+ )
+ if self._operation_scope.registry_name:
+ # Short form for env not supported with registry flow except when it's a curated env.
+ # Adding a graceful error message for the scenario
+ if not asset.startswith(CURATED_ENV_PREFIX):
+ msg = (
+ "Use fully qualified name to reference custom environments "
+ "when creating assets in registry. "
+ "The syntax for fully qualified names is "
+ "azureml://registries/azureml/environments/{{env-name}}/versions/{{version}}"
+ )
+ raise ValidationException(
+ message=msg.format(asset, azureml_type),
+ target=ErrorTarget.ASSET,
+ no_personal_data_message=msg.format("", azureml_type),
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
+
+ return VERSIONED_RESOURCE_ID_FORMAT.format(
+ self._operation_scope.subscription_id,
+ self._operation_scope.resource_group_name,
+ AZUREML_RESOURCE_PROVIDER,
+ self._operation_scope.workspace_name,
+ azureml_type,
+ name,
+ version,
+ )
+ msg = "Unsupported azureml type {} for asset: {}"
+ raise ValidationException(
+ message=msg.format(azureml_type, asset),
+ target=ErrorTarget.ASSET,
+ no_personal_data_message=msg.format(azureml_type, ""),
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
+ if isinstance(asset, Asset):
+ try:
+ result: Any = None
+ # TODO: once the asset redesign is finished, this logic can be replaced with unified API
+ if azureml_type == AzureMLResourceType.CODE and isinstance(asset, Code):
+ result = self._get_code_asset_arm_id(asset, register_asset=register_asset)
+ elif azureml_type == AzureMLResourceType.ENVIRONMENT and isinstance(asset, Environment):
+ result = self._get_environment_arm_id(asset, register_asset=register_asset)
+ elif azureml_type == AzureMLResourceType.MODEL and isinstance(asset, Model):
+ result = self._get_model_arm_id(asset, register_asset=register_asset)
+ elif azureml_type == AzureMLResourceType.DATA and isinstance(asset, Data):
+ result = self._get_data_arm_id(asset, register_asset=register_asset)
+ elif azureml_type == AzureMLResourceType.COMPONENT and isinstance(asset, Component):
+ result = self._get_component_arm_id(asset)
+ else:
+ msg = "Unsupported azureml type {} for asset: {}"
+ raise ValidationException(
+ message=msg.format(azureml_type, asset),
+ target=ErrorTarget.ASSET,
+ no_personal_data_message=msg.format(azureml_type, ""),
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
+ except EmptyDirectoryError as e:
+ msg = f"Error creating {azureml_type} asset : {e.message}"
+ raise AssetException(
+ message=msg.format(azureml_type, e.message),
+ target=ErrorTarget.ASSET,
+ no_personal_data_message=msg.format(azureml_type, ""),
+ error=e,
+ error_category=ErrorCategory.SYSTEM_ERROR,
+ ) from e
+ return result
+ msg = f"Error creating {azureml_type} asset: must be type Optional[Union[str, Asset]]"
+ raise ValidationException(
+ message=msg,
+ target=ErrorTarget.ASSET,
+ no_personal_data_message=msg,
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
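+
+    # Illustrative sketch of how string inputs are resolved (values are assumptions): a versioned
+    # reference such as "my-env:1" for AzureMLResourceType.ENVIRONMENT is expanded into the
+    # workspace-scoped ARM id
+    # "/subscriptions/<sub>/resourceGroups/<rg>/providers/Microsoft.MachineLearningServices/workspaces/<ws>/environments/my-env/versions/1",
+    # while entity objects are registered (or uploaded) and the resulting id or entity is returned.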
+
+ def _get_code_asset_arm_id(self, code_asset: Code, register_asset: bool = True) -> Union[Code, str]:
+ try:
+ self._validate_datastore_name(code_asset.path)
+ if register_asset:
+ code_asset = self._code_assets.create_or_update(code_asset) # type: ignore[attr-defined]
+ return str(code_asset.id)
+ sas_info = get_storage_info_for_non_registry_asset(
+ service_client=self._code_assets._service_client, # type: ignore[attr-defined]
+ workspace_name=self._operation_scope.workspace_name,
+ name=code_asset.name,
+ version=code_asset.version,
+ resource_group=self._operation_scope.resource_group_name,
+ )
+ uploaded_code_asset, _ = _check_and_upload_path(
+ artifact=code_asset,
+ asset_operations=self._code_assets, # type: ignore[arg-type]
+ artifact_type=ErrorTarget.CODE,
+ show_progress=self._operation_config.show_progress,
+ sas_uri=sas_info["sas_uri"],
+ blob_uri=sas_info["blob_uri"],
+ )
+ uploaded_code_asset._id = get_arm_id_with_version(
+ self._operation_scope,
+ AzureMLResourceType.CODE,
+ code_asset.name,
+ code_asset.version,
+ )
+ return uploaded_code_asset
+ except (MlException, HttpResponseError) as e:
+ raise e
+ except Exception as e:
+ raise AssetException(
+ message=f"Error with code: {e}",
+ target=ErrorTarget.ASSET,
+ no_personal_data_message="Error getting code asset",
+ error=e,
+ error_category=ErrorCategory.SYSTEM_ERROR,
+ ) from e
+
+ def _get_environment_arm_id(self, environment: Environment, register_asset: bool = True) -> Union[str, Environment]:
+ if register_asset:
+ if environment.id:
+ return environment.id
+ env_response = self._environments.create_or_update(environment) # type: ignore[attr-defined]
+ return env_response.id
+ environment = _check_and_upload_env_build_context(
+ environment=environment,
+ operations=self._environments, # type: ignore[arg-type]
+ show_progress=self._operation_config.show_progress,
+ )
+ environment._id = get_arm_id_with_version(
+ self._operation_scope,
+ AzureMLResourceType.ENVIRONMENT,
+ environment.name,
+ environment.version,
+ )
+ return environment
+
+ def _get_model_arm_id(self, model: Model, register_asset: bool = True) -> Union[str, Model]:
+ try:
+ self._validate_datastore_name(model.path)
+
+ if register_asset:
+ if model.id:
+ return model.id
+ return self._model.create_or_update(model).id # type: ignore[attr-defined]
+ uploaded_model, _ = _check_and_upload_path(
+ artifact=model,
+ asset_operations=self._model, # type: ignore[arg-type]
+ artifact_type=ErrorTarget.MODEL,
+ show_progress=self._operation_config.show_progress,
+ )
+ uploaded_model._id = get_arm_id_with_version(
+ self._operation_scope,
+ AzureMLResourceType.MODEL,
+ model.name,
+ model.version,
+ )
+ return uploaded_model
+ except (MlException, HttpResponseError) as e:
+ raise e
+ except Exception as e:
+ raise ModelException(
+ message=f"Error with model: {e}",
+ target=ErrorTarget.MODEL,
+ no_personal_data_message="Error getting model",
+ error=e,
+ error_category=ErrorCategory.SYSTEM_ERROR,
+ ) from e
+
+ def _get_data_arm_id(self, data_asset: Data, register_asset: bool = True) -> Union[str, Data]:
+ self._validate_datastore_name(data_asset.path)
+
+ if register_asset:
+ return self._data.create_or_update(data_asset).id # type: ignore[attr-defined]
+ data_asset, _ = _check_and_upload_path(
+ artifact=data_asset,
+ asset_operations=self._data, # type: ignore[arg-type]
+ artifact_type=ErrorTarget.DATA,
+ show_progress=self._operation_config.show_progress,
+ )
+ return data_asset
+
+ def _get_component_arm_id(self, component: Component) -> str:
+ """Gets the component ARM ID.
+
+ :param component: The component
+ :type component: Component
+ :return: The component id
+ :rtype: str
+ """
+
+ # If component arm id is already resolved, return the id otherwise get arm id via remote call.
+ # Register the component if necessary, and FILL BACK the arm id to component to reduce remote call.
+ if not component.id:
+ component._id = self._component.create_or_update( # type: ignore[attr-defined]
+ component, is_anonymous=True, show_progress=self._operation_config.show_progress
+ ).id
+ return str(component.id)
+
+ def _get_singularity_arm_id_from_full_name(self, singularity: str) -> str:
+ match = re.match(SINGULARITY_FULL_NAME_REGEX_FORMAT, singularity)
+ subscription_id = match.group("subscription_id") if match is not None else ""
+ resource_group_name = match.group("resource_group_name") if match is not None else ""
+ vc_name = match.group("name") if match is not None else ""
+ arm_id = SINGULARITY_ID_FORMAT.format(subscription_id, resource_group_name, vc_name)
+ vc = self._virtual_cluster.get(arm_id) # type: ignore[attr-defined]
+ return str(vc["id"])
+
+ def _get_singularity_arm_id_from_short_name(self, singularity: str) -> str:
+ match = re.match(SINGULARITY_SHORT_NAME_REGEX_FORMAT, singularity)
+ vc_name = match.group("name") if match is not None else ""
+        # The list operation below can be time-consuming; this may need optimization.
+ match_vcs = [vc for vc in self._virtual_cluster.list() if vc["name"] == vc_name] # type: ignore[attr-defined]
+ num_match_vc = len(match_vcs)
+ if num_match_vc != 1:
+ if num_match_vc == 0:
+ msg = "The virtual cluster {} could not be found."
+ else:
+ msg = "More than one match virtual clusters {} found."
+ raise ValidationException(
+ message=msg.format(vc_name),
+ no_personal_data_message=msg.format(""),
+ target=ErrorTarget.COMPUTE,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
+ return str(match_vcs[0]["id"])
+
+ def _resolve_name_version_from_name_label(self, aml_id: str, azureml_type: str) -> Tuple[str, Optional[str]]:
+ """Given an AzureML id of the form name@label, resolves the label to the actual ID.
+
+ :param aml_id: AzureML id of the form name@label
+ :type aml_id: str
+ :param azureml_type: The AzureML resource type. Defined in AzureMLResourceType.
+ :type azureml_type: str
+ :return: Returns tuple (name, version) on success, (name@label, None) if resolution fails
+ :rtype: Tuple[str, Optional[str]]
+ """
+ name, label = parse_name_label(aml_id)
+ if (
+ azureml_type not in AzureMLResourceType.VERSIONED_TYPES
+ or azureml_type == AzureMLResourceType.CODE
+ or not label
+ ):
+ return aml_id, None
+
+ return (
+ name,
+ _resolve_label_to_asset(
+ self._operation_container.all_operations[azureml_type],
+ name,
+ label=label,
+ ).version,
+ )
+
+ # pylint: disable=unused-argument
+ def resolve_azureml_id(self, arm_id: Optional[str] = None, **kwargs: Any) -> Optional[str]:
+ """This function converts ARM id to name or name:version AzureML id. It parses the ARM id and matches the
+ subscription Id, resource group name and workspace_name.
+
+ TODO: It is debatable whether this method should be in operation_orchestrator.
+
+ :param arm_id: entity's ARM id, defaults to None
+ :type arm_id: str
+ :return: AzureML id
+ :rtype: str
+ """
+
+ if arm_id:
+ if not isinstance(arm_id, str):
+ msg = "arm_id cannot be resolved: str expected but got {}".format(type(arm_id)) # type: ignore
+ raise ValidationException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.GENERAL,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
+ try:
+ arm_id_obj = AMLVersionedArmId(arm_id)
+ if arm_id_obj.is_registry_id:
+ return arm_id
+ if self._match(arm_id_obj):
+ return str(VERSIONED_RESOURCE_NAME.format(arm_id_obj.asset_name, arm_id_obj.asset_version))
+ except ValidationException:
+ pass # fall back to named arm id
+ try:
+ arm_id_obj = AMLLabelledArmId(arm_id)
+ if self._match(arm_id_obj):
+ return str(LABELLED_RESOURCE_NAME.format(arm_id_obj.asset_name, arm_id_obj.asset_label))
+ except ValidationException:
+ pass # fall back to named arm id
+ try:
+ arm_id_obj = AMLNamedArmId(arm_id)
+ if self._match(arm_id_obj):
+ return str(arm_id_obj.asset_name)
+ except ValidationException:
+ pass # fall back to be not a ARM_id
+ return arm_id
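+
+    # Illustrative sketch (ids are assumptions): a versioned ARM id that matches the current scope,
+    # e.g. ".../workspaces/<ws>/models/my-model/versions/3", resolves to "my-model:3"; a labelled id
+    # resolves to "name@label"; anything that does not match the scope is returned unchanged.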
+
+ def _match(self, id_: Any) -> bool:
+ return bool(
+ (
+ id_.subscription_id == self._operation_scope.subscription_id
+ and id_.resource_group_name == self._operation_scope.resource_group_name
+ and id_.workspace_name == self._operation_scope.workspace_name
+ )
+ )
+
+ def _validate_datastore_name(self, datastore_uri: Optional[Union[str, PathLike]]) -> None:
+ if datastore_uri:
+ try:
+ if isinstance(datastore_uri, str):
+ if datastore_uri.startswith(FILE_PREFIX):
+ datastore_uri = datastore_uri[len(FILE_PREFIX) :]
+ elif datastore_uri.startswith(FOLDER_PREFIX):
+ datastore_uri = datastore_uri[len(FOLDER_PREFIX) :]
+ elif isinstance(datastore_uri, PathLike):
+ return
+
+ if datastore_uri.startswith(HTTPS_PREFIX) and datastore_uri.count("/") == 7:
+ # only long-form (i.e. "https://x.blob.core.windows.net/datastore/LocalUpload/guid/x/x")
+ # format includes datastore
+ datastore_name = datastore_uri.split("/")[3]
+ elif datastore_uri.startswith(ARM_ID_PREFIX) and not (
+ re.match(MLFLOW_URI_REGEX_FORMAT, datastore_uri) or re.match(JOB_URI_REGEX_FORMAT, datastore_uri)
+ ):
+ datastore_name = AzureMLDatastorePathUri(datastore_uri).datastore
+ else:
+ # local path
+ return
+
+ if datastore_name.startswith(ARM_ID_PREFIX):
+ datastore_name = datastore_name[len(ARM_ID_PREFIX) :]
+
+ self._datastore_operation.get(datastore_name) # type: ignore[attr-defined]
+ except ResourceNotFoundError as e:
+ msg = "The datastore {} could not be found in this workspace."
+ raise ValidationException(
+ message=msg.format(datastore_name),
+ target=ErrorTarget.DATASTORE,
+ no_personal_data_message=msg.format(""),
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.RESOURCE_NOT_FOUND,
+ ) from e
+
+
+class _AssetResolver(Protocol):
+ """Describes the type of a function used by operation classes like :py:class:`JobOperations` and
+ :py:class:`ComponentOperations` to resolve Assets
+
+    .. seealso:: methods :py:meth:`OperationOrchestrator.get_asset_arm_id`,
+        :py:meth:`OperationOrchestrator.resolve_azureml_id`
+
+ """
+
+ def __call__(
+ self,
+ asset: Optional[Union[str, Asset]],
+ azureml_type: str,
+ register_asset: bool = True,
+ sub_workspace_resource: bool = True,
+ ) -> Optional[Union[str, Asset]]:
+ """Resolver function
+
+        :param asset: The asset to resolve/register. It can be an ARM id or an entity object.
+ :type asset: Optional[Union[str, Asset]]
+ :param azureml_type: The AzureML resource type. Defined in AzureMLResourceType.
+ :type azureml_type: str
+ :param register_asset: Indicates if the asset should be registered, defaults to True.
+ :type register_asset: Optional[bool]
+ :param sub_workspace_resource:
+ :type sub_workspace_resource: Optional[bool]
+ :raises ~azure.ai.ml.exceptions.ValidationException: Raised if asset's ID cannot be converted
+ or asset cannot be successfully registered.
+ :return: The ARM Id or entity object
+ :rtype: Optional[Union[str, ~azure.ai.ml.entities.Asset]]
+ """
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_registry_operations.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_registry_operations.py
new file mode 100644
index 00000000..6cd1a326
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_registry_operations.py
@@ -0,0 +1,168 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access,unused-argument
+
+from typing import Dict, Iterable, Optional, cast
+
+from azure.ai.ml._restclient.v2022_10_01_preview import AzureMachineLearningWorkspaces as ServiceClient102022
+from azure.ai.ml._scope_dependent_operations import OperationsContainer, OperationScope
+from azure.ai.ml._telemetry import ActivityType, monitor_with_activity
+from azure.ai.ml._utils._logger_utils import OpsLogger
+from azure.ai.ml.entities import Registry
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationException
+from azure.core.credentials import TokenCredential
+from azure.core.polling import LROPoller
+
+from .._utils._azureml_polling import AzureMLPolling
+from ..constants._common import LROConfigurations, Scope
+
+ops_logger = OpsLogger(__name__)
+module_logger = ops_logger.module_logger
+
+
+class RegistryOperations:
+ """RegistryOperations.
+
+ You should not instantiate this class directly. Instead, you should create an MLClient instance that instantiates it
+ for you and attaches it as an attribute.
+ """
+
+ def __init__(
+ self,
+ operation_scope: OperationScope,
+ service_client: ServiceClient102022,
+ all_operations: OperationsContainer,
+ credentials: Optional[TokenCredential] = None,
+ **kwargs: Dict,
+ ):
+ ops_logger.update_filter()
+ self._subscription_id = operation_scope.subscription_id
+ self._resource_group_name = operation_scope.resource_group_name
+ self._default_registry_name = operation_scope.registry_name
+ self._operation = service_client.registries
+ self._all_operations = all_operations
+ self._credentials = credentials
+ self.containerRegistry = "none"
+ self._init_kwargs = kwargs
+
+ @monitor_with_activity(ops_logger, "Registry.List", ActivityType.PUBLICAPI)
+ def list(self, *, scope: str = Scope.RESOURCE_GROUP) -> Iterable[Registry]:
+ """List all registries that the user has access to in the current resource group or subscription.
+
+ :keyword scope: scope of the listing, "resource_group" or "subscription", defaults to "resource_group"
+ :paramtype scope: str
+ :return: An iterator like instance of Registry objects
+ :rtype: ~azure.core.paging.ItemPaged[Registry]
+ """
+ if scope.lower() == Scope.SUBSCRIPTION:
+ return cast(
+ Iterable[Registry],
+ self._operation.list_by_subscription(
+ cls=lambda objs: [Registry._from_rest_object(obj) for obj in objs]
+ ),
+ )
+ return cast(
+ Iterable[Registry],
+ self._operation.list(
+ cls=lambda objs: [Registry._from_rest_object(obj) for obj in objs],
+ resource_group_name=self._resource_group_name,
+ ),
+ )
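+
+    # Illustrative usage sketch: assuming an MLClient named `ml_client` exposing this class as
+    # `ml_client.registries` (names are assumptions):
+    #
+    #     for registry in ml_client.registries.list(scope="subscription"):
+    #         print(registry.name, registry.location)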
+
+ @monitor_with_activity(ops_logger, "Registry.Get", ActivityType.PUBLICAPI)
+ def get(self, name: Optional[str] = None) -> Registry:
+ """Get a registry by name.
+
+ :param name: Name of the registry.
+ :type name: str
+ :raises ~azure.ai.ml.exceptions.ValidationException: Raised if Registry name cannot be
+ successfully validated. Details will be provided in the error message.
+ :raises ~azure.core.exceptions.HttpResponseError: Raised if the corresponding name and version cannot be
+ retrieved from the service.
+ :return: The registry with the provided name.
+ :rtype: ~azure.ai.ml.entities.Registry
+ """
+
+ registry_name = self._check_registry_name(name)
+ resource_group = self._resource_group_name
+ obj = self._operation.get(resource_group, registry_name)
+ return Registry._from_rest_object(obj) # type: ignore[return-value]
+
+ def _check_registry_name(self, name: Optional[str]) -> str:
+ registry_name = name or self._default_registry_name
+ if not registry_name:
+ msg = "Please provide a registry name or use a MLClient with a registry name set."
+ raise ValidationException(
+ message=msg,
+ target=ErrorTarget.REGISTRY,
+ no_personal_data_message=msg,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+ return registry_name
+
+ def _get_polling(self, name: str) -> AzureMLPolling:
+ """Return the polling with custom poll interval.
+
+ :param name: The registry name
+ :type name: str
+ :return: A poller with custom poll interval.
+ :rtype: AzureMLPolling
+ """
+ path_format_arguments = {
+ "registryName": name,
+ "resourceGroupName": self._resource_group_name,
+ }
+ return AzureMLPolling(
+ LROConfigurations.POLL_INTERVAL,
+ path_format_arguments=path_format_arguments,
+ )
+
+ @monitor_with_activity(ops_logger, "Registry.BeginCreate", ActivityType.PUBLICAPI)
+ def begin_create(
+ self,
+ registry: Registry,
+ **kwargs: Dict,
+ ) -> LROPoller[Registry]:
+ """Create a new Azure Machine Learning Registry, or try to update if it already exists.
+
+ Note: Due to service limitations we have to sleep for
+ an additional 30~45 seconds AFTER the LRO Poller concludes
+ before the registry will be consistently deleted from the
+ perspective of subsequent operations.
+ If a deletion is required for subsequent operations to
+ work properly, callers should implement that logic until the
+ service has been fixed to return a reliable LRO.
+
+ :param registry: Registry definition.
+ :type registry: Registry
+ :return: A poller to track the operation status.
+ :rtype: LROPoller
+ """
+ registry_data = registry._to_rest_object()
+ poller = self._operation.begin_create_or_update(
+ resource_group_name=self._resource_group_name,
+ registry_name=registry.name,
+ body=registry_data,
+ polling=self._get_polling(str(registry.name)),
+ cls=lambda response, deserialized, headers: Registry._from_rest_object(deserialized),
+ )
+
+ return poller
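+
+    # Illustrative usage sketch (assumes `ml_client.registries` as above and a Registry entity,
+    # typically loaded from a YAML definition):
+    #
+    #     from azure.ai.ml import load_registry
+    #
+    #     registry = load_registry(source="registry.yml")
+    #     created = ml_client.registries.begin_create(registry).result()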
+
+ @monitor_with_activity(ops_logger, "Registry.BeginDelete", ActivityType.PUBLICAPI)
+ def begin_delete(self, *, name: str, **kwargs: Dict) -> LROPoller[None]:
+ """Delete a registry if it exists. Returns nothing on a successful operation.
+
+ :keyword name: Name of the registry
+ :paramtype name: str
+ :return: A poller to track the operation status.
+ :rtype: LROPoller
+ """
+ resource_group = kwargs.get("resource_group") or self._resource_group_name
+ return self._operation.begin_delete(
+ resource_group_name=resource_group,
+ registry_name=name,
+ **self._init_kwargs,
+ )
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_run_history_constants.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_run_history_constants.py
new file mode 100644
index 00000000..17aaac29
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_run_history_constants.py
@@ -0,0 +1,82 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+import os
+
+# default timeout of session for getting content in the job run,
+# the 1st element is conn timeout, the 2nd is the read timeout.
+
+
+class JobStatus(object):
+ """
+ * NotStarted - This is a temporary state client-side Run objects are in before cloud submission.
+ * Starting - The Run has started being processed in the cloud. The caller has a run ID at this point.
+ * Provisioning - Returned when on-demand compute is being created for a given job submission.
+ * Preparing - The run environment is being prepared:
+ * docker image build
+ * conda environment setup
+ * Queued - The job is queued in the compute target. For example, in BatchAI the job is in queued state
+ while waiting for all the requested nodes to be ready.
+ * Running - The job started to run in the compute target.
+ * Finalizing - User code has completed and the run is in post-processing stages.
+ * CancelRequested - Cancellation has been requested for the job.
+ * Completed - The run completed successfully. This includes both the user code and run
+ post-processing stages.
+ * Failed - The run failed. Usually the Error property on a run will provide details as to why.
+ * Canceled - Follows a cancellation request and indicates that the run is now successfully cancelled.
+ * NotResponding - For runs that have Heartbeats enabled, no heartbeat has been recently sent.
+ """
+
+ # Ordered by transition order
+ QUEUED = "Queued"
+ NOT_STARTED = "NotStarted"
+ PREPARING = "Preparing"
+ PROVISIONING = "Provisioning"
+ STARTING = "Starting"
+ RUNNING = "Running"
+ CANCEL_REQUESTED = "CancelRequested"
+ CANCELED = "Canceled" # Not official yet
+ FINALIZING = "Finalizing"
+ COMPLETED = "Completed"
+ FAILED = "Failed"
+ PAUSED = "Paused"
+ NOTRESPONDING = "NotResponding"
+
+
+class RunHistoryConstants(object):
+ _DEFAULT_GET_CONTENT_TIMEOUT = (5, 120)
+ _WAIT_COMPLETION_POLLING_INTERVAL_MIN = os.environ.get("AZUREML_RUN_POLLING_INTERVAL_MIN", 2)
+ _WAIT_COMPLETION_POLLING_INTERVAL_MAX = os.environ.get("AZUREML_RUN_POLLING_INTERVAL_MAX", 60)
+ ALL_STATUSES = [
+ JobStatus.QUEUED,
+ JobStatus.PREPARING,
+ JobStatus.PROVISIONING,
+ JobStatus.STARTING,
+ JobStatus.RUNNING,
+ JobStatus.CANCEL_REQUESTED,
+ JobStatus.CANCELED,
+ JobStatus.FINALIZING,
+ JobStatus.COMPLETED,
+ JobStatus.FAILED,
+ JobStatus.NOT_STARTED,
+ JobStatus.PAUSED,
+ JobStatus.NOTRESPONDING,
+ ]
+ IN_PROGRESS_STATUSES = [
+ JobStatus.NOT_STARTED,
+ JobStatus.QUEUED,
+ JobStatus.PREPARING,
+ JobStatus.PROVISIONING,
+ JobStatus.STARTING,
+ JobStatus.RUNNING,
+ ]
+ POST_PROCESSING_STATUSES = [JobStatus.CANCEL_REQUESTED, JobStatus.FINALIZING]
+ TERMINAL_STATUSES = [
+ JobStatus.COMPLETED,
+ JobStatus.FAILED,
+ JobStatus.CANCELED,
+ JobStatus.NOTRESPONDING,
+ JobStatus.PAUSED,
+ ]
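+
+# Illustrative usage sketch: these constants are typically used to decide whether to keep polling a job;
+# for example, `job.status in RunHistoryConstants.TERMINAL_STATUSES` means streaming/polling can stop,
+# while a status in IN_PROGRESS_STATUSES means the run is still queued, preparing, or running.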
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_run_operations.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_run_operations.py
new file mode 100644
index 00000000..627b1936
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_run_operations.py
@@ -0,0 +1,94 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+import logging
+from typing import Any, Iterable, Optional, cast
+
+from azure.ai.ml._restclient.runhistory import AzureMachineLearningWorkspaces as RunHistoryServiceClient
+from azure.ai.ml._restclient.runhistory.models import GetRunDataRequest, GetRunDataResult, Run, RunDetails
+from azure.ai.ml._scope_dependent_operations import OperationConfig, OperationScope, _ScopeDependentOperations
+from azure.ai.ml.constants._common import AZUREML_RESOURCE_PROVIDER, NAMED_RESOURCE_ID_FORMAT, AzureMLResourceType
+from azure.ai.ml.entities._job.base_job import _BaseJob
+from azure.ai.ml.entities._job.job import Job
+from azure.ai.ml.exceptions import JobParsingError
+
+module_logger = logging.getLogger(__name__)
+
+
+class RunOperations(_ScopeDependentOperations):
+ def __init__(
+ self,
+ operation_scope: OperationScope,
+ operation_config: OperationConfig,
+ service_client: RunHistoryServiceClient,
+ ):
+ super(RunOperations, self).__init__(operation_scope, operation_config)
+ self._operation = service_client.runs
+
+ def get_run(self, run_id: str) -> Run:
+ return self._operation.get(
+ self._operation_scope.subscription_id,
+ self._operation_scope.resource_group_name,
+ self._workspace_name,
+ run_id,
+ )
+
+ def get_run_details(self, run_id: str) -> RunDetails:
+ return self._operation.get_details(
+ self._operation_scope.subscription_id,
+ self._operation_scope.resource_group_name,
+ self._workspace_name,
+ run_id,
+ )
+
+ def get_run_children(self, run_id: str, **kwargs) -> Iterable[_BaseJob]:
+ return cast(
+ Iterable[_BaseJob],
+ self._operation.get_child(
+ self._subscription_id,
+ self._resource_group_name,
+ self._workspace_name,
+ run_id,
+ top=kwargs.pop("max_results", None),
+ cls=lambda objs: [self._translate_from_rest_object(obj) for obj in objs],
+ ),
+ )
+
+ def _translate_from_rest_object(self, job_object: Run) -> Optional[_BaseJob]:
+ """Handle errors during list operation.
+
+ :param job_object: The job object
+ :type job_object: Run
+ :return: A job entity if parsing was successful
+ :rtype: Optional[_BaseJob]
+ """
+ try:
+ from_rest_job: Any = Job._from_rest_object(job_object)
+ from_rest_job._id = NAMED_RESOURCE_ID_FORMAT.format(
+ self._subscription_id,
+ self._resource_group_name,
+ AZUREML_RESOURCE_PROVIDER,
+ self._workspace_name,
+ AzureMLResourceType.JOB,
+ from_rest_job.name,
+ )
+ return from_rest_job
+ except JobParsingError:
+ return None
+
+ def get_run_data(self, run_id: str) -> GetRunDataResult:
+ run_data_request = GetRunDataRequest(
+ run_id=run_id,
+ select_run_metadata=True,
+ select_run_definition=True,
+ select_job_specification=True,
+ )
+ return self._operation.get_run_data(
+ self._subscription_id,
+ self._resource_group_name,
+ self._workspace_name,
+ body=run_data_request,
+ )
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_schedule_operations.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_schedule_operations.py
new file mode 100644
index 00000000..8c34af40
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_schedule_operations.py
@@ -0,0 +1,608 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+# pylint: disable=protected-access
+from datetime import datetime, timezone
+from typing import Any, Iterable, List, Optional, Tuple, cast
+
+from azure.ai.ml._restclient.v2023_06_01_preview import AzureMachineLearningWorkspaces as ServiceClient062023Preview
+from azure.ai.ml._restclient.v2024_01_01_preview import AzureMachineLearningWorkspaces as ServiceClient012024Preview
+from azure.ai.ml._scope_dependent_operations import (
+ OperationConfig,
+ OperationsContainer,
+ OperationScope,
+ _ScopeDependentOperations,
+)
+from azure.ai.ml._telemetry import ActivityType, monitor_with_activity, monitor_with_telemetry_mixin
+from azure.ai.ml._utils._logger_utils import OpsLogger
+from azure.ai.ml.entities import Job, JobSchedule, Schedule
+from azure.ai.ml.entities._inputs_outputs.input import Input
+from azure.ai.ml.entities._monitoring.schedule import MonitorSchedule
+from azure.ai.ml.entities._monitoring.signals import (
+ BaselineDataRange,
+ FADProductionData,
+ GenerationTokenStatisticsSignal,
+ LlmData,
+ ProductionData,
+ ReferenceData,
+)
+from azure.ai.ml.entities._monitoring.target import MonitoringTarget
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ScheduleException
+from azure.core.credentials import TokenCredential
+from azure.core.polling import LROPoller
+from azure.core.tracing.decorator import distributed_trace
+
+from .._restclient.v2022_10_01.models import ScheduleListViewType
+from .._restclient.v2024_01_01_preview.models import TriggerOnceRequest
+from .._utils._arm_id_utils import AMLNamedArmId, AMLVersionedArmId, is_ARM_id_for_parented_resource
+from .._utils._azureml_polling import AzureMLPolling
+from .._utils.utils import snake_to_camel
+from ..constants._common import (
+ ARM_ID_PREFIX,
+ AZUREML_RESOURCE_PROVIDER,
+ NAMED_RESOURCE_ID_FORMAT_WITH_PARENT,
+ AzureMLResourceType,
+ LROConfigurations,
+)
+from ..constants._monitoring import (
+ DEPLOYMENT_MODEL_INPUTS_COLLECTION_KEY,
+ DEPLOYMENT_MODEL_INPUTS_NAME_KEY,
+ DEPLOYMENT_MODEL_INPUTS_VERSION_KEY,
+ DEPLOYMENT_MODEL_OUTPUTS_COLLECTION_KEY,
+ DEPLOYMENT_MODEL_OUTPUTS_NAME_KEY,
+ DEPLOYMENT_MODEL_OUTPUTS_VERSION_KEY,
+ MonitorDatasetContext,
+ MonitorSignalType,
+)
+from ..entities._schedule.schedule import ScheduleTriggerResult
+from . import DataOperations, JobOperations, OnlineDeploymentOperations
+from ._job_ops_helper import stream_logs_until_completion
+from ._operation_orchestrator import OperationOrchestrator
+
+ops_logger = OpsLogger(__name__)
+module_logger = ops_logger.module_logger
+
+
+class ScheduleOperations(_ScopeDependentOperations):
+ # pylint: disable=too-many-instance-attributes
+ """
+ ScheduleOperations
+
+ You should not instantiate this class directly.
+ Instead, you should create an MLClient instance that instantiates it for you and attaches it as an attribute.
+ """
+
+ def __init__(
+ self,
+ operation_scope: OperationScope,
+ operation_config: OperationConfig,
+ service_client_06_2023_preview: ServiceClient062023Preview,
+ service_client_01_2024_preview: ServiceClient012024Preview,
+ all_operations: OperationsContainer,
+ credential: TokenCredential,
+ **kwargs: Any,
+ ):
+ super(ScheduleOperations, self).__init__(operation_scope, operation_config)
+ ops_logger.update_filter()
+ self.service_client = service_client_06_2023_preview.schedules
+        # Note: trigger-once has been supported since the 2024-01 API version. We don't upgrade the other
+        # operations' clients because there are some breaking changes, for example:
+        # AzMonMonitoringAlertNotificationSettings is removed.
+ self.schedule_trigger_service_client = service_client_01_2024_preview.schedules
+ self._all_operations = all_operations
+ self._stream_logs_until_completion = stream_logs_until_completion
+ # Dataplane service clients are lazily created as they are needed
+ self._runs_operations_client = None
+ self._dataset_dataplane_operations_client = None
+ self._model_dataplane_operations_client = None
+ # Kwargs to propagate to dataplane service clients
+ self._service_client_kwargs = kwargs.pop("_service_client_kwargs", {})
+ self._api_base_url = None
+ self._container = "azureml"
+ self._credential = credential
+ self._orchestrators = OperationOrchestrator(self._all_operations, self._operation_scope, self._operation_config)
+
+ self._kwargs = kwargs
+
+ @property
+ def _job_operations(self) -> JobOperations:
+ return cast(
+ JobOperations,
+ self._all_operations.get_operation( # type: ignore[misc]
+ AzureMLResourceType.JOB, lambda x: isinstance(x, JobOperations)
+ ),
+ )
+
+ @property
+ def _online_deployment_operations(self) -> OnlineDeploymentOperations:
+ return cast(
+ OnlineDeploymentOperations,
+ self._all_operations.get_operation( # type: ignore[misc]
+ AzureMLResourceType.ONLINE_DEPLOYMENT, lambda x: isinstance(x, OnlineDeploymentOperations)
+ ),
+ )
+
+ @property
+ def _data_operations(self) -> DataOperations:
+ return cast(
+ DataOperations,
+ self._all_operations.get_operation( # type: ignore[misc]
+ AzureMLResourceType.DATA, lambda x: isinstance(x, DataOperations)
+ ),
+ )
+
+ @distributed_trace
+ @monitor_with_activity(ops_logger, "Schedule.List", ActivityType.PUBLICAPI)
+ def list(
+ self,
+ *,
+ list_view_type: ScheduleListViewType = ScheduleListViewType.ENABLED_ONLY,
+ **kwargs: Any,
+ ) -> Iterable[Schedule]:
+ """List schedules in specified workspace.
+
+ :keyword list_view_type: View type for including or excluding schedules (for example,
+ archived schedules). Defaults to ENABLED_ONLY.
+ :type list_view_type: Optional[ScheduleListViewType]
+ :return: An iterator over Schedule objects.
+ :rtype: Iterable[Schedule]
+ """
+
+ def safe_from_rest_object(objs: Any) -> List:
+ result = []
+ for obj in objs:
+ try:
+ result.append(Schedule._from_rest_object(obj))
+ except Exception as e: # pylint: disable=W0718
+ print(f"Translate {obj.name} to Schedule failed with: {e}")
+ return result
+
+ return cast(
+ Iterable[Schedule],
+ self.service_client.list(
+ resource_group_name=self._operation_scope.resource_group_name,
+ workspace_name=self._workspace_name,
+ list_view_type=list_view_type,
+ cls=safe_from_rest_object,
+ **self._kwargs,
+ **kwargs,
+ ),
+ )
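As a usage sketch (not part of the patch itself), listing schedules goes through MLClient, which constructs ScheduleOperations internally. The workspace details below are placeholders, and the private ScheduleListViewType import mirrors the one used by this module.

```python
# Hedged usage sketch: assumes DefaultAzureCredential can authenticate and that
# the placeholder subscription/resource group/workspace values are filled in.
from azure.ai.ml import MLClient
from azure.ai.ml._restclient.v2022_10_01.models import ScheduleListViewType
from azure.identity import DefaultAzureCredential

ml_client = MLClient(
    credential=DefaultAzureCredential(),
    subscription_id="<subscription-id>",
    resource_group_name="<resource-group>",
    workspace_name="<workspace-name>",
)

# Include disabled schedules as well as enabled ones.
for schedule in ml_client.schedules.list(list_view_type=ScheduleListViewType.ALL):
    print(schedule.name, schedule.is_enabled)
```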
+
+ def _get_polling(self, name: Optional[str]) -> AzureMLPolling:
+ """Return the polling with custom poll interval.
+
+ :param name: The schedule name
+ :type name: str
+ :return: The AzureMLPolling object
+ :rtype: AzureMLPolling
+ """
+ path_format_arguments = {
+ "scheduleName": name,
+ "resourceGroupName": self._resource_group_name,
+ "workspaceName": self._workspace_name,
+ }
+ return AzureMLPolling(
+ LROConfigurations.POLL_INTERVAL,
+ path_format_arguments=path_format_arguments,
+ )
+
+ @distributed_trace
+ @monitor_with_activity(ops_logger, "Schedule.Delete", ActivityType.PUBLICAPI)
+ def begin_delete(
+ self,
+ name: str,
+ **kwargs: Any,
+ ) -> LROPoller[None]:
+ """Delete schedule.
+
+ :param name: Schedule name.
+ :type name: str
+ :return: A poller for deletion status
+ :rtype: LROPoller[None]
+ """
+ poller = self.service_client.begin_delete(
+ resource_group_name=self._operation_scope.resource_group_name,
+ workspace_name=self._workspace_name,
+ name=name,
+ polling=self._get_polling(name),
+ **self._kwargs,
+ **kwargs,
+ )
+ return poller
+
+ @distributed_trace
+ @monitor_with_telemetry_mixin(ops_logger, "Schedule.Get", ActivityType.PUBLICAPI)
+ def get(
+ self,
+ name: str,
+ **kwargs: Any,
+ ) -> Schedule:
+ """Get a schedule.
+
+ :param name: Schedule name.
+ :type name: str
+ :return: The schedule object.
+ :rtype: Schedule
+ """
+ return self.service_client.get(
+ resource_group_name=self._operation_scope.resource_group_name,
+ workspace_name=self._workspace_name,
+ name=name,
+ cls=lambda _, obj, __: Schedule._from_rest_object(obj),
+ **self._kwargs,
+ **kwargs,
+ )
+
+ @distributed_trace
+ @monitor_with_telemetry_mixin(ops_logger, "Schedule.CreateOrUpdate", ActivityType.PUBLICAPI)
+ def begin_create_or_update(
+ self,
+ schedule: Schedule,
+ **kwargs: Any,
+ ) -> LROPoller[Schedule]:
+ """Create or update schedule.
+
+ :param schedule: Schedule definition.
+ :type schedule: Schedule
+ :return: An instance of LROPoller that returns Schedule
+ :rtype: ~azure.core.polling.LROPoller[~azure.ai.ml.entities.Schedule]
+ """
+
+ if isinstance(schedule, JobSchedule):
+ schedule._validate(raise_error=True)
+ if isinstance(schedule.create_job, Job):
+ # Create all dependent resources for job inside schedule
+ self._job_operations._resolve_arm_id_or_upload_dependencies(schedule.create_job)
+ elif isinstance(schedule, MonitorSchedule):
+ # resolve ARM id for target, compute, and input datasets for each signal
+ self._resolve_monitor_schedule_arm_id(schedule)
+ # Create schedule
+ schedule_data = schedule._to_rest_object() # type: ignore
+ poller = self.service_client.begin_create_or_update(
+ resource_group_name=self._operation_scope.resource_group_name,
+ workspace_name=self._workspace_name,
+ name=schedule.name,
+ cls=lambda _, obj, __: Schedule._from_rest_object(obj),
+ body=schedule_data,
+ polling=self._get_polling(schedule.name),
+ **self._kwargs,
+ **kwargs,
+ )
+ return poller
+
+ @distributed_trace
+ @monitor_with_activity(ops_logger, "Schedule.Enable", ActivityType.PUBLICAPI)
+ def begin_enable(
+ self,
+ name: str,
+ **kwargs: Any,
+ ) -> LROPoller[Schedule]:
+ """Enable a schedule.
+
+ :param name: Schedule name.
+ :type name: str
+ :return: An instance of LROPoller that returns Schedule
+ :rtype: ~azure.core.polling.LROPoller[~azure.ai.ml.entities.Schedule]
+ """
+ schedule = self.get(name=name)
+ schedule._is_enabled = True
+ return self.begin_create_or_update(schedule, **kwargs)
+
+ @distributed_trace
+ @monitor_with_activity(ops_logger, "Schedule.Disable", ActivityType.PUBLICAPI)
+ def begin_disable(
+ self,
+ name: str,
+ **kwargs: Any,
+ ) -> LROPoller[Schedule]:
+ """Disable a schedule.
+
+ :param name: Schedule name.
+ :type name: str
+ :return: An instance of LROPoller that returns Schedule
+ :rtype: ~azure.core.polling.LROPoller[~azure.ai.ml.entities.Schedule]
+ """
+ schedule = self.get(name=name)
+ schedule._is_enabled = False
+ return self.begin_create_or_update(schedule, **kwargs)
+
+ @distributed_trace
+ @monitor_with_activity(ops_logger, "Schedule.Trigger", ActivityType.PUBLICAPI)
+ def trigger(
+ self,
+ name: str,
+ **kwargs: Any,
+ ) -> ScheduleTriggerResult:
+ """Trigger a schedule once.
+
+ :param name: Schedule name.
+ :type name: str
+ :return: The submission result of the triggered run.
+ :rtype: ~azure.ai.ml.entities.ScheduleTriggerResult
+ """
+ schedule_time = kwargs.pop("schedule_time", datetime.now(timezone.utc).isoformat())
+ return self.schedule_trigger_service_client.trigger(
+ name=name,
+ resource_group_name=self._operation_scope.resource_group_name,
+ workspace_name=self._workspace_name,
+ body=TriggerOnceRequest(schedule_time=schedule_time),
+ cls=lambda _, obj, __: ScheduleTriggerResult._from_rest_object(obj),
+ **kwargs,
+ )
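For context, here is a hedged sketch of enabling a schedule and then submitting a one-off run with trigger; the schedule name and workspace details are placeholders.

```python
# Hedged usage sketch; "my-schedule" and the client setup are placeholders.
from azure.ai.ml import MLClient
from azure.identity import DefaultAzureCredential

ml_client = MLClient(
    credential=DefaultAzureCredential(),
    subscription_id="<subscription-id>",
    resource_group_name="<resource-group>",
    workspace_name="<workspace-name>",
)

# Make sure the schedule is enabled, then submit a single ad-hoc run.
ml_client.schedules.begin_enable(name="my-schedule").result()
trigger_result = ml_client.schedules.trigger(name="my-schedule")
print(trigger_result)  # ScheduleTriggerResult describing the submitted run
```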
+
+ def _resolve_monitor_schedule_arm_id( # pylint:disable=too-many-branches,too-many-statements,too-many-locals
+ self, schedule: MonitorSchedule
+ ) -> None:
+ # resolve target ARM ID
+ model_inputs_name, model_outputs_name = None, None
+ app_traces_name, app_traces_version = None, None
+ model_inputs_version, model_outputs_version = None, None
+ mdc_input_enabled, mdc_output_enabled = False, False
+ target = schedule.create_monitor.monitoring_target
+ if target and target.endpoint_deployment_id:
+ endpoint_name, deployment_name = self._process_and_get_endpoint_deployment_names_from_id(target)
+ online_deployment = self._online_deployment_operations.get(deployment_name, endpoint_name)
+ deployment_data_collector = online_deployment.data_collector
+ if deployment_data_collector:
+ in_reg = AMLVersionedArmId(deployment_data_collector.collections.get("model_inputs").data)
+ out_reg = AMLVersionedArmId(deployment_data_collector.collections.get("model_outputs").data)
+ if "app_traces" in deployment_data_collector.collections:
+ app_traces = AMLVersionedArmId(deployment_data_collector.collections.get("app_traces").data)
+ app_traces_name = app_traces.asset_name
+ app_traces_version = app_traces.asset_version
+ model_inputs_name = in_reg.asset_name
+ model_inputs_version = in_reg.asset_version
+ model_outputs_name = out_reg.asset_name
+ model_outputs_version = out_reg.asset_version
+ mdc_input_enabled_str = deployment_data_collector.collections.get("model_inputs").enabled
+ mdc_output_enabled_str = deployment_data_collector.collections.get("model_outputs").enabled
+ else:
+ model_inputs_name = online_deployment.tags.get(DEPLOYMENT_MODEL_INPUTS_NAME_KEY)
+ model_inputs_version = online_deployment.tags.get(DEPLOYMENT_MODEL_INPUTS_VERSION_KEY)
+ model_outputs_name = online_deployment.tags.get(DEPLOYMENT_MODEL_OUTPUTS_NAME_KEY)
+ model_outputs_version = online_deployment.tags.get(DEPLOYMENT_MODEL_OUTPUTS_VERSION_KEY)
+ mdc_input_enabled_str = online_deployment.tags.get(DEPLOYMENT_MODEL_INPUTS_COLLECTION_KEY)
+ mdc_output_enabled_str = online_deployment.tags.get(DEPLOYMENT_MODEL_OUTPUTS_COLLECTION_KEY)
+ if mdc_input_enabled_str and mdc_input_enabled_str.lower() == "true":
+ mdc_input_enabled = True
+ if mdc_output_enabled_str and mdc_output_enabled_str.lower() == "true":
+ mdc_output_enabled = True
+ elif target and target.model_id:
+ target.model_id = self._orchestrators.get_asset_arm_id( # type: ignore
+ target.model_id,
+ AzureMLResourceType.MODEL,
+ register_asset=False,
+ )
+
+ if not schedule.create_monitor.monitoring_signals:
+ if mdc_input_enabled and mdc_output_enabled:
+ schedule._create_default_monitor_definition()
+ else:
+ msg = (
+ "An ARM id for a deployment with data collector enabled for model inputs and outputs must be "
+ "given if monitoring_signals is None"
+ )
+ raise ScheduleException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.SCHEDULE,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+ # resolve ARM id for each signal and populate any defaults if needed
+ for signal_name, signal in schedule.create_monitor.monitoring_signals.items(): # type: ignore
+ if signal.type == MonitorSignalType.GENERATION_SAFETY_QUALITY:
+ for llm_data in signal.production_data: # type: ignore[union-attr]
+ self._job_operations._resolve_job_input(llm_data.input_data, schedule._base_path)
+ continue
+ if signal.type == MonitorSignalType.GENERATION_TOKEN_STATISTICS:
+ if not signal.production_data: # type: ignore[union-attr]
+ # if target dataset is absent and data collector for input is enabled,
+ # create a default target dataset with production app traces as target
+ if isinstance(signal, GenerationTokenStatisticsSignal):
+ signal.production_data = LlmData( # type: ignore[union-attr]
+ input_data=Input(
+ path=f"{app_traces_name}:{app_traces_version}",
+ type=self._data_operations.get(app_traces_name, app_traces_version).type,
+ ),
+ data_window=BaselineDataRange(lookback_window_size="P7D", lookback_window_offset="P0D"),
+ )
+ self._job_operations._resolve_job_input(
+ signal.production_data.input_data, schedule._base_path # type: ignore[union-attr]
+ )
+ continue
+ if signal.type == MonitorSignalType.CUSTOM:
+ if signal.inputs: # type: ignore[union-attr]
+ for inputs in signal.inputs.values(): # type: ignore[union-attr]
+ self._job_operations._resolve_job_input(inputs, schedule._base_path)
+ for data in signal.input_data.values(): # type: ignore[union-attr]
+ if data.input_data is not None:
+ for inputs in data.input_data.values():
+ self._job_operations._resolve_job_input(inputs, schedule._base_path)
+ data.pre_processing_component = self._orchestrators.get_asset_arm_id(
+ asset=data.pre_processing_component if hasattr(data, "pre_processing_component") else None,
+ azureml_type=AzureMLResourceType.COMPONENT,
+ )
+ continue
+ error_messages = []
+ if not signal.production_data or not signal.reference_data: # type: ignore[union-attr]
+ # if there is no target dataset, we check the type of signal
+ if signal.type in {MonitorSignalType.DATA_DRIFT, MonitorSignalType.DATA_QUALITY}:
+ if mdc_input_enabled:
+ if not signal.production_data: # type: ignore[union-attr]
+ # if target dataset is absent and data collector for input is enabled,
+ # create a default target dataset with production model inputs as target
+ signal.production_data = ProductionData( # type: ignore[union-attr]
+ input_data=Input(
+ path=f"{model_inputs_name}:{model_inputs_version}",
+ type=self._data_operations.get(model_inputs_name, model_inputs_version).type,
+ ),
+ data_context=MonitorDatasetContext.MODEL_INPUTS,
+ data_window=BaselineDataRange(
+ lookback_window_size="default", lookback_window_offset="P0D"
+ ),
+ )
+ if not signal.reference_data: # type: ignore[union-attr]
+ signal.reference_data = ReferenceData( # type: ignore[union-attr]
+ input_data=Input(
+ path=f"{model_inputs_name}:{model_inputs_version}",
+ type=self._data_operations.get(model_inputs_name, model_inputs_version).type,
+ ),
+ data_context=MonitorDatasetContext.MODEL_INPUTS,
+ data_window=BaselineDataRange(
+ lookback_window_size="default", lookback_window_offset="default"
+ ),
+ )
+ elif not mdc_input_enabled and not (
+ signal.production_data and signal.reference_data # type: ignore[union-attr]
+ ):
+ # if target or baseline dataset is absent and data collector for input is not enabled,
+ # collect exception message
+ msg = (
+ f"A target and baseline dataset must be provided for signal with name {signal_name}"
+ f"and type {signal.type} if the monitoring_target endpoint_deployment_id is empty"
+ "or refers to a deployment for which data collection for model inputs is not enabled."
+ )
+ error_messages.append(msg)
+ elif signal.type == MonitorSignalType.PREDICTION_DRIFT:
+ if mdc_output_enabled:
+ if not signal.production_data: # type: ignore[union-attr]
+ # if target dataset is absent and data collector for output is enabled,
+ # create a default target dataset with production model outputs as target
+ signal.production_data = ProductionData( # type: ignore[union-attr]
+ input_data=Input(
+ path=f"{model_outputs_name}:{model_outputs_version}",
+ type=self._data_operations.get(model_outputs_name, model_outputs_version).type,
+ ),
+ data_context=MonitorDatasetContext.MODEL_OUTPUTS,
+ data_window=BaselineDataRange(
+ lookback_window_size="default", lookback_window_offset="P0D"
+ ),
+ )
+ if not signal.reference_data: # type: ignore[union-attr]
+ signal.reference_data = ReferenceData( # type: ignore[union-attr]
+ input_data=Input(
+ path=f"{model_outputs_name}:{model_outputs_version}",
+ type=self._data_operations.get(model_outputs_name, model_outputs_version).type,
+ ),
+ data_context=MonitorDatasetContext.MODEL_OUTPUTS,
+ data_window=BaselineDataRange(
+ lookback_window_size="default", lookback_window_offset="default"
+ ),
+ )
+ elif not mdc_output_enabled and not (
+ signal.production_data and signal.reference_data # type: ignore[union-attr]
+ ):
+ # if target dataset is absent and data collector for output is not enabled,
+ # collect exception message
+ msg = (
+ f"A target and baseline dataset must be provided for signal with name {signal_name}"
+ f"and type {signal.type} if the monitoring_target endpoint_deployment_id is empty"
+ "or refers to a deployment for which data collection for model outputs is not enabled."
+ )
+ error_messages.append(msg)
+ elif signal.type == MonitorSignalType.FEATURE_ATTRIBUTION_DRIFT:
+ if mdc_input_enabled:
+ if not signal.production_data: # type: ignore[union-attr]
+ # if production dataset is absent and data collector for input is enabled,
+ # create a default prod dataset with production model inputs and outputs as target
+ signal.production_data = [ # type: ignore[union-attr]
+ FADProductionData(
+ input_data=Input(
+ path=f"{model_inputs_name}:{model_inputs_version}",
+ type=self._data_operations.get(model_inputs_name, model_inputs_version).type,
+ ),
+ data_context=MonitorDatasetContext.MODEL_INPUTS,
+ data_window=BaselineDataRange(
+ lookback_window_size="default", lookback_window_offset="P0D"
+ ),
+ ),
+ FADProductionData(
+ input_data=Input(
+ path=f"{model_outputs_name}:{model_outputs_version}",
+ type=self._data_operations.get(model_outputs_name, model_outputs_version).type,
+ ),
+ data_context=MonitorDatasetContext.MODEL_OUTPUTS,
+ data_window=BaselineDataRange(
+ lookback_window_size="default", lookback_window_offset="P0D"
+ ),
+ ),
+ ]
+ elif not mdc_output_enabled and not signal.production_data: # type: ignore[union-attr]
+ # if target dataset is absent and data collector for output is not enabled,
+ # collect exception message
+ msg = (
+ f"A production data must be provided for signal with name {signal_name}"
+ f"and type {signal.type} if the monitoring_target endpoint_deployment_id is empty"
+ "or refers to a deployment for which data collection for model outputs is not enabled."
+ )
+ error_messages.append(msg)
+ if error_messages:
+ # if any error messages, raise an exception with all of them so user knows which signals
+ # need to be fixed
+ msg = "\n".join(error_messages)
+ raise ScheduleException(
+ message=msg,
+ no_personal_data_message=msg,
+ target=ErrorTarget.SCHEDULE,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+ if signal.type == MonitorSignalType.FEATURE_ATTRIBUTION_DRIFT:
+ for prod_data in signal.production_data: # type: ignore[union-attr]
+ self._job_operations._resolve_job_input(prod_data.input_data, schedule._base_path)
+ prod_data.pre_processing_component = self._orchestrators.get_asset_arm_id( # type: ignore
+ asset=prod_data.pre_processing_component, # type: ignore[union-attr]
+ azureml_type=AzureMLResourceType.COMPONENT,
+ )
+ self._job_operations._resolve_job_input(
+ signal.reference_data.input_data, schedule._base_path # type: ignore[union-attr]
+ )
+ signal.reference_data.pre_processing_component = self._orchestrators.get_asset_arm_id( # type: ignore
+ asset=signal.reference_data.pre_processing_component, # type: ignore[union-attr]
+ azureml_type=AzureMLResourceType.COMPONENT,
+ )
+ continue
+
+ self._job_operations._resolve_job_inputs(
+ [signal.production_data.input_data, signal.reference_data.input_data], # type: ignore[union-attr]
+ schedule._base_path,
+ )
+ signal.production_data.pre_processing_component = self._orchestrators.get_asset_arm_id( # type: ignore
+ asset=signal.production_data.pre_processing_component, # type: ignore[union-attr]
+ azureml_type=AzureMLResourceType.COMPONENT,
+ )
+ signal.reference_data.pre_processing_component = self._orchestrators.get_asset_arm_id( # type: ignore
+ asset=signal.reference_data.pre_processing_component, # type: ignore[union-attr]
+ azureml_type=AzureMLResourceType.COMPONENT,
+ )
+
+ def _process_and_get_endpoint_deployment_names_from_id(self, target: MonitoringTarget) -> Tuple:
+ target.endpoint_deployment_id = (
+ target.endpoint_deployment_id[len(ARM_ID_PREFIX) :] # type: ignore
+ if target.endpoint_deployment_id is not None and target.endpoint_deployment_id.startswith(ARM_ID_PREFIX)
+ else target.endpoint_deployment_id
+ )
+
+ # If the value is not already an ARM ID, expand the "endpoint_name:deployment_name" shorthand into one; otherwise use it as-is.
+ if not is_ARM_id_for_parented_resource(
+ target.endpoint_deployment_id,
+ snake_to_camel(AzureMLResourceType.ONLINE_ENDPOINT),
+ AzureMLResourceType.DEPLOYMENT,
+ ):
+ endpoint_name, deployment_name = target.endpoint_deployment_id.split(":") # type: ignore
+ target.endpoint_deployment_id = NAMED_RESOURCE_ID_FORMAT_WITH_PARENT.format(
+ self._subscription_id,
+ self._resource_group_name,
+ AZUREML_RESOURCE_PROVIDER,
+ self._workspace_name,
+ snake_to_camel(AzureMLResourceType.ONLINE_ENDPOINT),
+ endpoint_name,
+ AzureMLResourceType.DEPLOYMENT,
+ deployment_name,
+ )
+ else:
+ deployment_arm_id_entity = AMLNamedArmId(target.endpoint_deployment_id)
+ endpoint_name = deployment_arm_id_entity.parent_asset_name
+ deployment_name = deployment_arm_id_entity.asset_name
+
+ return endpoint_name, deployment_name
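To make the normalization above concrete, here is an illustrative sketch of the two endpoint_deployment_id forms that _process_and_get_endpoint_deployment_names_from_id accepts; all resource names and IDs are hypothetical placeholders.

```python
# Illustrative sketch only; names and IDs are placeholders.
from azure.ai.ml.entities._monitoring.target import MonitoringTarget

# Shorthand "endpoint_name:deployment_name" form, which the method expands into a
# full ARM ID scoped to the client's subscription, resource group, and workspace.
shorthand_target = MonitoringTarget(endpoint_deployment_id="my-endpoint:my-deployment")

# Fully qualified ARM ID form, used as-is (after stripping an optional "azureml:" prefix).
arm_target = MonitoringTarget(
    endpoint_deployment_id=(
        "/subscriptions/<subscription-id>/resourceGroups/<resource-group>"
        "/providers/Microsoft.MachineLearningServices/workspaces/<workspace-name>"
        "/onlineEndpoints/my-endpoint/deployments/my-deployment"
    )
)
```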
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_serverless_endpoint_operations.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_serverless_endpoint_operations.py
new file mode 100644
index 00000000..5efec117
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_serverless_endpoint_operations.py
@@ -0,0 +1,223 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+import re
+from typing import Iterable
+
+from azure.ai.ml._restclient.v2024_01_01_preview import (
+ AzureMachineLearningWorkspaces as ServiceClient202401Preview,
+)
+from azure.ai.ml._restclient.v2024_01_01_preview.models import (
+ KeyType,
+ RegenerateEndpointKeysRequest,
+)
+from azure.ai.ml._scope_dependent_operations import (
+ OperationConfig,
+ OperationsContainer,
+ OperationScope,
+ _ScopeDependentOperations,
+)
+from azure.ai.ml._telemetry import ActivityType, monitor_with_activity
+from azure.ai.ml._utils._experimental import experimental
+from azure.ai.ml._utils._logger_utils import OpsLogger
+from azure.ai.ml.constants._common import REGISTRY_VERSION_PATTERN, AzureMLResourceType
+from azure.ai.ml.constants._endpoint import EndpointKeyType
+from azure.ai.ml.entities._autogen_entities.models import ServerlessEndpoint
+from azure.ai.ml.entities._endpoint.online_endpoint import EndpointAuthKeys
+from azure.ai.ml.exceptions import (
+ ErrorCategory,
+ ErrorTarget,
+ ValidationErrorType,
+ ValidationException,
+)
+from azure.core.polling import LROPoller
+
+ops_logger = OpsLogger(__name__)
+module_logger = ops_logger.module_logger
+
+
+class ServerlessEndpointOperations(_ScopeDependentOperations):
+ """ServerlessEndpointOperations.
+
+ You should not instantiate this class directly. Instead, you should
+ create an MLClient instance that instantiates it for you and
+ attaches it as an attribute.
+ """
+
+ def __init__(
+ self,
+ operation_scope: OperationScope,
+ operation_config: OperationConfig,
+ service_client: ServiceClient202401Preview,
+ all_operations: OperationsContainer,
+ ):
+ super().__init__(operation_scope, operation_config)
+ ops_logger.update_filter()
+ self._service_client = service_client.serverless_endpoints
+ self._marketplace_subscriptions = service_client.marketplace_subscriptions
+ self._all_operations = all_operations
+
+ def _get_workspace_location(self) -> str:
+ return str(
+ self._all_operations.all_operations[AzureMLResourceType.WORKSPACE].get(self._workspace_name).location
+ )
+
+ @experimental
+ @monitor_with_activity(ops_logger, "ServerlessEndpoint.BeginCreateOrUpdate", ActivityType.PUBLICAPI)
+ def begin_create_or_update(self, endpoint: ServerlessEndpoint, **kwargs) -> LROPoller[ServerlessEndpoint]:
+ """Create or update a serverless endpoint.
+
+ :param endpoint: The serverless endpoint entity.
+ :type endpoint: ~azure.ai.ml.entities.ServerlessEndpoint
+ :raises ~azure.ai.ml.exceptions.ValidationException: Raised if ServerlessEndpoint cannot be
+ successfully validated. Details will be provided in the error message.
+ :return: A poller to track the operation status
+ :rtype: ~azure.core.polling.LROPoller[~azure.ai.ml.entities.ServerlessEndpoint]
+ """
+ if not endpoint.location:
+ endpoint.location = self._get_workspace_location()
+ if re.match(REGISTRY_VERSION_PATTERN, endpoint.model_id):
+ msg = (
+ "The given model_id {} points to a specific model version, which is not supported. "
+ "Please provide a model_id without the version information."
+ )
+ raise ValidationException(
+ message=msg.format(endpoint.model_id),
+ no_personal_data_message="Invalid model_id given for serverless endpoint",
+ target=ErrorTarget.SERVERLESS_ENDPOINT,
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
+ return self._service_client.begin_create_or_update(
+ self._resource_group_name,
+ self._workspace_name,
+ endpoint.name,
+ endpoint._to_rest_object(), # type: ignore
+ cls=(
+ lambda response, deserialized, headers: ServerlessEndpoint._from_rest_object( # type: ignore
+ deserialized
+ )
+ ),
+ **kwargs,
+ )
+
+ @experimental
+ @monitor_with_activity(ops_logger, "ServerlessEndpoint.Get", ActivityType.PUBLICAPI)
+ def get(self, name: str, **kwargs) -> ServerlessEndpoint:
+ """Get a Serverless Endpoint resource.
+
+ :param name: Name of the serverless endpoint.
+ :type name: str
+ :return: Serverless endpoint object retrieved from the service.
+ :rtype: ~azure.ai.ml.entities.ServerlessEndpoint
+ """
+ return self._service_client.get(
+ self._resource_group_name,
+ self._workspace_name,
+ name,
+ cls=(
+ lambda response, deserialized, headers: ServerlessEndpoint._from_rest_object( # type: ignore
+ deserialized
+ )
+ ),
+ **kwargs,
+ )
+
+ @experimental
+ @monitor_with_activity(ops_logger, "ServerlessEndpoint.list", ActivityType.PUBLICAPI)
+ def list(self, **kwargs) -> Iterable[ServerlessEndpoint]:
+ """List serverless endpoints of the workspace.
+
+ :return: A list of serverless endpoints
+ :rtype: ~typing.Iterable[~azure.ai.ml.entities.ServerlessEndpoint]
+ """
+ return self._service_client.list(
+ self._resource_group_name,
+ self._workspace_name,
+ cls=lambda objs: [ServerlessEndpoint._from_rest_object(obj) for obj in objs], # type: ignore
+ **kwargs,
+ )
+
+ @experimental
+ @monitor_with_activity(ops_logger, "ServerlessEndpoint.BeginDelete", ActivityType.PUBLICAPI)
+ def begin_delete(self, name: str, **kwargs) -> LROPoller[None]:
+ """Delete a Serverless Endpoint.
+
+ :param name: Name of the serverless endpoint.
+ :type name: str
+ :return: A poller to track the operation status.
+ :rtype: ~azure.core.polling.LROPoller[None]
+ """
+ return self._service_client.begin_delete(
+ self._resource_group_name,
+ self._workspace_name,
+ name,
+ **kwargs,
+ )
+
+ @experimental
+ @monitor_with_activity(ops_logger, "ServerlessEndpoint.GetKeys", ActivityType.PUBLICAPI)
+ def get_keys(self, name: str, **kwargs) -> EndpointAuthKeys:
+ """Get serveless endpoint auth keys.
+
+ :param name: The serverless endpoint name
+ :type name: str
+ :return: Returns the keys of the serverless endpoint
+ :rtype: ~azure.ai.ml.entities.EndpointAuthKeys
+ """
+ return self._service_client.list_keys(
+ self._resource_group_name,
+ self._workspace_name,
+ name,
+ cls=lambda response, deserialized, headers: EndpointAuthKeys._from_rest_object(deserialized),
+ **kwargs,
+ )
+
+ @experimental
+ @monitor_with_activity(ops_logger, "ServerlessEndpoint.BeginRegenerateKeys", ActivityType.PUBLICAPI)
+ def begin_regenerate_keys(
+ self,
+ name: str,
+ *,
+ key_type: str = EndpointKeyType.PRIMARY_KEY_TYPE,
+ **kwargs,
+ ) -> LROPoller[EndpointAuthKeys]:
+ """Regenerate keys for a serverless endpoint.
+
+ :param name: The endpoint name.
+ :type name: str
+ :keyword key_type: One of "primary", "secondary". Defaults to "primary".
+ :paramtype key_type: str
+ :raises ~azure.ai.ml.exceptions.ValidationException: Raised if key_type is not "primary"
+ or "secondary"
+ :return: A poller to track the operation status.
+ :rtype: ~azure.core.polling.LROPoller[EndpointAuthKeys]
+ """
+ keys = self.get_keys(
+ name=name,
+ )
+ if key_type.lower() == EndpointKeyType.PRIMARY_KEY_TYPE:
+ key_request = RegenerateEndpointKeysRequest(key_type=KeyType.Primary, key_value=keys.primary_key)
+ elif key_type.lower() == EndpointKeyType.SECONDARY_KEY_TYPE:
+ key_request = RegenerateEndpointKeysRequest(key_type=KeyType.Secondary, key_value=keys.secondary_key)
+ else:
+ msg = "Key type must be 'primary' or 'secondary'."
+ raise ValidationException(
+ message=msg,
+ target=ErrorTarget.SERVERLESS_ENDPOINT,
+ no_personal_data_message=msg,
+ error_category=ErrorCategory.USER_ERROR,
+ error_type=ValidationErrorType.INVALID_VALUE,
+ )
+
+ return self._service_client.begin_regenerate_keys(
+ resource_group_name=self._resource_group_name,
+ workspace_name=self._workspace_name,
+ endpoint_name=name,
+ body=key_request,
+ cls=lambda response, deserialized, headers: EndpointAuthKeys._from_rest_object(deserialized),
+ **kwargs,
+ )
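A hedged end-to-end sketch of the operations above, assuming an MLClient with access to a workspace; the endpoint name and model_id are placeholders, and note that the model_id must not carry a version suffix per the REGISTRY_VERSION_PATTERN check in begin_create_or_update.

```python
# Hedged usage sketch; all names and IDs are placeholders.
from azure.ai.ml import MLClient
from azure.ai.ml.entities import ServerlessEndpoint
from azure.identity import DefaultAzureCredential

ml_client = MLClient(
    credential=DefaultAzureCredential(),
    subscription_id="<subscription-id>",
    resource_group_name="<resource-group>",
    workspace_name="<workspace-name>",
)

endpoint = ServerlessEndpoint(
    name="my-serverless-endpoint",
    model_id="azureml://registries/azureml/models/<model-name>",  # no version suffix
)
created = ml_client.serverless_endpoints.begin_create_or_update(endpoint).result()

# Rotate the primary key, then read back the current keys.
ml_client.serverless_endpoints.begin_regenerate_keys(name=created.name, key_type="primary").result()
keys = ml_client.serverless_endpoints.get_keys(name=created.name)
```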
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_virtual_cluster_operations.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_virtual_cluster_operations.py
new file mode 100644
index 00000000..256664ad
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_virtual_cluster_operations.py
@@ -0,0 +1,174 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from typing import Dict, Iterable, Optional, cast
+
+from azure.ai.ml._scope_dependent_operations import OperationScope
+from azure.ai.ml._telemetry import ActivityType, monitor_with_activity
+from azure.ai.ml._utils._arm_id_utils import AzureResourceId
+from azure.ai.ml._utils._logger_utils import OpsLogger
+from azure.ai.ml._utils._virtual_cluster_utils import (
+ CrossRegionIndexEntitiesRequest,
+ IndexEntitiesRequest,
+ IndexEntitiesRequestFilter,
+ IndexEntitiesRequestOrder,
+ IndexEntitiesRequestOrderDirection,
+ IndexServiceAPIs,
+ index_entity_response_to_job,
+)
+from azure.ai.ml._utils.azure_resource_utils import (
+ get_generic_resource_by_id,
+ get_virtual_cluster_by_name,
+ get_virtual_clusters_from_subscriptions,
+)
+from azure.ai.ml.constants._common import AZUREML_RESOURCE_PROVIDER, LEVEL_ONE_NAMED_RESOURCE_ID_FORMAT, Scope
+from azure.ai.ml.entities import Job
+from azure.ai.ml.exceptions import UserErrorException, ValidationException
+from azure.core.credentials import TokenCredential
+from azure.core.tracing.decorator import distributed_trace
+
+ops_logger = OpsLogger(__name__)
+module_logger = ops_logger.module_logger
+
+
+class VirtualClusterOperations:
+ """VirtualClusterOperations.
+
+ You should not instantiate this class directly. Instead, you should create an MLClient instance that instantiates it
+ for you and attaches it as an attribute.
+ """
+
+ def __init__(
+ self,
+ operation_scope: OperationScope,
+ credentials: TokenCredential,
+ *,
+ _service_client_kwargs: Dict,
+ **kwargs: Dict,
+ ):
+ ops_logger.update_filter()
+ self._resource_group_name = operation_scope.resource_group_name
+ self._subscription_id = operation_scope.subscription_id
+ self._credentials = credentials
+ self._init_kwargs = kwargs
+ self._service_client_kwargs = _service_client_kwargs
+
+ """A (mostly) autogenerated rest client for the index service.
+
+ TODO: Remove this property and the rest client when list job by virtual cluster is added to virtual cluster rp
+ """
+ self._index_service = IndexServiceAPIs(
+ credential=self._credentials, base_url="https://westus2.api.azureml.ms", **self._service_client_kwargs
+ )
+
+ @distributed_trace
+ @monitor_with_activity(ops_logger, "VirtualCluster.List", ActivityType.PUBLICAPI)
+ def list(self, *, scope: Optional[str] = None) -> Iterable[Dict]:
+ """List virtual clusters a user has access to.
+
+ :keyword scope: scope of the listing, "subscription" or None, defaults to None.
+ If None, list virtual clusters across all subscriptions a customer has access to.
+ :paramtype scope: str
+ :return: An iterator like instance of dictionaries.
+ :rtype: ~azure.core.paging.ItemPaged[Dict]
+ """
+
+ if scope is None:
+ subscription_list = None
+ elif scope.lower() == Scope.SUBSCRIPTION:
+ subscription_list = [self._subscription_id]
+ else:
+ message = f"Invalid scope: {scope}. Valid values are 'subscription' or None."
+ raise UserErrorException(message=message, no_personal_data_message=message)
+
+ try:
+ return cast(
+ Iterable[Dict],
+ get_virtual_clusters_from_subscriptions(self._credentials, subscription_list=subscription_list),
+ )
+ except ImportError as e:
+ raise UserErrorException(
+ message="Met ImportError when trying to list virtual clusters. "
+ "Please install azure-mgmt-resource to enable this feature; "
+ "and please install azure-mgmt-resource to enable listing virtual clusters "
+ "across all subscriptions a customer has access to."
+ ) from e
+
+ @distributed_trace
+ @monitor_with_activity(ops_logger, "VirtualCluster.ListJobs", ActivityType.PUBLICAPI)
+ def list_jobs(self, name: str) -> Iterable[Job]:
+ """List of jobs that target the virtual cluster
+
+ :param name: Name of virtual cluster
+ :type name: str
+ :return: An iterable of jobs.
+ :rtype: Iterable[Job]
+ """
+
+ def make_id(entity_type: str) -> str:
+ return str(
+ LEVEL_ONE_NAMED_RESOURCE_ID_FORMAT.format(
+ self._subscription_id, self._resource_group_name, AZUREML_RESOURCE_PROVIDER, entity_type, name
+ )
+ )
+
+ # List of virtual cluster ids to match
+ # Needs to include several capitalizations for historical reasons. Will be fixed in a service side change
+ vc_ids = [make_id("virtualClusters"), make_id("virtualclusters"), make_id("virtualClusters").lower()]
+
+ filters = [
+ IndexEntitiesRequestFilter(field="type", operator="eq", values=["runs"]),
+ IndexEntitiesRequestFilter(field="annotations/archived", operator="eq", values=["false"]),
+ IndexEntitiesRequestFilter(field="properties/userProperties/azureml.VC", operator="eq", values=vc_ids),
+ ]
+ order = [
+ IndexEntitiesRequestOrder(
+ field="properties/creationContext/createdTime", direction=IndexEntitiesRequestOrderDirection.DESC
+ )
+ ]
+ index_entities_request = IndexEntitiesRequest(filters=filters, order=order)
+
+ # cspell:ignore entites
+ return cast(
+ Iterable[Job],
+ self._index_service.index_entities.get_entites_cross_region(
+ body=CrossRegionIndexEntitiesRequest(index_entities_request=index_entities_request),
+ cls=lambda objs: [index_entity_response_to_job(obj) for obj in objs],
+ ),
+ )
+
+ @distributed_trace
+ @monitor_with_activity(ops_logger, "VirtualCluster.Get", ActivityType.PUBLICAPI)
+ def get(self, name: str) -> Dict:
+ """
+ Get a virtual cluster resource. If name is provided, the virtual cluster
+ with the name in the subscription and resource group of the MLClient object
+ will be returned. If an ARM id is provided, a virtual cluster with the id will be returned.
+
+ :param name: Name or ARM ID of the virtual cluster.
+ :type name: str
+ :return: Virtual cluster object
+ :rtype: Dict
+ """
+
+ try:
+ arm_id = AzureResourceId(name)
+ sub_id = arm_id.subscription_id
+
+ return cast(
+ Dict,
+ get_generic_resource_by_id(
+ arm_id=name, credential=self._credentials, subscription_id=sub_id, api_version="2021-03-01-preview"
+ ),
+ )
+ except ValidationException:
+ return cast(
+ Dict,
+ get_virtual_cluster_by_name(
+ name=name,
+ resource_group=self._resource_group_name,
+ subscription_id=self._subscription_id,
+ credential=self._credentials,
+ ),
+ )
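A minimal sketch of the name-versus-ARM-ID branch used by get above: input that parses as an ARM ID is fetched as a generic resource in its own subscription, while anything else is treated as a cluster name in the client's scope. The IDs below are placeholders.

```python
# Minimal sketch of the branching logic in VirtualClusterOperations.get.
from azure.ai.ml._utils._arm_id_utils import AzureResourceId
from azure.ai.ml.exceptions import ValidationException


def describe_lookup(name_or_id: str) -> str:
    try:
        arm_id = AzureResourceId(name_or_id)
        return f"ARM ID lookup in subscription {arm_id.subscription_id}"
    except ValidationException:
        return "name lookup in the MLClient's subscription and resource group"


print(describe_lookup("my-virtual-cluster"))
print(
    describe_lookup(
        "/subscriptions/<subscription-id>/resourceGroups/<resource-group>"
        "/providers/Microsoft.MachineLearningServices/virtualclusters/my-virtual-cluster"
    )
)
```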
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_workspace_connections_operations.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_workspace_connections_operations.py
new file mode 100644
index 00000000..6debb243
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_workspace_connections_operations.py
@@ -0,0 +1,189 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+from typing import Any, Dict, Iterable, Optional, cast
+
+from azure.ai.ml._restclient.v2024_04_01_preview import AzureMachineLearningWorkspaces as ServiceClient082023Preview
+from azure.ai.ml._scope_dependent_operations import (
+ OperationConfig,
+ OperationsContainer,
+ OperationScope,
+ _ScopeDependentOperations,
+)
+from azure.ai.ml._telemetry import ActivityType, monitor_with_activity
+from azure.ai.ml._utils._logger_utils import OpsLogger
+from azure.ai.ml._utils.utils import _snake_to_camel
+from azure.ai.ml.entities._credentials import ApiKeyConfiguration
+from azure.ai.ml.entities._workspace.connections.workspace_connection import WorkspaceConnection
+from azure.core.credentials import TokenCredential
+from azure.core.tracing.decorator import distributed_trace
+
+ops_logger = OpsLogger(__name__)
+module_logger = ops_logger.module_logger
+
+
+class WorkspaceConnectionsOperations(_ScopeDependentOperations):
+ """WorkspaceConnectionsOperations.
+
+ You should not instantiate this class directly. Instead, you should create
+ an MLClient instance that instantiates it for you and attaches it as an attribute.
+ """
+
+ def __init__(
+ self,
+ operation_scope: OperationScope,
+ operation_config: OperationConfig,
+ service_client: ServiceClient082023Preview,
+ all_operations: OperationsContainer,
+ credentials: Optional[TokenCredential] = None,
+ **kwargs: Dict,
+ ):
+ super(WorkspaceConnectionsOperations, self).__init__(operation_scope, operation_config)
+ ops_logger.update_filter()
+ self._all_operations = all_operations
+ self._operation = service_client.workspace_connections
+ self._credentials = credentials
+ self._init_kwargs = kwargs
+
+ def _try_fill_api_key(self, connection: WorkspaceConnection) -> None:
+ """Try to fill in a connection's credentials with it's actual values.
+ Connection data retrievals normally return an empty credential object that merely includes the
+ connection's credential type, but not the actual secrets of that credential.
+ However, it's extremely common for users to want to know the contents of their connection's credentials.
+ This method tries to fill in the user's credentials with the actual values by making
+ a secondary API call to the service. It requires that the user have the necessary permissions to do so,
+ and it only works on api key-based credentials.
+
+ :param connection: The connection to try to fill in the credentials for.
+ :type connection: ~azure.ai.ml.entities.WorkspaceConnection
+ """
+ if hasattr(connection, "credentials") and isinstance(connection.credentials, ApiKeyConfiguration):
+ list_secrets_response = self._operation.list_secrets(
+ connection_name=connection.name,
+ resource_group_name=self._operation_scope.resource_group_name,
+ workspace_name=self._operation_scope.workspace_name,
+ )
+ if list_secrets_response.properties.credentials is not None:
+ connection.credentials.key = list_secrets_response.properties.credentials.key
+
+ @distributed_trace
+ @monitor_with_activity(ops_logger, "WorkspaceConnections.Get", ActivityType.PUBLICAPI)
+ def get(self, name: str, *, populate_secrets: bool = False, **kwargs: Dict) -> WorkspaceConnection:
+ """Get a connection by name.
+
+ :param name: Name of the connection.
+ :type name: str
+ :keyword populate_secrets: If true, make a secondary API call to try filling in the connection's
+ credentials. Currently only works for API key-based credentials. Defaults to False.
+ :paramtype populate_secrets: bool
+ :raises ~azure.core.exceptions.HttpResponseError: Raised if a connection with the provided name cannot be
+ retrieved from the service.
+ :return: The connection with the provided name.
+ :rtype: ~azure.ai.ml.entities.WorkspaceConnection
+ """
+
+ connection = WorkspaceConnection._from_rest_object(
+ rest_obj=self._operation.get(
+ workspace_name=self._workspace_name,
+ connection_name=name,
+ **self._scope_kwargs,
+ **kwargs,
+ )
+ )
+
+ if populate_secrets and connection is not None:
+ self._try_fill_api_key(connection)
+ return connection # type: ignore[return-value]
+
+ @distributed_trace
+ @monitor_with_activity(ops_logger, "WorkspaceConnections.CreateOrUpdate", ActivityType.PUBLICAPI)
+ def create_or_update(
+ self, workspace_connection: WorkspaceConnection, *, populate_secrets: bool = False, **kwargs: Any
+ ) -> WorkspaceConnection:
+ """Create or update a connection.
+
+ :param workspace_connection: Definition of a workspace connection, one of its subclasses,
+ or an object that can be translated to a connection.
+ :type workspace_connection: ~azure.ai.ml.entities.WorkspaceConnection
+ :keyword populate_secrets: If true, make a secondary API call to try filling in the connection's
+ credentials. Currently only works for API key-based credentials. Defaults to False.
+ :paramtype populate_secrets: bool
+ :return: Created or updated connection.
+ :rtype: ~azure.ai.ml.entities.WorkspaceConnection
+ """
+ rest_workspace_connection = workspace_connection._to_rest_object()
+ response = self._operation.create(
+ workspace_name=self._workspace_name,
+ connection_name=workspace_connection.name,
+ body=rest_workspace_connection,
+ **self._scope_kwargs,
+ **kwargs,
+ )
+ conn = WorkspaceConnection._from_rest_object(rest_obj=response)
+ if populate_secrets and conn is not None:
+ self._try_fill_api_key(conn)
+ return conn
+
+ @distributed_trace
+ @monitor_with_activity(ops_logger, "WorkspaceConnections.Delete", ActivityType.PUBLICAPI)
+ def delete(self, name: str, **kwargs: Any) -> None:
+ """Delete the connection.
+
+ :param name: Name of the connection.
+ :type name: str
+ """
+
+ self._operation.delete(
+ connection_name=name,
+ workspace_name=self._workspace_name,
+ **self._scope_kwargs,
+ **kwargs,
+ )
+
+ @distributed_trace
+ @monitor_with_activity(ops_logger, "WorkspaceConnections.List", ActivityType.PUBLICAPI)
+ def list(
+ self,
+ connection_type: Optional[str] = None,
+ *,
+ populate_secrets: bool = False,
+ include_data_connections: bool = False,
+ **kwargs: Any,
+ ) -> Iterable[WorkspaceConnection]:
+ """List all connections for a workspace.
+
+ :param connection_type: Type of connection to list.
+ :type connection_type: Optional[str]
+ :keyword populate_secrets: If true, make a secondary API call to try filling in each connection's
+ credentials. Currently only works for API key-based credentials. Defaults to False.
+ :paramtype populate_secrets: bool
+ :keyword include_data_connections: If true, also return data connections. Defaults to False.
+ :paramtype include_data_connections: bool
+ :return: An iterator like instance of connection objects
+ :rtype: Iterable[~azure.ai.ml.entities.WorkspaceConnection]
+ """
+
+ if include_data_connections:
+ if "params" in kwargs:
+ kwargs["params"]["includeAll"] = "true"
+ else:
+ kwargs["params"] = {"includeAll": "true"}
+
+ def post_process_conn(rest_obj):
+ result = WorkspaceConnection._from_rest_object(rest_obj)
+ if populate_secrets and result is not None:
+ self._try_fill_api_key(result)
+ return result
+
+ result = self._operation.list(
+ workspace_name=self._workspace_name,
+ cls=lambda objs: [post_process_conn(obj) for obj in objs],
+ category=_snake_to_camel(connection_type) if connection_type else connection_type,
+ **self._scope_kwargs,
+ **kwargs,
+ )
+
+ return cast(Iterable[WorkspaceConnection], result)
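For reference, a hedged sketch of listing connections through MLClient (which exposes these operations as ml_client.connections), including data connections and attempting to resolve API-key secrets via the list_secrets call described above; the workspace details are placeholders.

```python
# Hedged usage sketch; workspace details are placeholders.
from azure.ai.ml import MLClient
from azure.identity import DefaultAzureCredential

ml_client = MLClient(
    credential=DefaultAzureCredential(),
    subscription_id="<subscription-id>",
    resource_group_name="<resource-group>",
    workspace_name="<workspace-name>",
)

for conn in ml_client.connections.list(populate_secrets=True, include_data_connections=True):
    print(conn.name, conn.type)
```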
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_workspace_operations.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_workspace_operations.py
new file mode 100644
index 00000000..47fd8747
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_workspace_operations.py
@@ -0,0 +1,443 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+from typing import Any, Dict, Iterable, List, Optional, Union, cast
+
+from marshmallow import ValidationError
+
+from azure.ai.ml._restclient.v2024_10_01_preview import AzureMachineLearningWorkspaces as ServiceClient102024Preview
+from azure.ai.ml._restclient.v2024_10_01_preview.models import ManagedNetworkProvisionOptions
+from azure.ai.ml._scope_dependent_operations import OperationsContainer, OperationScope
+from azure.ai.ml._telemetry import ActivityType, monitor_with_activity
+from azure.ai.ml._utils._http_utils import HttpPipeline
+from azure.ai.ml._utils._logger_utils import OpsLogger
+from azure.ai.ml._utils.utils import (
+ _get_workspace_base_url,
+ get_resource_and_group_name_from_resource_id,
+ get_resource_group_name_from_resource_group_id,
+ modified_operation_client,
+)
+from azure.ai.ml.constants._common import AzureMLResourceType, Scope, WorkspaceKind
+from azure.ai.ml.entities import (
+ DiagnoseRequestProperties,
+ DiagnoseResponseResult,
+ DiagnoseResponseResultValue,
+ DiagnoseWorkspaceParameters,
+ ManagedNetworkProvisionStatus,
+ Workspace,
+ WorkspaceKeys,
+)
+from azure.ai.ml.entities._credentials import IdentityConfiguration
+from azure.core.credentials import TokenCredential
+from azure.core.exceptions import HttpResponseError
+from azure.core.polling import LROPoller
+from azure.core.tracing.decorator import distributed_trace
+
+from ._workspace_operations_base import WorkspaceOperationsBase
+
+ops_logger = OpsLogger(__name__)
+module_logger = ops_logger.module_logger
+
+
+class WorkspaceOperations(WorkspaceOperationsBase):
+ """Handles workspaces and its subclasses, hubs and projects.
+
+ You should not instantiate this class directly. Instead, you should create
+ an MLClient instance that instantiates it for you and attaches it as an attribute.
+ """
+
+ def __init__(
+ self,
+ operation_scope: OperationScope,
+ service_client: ServiceClient102024Preview,
+ all_operations: OperationsContainer,
+ credentials: Optional[TokenCredential] = None,
+ **kwargs: Any,
+ ):
+ self.dataplane_workspace_operations = (
+ kwargs.pop("dataplane_client").workspaces if kwargs.get("dataplane_client") else None
+ )
+ self._requests_pipeline: HttpPipeline = kwargs.pop("requests_pipeline", None)
+ ops_logger.update_filter()
+ self._provision_network_operation = service_client.managed_network_provisions
+ super().__init__(
+ operation_scope=operation_scope,
+ service_client=service_client,
+ all_operations=all_operations,
+ credentials=credentials,
+ **kwargs,
+ )
+
+ @monitor_with_activity(ops_logger, "Workspace.List", ActivityType.PUBLICAPI)
+ def list(
+ self, *, scope: str = Scope.RESOURCE_GROUP, filtered_kinds: Optional[Union[str, List[str]]] = None
+ ) -> Iterable[Workspace]:
+ """List all Workspaces that the user has access to in the current resource group or subscription.
+
+ :keyword scope: scope of the listing, "resource_group" or "subscription", defaults to "resource_group"
+ :paramtype scope: str
+ :keyword filtered_kinds: The kinds of workspaces to list. If not provided, all workspace varieties will
+ be listed. Accepts either a single kind or a list of them.
+ Valid kind options include: "default", "project", and "hub".
+ :return: An iterator like instance of Workspace objects
+ :rtype: ~azure.core.paging.ItemPaged[~azure.ai.ml.entities.Workspace]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_workspace.py
+ :start-after: [START workspace_list]
+ :end-before: [END workspace_list]
+ :language: python
+ :dedent: 8
+ :caption: List the workspaces by resource group or subscription.
+ """
+
+ # Kind should be converted to a comma-separating string if multiple values are supplied.
+ formatted_kinds = filtered_kinds
+ if filtered_kinds and not isinstance(filtered_kinds, str):
+ formatted_kinds = ",".join(filtered_kinds) # type: ignore[arg-type]
+
+ if scope == Scope.SUBSCRIPTION:
+ return cast(
+ Iterable[Workspace],
+ self._operation.list_by_subscription(
+ cls=lambda objs: [Workspace._from_rest_object(obj) for obj in objs],
+ kind=formatted_kinds,
+ ),
+ )
+ return cast(
+ Iterable[Workspace],
+ self._operation.list_by_resource_group(
+ self._resource_group_name,
+ cls=lambda objs: [Workspace._from_rest_object(obj) for obj in objs],
+ kind=formatted_kinds,
+ ),
+ )
+
+ @monitor_with_activity(ops_logger, "Workspace.Get", ActivityType.PUBLICAPI)
+ @distributed_trace
+ # pylint: disable=arguments-renamed
+ def get(self, name: Optional[str] = None, **kwargs: Dict) -> Optional[Workspace]:
+ """Get a Workspace by name.
+
+ :param name: Name of the workspace.
+ :type name: str
+ :return: The workspace with the provided name.
+ :rtype: ~azure.ai.ml.entities.Workspace
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_workspace.py
+ :start-after: [START workspace_get]
+ :end-before: [END workspace_get]
+ :language: python
+ :dedent: 8
+ :caption: Get the workspace with the given name.
+ """
+
+ return super().get(workspace_name=name, **kwargs)
+
+ @monitor_with_activity(ops_logger, "Workspace.Get_Keys", ActivityType.PUBLICAPI)
+ @distributed_trace
+ def get_keys(self, name: Optional[str] = None) -> Optional[WorkspaceKeys]:
+ """Get WorkspaceKeys by workspace name.
+
+ :param name: Name of the workspace.
+ :type name: str
+ :return: Keys of workspace dependent resources.
+ :rtype: ~azure.ai.ml.entities.WorkspaceKeys
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_workspace.py
+ :start-after: [START workspace_get_keys]
+ :end-before: [END workspace_get_keys]
+ :language: python
+ :dedent: 8
+ :caption: Get the workspace keys for the workspace with the given name.
+ """
+ workspace_name = self._check_workspace_name(name)
+ obj = self._operation.list_keys(self._resource_group_name, workspace_name)
+ return WorkspaceKeys._from_rest_object(obj)
+
+ @monitor_with_activity(ops_logger, "Workspace.BeginSyncKeys", ActivityType.PUBLICAPI)
+ @distributed_trace
+ def begin_sync_keys(self, name: Optional[str] = None) -> LROPoller[None]:
+ """Triggers the workspace to immediately synchronize keys. If keys for any resource in the workspace are
+ changed, it can take around an hour for them to automatically be updated. This function enables keys to be
+ updated upon request. An example scenario is needing immediate access to storage after regenerating storage
+ keys.
+
+ :param name: Name of the workspace.
+ :type name: str
+ :return: An instance of LROPoller that returns either None or the sync keys result.
+ :rtype: ~azure.core.polling.LROPoller[None]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_workspace.py
+ :start-after: [START workspace_sync_keys]
+ :end-before: [END workspace_sync_keys]
+ :language: python
+ :dedent: 8
+ :caption: Begin sync keys for the workspace with the given name.
+ """
+ workspace_name = self._check_workspace_name(name)
+ return self._operation.begin_resync_keys(self._resource_group_name, workspace_name)
+
+ @monitor_with_activity(ops_logger, "Workspace.BeginProvisionNetwork", ActivityType.PUBLICAPI)
+ @distributed_trace
+ def begin_provision_network(
+ self,
+ *,
+ workspace_name: Optional[str] = None,
+ include_spark: bool = False,
+ **kwargs: Any,
+ ) -> LROPoller[ManagedNetworkProvisionStatus]:
+ """Triggers the workspace to provision the managed network. Specifying spark enabled
+ as true prepares the workspace managed network for supporting Spark.
+
+ :keyword workspace_name: Name of the workspace.
+ :paramtype workspace_name: str
+ :keyword include_spark: Whether the workspace managed network should prepare to support Spark.
+ :paramtype include_spark: bool
+ :return: An instance of LROPoller.
+ :rtype: ~azure.core.polling.LROPoller[~azure.ai.ml.entities.ManagedNetworkProvisionStatus]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_workspace.py
+ :start-after: [START workspace_provision_network]
+ :end-before: [END workspace_provision_network]
+ :language: python
+ :dedent: 8
+ :caption: Begin provision network for a workspace with managed network.
+ """
+ workspace_name = self._check_workspace_name(workspace_name)
+
+ poller = self._provision_network_operation.begin_provision_managed_network(
+ self._resource_group_name,
+ workspace_name,
+ ManagedNetworkProvisionOptions(include_spark=include_spark),
+ polling=True,
+ cls=lambda response, deserialized, headers: ManagedNetworkProvisionStatus._from_rest_object(deserialized),
+ **kwargs,
+ )
+ module_logger.info("Provision network request initiated for workspace: %s\n", workspace_name)
+ return poller
+
+ @monitor_with_activity(ops_logger, "Workspace.BeginCreate", ActivityType.PUBLICAPI)
+ @distributed_trace
+ # pylint: disable=arguments-differ
+ def begin_create(
+ self,
+ workspace: Workspace,
+ update_dependent_resources: bool = False,
+ **kwargs: Any,
+ ) -> LROPoller[Workspace]:
+ """Create a new Azure Machine Learning Workspace.
+
+ Returns the workspace if it already exists.
+
+ :param workspace: Workspace definition.
+ :type workspace: ~azure.ai.ml.entities.Workspace
+ :param update_dependent_resources: Whether to update dependent resources, defaults to False.
+ :type update_dependent_resources: boolean
+ :return: An instance of LROPoller that returns a Workspace.
+ :rtype: ~azure.core.polling.LROPoller[~azure.ai.ml.entities.Workspace]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_workspace.py
+ :start-after: [START workspace_begin_create]
+ :end-before: [END workspace_begin_create]
+ :language: python
+ :dedent: 8
+ :caption: Begin create for a workspace.
+ """
+ # Add hub values to project if possible
+ if workspace._kind == WorkspaceKind.PROJECT:
+ try:
+ parent_name = workspace._hub_id.split("/")[-1] if workspace._hub_id else ""
+ parent = self.get(parent_name)
+ if parent:
+ # A project's location cannot differ from its hub's, so try to force-match them if possible.
+ workspace.location = parent.location
+ # Projects technically don't save their PNA (public network access) setting, since it implicitly matches their parent's.
+ # However, some PNA-dependent code runs server-side before that alignment is made, so make sure
+ # they're aligned before the request hits the server.
+ workspace.public_network_access = parent.public_network_access
+ except HttpResponseError:
+ module_logger.warning("Failed to get parent hub for project; some values won't be transferred.")
+ try:
+ return super().begin_create(workspace, update_dependent_resources=update_dependent_resources, **kwargs)
+ except HttpResponseError as error:
+ if error.status_code == 403 and workspace._kind == WorkspaceKind.PROJECT:
+ resource_group = kwargs.get("resource_group") or self._resource_group_name
+ hub_name, _ = get_resource_and_group_name_from_resource_id(workspace._hub_id)
+ rest_workspace_obj = self._operation.get(resource_group, hub_name)
+ hub_default_project_resource_group = get_resource_group_name_from_resource_group_id(
+ rest_workspace_obj.workspace_hub_config.default_workspace_resource_group
+ )
+ # we only want to try joining the workspaceHub when the default workspace resource group
+ # is same with the user provided resource group.
+ if hub_default_project_resource_group == resource_group:
+ log_msg = (
+ "User lacked permission to create project workspace,"
+ + "trying to join the workspaceHub default resource group."
+ )
+ module_logger.info(log_msg)
+ return self._begin_join(workspace, **kwargs)
+ raise error
+
+ @monitor_with_activity(ops_logger, "Workspace.BeginUpdate", ActivityType.PUBLICAPI)
+ @distributed_trace
+ def begin_update(
+ self,
+ workspace: Workspace,
+ *,
+ update_dependent_resources: bool = False,
+ **kwargs: Any,
+ ) -> LROPoller[Workspace]:
+ """Updates a Azure Machine Learning Workspace.
+
+ :param workspace: Workspace definition.
+ :type workspace: ~azure.ai.ml.entities.Workspace
+ :keyword update_dependent_resources: Whether to update dependent resources, defaults to False.
+ :paramtype update_dependent_resources: boolean
+ :return: An instance of LROPoller that returns a Workspace.
+ :rtype: ~azure.core.polling.LROPoller[~azure.ai.ml.entities.Workspace]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_workspace.py
+ :start-after: [START workspace_begin_update]
+ :end-before: [END workspace_begin_update]
+ :language: python
+ :dedent: 8
+ :caption: Begin update for a workspace.
+ """
+ return super().begin_update(workspace, update_dependent_resources=update_dependent_resources, **kwargs)
+
+ @monitor_with_activity(ops_logger, "Workspace.BeginDelete", ActivityType.PUBLICAPI)
+ @distributed_trace
+ def begin_delete(
+ self, name: str, *, delete_dependent_resources: bool, permanently_delete: bool = False, **kwargs: Dict
+ ) -> LROPoller[None]:
+ """Delete a workspace.
+
+ :param name: Name of the workspace
+ :type name: str
+ :keyword delete_dependent_resources: Whether to delete resources associated with the workspace,
+ i.e., container registry, storage account, key vault, application insights, log analytics.
+ The default is False. Set to True to delete these resources.
+ :paramtype delete_dependent_resources: bool
+ :keyword permanently_delete: Workspaces are soft-deleted by default to allow recovery of workspace data.
+ Set this flag to true to override the soft-delete behavior and permanently delete your workspace.
+ :paramtype permanently_delete: bool
+ :return: A poller to track the operation status.
+ :rtype: ~azure.core.polling.LROPoller[None]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_workspace.py
+ :start-after: [START workspace_begin_delete]
+ :end-before: [END workspace_begin_delete]
+ :language: python
+ :dedent: 8
+ :caption: Begin permanent (force) deletion for a workspace and delete dependent resources.
+ """
+ return super().begin_delete(
+ name, delete_dependent_resources=delete_dependent_resources, permanently_delete=permanently_delete, **kwargs
+ )
+
+ @distributed_trace
+ @monitor_with_activity(ops_logger, "Workspace.BeginDiagnose", ActivityType.PUBLICAPI)
+ def begin_diagnose(self, name: str, **kwargs: Dict) -> LROPoller[DiagnoseResponseResultValue]:
+ """Diagnose workspace setup problems.
+
+ If your workspace is not working as expected, you can run this diagnosis to
+ check whether the workspace has been broken.
+ For a private endpoint workspace, it also helps check whether the network
+ setup to this workspace and its dependent resources has problems.
+
+ :param name: Name of the workspace
+ :type name: str
+ :return: A poller to track the operation status.
+ :rtype: ~azure.core.polling.LROPoller[~azure.ai.ml.entities.DiagnoseResponseResultValue]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_workspace.py
+ :start-after: [START workspace_begin_diagnose]
+ :end-before: [END workspace_begin_diagnose]
+ :language: python
+ :dedent: 8
+ :caption: Begin diagnose operation for a workspace.
+ """
+ resource_group = kwargs.get("resource_group") or self._resource_group_name
+ parameters = DiagnoseWorkspaceParameters(value=DiagnoseRequestProperties())._to_rest_object()
+
+ # pylint: disable=unused-argument, docstring-missing-param
+ def callback(_: Any, deserialized: Any, args: Any) -> Optional[DiagnoseResponseResultValue]:
+ """Callback to be called after completion
+
+ :return: The value of the deserialized DiagnoseResponseResult, if any.
+ :rtype: Optional[~azure.ai.ml.entities.DiagnoseResponseResultValue]
+ """
+ diagnose_response_result = DiagnoseResponseResult._from_rest_object(deserialized)
+ res = None
+ if diagnose_response_result:
+ res = diagnose_response_result.value
+ return res
+
+ poller = self._operation.begin_diagnose(resource_group, name, parameters, polling=True, cls=callback)
+ module_logger.info("Diagnose request initiated for workspace: %s\n", name)
+ return poller
+
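+ # Illustrative sketch (hedged): running the diagnose operation and reading the value produced
+ # by the callback above. "ml_client" and "my-workspace" are placeholders.
+ #
+ # poller = ml_client.workspaces.begin_diagnose(name="my-workspace")
+ # diagnose_value = poller.result()  # DiagnoseResponseResultValue or None
+ # if diagnose_value is not None:
+ #     print(diagnose_value)
+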
+ @distributed_trace
+ def _begin_join(self, workspace: Workspace, **kwargs: Dict) -> LROPoller[Workspace]:
+ """Join a WorkspaceHub by creating a project workspace under the workspaceHub's default resource group.
+
+ :param workspace: Project workspace definition to create
+ :type workspace: Workspace
+ :return: An instance of LROPoller that returns a project Workspace.
+ :rtype: ~azure.core.polling.LROPoller[~azure.ai.ml.entities.Workspace]
+ """
+ if not workspace._hub_id:
+ raise ValidationError(
+ "{0} is not a Project workspace; the join operation can only be performed when a workspaceHub is provided".format(
+ workspace.name
+ )
+ )
+
+ resource_group = kwargs.get("resource_group") or self._resource_group_name
+ hub_name, _ = get_resource_and_group_name_from_resource_id(workspace._hub_id)
+ rest_workspace_obj = self._operation.get(resource_group, hub_name)
+
+ # override the location to the same as the workspaceHub
+ workspace.location = rest_workspace_obj.location
+ # set this to system assigned so ARM will create MSI
+ if not hasattr(workspace, "identity") or not workspace.identity:
+ workspace.identity = IdentityConfiguration(type="system_assigned")
+
+ workspace_operations = self._all_operations.all_operations[AzureMLResourceType.WORKSPACE]
+ workspace_base_uri = _get_workspace_base_url(workspace_operations, hub_name, self._requests_pipeline)
+
+ # pylint:disable=unused-argument
+ def callback(_: Any, deserialized: Any, args: Any) -> Optional[Workspace]:
+ return Workspace._from_rest_object(deserialized)
+
+ with modified_operation_client(self.dataplane_workspace_operations, workspace_base_uri):
+ result = self.dataplane_workspace_operations.begin_hub_join( # type: ignore
+ resource_group_name=resource_group,
+ workspace_name=hub_name,
+ project_workspace_name=workspace.name,
+ body=workspace._to_rest_object(),
+ cls=callback,
+ **self._init_kwargs,
+ )
+ return result
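+ # Hedged note on the hub ID handling above: workspace._hub_id is expected to be a full ARM
+ # resource ID such as
+ #   /subscriptions/<sub>/resourceGroups/<rg>/providers/Microsoft.MachineLearningServices/workspaces/my-hub
+ # so get_resource_and_group_name_from_resource_id() is assumed to yield the trailing resource
+ # name ("my-hub") plus its group, matching the simple split("/")[-1] used in begin_create().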
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_workspace_operations_base.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_workspace_operations_base.py
new file mode 100644
index 00000000..1a3293fe
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_workspace_operations_base.py
@@ -0,0 +1,1167 @@
+# pylint: disable=too-many-lines
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+# pylint: disable=protected-access
+
+import time
+from abc import ABC
+from typing import Any, Callable, Dict, MutableMapping, Optional, Tuple
+
+from azure.ai.ml._arm_deployments import ArmDeploymentExecutor
+from azure.ai.ml._arm_deployments.arm_helper import get_template
+from azure.ai.ml._restclient.v2024_10_01_preview import AzureMachineLearningWorkspaces as ServiceClient102024Preview
+from azure.ai.ml._restclient.v2024_10_01_preview.models import (
+ EncryptionKeyVaultUpdateProperties,
+ EncryptionUpdateProperties,
+ WorkspaceUpdateParameters,
+)
+from azure.ai.ml._scope_dependent_operations import OperationsContainer, OperationScope
+from azure.ai.ml._utils._appinsights_utils import get_log_analytics_arm_id
+
+# from azure.ai.ml._telemetry import ActivityType, monitor_with_activity
+from azure.ai.ml._utils._logger_utils import OpsLogger
+from azure.ai.ml._utils._workspace_utils import (
+ delete_resource_by_arm_id,
+ get_deployment_name,
+ get_generic_arm_resource_by_arm_id,
+ get_name_for_dependent_resource,
+ get_resource_and_group_name,
+ get_resource_group_location,
+ get_sub_id_resource_and_group_name,
+)
+from azure.ai.ml._utils.utils import camel_to_snake, from_iso_duration_format_min_sec
+from azure.ai.ml._version import VERSION
+from azure.ai.ml.constants import ManagedServiceIdentityType
+from azure.ai.ml.constants._common import (
+ WORKSPACE_PATCH_REJECTED_KEYS,
+ ArmConstants,
+ LROConfigurations,
+ WorkspaceKind,
+ WorkspaceResourceConstants,
+)
+from azure.ai.ml.constants._workspace import IsolationMode, OutboundRuleCategory
+from azure.ai.ml.entities import Hub, Project, Workspace
+from azure.ai.ml.entities._credentials import IdentityConfiguration
+from azure.ai.ml.entities._workspace._ai_workspaces._constants import ENDPOINT_AI_SERVICE_KIND
+from azure.ai.ml.entities._workspace.network_acls import NetworkAcls
+from azure.ai.ml.entities._workspace.networking import ManagedNetwork
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationException
+from azure.core.credentials import TokenCredential
+from azure.core.polling import LROPoller, PollingMethod
+
+ops_logger = OpsLogger(__name__)
+module_logger = ops_logger.module_logger
+
+
+class WorkspaceOperationsBase(ABC):
+ """Base class for WorkspaceOperations; it can't be instantiated directly."""
+
+ def __init__(
+ self,
+ operation_scope: OperationScope,
+ service_client: ServiceClient102024Preview,
+ all_operations: OperationsContainer,
+ credentials: Optional[TokenCredential] = None,
+ **kwargs: Dict,
+ ):
+ ops_logger.update_filter()
+ self._subscription_id = operation_scope.subscription_id
+ self._resource_group_name = operation_scope.resource_group_name
+ self._default_workspace_name = operation_scope.workspace_name
+ self._all_operations = all_operations
+ self._operation = service_client.workspaces
+ self._credentials = credentials
+ self._init_kwargs = kwargs
+ self.containerRegistry = "none"
+
+ def get(self, workspace_name: Optional[str] = None, **kwargs: Any) -> Optional[Workspace]:
+ """Get a Workspace by name.
+
+ :param workspace_name: Name of the workspace.
+ :type workspace_name: str
+ :return: The workspace with the provided name.
+ :rtype: ~azure.ai.ml.entities.Workspace
+ """
+ workspace_name = self._check_workspace_name(workspace_name)
+ resource_group = kwargs.get("resource_group") or self._resource_group_name
+ obj = self._operation.get(resource_group, workspace_name)
+ v2_service_context = {}
+
+ v2_service_context["subscription_id"] = self._subscription_id
+ v2_service_context["workspace_name"] = workspace_name
+ v2_service_context["resource_group_name"] = resource_group
+ v2_service_context["auth"] = self._credentials # type: ignore
+
+ from urllib.parse import urlparse
+
+ if obj is not None and obj.ml_flow_tracking_uri:
+ parsed_url = urlparse(obj.ml_flow_tracking_uri)
+ host_url = "https://{}".format(parsed_url.netloc)
+ v2_service_context["host_url"] = host_url
+ else:
+ v2_service_context["host_url"] = ""
+
+ # host_url=service_context._get_mlflow_url(),
+ # cloud=_get_cloud_or_default(
+ # service_context.get_auth()._cloud_type.name
+ if obj is not None and obj.kind is not None and obj.kind.lower() == WorkspaceKind.HUB:
+ return Hub._from_rest_object(obj, v2_service_context)
+ if obj is not None and obj.kind is not None and obj.kind.lower() == WorkspaceKind.PROJECT:
+ return Project._from_rest_object(obj, v2_service_context)
+ return Workspace._from_rest_object(obj, v2_service_context)
+
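+ # Hedged sketch of the kind-based dispatch in get(): callers receive a more specific entity
+ # depending on the REST object's kind (placeholder names below).
+ #
+ # hub = ml_client.workspaces.get("my-hub")          # kind "hub"     -> entities.Hub
+ # project = ml_client.workspaces.get("my-project")  # kind "project" -> entities.Project
+ # ws = ml_client.workspaces.get("my-workspace")     # other kinds    -> entities.Workspace
+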
+ def begin_create(
+ self,
+ workspace: Workspace,
+ update_dependent_resources: bool = False,
+ get_callback: Optional[Callable[[], Workspace]] = None,
+ **kwargs: Any,
+ ) -> LROPoller[Workspace]:
+ """Create a new Azure Machine Learning Workspace.
+
+ Returns the existing workspace if one already exists.
+
+ :param workspace: Workspace definition.
+ :type workspace: ~azure.ai.ml.entities.Workspace
+ :param update_dependent_resources: Whether to update dependent resources, defaults to False.
+ :type update_dependent_resources: bool
+ :param get_callback: A callable function to call at the end of the operation.
+ :type get_callback: Optional[Callable[[], ~azure.ai.ml.entities.Workspace]]
+ :return: An instance of LROPoller that returns a Workspace.
+ :rtype: ~azure.core.polling.LROPoller[~azure.ai.ml.entities.Workspace]
+ :raises ~azure.ai.ml.ValidationException: Raised if workspace is Project workspace and user
+ specifies any of the following in workspace object: storage_account, container_registry, key_vault,
+ public_network_access, managed_network, customer_managed_key, system_datastores_auth_mode.
+ """
+ existing_workspace = None
+ resource_group = kwargs.get("resource_group") or workspace.resource_group or self._resource_group_name
+ endpoint_resource_id = kwargs.pop("endpoint_resource_id", "")
+ endpoint_kind = kwargs.pop("endpoint_kind", ENDPOINT_AI_SERVICE_KIND)
+
+ try:
+ existing_workspace = self.get(workspace.name, resource_group=resource_group)
+ except Exception: # pylint: disable=W0718
+ pass
+
+ # idempotent behavior
+ if existing_workspace:
+ if workspace.tags is not None and workspace.tags.get("createdByToolkit") is not None:
+ workspace.tags.pop("createdByToolkit")
+ if existing_workspace.tags is not None:
+ existing_workspace.tags.update(workspace.tags) # type: ignore
+ workspace.tags = existing_workspace.tags
+ # TODO do we want projects to do this?
+ if workspace._kind != WorkspaceKind.PROJECT:
+ workspace.container_registry = workspace.container_registry or existing_workspace.container_registry
+ workspace.application_insights = (
+ workspace.application_insights or existing_workspace.application_insights
+ )
+ workspace.identity = workspace.identity or existing_workspace.identity
+ workspace.primary_user_assigned_identity = (
+ workspace.primary_user_assigned_identity or existing_workspace.primary_user_assigned_identity
+ )
+ workspace._feature_store_settings = (
+ workspace._feature_store_settings or existing_workspace._feature_store_settings
+ )
+ return self.begin_update(
+ workspace,
+ update_dependent_resources=update_dependent_resources,
+ **kwargs,
+ )
+ # Add a tag to the workspace to indicate which SDK version the workspace was created from.
+ if workspace.tags is None:
+ workspace.tags = {}
+ if workspace.tags.get("createdByToolkit") is None:
+ workspace.tags["createdByToolkit"] = "sdk-v2-{}".format(VERSION)
+
+ workspace.resource_group = resource_group
+ (
+ template,
+ param,
+ resources_being_deployed,
+ ) = self._populate_arm_parameters(
+ workspace,
+ endpoint_resource_id=endpoint_resource_id,
+ endpoint_kind=endpoint_kind,
+ **kwargs,
+ )
+ # check if create with workspace hub request is valid
+ if workspace._kind == WorkspaceKind.PROJECT:
+ if not all(
+ x is None
+ for x in [
+ workspace.storage_account,
+ workspace.container_registry,
+ workspace.key_vault,
+ ]
+ ):
+ msg = (
+ "To create a project workspace with a workspace hub, please only specify name, description, "
+ + "display_name, location, application insights and identity"
+ )
+ raise ValidationException(
+ message=msg,
+ target=ErrorTarget.WORKSPACE,
+ no_personal_data_message=msg,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+
+ arm_submit = ArmDeploymentExecutor(
+ credentials=self._credentials,
+ resource_group_name=resource_group,
+ subscription_id=self._subscription_id,
+ deployment_name=get_deployment_name(workspace.name),
+ )
+
+ # deploy_resource() blocks for the poller to succeed if wait is True
+ poller = arm_submit.deploy_resource(
+ template=template,
+ resources_being_deployed=resources_being_deployed,
+ parameters=param,
+ wait=False,
+ )
+
+ def callback() -> Optional[Workspace]:
+ """Callback to be called after completion
+
+ :return: Result of calling appropriate callback.
+ :rtype: Any
+ """
+ return get_callback() if get_callback else self.get(workspace.name, resource_group=resource_group)
+
+ real_callback = callback
+ injected_callback = kwargs.get("cls", None)
+ if injected_callback:
+ # pylint: disable=function-redefined
+ def real_callback() -> Any:
+ """Callback to be called after completion
+
+ :return: Result of calling appropriate callback.
+ :rtype: Any
+ """
+ return injected_callback(callback())
+
+ return LROPoller(
+ self._operation._client,
+ None,
+ lambda *x, **y: None,
+ CustomArmTemplateDeploymentPollingMethod(poller, arm_submit, real_callback),
+ )
+
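+ # Hedged sketch of the injected "cls" callback pattern used above: a caller-supplied callable
+ # receives the Workspace entity produced by the inner callback before poller.result() returns it.
+ # "my_cls" is a placeholder name.
+ #
+ # def my_cls(workspace_entity):
+ #     print(f"created: {workspace_entity.name}")
+ #     return workspace_entity
+ #
+ # poller = ml_client.workspaces.begin_create(workspace, cls=my_cls)
+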
+ # pylint: disable=too-many-statements,too-many-locals
+ def begin_update(
+ self,
+ workspace: Workspace,
+ *,
+ update_dependent_resources: bool = False,
+ deserialize_callback: Optional[Callable] = None,
+ **kwargs: Any,
+ ) -> LROPoller[Workspace]:
+ """Updates an Azure Machine Learning Workspace.
+
+ :param workspace: Workspace resource.
+ :type workspace: ~azure.ai.ml.entities.Workspace
+ :keyword update_dependent_resources: Whether to update dependent resources, defaults to False.
+ :paramtype update_dependent_resources: bool
+ :keyword deserialize_callback: A callable function to call at the end of the operation.
+ :paramtype deserialize_callback: Optional[Callable[[], ~azure.ai.ml.entities.Workspace]]
+ :return: An instance of LROPoller that returns a Workspace.
+ :rtype: ~azure.core.polling.LROPoller[~azure.ai.ml.entities.Workspace]
+ :raises ~azure.ai.ml.ValidationException: Raised if updating container_registry for a workspace
+ and update_dependent_resources is not True.
+ :raises ~azure.ai.ml.ValidationException: Raised if updating application_insights for a workspace
+ and update_dependent_resources is not True.
+ """
+ identity = kwargs.get("identity", workspace.identity)
+ workspace_name = kwargs.get("workspace_name", workspace.name)
+ resource_group = kwargs.get("resource_group") or workspace.resource_group or self._resource_group_name
+ existing_workspace: Any = self.get(workspace_name, **kwargs)
+ if identity:
+ identity = identity._to_workspace_rest_object()
+ rest_user_assigned_identities = identity.user_assigned_identities
+ # Add the UAI resource_ids that need to be deleted (those not provided in the incoming list).
+ if (
+ existing_workspace
+ and existing_workspace.identity
+ and existing_workspace.identity.user_assigned_identities
+ ):
+ if rest_user_assigned_identities is None:
+ rest_user_assigned_identities = {}
+ for uai in existing_workspace.identity.user_assigned_identities:
+ if uai.resource_id not in rest_user_assigned_identities:
+ rest_user_assigned_identities[uai.resource_id] = None
+ identity.user_assigned_identities = rest_user_assigned_identities
+
+ managed_network = kwargs.get("managed_network", workspace.managed_network)
+ if isinstance(managed_network, str):
+ managed_network = ManagedNetwork(isolation_mode=managed_network)._to_rest_object()
+ elif isinstance(managed_network, ManagedNetwork):
+ if managed_network.outbound_rules is not None:
+ # Drop recommended and required rules from the update request since they would result in a bad request.
+ managed_network.outbound_rules = [
+ rule
+ for rule in managed_network.outbound_rules
+ if rule.category not in (OutboundRuleCategory.REQUIRED, OutboundRuleCategory.RECOMMENDED)
+ ]
+ managed_network = managed_network._to_rest_object()
+
+ container_registry = kwargs.get("container_registry", workspace.container_registry)
+ # An empty string erases the value of container_registry; None means the value is ignored.
+ if (
+ container_registry is not None
+ and container_registry != existing_workspace.container_registry
+ and not update_dependent_resources
+ ):
+ msg = (
+ "Updating the workspace-attached Azure Container Registry resource may break lineage of "
+ "previous jobs or your ability to rerun earlier jobs in this workspace. "
+ "Are you sure you want to perform this operation? "
+ "Include the update_dependent_resources argument in SDK or the "
+ "--update-dependent-resources/-u parameter in CLI with this request to confirm."
+ )
+ raise ValidationException(
+ message=msg,
+ target=ErrorTarget.WORKSPACE,
+ no_personal_data_message=msg,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+ application_insights = kwargs.get("application_insights", workspace.application_insights)
+ # An empty string erases the value of application_insights; None means the value is ignored.
+ if (
+ application_insights is not None
+ and application_insights != existing_workspace.application_insights
+ and not update_dependent_resources
+ ):
+ msg = (
+ "Updating the workspace-attached Azure Application Insights resource may break lineage "
+ "of deployed inference endpoints in this workspace. Are you sure you want to perform this "
+ "operation? "
+ "Include the update_dependent_resources argument in SDK or the "
+ "--update-dependent-resources/-u parameter in CLI with this request to confirm."
+ )
+ raise ValidationException(
+ message=msg,
+ target=ErrorTarget.WORKSPACE,
+ no_personal_data_message=msg,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+
+ feature_store_settings = kwargs.get("feature_store_settings", workspace._feature_store_settings)
+ if feature_store_settings:
+ feature_store_settings = feature_store_settings._to_rest_object()
+
+ serverless_compute_settings = kwargs.get("serverless_compute", workspace.serverless_compute)
+ if serverless_compute_settings:
+ serverless_compute_settings = serverless_compute_settings._to_rest_object()
+
+ public_network_access = kwargs.get("public_network_access", workspace.public_network_access)
+ network_acls = kwargs.get("network_acls", workspace.network_acls)
+ if network_acls:
+ network_acls = network_acls._to_rest_object() # pylint: disable=protected-access
+
+ if public_network_access == "Disabled" or (
+ existing_workspace
+ and existing_workspace.public_network_access == "Disabled"
+ and public_network_access is None
+ ):
+ network_acls = NetworkAcls()._to_rest_object() # pylint: disable=protected-access
+
+ update_param = WorkspaceUpdateParameters(
+ tags=kwargs.get("tags", workspace.tags),
+ description=kwargs.get("description", workspace.description),
+ friendly_name=kwargs.get("display_name", workspace.display_name),
+ public_network_access=kwargs.get("public_network_access", workspace.public_network_access),
+ system_datastores_auth_mode=kwargs.get(
+ "system_datastores_auth_mode", workspace.system_datastores_auth_mode
+ ),
+ allow_role_assignment_on_rg=kwargs.get(
+ "allow_roleassignment_on_rg", workspace.allow_roleassignment_on_rg
+ ), # diff due to swagger restclient casing diff
+ image_build_compute=kwargs.get("image_build_compute", workspace.image_build_compute),
+ identity=identity,
+ primary_user_assigned_identity=kwargs.get(
+ "primary_user_assigned_identity", workspace.primary_user_assigned_identity
+ ),
+ managed_network=managed_network,
+ feature_store_settings=feature_store_settings,
+ network_acls=network_acls,
+ )
+ if serverless_compute_settings:
+ update_param.serverless_compute_settings = serverless_compute_settings
+ update_param.container_registry = container_registry or None
+ update_param.application_insights = application_insights or None
+
+ # Only the key uri property of customer_managed_key can be updated.
+ # Check if user is updating CMK key uri, if so, add to update_param
+ if workspace.customer_managed_key is not None and workspace.customer_managed_key.key_uri is not None:
+ customer_managed_key_uri = workspace.customer_managed_key.key_uri
+ update_param.encryption = EncryptionUpdateProperties(
+ key_vault_properties=EncryptionKeyVaultUpdateProperties(
+ key_identifier=customer_managed_key_uri,
+ )
+ )
+
+ update_role_assignment = (
+ kwargs.get("update_workspace_role_assignment", None)
+ or kwargs.get("update_offline_store_role_assignment", None)
+ or kwargs.get("update_online_store_role_assignment", None)
+ )
+ grant_materialization_permissions = kwargs.get("grant_materialization_permissions", None)
+
+ # Remove deprecated keys from older workspaces that might still have them before we try to update.
+ if workspace.tags is not None:
+ for bad_key in WORKSPACE_PATCH_REJECTED_KEYS:
+ _ = workspace.tags.pop(bad_key, None)
+
+ # pylint: disable=unused-argument, docstring-missing-param
+ def callback(_: Any, deserialized: Any, args: Any) -> Optional[Workspace]:
+ """Callback to be called after completion
+
+ :return: Workspace deserialized.
+ :rtype: ~azure.ai.ml.entities.Workspace
+ """
+ if (
+ workspace._kind
+ and workspace._kind.lower() == "featurestore"
+ and update_role_assignment
+ and grant_materialization_permissions
+ ):
+ module_logger.info("updating feature store materialization identity role assignments..")
+ template, param, resources_being_deployed = self._populate_feature_store_role_assignment_parameters(
+ workspace, resource_group=resource_group, location=existing_workspace.location, **kwargs
+ )
+
+ arm_submit = ArmDeploymentExecutor(
+ credentials=self._credentials,
+ resource_group_name=resource_group,
+ subscription_id=self._subscription_id,
+ deployment_name=get_deployment_name(workspace.name),
+ )
+
+ # deploy_resource() blocks for the poller to succeed if wait is True
+ poller = arm_submit.deploy_resource(
+ template=template,
+ resources_being_deployed=resources_being_deployed,
+ parameters=param,
+ wait=False,
+ )
+
+ poller.result()
+ return (
+ deserialize_callback(deserialized)
+ if deserialize_callback
+ else Workspace._from_rest_object(deserialized)
+ )
+
+ real_callback = callback
+ injected_callback = kwargs.get("cls", None)
+ if injected_callback:
+ # pylint: disable=function-redefined, docstring-missing-param
+ def real_callback(_: Any, deserialized: Any, args: Any) -> Any:
+ """Callback to be called after completion
+
+ :return: Result of calling appropriate callback.
+ :rtype: Any
+ """
+ return injected_callback(callback(_, deserialized, args))
+
+ poller = self._operation.begin_update(
+ resource_group, workspace_name, update_param, polling=True, cls=real_callback
+ )
+ return poller
+
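+ # Hedged sketch of the PATCH semantics implemented by the identity handling above: any
+ # user-assigned identity present on the existing workspace but omitted from the request is
+ # re-added with a value of None, which the service is expected to treat as a removal.
+ #
+ # existing_ids = {"/subscriptions/.../userAssignedIdentities/uai-a": {},
+ #                 "/subscriptions/.../userAssignedIdentities/uai-b": {}}
+ # requested_ids = {"/subscriptions/.../userAssignedIdentities/uai-a": {}}
+ # patched = dict(requested_ids)
+ # for resource_id in existing_ids:
+ #     if resource_id not in patched:
+ #         patched[resource_id] = None  # uai-b is marked for removal
+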
+ def begin_delete(
+ self, name: str, *, delete_dependent_resources: bool, permanently_delete: bool = False, **kwargs: Any
+ ) -> LROPoller[None]:
+ """Delete a Workspace.
+
+ :param name: Name of the Workspace
+ :type name: str
+ :keyword delete_dependent_resources: Whether to delete resources associated with the Workspace,
+ i.e., container registry, storage account, key vault, application insights, log analytics.
+ The default is False. Set to True to delete these resources.
+ :paramtype delete_dependent_resources: bool
+ :keyword permanently_delete: Workspaces are soft-deleted by default to allow recovery of workspace data.
+ Set this flag to true to override the soft-delete behavior and permanently delete your workspace.
+ :paramtype permanently_delete: bool
+ :return: A poller to track the operation status.
+ :rtype: ~azure.core.polling.LROPoller[None]
+ """
+ workspace: Any = self.get(name, **kwargs)
+ resource_group = kwargs.get("resource_group") or self._resource_group_name
+
+ # For a lean (project) workspace, skip most dependent resource deletion; only delete app insights and its associated log analytics resource.
+ if workspace._kind == WorkspaceKind.PROJECT and delete_dependent_resources:
+ app_insights = get_generic_arm_resource_by_arm_id(
+ self._credentials,
+ self._subscription_id,
+ workspace.application_insights,
+ ArmConstants.AZURE_MGMT_APPINSIGHT_API_VERSION,
+ )
+ if app_insights is not None and "WorkspaceResourceId" in app_insights.properties:
+ delete_resource_by_arm_id(
+ self._credentials,
+ self._subscription_id,
+ app_insights.properties["WorkspaceResourceId"],
+ ArmConstants.AZURE_MGMT_LOGANALYTICS_API_VERSION,
+ )
+ delete_resource_by_arm_id(
+ self._credentials,
+ self._subscription_id,
+ workspace.application_insights,
+ ArmConstants.AZURE_MGMT_APPINSIGHT_API_VERSION,
+ )
+ elif delete_dependent_resources:
+ app_insights = get_generic_arm_resource_by_arm_id(
+ self._credentials,
+ self._subscription_id,
+ workspace.application_insights,
+ ArmConstants.AZURE_MGMT_APPINSIGHT_API_VERSION,
+ )
+ if app_insights is not None and "WorkspaceResourceId" in app_insights.properties:
+ delete_resource_by_arm_id(
+ self._credentials,
+ self._subscription_id,
+ app_insights.properties["WorkspaceResourceId"],
+ ArmConstants.AZURE_MGMT_LOGANALYTICS_API_VERSION,
+ )
+ delete_resource_by_arm_id(
+ self._credentials,
+ self._subscription_id,
+ workspace.application_insights,
+ ArmConstants.AZURE_MGMT_APPINSIGHT_API_VERSION,
+ )
+ delete_resource_by_arm_id(
+ self._credentials,
+ self._subscription_id,
+ workspace.storage_account,
+ ArmConstants.AZURE_MGMT_STORAGE_API_VERSION,
+ )
+ delete_resource_by_arm_id(
+ self._credentials,
+ self._subscription_id,
+ workspace.key_vault,
+ ArmConstants.AZURE_MGMT_KEYVAULT_API_VERSION,
+ )
+ delete_resource_by_arm_id(
+ self._credentials,
+ self._subscription_id,
+ workspace.container_registry,
+ ArmConstants.AZURE_MGMT_CONTAINER_REG_API_VERSION,
+ )
+
+ poller = self._operation.begin_delete(
+ resource_group_name=resource_group,
+ workspace_name=name,
+ force_to_purge=permanently_delete,
+ **self._init_kwargs,
+ )
+ module_logger.info("Delete request initiated for workspace: %s\n", name)
+ return poller
+
+ # pylint: disable=too-many-statements,too-many-branches,too-many-locals
+ def _populate_arm_parameters(self, workspace: Workspace, **kwargs: Any) -> Tuple[dict, dict, dict]:
+ """Populates the ARM template parameters used to deploy a workspace resource.
+
+ :param workspace: Workspace resource.
+ :type workspace: ~azure.ai.ml.entities.Workspace
+ :return: A tuple of three dicts: an ARM template, ARM template parameters, resources_being_deployed.
+ :rtype: Tuple[dict, dict, dict]
+ """
+ resources_being_deployed: Dict = {}
+ if not workspace.location:
+ workspace.location = get_resource_group_location(
+ self._credentials, self._subscription_id, workspace.resource_group
+ )
+ template = get_template(resource_type=ArmConstants.WORKSPACE_BASE)
+ param = get_template(resource_type=ArmConstants.WORKSPACE_PARAM)
+ if workspace._kind == WorkspaceKind.PROJECT:
+ template = get_template(resource_type=ArmConstants.WORKSPACE_PROJECT)
+ endpoint_resource_id = kwargs.get("endpoint_resource_id") or ""
+ endpoint_kind = kwargs.get("endpoint_kind") or ENDPOINT_AI_SERVICE_KIND
+ _set_val(param["workspaceName"], workspace.name)
+ if not workspace.display_name:
+ _set_val(param["friendlyName"], workspace.name)
+ else:
+ _set_val(param["friendlyName"], workspace.display_name)
+
+ if not workspace.description:
+ _set_val(param["description"], workspace.name)
+ else:
+ _set_val(param["description"], workspace.description)
+ _set_val(param["location"], workspace.location)
+
+ if not workspace._kind:
+ _set_val(param["kind"], "default")
+ else:
+ _set_val(param["kind"], workspace._kind)
+
+ _set_val(param["resourceGroupName"], workspace.resource_group)
+
+ if workspace.key_vault:
+ resource_name, group_name = get_resource_and_group_name(workspace.key_vault)
+ _set_val(param["keyVaultName"], resource_name)
+ _set_val(param["keyVaultOption"], "existing")
+ _set_val(param["keyVaultResourceGroupName"], group_name)
+ else:
+ key_vault = _generate_key_vault(workspace.name, resources_being_deployed)
+ _set_val(param["keyVaultName"], key_vault)
+ _set_val(
+ param["keyVaultResourceGroupName"],
+ workspace.resource_group,
+ )
+
+ if workspace.storage_account:
+ subscription_id, resource_name, group_name = get_sub_id_resource_and_group_name(workspace.storage_account)
+ _set_val(param["storageAccountName"], resource_name)
+ _set_val(param["storageAccountOption"], "existing")
+ _set_val(param["storageAccountResourceGroupName"], group_name)
+ _set_val(param["storageAccountSubscriptionId"], subscription_id)
+ else:
+ storage = _generate_storage(workspace.name, resources_being_deployed)
+ _set_val(param["storageAccountName"], storage)
+ _set_val(
+ param["storageAccountResourceGroupName"],
+ workspace.resource_group,
+ )
+ _set_val(
+ param["storageAccountSubscriptionId"],
+ self._subscription_id,
+ )
+
+ if workspace.application_insights:
+ resource_name, group_name = get_resource_and_group_name(workspace.application_insights)
+ _set_val(param["applicationInsightsName"], resource_name)
+ _set_val(param["applicationInsightsOption"], "existing")
+ _set_val(
+ param["applicationInsightsResourceGroupName"],
+ group_name,
+ )
+ elif workspace._kind and workspace._kind.lower() in {WorkspaceKind.HUB, WorkspaceKind.PROJECT}:
+ _set_val(param["applicationInsightsOption"], "none")
+ # Set empty values because arm templates whine over unset values.
+ _set_val(param["applicationInsightsName"], "ignoredButCantBeEmpty")
+ _set_val(
+ param["applicationInsightsResourceGroupName"],
+ "ignoredButCantBeEmpty",
+ )
+ else:
+ log_analytics = _generate_log_analytics(workspace.name, resources_being_deployed)
+ _set_val(param["logAnalyticsName"], log_analytics)
+ _set_val(
+ param["logAnalyticsArmId"],
+ get_log_analytics_arm_id(self._subscription_id, self._resource_group_name, log_analytics),
+ )
+
+ app_insights = _generate_app_insights(workspace.name, resources_being_deployed)
+ _set_val(param["applicationInsightsName"], app_insights)
+ _set_val(
+ param["applicationInsightsResourceGroupName"],
+ workspace.resource_group,
+ )
+
+ if workspace.container_registry:
+ resource_name, group_name = get_resource_and_group_name(workspace.container_registry)
+ _set_val(param["containerRegistryName"], resource_name)
+ _set_val(param["containerRegistryOption"], "existing")
+ _set_val(param["containerRegistryResourceGroupName"], group_name)
+
+ if workspace.customer_managed_key:
+ _set_val(param["cmk_keyvault"], workspace.customer_managed_key.key_vault)
+ _set_val(param["resource_cmk_uri"], workspace.customer_managed_key.key_uri)
+ _set_val(
+ param["encryption_status"],
+ WorkspaceResourceConstants.ENCRYPTION_STATUS_ENABLED,
+ )
+ _set_val(
+ param["encryption_cosmosdb_resourceid"],
+ workspace.customer_managed_key.cosmosdb_id,
+ )
+ _set_val(
+ param["encryption_storage_resourceid"],
+ workspace.customer_managed_key.storage_id,
+ )
+ _set_val(
+ param["encryption_search_resourceid"],
+ workspace.customer_managed_key.search_id,
+ )
+
+ if workspace.hbi_workspace:
+ _set_val(param["confidential_data"], "true")
+
+ if workspace.public_network_access:
+ _set_val(param["publicNetworkAccess"], workspace.public_network_access)
+ _set_val(param["associatedResourcePNA"], workspace.public_network_access)
+
+ if workspace.system_datastores_auth_mode:
+ _set_val(param["systemDatastoresAuthMode"], workspace.system_datastores_auth_mode)
+
+ if workspace.allow_roleassignment_on_rg is False:
+ _set_val(param["allowRoleAssignmentOnRG"], "false")
+
+ if workspace.image_build_compute:
+ _set_val(param["imageBuildCompute"], workspace.image_build_compute)
+
+ if workspace.tags:
+ for key, val in workspace.tags.items():
+ param["tagValues"]["value"][key] = val
+
+ identity = None
+ if workspace.identity:
+ identity = workspace.identity._to_workspace_rest_object()
+ else:
+ identity = IdentityConfiguration(
+ type=camel_to_snake(ManagedServiceIdentityType.SYSTEM_ASSIGNED)
+ )._to_workspace_rest_object()
+ _set_val(param["identity"], identity)
+
+ if workspace.primary_user_assigned_identity:
+ _set_val(param["primaryUserAssignedIdentity"], workspace.primary_user_assigned_identity)
+
+ if workspace._feature_store_settings:
+ _set_val(
+ param["spark_runtime_version"], workspace._feature_store_settings.compute_runtime.spark_runtime_version
+ )
+ if workspace._feature_store_settings.offline_store_connection_name:
+ _set_val(
+ param["offline_store_connection_name"],
+ workspace._feature_store_settings.offline_store_connection_name,
+ )
+ if workspace._feature_store_settings.online_store_connection_name:
+ _set_val(
+ param["online_store_connection_name"],
+ workspace._feature_store_settings.online_store_connection_name,
+ )
+
+ if workspace._kind and workspace._kind.lower() == "featurestore":
+ materialization_identity = kwargs.get("materialization_identity", None)
+ offline_store_target = kwargs.get("offline_store_target", None)
+ online_store_target = kwargs.get("online_store_target", None)
+
+ from azure.ai.ml._utils._arm_id_utils import AzureResourceId, AzureStorageContainerResourceId
+
+ # set workspace storage account access auth type to identity-based
+ _set_val(param["systemDatastoresAuthMode"], "identity")
+
+ if offline_store_target:
+ arm_id = AzureStorageContainerResourceId(offline_store_target)
+ _set_val(param["offlineStoreStorageAccountOption"], "existing")
+ _set_val(param["offline_store_container_name"], arm_id.container)
+ _set_val(param["offline_store_storage_account_name"], arm_id.storage_account)
+ _set_val(param["offline_store_resource_group_name"], arm_id.resource_group_name)
+ _set_val(param["offline_store_subscription_id"], arm_id.subscription_id)
+ else:
+ _set_val(param["offlineStoreStorageAccountOption"], "new")
+ _set_val(
+ param["offline_store_container_name"],
+ _generate_storage_container(workspace.name, resources_being_deployed),
+ )
+ if not workspace.storage_account:
+ _set_val(param["offline_store_storage_account_name"], param["storageAccountName"]["value"])
+ else:
+ _set_val(
+ param["offline_store_storage_account_name"],
+ _generate_storage(workspace.name, resources_being_deployed),
+ )
+ _set_val(param["offline_store_resource_group_name"], workspace.resource_group)
+ _set_val(param["offline_store_subscription_id"], self._subscription_id)
+
+ if online_store_target:
+ arm_id = AzureResourceId(online_store_target)
+ _set_val(param["online_store_resource_id"], online_store_target)
+ _set_val(param["online_store_resource_group_name"], arm_id.resource_group_name)
+ _set_val(param["online_store_subscription_id"], arm_id.subscription_id)
+
+ if materialization_identity:
+ arm_id = AzureResourceId(materialization_identity.resource_id)
+ _set_val(param["materializationIdentityOption"], "existing")
+ _set_val(param["materialization_identity_name"], arm_id.asset_name)
+ _set_val(param["materialization_identity_resource_group_name"], arm_id.resource_group_name)
+ _set_val(param["materialization_identity_subscription_id"], arm_id.subscription_id)
+ else:
+ _set_val(param["materializationIdentityOption"], "new")
+ _set_val(
+ param["materialization_identity_name"],
+ _generate_materialization_identity(workspace, self._subscription_id, resources_being_deployed),
+ )
+ _set_val(param["materialization_identity_resource_group_name"], workspace.resource_group)
+ _set_val(param["materialization_identity_subscription_id"], self._subscription_id)
+
+ if not kwargs.get("grant_materialization_permissions", None):
+ _set_val(param["grant_materialization_permissions"], "false")
+
+ if workspace.provision_network_now:
+ _set_val(param["provisionNetworkNow"], "true")
+
+ managed_network = None
+ if workspace.managed_network:
+ managed_network = workspace.managed_network._to_rest_object()
+ if workspace.managed_network.isolation_mode in [
+ IsolationMode.ALLOW_INTERNET_OUTBOUND,
+ IsolationMode.ALLOW_ONLY_APPROVED_OUTBOUND,
+ ]:
+ _set_val(param["associatedResourcePNA"], "Disabled")
+ else:
+ managed_network = ManagedNetwork(isolation_mode=IsolationMode.DISABLED)._to_rest_object()
+ _set_obj_val(param["managedNetwork"], managed_network)
+ if workspace.enable_data_isolation:
+ _set_val(param["enable_data_isolation"], "true")
+
+ if workspace._kind and workspace._kind.lower() == WorkspaceKind.HUB:
+ _set_obj_val(param["workspace_hub_config"], workspace._hub_values_to_rest_object()) # type: ignore
+ # A user-supplied resource ID (either AOAI, AI Services, or null).
+ # endpoint_kind differentiates between a 'bring a legacy AOAI resource' hub and any other kind of hub.
+ # The former doesn't create non-AOAI endpoints, and is set below only if the user provided a BYO AOAI
+ # resource ID. The latter case is the default and not shown here.
+ if endpoint_resource_id != "":
+ _set_val(param["endpoint_resource_id"], endpoint_resource_id)
+ _set_val(param["endpoint_kind"], endpoint_kind)
+
+ # Lean related param
+ if (
+ hasattr(workspace, "_kind")
+ and workspace._kind is not None
+ and workspace._kind.lower() == WorkspaceKind.PROJECT
+ ):
+ if hasattr(workspace, "_hub_id"):
+ _set_val(param["workspace_hub"], workspace._hub_id)
+
+ # Serverless compute related param
+ serverless_compute = workspace.serverless_compute if workspace.serverless_compute else None
+ if serverless_compute:
+ _set_obj_val(param["serverless_compute_settings"], serverless_compute._to_rest_object())
+
+ resources_being_deployed[workspace.name] = (ArmConstants.WORKSPACE, None)
+ return template, param, resources_being_deployed
+
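+ # Hedged sketch of the populated ARM parameters: the template loaded via
+ # get_template(ArmConstants.WORKSPACE_PARAM) is assumed to be a dict of
+ # {"<parameterName>": {"value": ...}} entries that _set_val/_set_obj_val fill in. For a plain
+ # workspace with no pre-existing dependent resources, the result looks roughly like:
+ #
+ # {
+ #     "workspaceName": {"value": "my-workspace"},
+ #     "kind": {"value": "default"},
+ #     "keyVaultName": {"value": "<generated name>"},
+ #     "storageAccountName": {"value": "<generated name>"},
+ #     "applicationInsightsName": {"value": "<generated name>"},
+ # }
+ # while resources_being_deployed maps each generated name to its (ArmConstants.<TYPE>, None) pair.
+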
+ def _populate_feature_store_role_assignment_parameters(
+ self, workspace: Workspace, **kwargs: Any
+ ) -> Tuple[dict, dict, dict]:
+ """Populates the ARM template parameters used to update feature store materialization identity role assignments.
+
+ :param workspace: Workspace resource.
+ :type workspace: ~azure.ai.ml.entities.Workspace
+ :return: A tuple of three dicts: an ARM template, ARM template parameters, resources_being_deployed.
+ :rtype: Tuple[dict, dict, dict]
+ """
+ resources_being_deployed = {}
+ template = get_template(resource_type=ArmConstants.FEATURE_STORE_ROLE_ASSIGNMENTS)
+ param = get_template(resource_type=ArmConstants.FEATURE_STORE_ROLE_ASSIGNMENTS_PARAM)
+
+ materialization_identity_id = kwargs.get("materialization_identity_id", None)
+ if materialization_identity_id:
+ _set_val(param["materialization_identity_resource_id"], materialization_identity_id)
+
+ _set_val(param["workspace_name"], workspace.name)
+ resource_group = kwargs.get("resource_group", workspace.resource_group)
+ _set_val(param["resource_group_name"], resource_group)
+ location = kwargs.get("location", workspace.location)
+ _set_val(param["location"], location)
+
+ update_workspace_role_assignment = kwargs.get("update_workspace_role_assignment", None)
+ if update_workspace_role_assignment:
+ _set_val(param["update_workspace_role_assignment"], "true")
+ update_offline_store_role_assignment = kwargs.get("update_offline_store_role_assignment", None)
+ if update_offline_store_role_assignment:
+ _set_val(param["update_offline_store_role_assignment"], "true")
+ update_online_store_role_assignment = kwargs.get("update_online_store_role_assignment", None)
+ if update_online_store_role_assignment:
+ _set_val(param["update_online_store_role_assignment"], "true")
+
+ offline_store_target = kwargs.get("offline_store_target", None)
+ online_store_target = kwargs.get("online_store_target", None)
+
+ from azure.ai.ml._utils._arm_id_utils import AzureResourceId
+
+ if offline_store_target:
+ arm_id = AzureResourceId(offline_store_target)
+ _set_val(param["offline_store_target"], offline_store_target)
+ _set_val(param["offline_store_resource_group_name"], arm_id.resource_group_name)
+ _set_val(param["offline_store_subscription_id"], arm_id.subscription_id)
+
+ if online_store_target:
+ arm_id = AzureResourceId(online_store_target)
+ _set_val(param["online_store_target"], online_store_target)
+ _set_val(param["online_store_resource_group_name"], arm_id.resource_group_name)
+ _set_val(param["online_store_subscription_id"], arm_id.subscription_id)
+
+ resources_being_deployed[materialization_identity_id] = (ArmConstants.USER_ASSIGNED_IDENTITIES, None)
+ return template, param, resources_being_deployed
+
+ def _check_workspace_name(self, name: Optional[str]) -> str:
+ """Resolves the workspace name to use, falling back to the client's default.
+
+ :param name: Name for a workspace resource.
+ :type name: str
+ :return: The resolved workspace name.
+ :rtype: str
+ :raises ~azure.ai.ml.ValidationException: Raised if no name is specified and the
+ MLClient does not have a workspace name set.
+ """
+ workspace_name = name or self._default_workspace_name
+ if not workspace_name:
+ msg = "Please provide a workspace name or use an MLClient with a workspace name set."
+ raise ValidationException(
+ message=msg,
+ target=ErrorTarget.WORKSPACE,
+ no_personal_data_message=msg,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+ return workspace_name
+
+
+def _set_val(dict: dict, val: Optional[str]) -> None:
+ """Sets the "value" entry of a parameter dict in the ARM parameters template.
+
+ :param dict: Dict for a certain parameter.
+ :type dict: dict
+ :param val: The value to set for "value" in the passed in dict.
+ :type val: Optional[str]
+ :return: No Return.
+ :rtype: None
+ """
+ dict["value"] = val
+
+
+def _set_obj_val(dict: dict, val: Any) -> None:
+ """Serializes an object into a JSON-compatible dict and stores a deep copy in the parameters dict.
+
+ :param dict: Parameters dict.
+ :type dict: dict
+ :param val: The obj to serialize.
+ :type val: Any type. Must have `.serialize() -> MutableMapping[str, Any]` method.
+ :return: No Return.
+ :rtype: None
+ """
+ from copy import deepcopy
+
+ json: MutableMapping[str, Any] = val.serialize()
+ dict["value"] = deepcopy(json)
+
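+ # Hedged sketch of the helpers above on a single template parameter entry:
+ #
+ # param_entry = {"value": None}
+ # _set_val(param_entry, "my-workspace")   # param_entry is now {"value": "my-workspace"}
+ #
+ # _set_obj_val only requires the passed object to expose .serialize() returning a
+ # JSON-compatible mapping; a deep copy of that mapping becomes the parameter value.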
+
+def _generate_key_vault(name: Optional[str], resources_being_deployed: dict) -> str:
+ """Generates a name for a key vault resource to be created alongside the workspace, based on the workspace
+ name, and records the name and type in resources_being_deployed.
+
+ :param name: The name for the related workspace.
+ :type name: str
+ :param resources_being_deployed: Dict for resources being deployed.
+ :type resources_being_deployed: dict
+ :return: String for name of key vault.
+ :rtype: str
+ """
+ # Vault name must only contain alphanumeric characters and dashes and cannot start with a number.
+ # Vault name must be between 3-24 alphanumeric characters.
+ # The name must begin with a letter, end with a letter or digit, and not contain consecutive hyphens.
+ key_vault = get_name_for_dependent_resource(name, "keyvault")
+ resources_being_deployed[key_vault] = (ArmConstants.KEY_VAULT, None)
+ return str(key_vault)
+
+
+def _generate_storage(name: Optional[str], resources_being_deployed: dict) -> str:
+ """Generates a name for a storage account resource to be created alongside the workspace, based on the
+ workspace name, and records the name and type in resources_being_deployed.
+
+ :param name: The name for the related workspace.
+ :type name: str
+ :param resources_being_deployed: Dict for resources being deployed.
+ :type resources_being_deployed: dict
+ :return: String for name of storage account.
+ :rtype: str
+ """
+ storage = get_name_for_dependent_resource(name, "storage")
+ resources_being_deployed[storage] = (ArmConstants.STORAGE, None)
+ return str(storage)
+
+
+def _generate_storage_container(name: Optional[str], resources_being_deployed: dict) -> str:
+ """Generates a name for a storage container resource to be created alongside the workspace, based on the
+ workspace name, and records the name and type in resources_being_deployed.
+
+ :param name: The name for the related workspace.
+ :type name: str
+ :param resources_being_deployed: Dict for resources being deployed.
+ :type resources_being_deployed: dict
+ :return: String for name of storage container
+ :rtype: str
+ """
+ storage_container = get_name_for_dependent_resource(name, "container")
+ resources_being_deployed[storage_container] = (ArmConstants.STORAGE_CONTAINER, None)
+ return str(storage_container)
+
+
+def _generate_log_analytics(name: Optional[str], resources_being_deployed: dict) -> str:
+ """Generates a name for a log analytics resource to be created alongside the workspace, based on the
+ workspace name, and records the name and type in resources_being_deployed.
+
+ :param name: The name for the related workspace.
+ :type name: str
+ :param resources_being_deployed: Dict for resources being deployed.
+ :type resources_being_deployed: dict
+ :return: String for name of log analytics.
+ :rtype: str
+ """
+ log_analytics = get_name_for_dependent_resource(name, "logalytics") # cspell:disable-line
+ resources_being_deployed[log_analytics] = (
+ ArmConstants.LOG_ANALYTICS,
+ None,
+ )
+ return str(log_analytics)
+
+
+def _generate_app_insights(name: Optional[str], resources_being_deployed: dict) -> str:
+ """Generates a name for an application insights resource to be created alongside the workspace, based on the
+ workspace name, and records the name and type in resources_being_deployed.
+
+ :param name: The name for the related workspace.
+ :type name: str
+ :param resources_being_deployed: Dict for resources being deployed.
+ :type resources_being_deployed: dict
+ :return: String for name of app insights.
+ :rtype: str
+ """
+ # Application name only allows alphanumeric characters, periods, underscores,
+ # hyphens and parentheses, and cannot end in a period
+ app_insights = get_name_for_dependent_resource(name, "insights")
+ resources_being_deployed[app_insights] = (
+ ArmConstants.APP_INSIGHTS,
+ None,
+ )
+ return str(app_insights)
+
+
+def _generate_container_registry(name: Optional[str], resources_being_deployed: dict) -> str:
+ """Generates a name for a container registry resource to be created alongside the workspace, based on the
+ workspace name, and records the name and type in resources_being_deployed.
+
+ :param name: The name for the related workspace.
+ :type name: str
+ :param resources_being_deployed: Dict for resources being deployed.
+ :type resources_being_deployed: dict
+ :return: String for name of container registry.
+ :rtype: str
+ """
+ # Application name only allows alphanumeric characters, periods, underscores,
+ # hyphens and parentheses, and cannot end in a period
+ con_reg = get_name_for_dependent_resource(name, "containerRegistry")
+ resources_being_deployed[con_reg] = (
+ ArmConstants.CONTAINER_REGISTRY,
+ None,
+ )
+ return str(con_reg)
+
+
+def _generate_materialization_identity(
+ workspace: Workspace, subscription_id: str, resources_being_deployed: dict
+) -> str:
+ """Generates a name for a materialization identity resource to be created alongside the feature store,
+ based on workspace information, and records the name and type in resources_being_deployed.
+
+ :param workspace: The workspace object.
+ :type workspace: Workspace
+ :param subscription_id: The subscription id
+ :type subscription_id: str
+ :param resources_being_deployed: Dict for resources being deployed.
+ :type resources_being_deployed: dict
+ :return: String for name of materialization identity.
+ :rtype: str
+ """
+ import uuid
+
+ namespace = ""
+ namespace_raw = f"{subscription_id[:12]}_{str(workspace.resource_group)[:12]}_{workspace.location}"
+ for char in namespace_raw.lower():
+ if char.isalpha() or char.isdigit():
+ namespace = namespace + char
+ namespace = namespace.encode("utf-8").hex()
+ uuid_namespace = uuid.UUID(namespace[:32].ljust(32, "0"))
+ materialization_identity = f"materialization-uai-{uuid.uuid3(uuid_namespace, str(workspace.name).lower()).hex}"
+ resources_being_deployed[materialization_identity] = (
+ ArmConstants.USER_ASSIGNED_IDENTITIES,
+ None,
+ )
+ return materialization_identity
+
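+ # Hedged sketch of why the generated name above is deterministic: the namespace UUID is derived
+ # from subscription, resource group and location, and uuid3 hashes the lowercased workspace name
+ # within that namespace, so repeated deployments of the same feature store reuse the same identity
+ # name. Rough equivalent (namespace value is illustrative only):
+ #
+ # import uuid
+ # ns = uuid.UUID("00112233445566778899aabbccddeeff")
+ # name = f"materialization-uai-{uuid.uuid3(ns, 'my-feature-store').hex}"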
+
+class CustomArmTemplateDeploymentPollingMethod(PollingMethod):
+ """A custom polling method for ARM template deployment used internally for workspace creation."""
+
+ def __init__(self, poller: Any, arm_submit: Any, func: Any) -> None:
+ self.poller = poller
+ self.arm_submit = arm_submit
+ self.func = func
+ super().__init__()
+
+ def resource(self) -> Any:
+ """
+ Polls periodically until the resource creation completes, allowing the deployment to be cancelled, and
+ either raises the error or executes the function that "deserializes" the result.
+
+ :return: The response from polling result and calling func from CustomArmTemplateDeploymentPollingMethod
+ :rtype: Any
+ """
+ error: Any = None
+ try:
+ while not self.poller.done():
+ try:
+ time.sleep(LROConfigurations.SLEEP_TIME)
+ self.arm_submit._check_deployment_status()
+ except KeyboardInterrupt as e:
+ self.arm_submit._client.close()
+ error = e
+ raise
+
+ if self.poller._exception is not None:
+ error = self.poller._exception
+ except Exception as e: # pylint: disable=W0718
+ error = e
+ finally:
+ # one last check to make sure all print statements make it
+ if not isinstance(error, KeyboardInterrupt):
+ self.arm_submit._check_deployment_status()
+ total_duration = self.poller.result().properties.duration
+
+ if error is not None:
+ error_msg = f"Unable to create resource. \n {error}\n"
+ module_logger.error(error_msg)
+ raise error
+ module_logger.info(
+ "Total time : %s\n", from_iso_duration_format_min_sec(total_duration) # pylint: disable=E0606
+ )
+ return self.func()
+
+ # pylint: disable=docstring-missing-param
+ def initialize(self, *args: Any, **kwargs: Any) -> None:
+ """
+ Unused stub overridden from the ABC.
+
+ :return: No return.
+ :rtype: None
+ """
+
+ def finished(self) -> None:
+ """
+ Unused stub overridden from the ABC.
+
+ :return: No return.
+ :rtype: None
+ """
+
+ def run(self) -> None:
+ """
+ Unused stub overridden from the ABC.
+
+ :return: No return.
+ :rtype: None
+ """
+
+ def status(self) -> None:
+ """
+ Unused stub overridden from the ABC.
+
+ :return: No return.
+ :rtype: None
+ """
diff --git a/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_workspace_outbound_rule_operations.py b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_workspace_outbound_rule_operations.py
new file mode 100644
index 00000000..e7fde295
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/azure/ai/ml/operations/_workspace_outbound_rule_operations.py
@@ -0,0 +1,245 @@
+# ---------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# ---------------------------------------------------------
+
+from typing import Any, Dict, Iterable, Optional
+
+from azure.ai.ml._restclient.v2024_10_01_preview import AzureMachineLearningWorkspaces as ServiceClient102024Preview
+from azure.ai.ml._restclient.v2024_10_01_preview.models import OutboundRuleBasicResource
+from azure.ai.ml._scope_dependent_operations import OperationsContainer, OperationScope
+from azure.ai.ml._telemetry import ActivityType, monitor_with_activity
+from azure.ai.ml._utils._logger_utils import OpsLogger
+from azure.ai.ml.entities._workspace.networking import OutboundRule
+from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationException
+from azure.core.credentials import TokenCredential
+from azure.core.polling import LROPoller
+
+ops_logger = OpsLogger(__name__)
+module_logger = ops_logger.module_logger
+
+
+class WorkspaceOutboundRuleOperations:
+ """WorkspaceOutboundRuleOperations.
+
+ You should not instantiate this class directly. Instead, you should create
+ an MLClient instance that instantiates it for you and attaches it as an attribute.
+ """
+
+ def __init__(
+ self,
+ operation_scope: OperationScope,
+ service_client: ServiceClient102024Preview,
+ all_operations: OperationsContainer,
+ credentials: Optional[TokenCredential] = None,
+ **kwargs: Dict,
+ ):
+ ops_logger.update_filter()
+ self._subscription_id = operation_scope.subscription_id
+ self._resource_group_name = operation_scope.resource_group_name
+ self._default_workspace_name = operation_scope.workspace_name
+ self._all_operations = all_operations
+ self._rule_operation = service_client.managed_network_settings_rule
+ self._credentials = credentials
+ self._init_kwargs = kwargs
+
+ @monitor_with_activity(ops_logger, "WorkspaceOutboundRule.Get", ActivityType.PUBLICAPI)
+ def get(self, workspace_name: str, outbound_rule_name: str, **kwargs: Any) -> OutboundRule:
+ """Get a workspace OutboundRule by name.
+
+ :param workspace_name: Name of the workspace.
+ :type workspace_name: str
+ :param outbound_rule_name: Name of the outbound rule.
+ :type outbound_rule_name: str
+ :return: The OutboundRule with the provided name for the workspace.
+ :rtype: ~azure.ai.ml.entities.OutboundRule
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_workspace.py
+ :start-after: [START outbound_rule_get]
+ :end-before: [END outbound_rule_get]
+ :language: python
+ :dedent: 8
+ :caption: Get the outbound rule for a workspace with the given name.
+ """
+
+ workspace_name = self._check_workspace_name(workspace_name)
+ resource_group = kwargs.get("resource_group") or self._resource_group_name
+
+ obj = self._rule_operation.get(resource_group, workspace_name, outbound_rule_name)
+ # pylint: disable=protected-access
+ res: OutboundRule = OutboundRule._from_rest_object(obj.properties, name=obj.name) # type: ignore
+ return res
+
+ @monitor_with_activity(ops_logger, "WorkspaceOutboundRule.BeginCreate", ActivityType.PUBLICAPI)
+ def begin_create(self, workspace_name: str, rule: OutboundRule, **kwargs: Any) -> LROPoller[OutboundRule]:
+ """Create a Workspace OutboundRule.
+
+ :param workspace_name: Name of the workspace.
+ :type workspace_name: str
+ :param rule: OutboundRule definition (FqdnDestination, PrivateEndpointDestination, or ServiceTagDestination).
+ :type rule: ~azure.ai.ml.entities.OutboundRule
+ :return: An instance of LROPoller that returns an OutboundRule.
+ :rtype: ~azure.core.polling.LROPoller[~azure.ai.ml.entities.OutboundRule]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_workspace.py
+ :start-after: [START outbound_rule_begin_create]
+ :end-before: [END outbound_rule_begin_create]
+ :language: python
+ :dedent: 8
+ :caption: Create an FQDN outbound rule for a workspace with the given name;
+ the same can be done for PrivateEndpointDestination or ServiceTagDestination.
+ """
+
+ workspace_name = self._check_workspace_name(workspace_name)
+ resource_group = kwargs.get("resource_group") or self._resource_group_name
+
+ # pylint: disable=protected-access
+ rule_params = OutboundRuleBasicResource(properties=rule._to_rest_object()) # type: ignore
+
+ # pylint: disable=unused-argument, docstring-missing-param
+ def callback(_: Any, deserialized: Any, args: Any) -> Optional[OutboundRule]:
+ """Callback to be called after completion
+
+ :return: Outbound rule deserialized.
+ :rtype: ~azure.ai.ml.entities.OutboundRule
+ """
+ properties = deserialized.properties
+ name = deserialized.name
+ return OutboundRule._from_rest_object(properties, name=name) # pylint: disable=protected-access
+
+ poller = self._rule_operation.begin_create_or_update(
+ resource_group, workspace_name, rule.name, rule_params, polling=True, cls=callback
+ )
+ module_logger.info("Create request initiated for outbound rule with name: %s\n", rule.name)
+ return poller
+
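+ # Illustrative sketch (hedged): creating an FQDN outbound rule through MLClient. The
+ # FqdnDestination constructor parameters and the "workspace_outbound_rules" client attribute
+ # name are assumptions.
+ #
+ # from azure.ai.ml.entities import FqdnDestination
+ #
+ # rule = FqdnDestination(name="allow-pypi", destination="pypi.org")
+ # poller = ml_client.workspace_outbound_rules.begin_create(workspace_name="my-workspace", rule=rule)
+ # created_rule = poller.result()
+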
+ @monitor_with_activity(ops_logger, "WorkspaceOutboundRule.BeginUpdate", ActivityType.PUBLICAPI)
+ def begin_update(self, workspace_name: str, rule: OutboundRule, **kwargs: Any) -> LROPoller[OutboundRule]:
+ """Update a Workspace OutboundRule.
+
+ :param workspace_name: Name of the workspace.
+ :type workspace_name: str
+ :param rule: OutboundRule definition (FqdnDestination, PrivateEndpointDestination, or ServiceTagDestination).
+ :type rule: ~azure.ai.ml.entities.OutboundRule
+ :return: An instance of LROPoller that returns an OutboundRule.
+ :rtype: ~azure.core.polling.LROPoller[~azure.ai.ml.entities.OutboundRule]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_workspace.py
+ :start-after: [START outbound_rule_begin_update]
+ :end-before: [END outbound_rule_begin_update]
+ :language: python
+ :dedent: 8
+ :caption: Update an FQDN outbound rule for a workspace with the given name;
+ the same can be done for PrivateEndpointDestination or ServiceTagDestination.
+ """
+
+ workspace_name = self._check_workspace_name(workspace_name)
+ resource_group = kwargs.get("resource_group") or self._resource_group_name
+
+ # pylint: disable=protected-access
+ rule_params = OutboundRuleBasicResource(properties=rule._to_rest_object()) # type: ignore
+
+ # pylint: disable=unused-argument, docstring-missing-param
+ def callback(_: Any, deserialized: Any, args: Any) -> Optional[OutboundRule]:
+ """Callback to be called after completion
+
+ :return: Outbound rule deserialized.
+ :rtype: ~azure.ai.ml.entities.OutboundRule
+ """
+ properties = deserialized.properties
+ name = deserialized.name
+ return OutboundRule._from_rest_object(properties, name=name) # pylint: disable=protected-access
+
+ poller = self._rule_operation.begin_create_or_update(
+ resource_group, workspace_name, rule.name, rule_params, polling=True, cls=callback
+ )
+ module_logger.info("Update request initiated for outbound rule with name: %s\n", rule.name)
+ return poller
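+
+ # A minimal usage sketch for begin_update() (assumes an authenticated MLClient exposed
+ # as ``ml_client``; the existing rule name and the new destination are illustrative only):
+ #
+ #     from azure.ai.ml.entities import FqdnDestination
+ #
+ #     updated_rule = FqdnDestination(name="allow-pypi", destination="files.pythonhosted.org")
+ #     poller = ml_client.workspace_outbound_rules.begin_update("my-workspace", updated_rule)
+ #     result = poller.result()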
+
+ @monitor_with_activity(ops_logger, "WorkspaceOutboundRule.List", ActivityType.PUBLICAPI)
+ def list(self, workspace_name: str, **kwargs: Any) -> Iterable[OutboundRule]:
+ """List Workspace OutboundRules.
+
+ :param workspace_name: Name of the workspace.
+ :type workspace_name: str
+ :return: An Iterable of OutboundRule.
+ :rtype: Iterable[OutboundRule]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_workspace.py
+ :start-after: [START outbound_rule_list]
+ :end-before: [END outbound_rule_list]
+ :language: python
+ :dedent: 8
+ :caption: List the outbound rules for a workspace with the given name.
+ """
+
+ workspace_name = self._check_workspace_name(workspace_name)
+ resource_group = kwargs.get("resource_group") or self._resource_group_name
+
+ rest_rules = self._rule_operation.list(resource_group, workspace_name)
+
+ result = [
+ OutboundRule._from_rest_object(rest_obj=obj.properties, name=obj.name) # pylint: disable=protected-access
+ for obj in rest_rules
+ ]
+ return result # type: ignore
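+
+ # A minimal usage sketch for list() (assumes an authenticated MLClient exposed as
+ # ``ml_client``; the workspace name is illustrative only):
+ #
+ #     for rule in ml_client.workspace_outbound_rules.list("my-workspace"):
+ #         print(rule.name)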
+
+ @monitor_with_activity(ops_logger, "WorkspaceOutboundRule.Remove", ActivityType.PUBLICAPI)
+ def begin_remove(self, workspace_name: str, outbound_rule_name: str, **kwargs: Any) -> LROPoller[None]:
+ """Remove a Workspace OutboundRule.
+
+ :param workspace_name: Name of the workspace.
+ :type workspace_name: str
+ :param outbound_rule_name: Name of the outbound rule to remove.
+ :type outbound_rule_name: str
+ :return: An instance of LROPoller that returns None once the outbound rule is removed.
+ :rtype: ~azure.core.polling.LROPoller[None]
+
+ .. admonition:: Example:
+
+ .. literalinclude:: ../samples/ml_samples_workspace.py
+ :start-after: [START outbound_rule_begin_remove]
+ :end-before: [END outbound_rule_begin_remove]
+ :language: python
+ :dedent: 8
+ :caption: Remove the outbound rule for a workspace with the given name.
+ """
+
+ workspace_name = self._check_workspace_name(workspace_name)
+ resource_group = kwargs.get("resource_group") or self._resource_group_name
+
+ poller = self._rule_operation.begin_delete(
+ resource_group_name=resource_group,
+ workspace_name=workspace_name,
+ rule_name=outbound_rule_name,
+ )
+ module_logger.info("Delete request initiated for outbound rule: %s\n", outbound_rule_name)
+ return poller
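+
+ # A minimal usage sketch for begin_remove() (assumes an authenticated MLClient exposed
+ # as ``ml_client``; the workspace and rule names are illustrative only):
+ #
+ #     poller = ml_client.workspace_outbound_rules.begin_remove("my-workspace", "allow-pypi")
+ #     poller.result()  # returns None once the rule has been deleted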
+
+ def _check_workspace_name(self, name: str) -> str:
+ """Validates that a workspace name exists.
+
+ :param name: Name for a workspace resource.
+ :type name: str
+ :raises ~azure.ai.ml.ValidationException: Raised if updating nothing is specified for name and
+ MLClient does not have workspace name set.
+ :return: No return
+ :rtype: None
+ """
+ workspace_name = name or self._default_workspace_name
+ if not workspace_name:
+ msg = "Please provide a workspace name or use a MLClient with a workspace name set."
+ raise ValidationException(
+ message=msg,
+ target=ErrorTarget.WORKSPACE,
+ no_personal_data_message=msg,
+ error_category=ErrorCategory.USER_ERROR,
+ )
+ return workspace_name